A Linear Transformer Based on Attention Pooling
Abstract
Transformer-based neural networks currently dominate artificial intelligence, with widespread applications in natural language processing, computer vision, image generation, and multimodal tasks. In particular, since OpenAI (reference) released ChatGPT in late 2022, GPT has become the undisputed leader in the language domain. However, from a first-principles perspective, GPT differs significantly from human thinking in two respects. First, GPT must store all historical information, whereas humans can forget. Second, the complexity of GPT is $O(N^2)$, whereas human reading has complexity $O(N)$. Because of these differences, applying GPT to lifelong learning still faces several challenges. This paper proposes a linear transformer model based on attention pooling that achieves linear complexity and a forgetting capability while retaining the core attention mechanism of GPT.
Introduction
Because it can capture long-range dependencies between words, the GPT architecture has achieved great success in natural language processing in recent years. RNNs and LSTMs rarely appear in the mainstream solutions adopted by major companies, mainly because they cannot be trained in parallel and suffer from catastrophic forgetting, which makes them hard to train and deploy. Despite its success, the GPT architecture still has limitations. Chief among them, GPT works on principles fundamentally different from human intelligence: human intelligence includes a forgetting mechanism, while GPT does not. The lack of forgetting may make GPT more powerful, but it poses a fatal problem for lifelong learning: the stored history keeps growing, and so does the cost of retrieving information from it. Both factors increase inference time.
To address these issues, compressing historical information is unavoidable. This paper proposes an attention pooling method that compresses historical information into a fixed-size memory for the current token to query. The method supports both parallel training and recurrent causal inference.
In summary, the contributions of this paper are as follows:
- Proposes a linear transformer named WQKV, which uses attention pooling to compress historical information.
- Presents parallel training and recursive causal inference methods for WQKV.
- Conducts ablation experiments on WQKV.
Related Work
There has been significant work on reducing the complexity of transformers from $O(N^2)$ to $O(N)$.
RWKV is inspired by Apple’s Attention Free Transformer. Its architecture is carefully simplified and optimized so that it can be computed as an RNN, and it relies on additional techniques, such as TokenShift and SmallInitEmb, to perform comparably to GPT.
In “Were RNNs All We Needed?”, the authors propose minLSTM and minGRU by removing the nonlinear dependencies that prevent parallel computation. These RNNs run recurrently at inference time and are trained in parallel with the parallel scan algorithm.
However, these methods have drawbacks. To support large contexts (state variables), they often require a very large number of parameters; our method does not. Moreover, they lack a clear connection to the attention mechanism of transformers, and their effectiveness has not been widely validated. In contrast, our method fully inherits the attention mechanism of transformers and only adds an attention pooling layer, so we expect it to scale in the same way transformers do.
Attention pooling.
Method
A straightforward way to reduce the complexity of GPT from $O(N^2)$ to $O(N)$ is to compress the queried history into a fixed-size representation, so that the current token always queries a context of constant size. This process resembles pooling. We therefore design a softmax-weighted pooling method that compresses the history into a fixed-size memory, followed by the original GPT attention operation. Below, we describe how this is achieved.
For simplicity, this paper assumes a batch size of 1 and does not use multi-head attention.
The attention pooling used in this paper is illustrated in the figure below.

Consistent with GPT, for the current sequence $X$ with shape $[L, C]$, we obtain $Q$, $K$, $V$, each with shape $[L, C]$, using the projection matrices $W_q$, $W_k$, $W_v$ (each of shape $[C, C]$). The calculations are:
$$Q = XW_q,\qquad K = XW_k,\qquad V = XW_v$$
In addition, we use an extra weight matrix $W_w$ with shape $[C, D]$ to obtain $W = XW_w$ with shape $[L, D]$. This matrix produces the weights used to compress the history. The attention pooling results are:
$$K' = \mathrm{softmax}(W^T)K,\qquad V' = \mathrm{softmax}(W^T)V$$
where the softmax normalizes each row of $W^T$ over the $L$ sequence positions. Thus, $K$ and $V$ are compressed from shape $[L, C]$ to $[D, C]$.
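The pooling step can be sketched in NumPy as follows. This is a minimal illustration, not the author's code: it assumes the row-vector convention ($X$ is $[L, C]$, $W_w$ is $[C, D]$), random weights, and the hypothetical helper names `softmax` and `attention_pool`.

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax along `axis`."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention_pool(X, W_w, W_k, W_v):
    """Compress K and V from [L, C] to [D, C] (hypothetical sketch).

    X:        [L, C] input sequence
    W_w:      [C, D] pooling projection (assumed shape)
    W_k, W_v: [C, C] key/value projections
    """
    W = X @ W_w                  # [L, D] pooling logits
    P = softmax(W.T, axis=-1)    # [D, L], each row sums to 1 over positions
    K = X @ W_k                  # [L, C]
    V = X @ W_v                  # [L, C]
    return P @ K, P @ V          # K', V': [D, C]

rng = np.random.default_rng(0)
L, C, D = 16, 8, 4
X = rng.standard_normal((L, C))
W_w = rng.standard_normal((C, D))
W_k = rng.standard_normal((C, C))
W_v = rng.standard_normal((C, C))
K_p, V_p = attention_pool(X, W_w, W_k, W_v)
print(K_p.shape, V_p.shape)  # (4, 8) (4, 8)
```

Each of the $D$ memory slots is a convex combination of all $L$ keys (or values), so the memory size no longer depends on $L$.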
After the sequence is compressed, the remaining operations follow GPT: $Q$ attends to $[K', V']$, i.e.,
$$X' = \mathrm{softmax}(QK'^T)V'$$
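A full non-causal WQKV layer then chains the pooling and the attention step. The sketch below is an assumption-laden illustration (row-vector convention, no $1/\sqrt{C}$ scaling, single head, hypothetical name `wqkv_layer`); the costly products `P @ K` and `Q @ K_p.T` each take $O(LDC)$ operations, i.e. linear in $L$.

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax along `axis`."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def wqkv_layer(X, W_q, W_k, W_v, W_w):
    """Non-causal WQKV attention (hypothetical sketch).

    Pools K, V into D memory slots, then lets every query attend to them.
    """
    Q, K, V = X @ W_q, X @ W_k, X @ W_v   # each [L, C]
    P = softmax((X @ W_w).T, axis=-1)     # [D, L] pooling weights
    K_p, V_p = P @ K, P @ V               # [D, C] compressed memory
    A = softmax(Q @ K_p.T, axis=-1)       # [L, D] attention over D slots
    return A @ V_p                        # [L, C]

rng = np.random.default_rng(0)
L, C, D = 12, 8, 3
X = rng.standard_normal((L, C))
W_q, W_k, W_v = (rng.standard_normal((C, C)) for _ in range(3))
W_w = rng.standard_normal((C, D))
out = wqkv_layer(X, W_q, W_k, W_v, W_w)
print(out.shape)  # (12, 8)
```

Because the pooling sums over all positions, this encoder-style layer is permutation-equivariant: shuffling the input rows shuffles the output rows the same way.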
These formulas show that WQKV is highly similar to GPT; the only difference is an added pooling module that compresses information. WQKV contains the standard transformer as a special case: when the compressed history carries the full sequence information (for example, $D = L$ and $\mathrm{softmax}(W^T)$ approaching the identity matrix), then $K' \to K$ and $V' \to V$, and the layer reduces to ordinary attention. Reducing the memory dimension $D$ gives the model the ability to forget. Compared with the $O(N^2)$ complexity of GPT, the new model has complexity $O(ND)$.
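The special-case claim can be checked numerically. This hypothetical sketch bypasses $W_w$ and sets the pooling logits $W$ directly (an assumption made only for illustration): with $D = L$ and sharply peaked logits, $\mathrm{softmax}(W^T)$ is the identity to machine precision, and WQKV matches unscaled softmax attention.

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax along `axis`."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def wqkv_from_logits(Q, K, V, W):
    """WQKV attention given pooling logits W of shape [L, D] directly."""
    P = softmax(W.T, axis=-1)         # [D, L] pooling weights
    K_p, V_p = P @ K, P @ V           # [D, C]
    return softmax(Q @ K_p.T, axis=-1) @ V_p

def standard_attention(Q, K, V):
    """Unscaled, non-causal softmax attention, matching the formula above."""
    return softmax(Q @ K.T, axis=-1) @ V

rng = np.random.default_rng(0)
L, C = 6, 4
Q, K, V = (rng.standard_normal((L, C)) for _ in range(3))
W = 50.0 * np.eye(L)                  # D = L; each slot picks one position
print(np.allclose(wqkv_from_logits(Q, K, V, W), standard_attention(Q, K, V)))  # True
```

With $D = L$ the cost is back to $O(L^2)$; choosing $D < L$ is what trades exact recall for linear cost.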
The derivation above assumes that every token can see all other tokens, as in a transformer encoder layer. For infinitely long sequences, we need a causal form with both a parallel training method and a recurrent inference method. Fortunately, WQKV admits such a causal form, derived below.
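From the previous derivation, the computation of $K'$ and $V'$ has a recursive form. In the causal setting, the pooled keys $K'_l$ at step $l$ depend only on the first $l$ tokens. For convenience, we write $\exp(W^T)$ restricted to these tokens as $\begin{pmatrix} w_1 & w_2 & \cdots & w_l \end{pmatrix}$, where each $w_i$ is a $D$-dimensional column, and $K$ as $\begin{pmatrix} k_1 \\ k_2 \\ \vdots \\ k_l \end{pmatrix}$, where each $k_i$ is a $C$-dimensional row. Each product $w_i k_i$ is then a $[D, C]$ outer product, and division by $\sum_i w_i$ is applied row-wise. We rewrite $K'_l$ as follows: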
$$
\begin{aligned}
K'_l &= \frac{\begin{pmatrix} w_1 & w_2 & \cdots & w_l \end{pmatrix}}{\sum_{i=1}^{l} w_i} \begin{pmatrix} k_1 \\ k_2 \\ \vdots \\ k_l \end{pmatrix} \\
&= \frac{\begin{pmatrix} w_1 & w_2 & \cdots & w_l \end{pmatrix} \begin{pmatrix} k_1 \\ k_2 \\ \vdots \\ k_l \end{pmatrix}}{\sum_{i=1}^{l} w_i} \\
&= \frac{\sum_{i=1}^{l} w_i k_i}{\sum_{i=1}^{l} w_i} \\
&= \frac{\sum_{i=1}^{l-1} w_i k_i + w_l k_l}{\sum_{i=1}^{l-1} w_i + w_l}
\end{aligned}
$$
Let $M_l = \sum_{i=1}^{l} w_i k_i$ and $N_l = \sum_{i=1}^{l} w_i$. Then the following recurrence holds:
$$K'_l = \frac{M_{l-1} + w_l k_l}{N_{l-1} + w_l},\qquad K'_{l-1} = \frac{M_{l-1}}{N_{l-1}}$$
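The recurrence translates directly into token-by-token causal inference with a constant-size state. This is a hypothetical NumPy sketch, assuming the row-vector convention, pooling logits `W` of shape $[L, D]$ given directly, and logits small enough that `np.exp(W[l])` does not overflow (a real implementation would likely need a running-max stabilization).

```python
import numpy as np

def causal_wqkv_recurrent(Q, K, V, W):
    """Causal WQKV via running sums M_l, N_l, one token at a time.

    Q, K, V: [L, C]; W: [L, D] pooling logits (assumed given).
    State: M_k, M_v of shape [D, C] and N of shape [D], independent of L.
    """
    L, C = K.shape
    D = W.shape[1]
    M_k = np.zeros((D, C))
    M_v = np.zeros((D, C))
    N = np.zeros(D)
    out = np.empty_like(V)
    for l in range(L):
        w = np.exp(W[l])               # w_l: [D], no overflow guard (assumption)
        M_k += np.outer(w, K[l])       # M_l = M_{l-1} + w_l k_l
        M_v += np.outer(w, V[l])       # same recurrence for the values
        N += w                         # N_l = N_{l-1} + w_l
        K_p = M_k / N[:, None]         # K'_l = M_l / N_l (row-wise)
        V_p = M_v / N[:, None]
        s = K_p @ Q[l]                 # [D] scores of token l against the memory
        a = np.exp(s - s.max())
        a /= a.sum()
        out[l] = a @ V_p               # [C]
    return out

rng = np.random.default_rng(1)
L, C, D = 7, 5, 3
Q, K, V = (rng.standard_normal((L, C)) for _ in range(3))
W = rng.standard_normal((L, D))
out = causal_wqkv_recurrent(Q, K, V, W)
print(out.shape)  # (7, 5)
```

At the first step every slot holds $k_1$ and $v_1$, so the first output is exactly $v_1$; at the last step the state covers the whole sequence and matches the non-causal formula.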
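The same recurrence holds for $V'$ with $v_i$ in place of $k_i$. At inference time, the model therefore only needs to keep the running sums $M_l$ and $N_l$, whose size does not depend on the sequence length. For parallel training, $M_l$ and $N_l$ for all positions must be computed with cumulative sums along the sequence, which is similar to existing methods.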
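The cumulative-sum form for parallel training can be sketched as below. It is a hypothetical illustration under the same assumptions as the recurrent version (row-vector convention, pooling logits `W` given directly, no overflow guard on `np.exp`). Note that it materializes $M_l$ for every position as an $[L, D, C]$ array, which is where the memory cost of parallel training comes from.

```python
import numpy as np

def causal_wqkv_parallel(Q, K, V, W):
    """Causal WQKV for all positions at once using cumulative sums.

    Q, K, V: [L, C]; W: [L, D] pooling logits (assumed given).
    """
    E = np.exp(W)                                            # [L, D], rows are w_l
    M_k = np.cumsum(E[:, :, None] * K[:, None, :], axis=0)   # [L, D, C], M_l for every l
    M_v = np.cumsum(E[:, :, None] * V[:, None, :], axis=0)   # [L, D, C]
    N = np.cumsum(E, axis=0)[:, :, None]                     # [L, D, 1], N_l for every l
    K_p, V_p = M_k / N, M_v / N                              # K'_l, V'_l for every l
    S = np.einsum('lc,ldc->ld', Q, K_p)                      # [L, D] causal scores
    A = np.exp(S - S.max(axis=1, keepdims=True))
    A /= A.sum(axis=1, keepdims=True)
    return np.einsum('ld,ldc->lc', A, V_p)                   # [L, C]

rng = np.random.default_rng(2)
L, C, D = 9, 4, 3
Q, K, V = (rng.standard_normal((L, C)) for _ in range(3))
W = rng.standard_normal((L, D))
out = causal_wqkv_parallel(Q, K, V, W)
print(out.shape)  # (9, 4)
```

Every output row depends only on tokens up to its own position, so the result matches the token-by-token recurrence while running as a handful of vectorized operations.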
Experiments
There is a clear linear relationship between the number of tokens memorized in the context and performance.
Future Work
This linear transformer network based on attention pooling can be reinterpreted from the perspective of cross-attention.
For a text sequence $L_1$ with shape $[L, C]$, we append a randomly initialized weight sequence $L_2 = W$ with shape $[D, C]$, forming a combined sequence of shape $[L + D, C]$.
The basic principle of the cross-attention-based linear transformer is to apply cross-attention repeatedly between $L_1$ and $L_2$.
The formula for $L_2$ performing cross-attention on $L_1$ is:
$$W' = \mathrm{softmax}(W L_1^T)L_1$$
The formula for $L_1$ performing cross-attention on the updated $L_2$ (that is, $W'$), producing $L_1'$, is:
$$L_1' = \mathrm{softmax}(L_1 W'^T)W'$$
The overall formula is:
$$L_1' = \mathrm{softmax}\big(L_1 (\mathrm{softmax}(W L_1^T) L_1)^T\big)\,\mathrm{softmax}(W L_1^T) L_1$$
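This simple transformation shows that the cross-attention view is essentially equivalent to attention pooling: setting $W_w = W^T$ and taking $W_q$, $W_k$, $W_v$ as identity matrices turns the attention pooling formulas into exactly this expression. The connection between the two methods is worth further exploration.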
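The equivalence can be checked numerically. In this hypothetical sketch, `cross_attention_view` implements the two cross-attention steps above, and `attention_pooling_view` is attention pooling with identity $Q/K/V$ projections (an assumed special case, row-vector convention). Passing $W_w = W^T$ makes the two coincide, because $(L_1 W^T)^T = W L_1^T$.

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax along `axis`."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def cross_attention_view(L1, W):
    """L2 = W attends to L1, then L1 attends to the updated L2 (W')."""
    W_prime = softmax(W @ L1.T, axis=-1) @ L1            # [D, C]
    return softmax(L1 @ W_prime.T, axis=-1) @ W_prime    # [L, C]

def attention_pooling_view(X, W_w):
    """Attention pooling with identity Q/K/V projections (assumed special case)."""
    P = softmax((X @ W_w).T, axis=-1)   # [D, L] pooling weights
    K_p = V_p = P @ X                   # [D, C]
    return softmax(X @ K_p.T, axis=-1) @ V_p

rng = np.random.default_rng(3)
L, C, D = 10, 6, 4
L1 = rng.standard_normal((L, C))
W = rng.standard_normal((D, C))
print(np.allclose(cross_attention_view(L1, W), attention_pooling_view(L1, W.T)))  # True
```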
Limitations:
Parallel training with our method uses excessive GPU memory, because the cumulative sums materialize a $[D, C]$ state $M_l$ for every one of the $L$ positions.
