Deep tensor networks with the tensor-network attention layer
Tensor networks are ubiquitous in the field of many-body quantum physics. Since 2014, they have increasingly been applied in machine learning as well. So far, the two main applications of tensor networks in machine learning have been the compression of large weight layers and a new type of multilinear model. Tensor networks are also a stepping stone toward applicable quantum machine learning. However, we typically use tensor networks as single-layer linear models in an exponentially large Hilbert space. In this post, I will describe a non-linear tensor-network layer, which we can use to build deep tensor networks. I will discuss applications of deep tensor networks to specific datasets and tasks in the following posts.
This post is related to two papers:
- Deep tensor networks with matrix product operators
- [Grokking phase transitions in learning local rules with gradient descent](https://arxiv.org/abs/2210.15435)
Architecture
A standard tensor network model consists of two parts. The first part is the embedding, which transforms the raw input (categorical or numerical) into vectors of size $d$. The second part is an exponentially large linear model with weights given/compressed in a tensor-network form. Although the embedding part is important, I will focus on the second part. We will assume that we have an embedding layer that transforms an input $\vec{x}=(x_1,x_2,\ldots,x_N)$ of length $N$ into a Hilbert-space vector of size $d^N$ via a local embedding of the input features $x_j$, i.e. $\phi(x_j)\in\mathbb{R}^d$.
We then write the embedding of a single input vector as a tensor product of the local embeddings, $\Phi(\vec{x}) = \phi(x_1)\otimes\phi(x_2)\otimes\cdots\otimes\phi(x_N)$.
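As a concrete illustration, here is a minimal NumPy sketch of such an embedding layer, assuming a trigonometric local map with $d=2$ (a common choice in the tensor-network literature; the post does not fix a particular $\phi$):

```python
import numpy as np

def local_embedding(x):
    """One possible local map phi: [0, 1] -> R^2 (an assumed, illustrative choice)."""
    return np.array([np.cos(np.pi * x / 2), np.sin(np.pi * x / 2)])

def embed_sequence(xs):
    """Stack the local embeddings of an input of length N into an (N, d) array.
    The full d**N Hilbert-space vector is their tensor (Kronecker) product,
    which we never need to build explicitly."""
    return np.stack([local_embedding(x) for x in xs])

phi = embed_sequence(np.random.rand(6))  # shape (6, 2): N = 6 sites, d = 2
```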
Tensor network attention
In the following, we will transform the embedded inputs with the tensor-network (TN) attention layer, determined by an attention tensor $A\in\mathbb{R}^{D\times d\times D}$ and a transformation tensor $B\in\mathbb{R}^{D\times d\times d\times D}$, where $D$ denotes the bond dimension.
We will keep the tensors $A$ and $B$ the same for every input position. The generalization of uniform to non-uniform TN attention is straightforward.
First, we construct the matrices $A^{(j)}$ by contracting the attention tensor with the local embedding vectors, $A^{(j)}_{\alpha\beta} = \sum_{s=1}^{d} A_{\alpha s\beta}\,\phi_s(x_j)$.
Then, we use the matrices $A^{(j)}$ to construct the left and right context matrices $H^{\rm L}_j = A^{(1)}A^{(2)}\cdots A^{(j-1)}$ and $H^{\rm R}_j = A^{(j+1)}\cdots A^{(N)}$, which we additionally normalize.
The matrix $V$ determines the boundary conditions. In the case of closed/cyclic boundary conditions, we use $V=\mathbb{1}$. In the case of open boundary conditions, we use $V=v_{\rm R}\,v_{\rm L}^{\rm T}$, where the boundary vectors $v_{\rm L}$ and $v_{\rm R}$ are determined as the left and right eigenvectors of the matrix corresponding to the largest eigenvalue.
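A minimal NumPy sketch of the two steps above, under the index convention $A\in\mathbb{R}^{D\times d\times D}$ used here; the Frobenius normalization of the context matrices is an assumption, since the text only states that they are normalized:

```python
def attention_matrices(A, phi):
    """A^{(j)}_{ab} = sum_s A[a, s, b] * phi[j, s]; returns shape (N, D, D)."""
    return np.einsum('asb,js->jab', A, phi)

def context_matrices(A_js):
    """Left/right context matrices: ordered products of A^{(i)} for i < j and i > j."""
    N, D, _ = A_js.shape
    HL = np.empty((N, D, D))
    HR = np.empty((N, D, D))
    acc = np.eye(D)
    for j in range(N):               # product over positions to the left of j
        HL[j] = acc
        acc = acc @ A_js[j]
    acc = np.eye(D)
    for j in reversed(range(N)):     # product over positions to the right of j
        HR[j] = acc
        acc = A_js[j] @ acc
    # Normalize each context matrix (Frobenius norm; the exact norm is assumed).
    HL /= np.linalg.norm(HL, axis=(1, 2), keepdims=True)
    HR /= np.linalg.norm(HR, axis=(1, 2), keepdims=True)
    return HL, HR
```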
In the next step, we calculate the tensor $B^{(j)}$ by contracting the embedding at position $j$, i.e. $\phi(x_j)$, with the transformation tensor $B$, $B^{(j)}_{\alpha s\beta} = \sum_{s'=1}^{d} B_{\alpha s s'\beta}\,\phi_{s'}(x_j)$,
where $B^{(j)}_{:,s,:}$ denotes the matrix with elements $B^{(j)}_{\alpha s\beta}$. Finally, we contract the context matrices with the tensor $B^{(j)}$ to obtain the transformation of the local embedding, $\tilde\phi_s(x_j) = \mathrm{Tr}\left(H^{\rm L}_j\, B^{(j)}_{:,s,:}\, H^{\rm R}_j\, V\right)$.
Optionally, we can add a bias and a final non-linearity. In some variants of the TN attention, we use a four-dimensional attention tensor and construct $A^{(j)}$ by contracting it with two copies of the local embedding.
The entire embedding is then transformed by the TN attention as $\Phi(\vec{x})\mapsto\tilde\Phi(\vec{x}) = \tilde\phi(x_1)\otimes\tilde\phi(x_2)\otimes\cdots\otimes\tilde\phi(x_N)$.
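Putting the pieces together, here is a sketch of the full layer, reusing the helpers above. The placement of the boundary matrix $V$ inside the trace follows the reconstruction given in the text, and the bias and non-linearity arguments correspond to the optional additions mentioned above:

```python
def tn_attention(phi, A, B, V=None, bias=None, nonlinearity=None):
    """One TN attention layer: (N, d) local embeddings -> (N, d) transformed embeddings.

    A : (D, d, D) attention tensor
    B : (D, d, d, D) transformation tensor (left bond, output s, input s', right bond)
    V : (D, D) boundary matrix; identity corresponds to cyclic boundary conditions
    """
    D = A.shape[0]
    if V is None:
        V = np.eye(D)
    A_js = attention_matrices(A, phi)           # (N, D, D)
    HL, HR = context_matrices(A_js)             # (N, D, D) each
    # B^{(j)}_{a s b} = sum_{s'} B[a, s, s', b] * phi[j, s']
    B_js = np.einsum('astb,jt->jasb', B, phi)   # (N, D, d, D)
    # tilde-phi_s(x_j) = Tr(HL_j  B^{(j)}_{:, s, :}  HR_j  V)
    out = np.einsum('jab,jbsc,jcd,da->js', HL, B_js, HR, V)
    if bias is not None:
        out = out + bias
    return nonlinearity(out) if nonlinearity is not None else out
```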
Short diagrammatic overview
Finally, let us summarize the TN attention layer in diagrammatic notation. Using the definitions above, the action of the TN attention on an embedding proceeds in three steps: contract the attention tensor with the local embeddings, build the left and right context matrices, and contract the context matrices and the transformation tensor with the local embedding at each position.
Local weight matrix
We can also represent the TN attention layer as a local weight matrix $W^{(j)}$ acting on the local embedding, $\tilde\phi(x_j) = W^{(j)}\phi(x_j)$ with $W^{(j)}_{ss'} = \mathrm{Tr}\left(H^{\rm L}_j\, B_{:,s,s',:}\, H^{\rm R}_j\, V\right)$. The diagrammatic definition of the local weight matrix is shown in the next figure.
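Under the same reconstruction, the local weight matrices can be computed directly, which is convenient for inspecting what the layer does at each position:

```python
def local_weight_matrices(HL, HR, B, V):
    """W^{(j)}_{s s'} = Tr(HL_j  B[:, s, s', :]  HR_j  V); returns shape (N, d, d)."""
    return np.einsum('jab,bstc,jcd,da->jst', HL, B, HR, V)

# Equivalent to the layer output above:
# out = np.einsum('jst,jt->js', local_weight_matrices(HL, HR, B, V), phi)
```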
Generalised attention tensor
We can add one more dimension to the attention tensor, i.e., $A\in\mathbb{R}^{D\times d\times d\times D}$. In this case, we calculate the matrices $A^{(j)}$ by contracting the attention tensor with two copies of the embedding vectors, $A^{(j)}_{\alpha\beta} = \sum_{s,s'} A_{\alpha s s'\beta}\,\phi_s(x_j)\,\phi_{s'}(x_j)$.
Alternatively, we could introduce a new local embedding map as $\varphi(x_j) = \phi(x_j)\otimes\phi(x_j)$ and keep the original definition of $A^{(j)}$. This transformation would lead to the same left and right context matrices but a slightly more general weight matrix $W^{(j)}$.
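A sketch of this generalized variant, assuming the four-index attention tensor is stored as `A4[D, d, d, D]`:

```python
def attention_matrices_general(A4, phi):
    """A^{(j)}_{ab} = sum_{s, s'} A4[a, s, s', b] * phi[j, s] * phi[j, s']."""
    return np.einsum('astb,js,jt->jab', A4, phi, phi)

# Equivalently, reshape A4 to (D, d*d, D) and feed the doubled embedding
# phi_j (x) phi_j to attention_matrices() from the sketch above.
```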
Deep tensor networks
The main benefit of the TN attention layer is that the bond dimension of the input matrix product state (the embedding vectors) does not increase with the layer application. Therefore, we can apply the TN attention layer repeatedly. We call the resulting model a deep tensor network. The final architecture is shown in the figure below and consists of an embedding layer followed by the TN attention layers. We also add a skip connection and normalization, which we can turn off. The final layer depends on the task. In the classification case, we use the MPS layer described in a previous post on MPS baselines. In the sequence prediction case, the final layer is a decoder layer. The decoder is a simple inverse operation of the embedding layer. We will discuss the sequence prediction case in one of the following posts on grokking.
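A sketch of how the layers could be stacked into a deep tensor network; the particular skip connection and per-site normalization below are one plausible realization of the components named in the text, not necessarily the exact ones used in the papers:

```python
def deep_tensor_network(phi, layers, skip=True, normalize=True):
    """Apply a stack of TN attention layers; `layers` is a list of (A, B, V) tuples."""
    for A, B, V in layers:
        out = tn_attention(phi, A, B, V)
        if skip:
            out = out + phi                                         # skip connection
        if normalize:
            out = out / np.linalg.norm(out, axis=1, keepdims=True)  # per-site norm
        phi = out
    return phi  # fed into the final MPS classifier or the decoder layer
```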
TN attention as generalized linear-dot attention
I will now explain how we can rewrite linear dot attention as a TN attention layer with a sufficiently large bond dimension $D$. In particular, we show that a slight modification of the proposed TN attention layer implements the linear dot-attention transformation $\tilde\phi(x_j) = \sum_{k=1}^{N}\left(q_j\cdot k_k\right) v_k$, where $q_j = Q\,\phi(x_j)$, $k_k = K\,\phi(x_k)$, and $v_k = \phi(x_k)$. For simplicity, we removed the normalization factor and used a trivial (identity) transformation of the values. We obtain a more general transformation by applying a linear transformation to the transformed values $\tilde\phi(x_j)$. Additionally, we assume that the embedding vectors are normalized.
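For reference, the softmax-free linear dot attention with trivial values that this construction targets can be written in a few lines:

```python
def linear_dot_attention(phi, Q, K):
    """tilde-phi_j = sum_k (Q phi_j . K phi_k) phi_k  (no softmax, values v_k = phi_k)."""
    q = phi @ Q.T          # (N, d) queries
    k = phi @ K.T          # (N, d) keys
    scores = q @ k.T       # (N, N) dot products
    return scores @ phi    # (N, d) transformed embeddings
```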
The embedding tensor $\Phi(\vec{x})$ can be interpreted as an $N$-fold tensor product of the embedding vectors $\phi(x_j)$. We consider the action of the TN attention on this vector. First, we decompose the TN attention tensors into three components.
The matrices $Q$ and $K$ transform the local vectors into queries and keys, while the remaining part of the TN attention implements a permutation operator.
The local weight matrix, without the normalization of the left and right context matrices, is then given by $W^{(j)} = \left(\sum_{k\neq j}\phi(x_k)\,\phi(x_k)^{\rm T}\right) K^{\rm T} Q$.
The final transformation of the embedding with the described TN attention layer is then $\tilde\phi(x_j) = W^{(j)}\phi(x_j) = \sum_{k\neq j}\left(q_j\cdot k_k\right)\phi(x_k)$, where the sum runs over all positions except $k=j$. We can correct this difference between the linear attention and the TN attention result by a simple local residual connection, which restores the missing $k=j$ term. The final difference is that we normalize the left and the right context matrices. This normalization translates to a rescaling of the final output and has no effect if we normalize the output after each application of the attention mechanism.
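To make the comparison concrete, here is a small sketch based on the reconstruction above (so the exact form is an assumption): the unnormalized TN-style weight matrices miss exactly the $k=j$ term, which the local residual connection restores.

```python
def tn_style_weight_matrices(phi, Q, K):
    """W^{(j)} = (sum_{k != j} phi_k phi_k^T) K^T Q, without context normalization."""
    outer_all = np.einsum('ks,kt->st', phi, phi)                 # sum over all positions k
    outer = outer_all[None] - np.einsum('js,jt->jst', phi, phi)  # drop the k = j term
    return outer @ (K.T @ Q)                                     # (N, d, d)

# The missing diagonal term is (q_j . k_j) * phi_j, so for every position j:
# linear_dot_attention(phi, Q, K)[j]
#   == tn_style_weight_matrices(phi, Q, K)[j] @ phi[j] + (phi[j] @ Q.T @ K @ phi[j]) * phi[j]
```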
Finally, we can write down TN attention tensors that explicitly implement the permutation transformation.
The remaining elements of the tensors $A$ and $B$ are zero.
As we have just seen, the TN attention layer can implement (though in an inefficient way) linear dot attention. In addition, the TN attention can model higher-order correlations/interactions between embeddings and explicitly incorporates the embedding position. Therefore, we interpret the tensor-network attention as a generalized linear dot-attention mechanism. On the other hand, we cannot straightforwardly introduce nonlinearities into the presented formalism.
I will discuss the application of the TN attention layer/formalism to the image classification problem and grokking in the following posts.