Deep tensor networks with the tensor-network attention layer

Bojan Žunkovič

Tensor networks are ubiquitous in the field of many-body quantum physics. Since 2014, they have increasingly been applied in machine learning as well. So far, the two main applications of tensor networks in machine learning have been the compression of large weight layers and a new type of multilinear model. Tensor networks are also a stepping stone toward practical quantum machine learning. However, we typically use tensor networks as single-layer linear models in an exponentially large Hilbert space. In this post, I will describe a non-linear tensor-network layer, which we can use to build deep tensor networks. I will discuss applications of deep tensor networks to specific datasets and tasks in the following posts.

This post is related to two papers:


A standard tensor network model consists of two parts. The first part is the embedding, which transforms the raw input (categorical or numerical) into vectors of size $d$. The second part is an exponentially large linear model with weights given/compressed in a tensor-network form. Although the embedding part is important, I will focus on the second part. We will assume that we have an embedding layer that transforms an input of length $N$ into a Hilbert-space vector of size $d^N$ via a local embedding of the input features $x_j$, i.e. $x_j\rightarrow\phi(x_j)$.

We then write the embedding of a single input vector as a tensor product of the local embeddings, $\Phi(x)=\phi(x_1)\otimes\phi(x_2)\otimes\ldots\otimes\phi(x_N)$.
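As a concrete illustration, here is a minimal numpy sketch of a local embedding and the resulting product embedding. The trigonometric map $\phi(x)=(\cos(\pi x/2),\sin(\pi x/2))$ is a common choice in tensor-network machine learning; it is used here only as an assumed example, not necessarily the embedding used in the experiments.

```python
import numpy as np

def phi(x):
    """Local embedding R -> R^2 (trigonometric map, a common choice)."""
    return np.array([np.cos(np.pi * x / 2), np.sin(np.pi * x / 2)])

def product_embedding(x):
    """Full embedding Phi(x) = phi(x_1) ⊗ ... ⊗ phi(x_N), a vector of size d^N."""
    out = np.array([1.0])
    for xi in x:
        out = np.kron(out, phi(xi))  # tensor (Kronecker) product of local embeddings
    return out

x = np.array([0.1, 0.5, 0.9])      # N = 3 input features
Phi = product_embedding(x)         # vector of size 2**3 = 8
```

Note that because each local embedding is unit-norm, the full product embedding is unit-norm as well, which matters for the normalization assumptions made later in the post.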

Tensor network attention

In the following, we will transform the embedded inputs with the tensor-network (TN) attention layer determined by an attention tensor $A\in R^{D\times D\times d}$ and a transformation tensor $B\in R^{D\times D\times d\times d}$.

We will keep the tensors $A$ and $B$ the same at every input position. The generalization from uniform to non-uniform TN attention is straightforward.

First, we construct matrices $\mathcal{A}(j)$ by contracting the attention tensor $A$ with the local embedding vectors $\phi(x_j)$.
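In formulas, this contraction can be written as follows (using $A_s$ for the $D\times D$ slices of the attention tensor; this is a reconstruction in my notation, not necessarily the original):

```latex
\mathcal{A}(j) = \sum_{s=1}^{d} \phi_s(x_j)\, A_s,
\qquad
[A_s]_{\mu,\nu} = A_{\mu,\nu,s}.
```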

Then, we use the matrices $\mathcal{A}(j)$ to construct the left and right context matrices $H^{\rm L,R}(j)$.
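One consistent way to write the context matrices, assuming they are plain ordered products of the $\mathcal{A}$ matrices (a sketch; the normalization discussed later can be applied to each partial product):

```latex
H^{\rm L}(j) = \mathcal{A}(1)\mathcal{A}(2)\cdots\mathcal{A}(j-1),
\qquad
H^{\rm R}(j) = \mathcal{A}(j+1)\cdots\mathcal{A}(N),
```

with the boundary conventions $H^{\rm L}(1)=H^{\rm R}(N)=1_D$.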

The matrix $G$ determines the boundary conditions. In the case of closed/cyclic boundary conditions, we use $G=1_{D}$. In the case of open boundary conditions, we use $G=v^{\rm L}\otimes v^{\rm R}$, where the boundary vectors $v^{\rm L,R}\in R^{D}$ are determined as the left and right eigenvectors of the matrix $A_0$ corresponding to its largest eigenvalue.

In the next step, we calculate the tensor $\mathcal{B}(j)$ by contracting the embedding at position $j$, i.e. $\phi(x_j)$, with the transformation tensor $B$
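One consistent reading of this contraction (in my indexing, producing a $D\times D$ matrix $\mathcal{B}_k(j)$ for each output index $k$):

```latex
\mathcal{B}_k(j) = \sum_{l=1}^{d} \phi_l(x_j)\, B_{k,l},
```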

where $B_{j,k}$ denotes the matrix with elements $[B_{j,k}]_{\mu,\nu}=B_{\mu,\nu,j,k}$. Finally, we contract the context matrices with the tensor $\mathcal{B}(j)$ to obtain the transformation of the local embedding $\phi(x_j)$.
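Spelled out, and writing $\mathcal{B}_k(j)$ for the $D\times D$ matrix obtained from $\mathcal{B}(j)$ by fixing the output index $k$, the transformed embedding plausibly reads (up to the exact placement of the boundary matrix $G$ in the trace):

```latex
y_k(j) = {\rm Tr}\left( G\, H^{\rm L}(j)\, \mathcal{B}_k(j)\, H^{\rm R}(j) \right).
```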

Optionally, we can add a bias and a final non-linearity. In some variants of the TN attention, we use four-dimensional attention tensors $A\in R^{D\times D\times d\times d}$ and construct $\mathcal{A}(i)$ by contracting $A$ with two copies of the local embeddings.

The entire embedding is then transformed by the TN attention as $\Phi(x)\rightarrow y(1)\otimes y(2)\otimes\ldots\otimes y(N)$.
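Putting the steps together, here is a minimal numpy sketch of one TN attention layer as described; the contraction order, boundary handling, and the exact normalization of the context matrices are my assumptions, not a verbatim reimplementation.

```python
import numpy as np

def tn_attention(phi, A, B, G):
    """One TN attention layer (sketch).

    phi: (N, d) array of local embeddings phi(x_j)
    A:   (D, D, d) attention tensor; B: (D, D, d, d) transformation tensor
    G:   (D, D) boundary matrix (identity for cyclic boundary conditions)
    Returns the transformed embeddings y(j), shape (N, d).
    """
    N, d = phi.shape
    D = A.shape[0]
    # A(j): contract the attention tensor with the local embedding at each site
    Aj = np.einsum('ns,uvs->nuv', phi, A)                  # (N, D, D)
    # Left/right context matrices, normalized for numerical stability
    HL = [np.eye(D)]
    for j in range(N - 1):
        M = HL[-1] @ Aj[j]
        HL.append(M / np.linalg.norm(M))
    HR = [np.eye(D)]
    for j in range(N - 1, 0, -1):
        M = Aj[j] @ HR[0]
        HR.insert(0, M / np.linalg.norm(M))
    y = np.empty_like(phi)
    for j in range(N):
        Bj = np.einsum('s,uvts->tuv', phi[j], B)           # (d, D, D)
        # y_t(j) = Tr(G H^L(j) B_t(j) H^R(j))
        y[j] = np.einsum('uv,vw,twx,xu->t', G, HL[j], Bj, HR[j])
    return y

# Toy usage: N = 6 sites, local dimension d = 2, bond dimension D = 3
rng = np.random.default_rng(0)
phi = rng.normal(size=(6, 2))
phi /= np.linalg.norm(phi, axis=1, keepdims=True)
y = tn_attention(phi, rng.normal(size=(3, 3, 2)),
                 rng.normal(size=(3, 3, 2, 2)), np.eye(3))
```

Passing `G = np.eye(D)` corresponds to the closed/cyclic boundary conditions described above; for open boundary conditions one would instead pass the outer product of the boundary vectors.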

Short diagrammatic overview

Finally, let us summarize the TN attention layer in diagrammatic notation, using the following definitions.

We summarize the action of the TN attention on an embedding $\phi(j)$ in three steps.

Local weight matrix

We can also represent the TN attention layer as a local weight matrix $W(j)$ acting on the embedding $\phi(x_j)$. The diagrammatic definition of the local weight matrix is shown in the next figure.

Generalised attention tensor

We can add one more dimension to the attention tensor, i.e., $A\in R^{D\times D\times d\times d}$. In this case, we calculate the matrices $\mathcal{A}(j)$ by contracting the attention tensor $A$ with two copies of the embedding vectors.
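In formulas, a plausible form of this contraction (with $[A_{t,s}]_{\mu,\nu}=A_{\mu,\nu,t,s}$; my notation, not necessarily the original):

```latex
\mathcal{A}(j) = \sum_{t,s=1}^{d} \phi_t(x_j)\,\phi_s(x_j)\, A_{t,s}.
```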

Alternatively, we could introduce a new local embedding map $\tilde{\phi}(x_j)=\phi(x_j)\otimes\phi(x_j)$ and keep the definition of $A$. This transformation would lead to the same left and right context matrices but a slightly more general weight matrix $W$.

Deep tensor networks

The main benefit of the TN attention layer is that the bond dimension of the input matrix product state (the embedding vectors) does not increase when the layer is applied. Therefore, we can apply the TN attention layer repeatedly. We call the resulting model a deep tensor network. The final architecture is shown in the figure below and consists of an embedding layer followed by the TN attention layers. We also add a skip connection and normalization, which we can turn off. The final layer depends on the task. In the classification case, we use the MPS layer described in a previous post on MPS baselines. In the sequence-prediction case, the final layer is a decoder layer. The decoder is a simple inverse operation of the embedding layer. We will discuss the sequence-prediction case in one of the following posts on grokking.
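The stacking itself can be sketched as a simple loop. Here `layer` stands for any shape-preserving map on the $(N, d)$ embeddings, such as a TN attention layer; the skip connection and per-site $L_2$ normalization follow the description above, while the linear toy layers are placeholders of my own.

```python
import numpy as np

def deep_tn(phi, layers, normalize=True):
    """Stack of shape-preserving layers with skip connections (sketch)."""
    h = phi
    for layer in layers:
        h = h + layer(h)                                   # skip connection
        if normalize:                                      # per-site L2 normalization
            h = h / np.linalg.norm(h, axis=-1, keepdims=True)
    return h

# Toy usage: two linear "layers" standing in for TN attention layers
rng = np.random.default_rng(1)
W = rng.normal(size=(2, 2))
h = deep_tn(rng.normal(size=(5, 2)), [lambda z: z @ W.T] * 2)
```

Because the bond dimension of the embeddings does not grow, the per-layer cost stays fixed no matter how many layers are stacked, which is exactly what makes the "deep" construction feasible.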

TN attention as generalized linear-dot attention

I will now explain how we can rewrite linear dot-attention as a TN attention layer with bond dimension $D=d^2+1$. In particular, we show that a slight modification of the proposed TN attention layer implements the linear dot-attention transformation $y_j=\sum_l (q_j\cdot k_l)\,q_l$, where $q_j=W^{\rm Q}\phi(j)$ and $k_j=W^{\rm K}\phi(j)$. For simplicity, we removed the factor $\frac{1}{\sqrt{d}}$ and used a trivial transformation of the values ($v_j=\phi(j)$). We obtain a more general transformation by applying the linear transformation $W^{\rm V}(W^{\rm Q})^{-1}$ to the transformed values $y_j$. Additionally, we assume that the embedding vectors are $L_2$-normalized.

The embedding tensor can be interpreted as an $N$-fold tensor product of embedding vectors, $\Phi=\phi(1)\otimes\phi(2)\otimes\ldots\otimes\phi(N)$. We consider the action of the TN attention on the vector $\Phi$. First, we decompose the TN attention tensor $A\in R^{D\times D\times d\times d}$ into three components

$A^{t,s}_{a,a'}(j)=\sum_{t',s'=1}^{d}W^{\rm Q}_{t',t}\,\tilde{A}^{t',s'}_{a,a'}(j)\,W^{\rm K}_{s',s}.$

The matrices $W^{\rm Q}$ and $W^{\rm K}$ transform the local vectors $\phi(j)$ into queries and keys. The remaining TN attention implements a permutation operator

${\rm Tr}\left(\tilde{G}\,\tilde{A}^{t_1,s_1}\tilde{A}^{t_2,s_2}\ldots\tilde{A}^{t_N,s_N}\right)=\sum_{i<j=1}^{N}P_{ij},$


$P_{i,j}\,\phi(1)\otimes\phi(2)\otimes\ldots\otimes\phi(i)\otimes\ldots\otimes\phi(j)\otimes\ldots\otimes\phi(N)=\phi(1)\otimes\phi(2)\otimes\ldots\otimes\phi(j)\otimes\ldots\otimes\phi(i)\otimes\ldots\otimes\phi(N).$

The local weight matrix $W(j)$, without the normalization of the left and right context matrices, is then given by

$W(j)=\sum_{i,l\neq j}(q_i\cdot k_l)(k_i\cdot q_l)\,1_d+\sum_{i\neq j}q_i k_i^{T}.$

The final transformation of the embedding $\phi(j)$ with the described TN attention layer is then

$\phi(j)\rightarrow c\,q_j+\sum_{i\neq j}q_i\,(k_i\cdot q_j),$

where $c=\sum_{i,l\neq j}(q_i\cdot k_l)(k_i\cdot q_l)$. We can correct for the difference between the linear attention result and the TN attention result with a simple local residual connection. The final difference is that we normalize the left and right context matrices. This normalization amounts to a rescaling of the final output and has no effect if we normalize the output after each application of the attention mechanism.

Finally, we provide the TN attention tensors $\tilde{A}$ that implement the permutation transformation:

$\tilde{A}^{t,s}_{1,dt+s+1}=1,\quad t,s=1,2,\ldots,d,$

$\tilde{A}^{s,t}_{dt+s+1,1}=1,\quad t,s=1,2,\ldots,d,$

$\tilde{A}^{s,s}_{a,a}=1,\quad s=1,2,\ldots,d,\quad a=2,3,\ldots,d^2+1,$

$\tilde{G}_{1,1}=1.$

The remaining elements of the tensors $\tilde{A}$ and $\tilde{G}$ are zero.

As we have just seen, the TN attention layer can implement (though inefficiently) linear dot attention. In addition, TN attention can model higher-order correlations/interactions between embeddings and explicitly incorporates the embedding position. Therefore, we interpret tensor-network attention as a generalized linear dot-attention mechanism. On the other hand, we cannot straightforwardly introduce nonlinearities into the presented formalism.

I will discuss the application of the TN attention layer/formalism to the image classification problem and grokking in the following posts.