最近Deepseek比较出圈,连带着里面用到的MLA也被讨论得更多了。MLA无疑是一个非常出色的attention改进,但是由于它的KV cache设计,不能很好地兼容RoPE,因此作者们使用了decoupled RoPE这样的「补丁」来引入位置关系,这无疑也增加了实现的复杂度。
最近,修改注意力KV Cache这一线工作又增添了TPA这个新成员,笔者觉得这篇文章比较有趣,因此希望写一篇简短的导读。之所以叫「导读」,是因为本文不打算太深入文章的formulation和实验,而是从笔者自己的视角出发介绍文章的一些重点贡献。。
MHA的拆解
最标准的多头注意力(Vaswani的版本),大致可以拆解成3个步骤(这里默认讨论自注意力)1:
$$ \begin{aligned} \text{step 1 }& \begin{cases} \boldsymbol{q}_i^{(h)} = \boldsymbol{x}_i\boldsymbol{W}_q^{(h)}\in\mathbb{R}^{d_k},\quad \boldsymbol{W}_q^{(h)}\in\mathbb{R}^{d\times d_k} \\ \boldsymbol{k}_i^{(h)} = \boldsymbol{x}_i\boldsymbol{W}_k^{(h)}\in\mathbb{R}^{d_k},\quad \boldsymbol{W}_k^{(h)}\in\mathbb{R}^{d\times d_k}\\ \boldsymbol{v}_i^{(h)} = \boldsymbol{x}_i\boldsymbol{W}_v^{(h)}\in\mathbb{R}^{d_v},\quad \boldsymbol{W}_v^{(h)}\in\mathbb{R}^{d\times d_v} \end{cases} \\ \text{step 2 }& \begin{cases} \boldsymbol{o}_t^{(h)} = \text{Attention}\left(\boldsymbol{q}_t^{(h)}, \boldsymbol{k}_{\leq t}^{(h)}, \boldsymbol{v}_{\leq t}^{(h)}\right)\triangleq\frac{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(h)} \boldsymbol{k}_i^{(h)}{}^{\top}\right)\boldsymbol{v}_i^{(h)}}{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(h)} \boldsymbol{k}_i^{(h)}{}^{\top}\right)} \\ \end{cases} \\ \text{step 3 }& \begin{cases} \boldsymbol{o}_t = \left[\boldsymbol{o}_t^{(1)}, \boldsymbol{o}_t^{(2)}, \cdots, \boldsymbol{o}_t^{(H)}\right] \end{cases} \\ \end{aligned}\\ $$
这里$\boldsymbol{x}_1,\boldsymbol{x}_2,\cdots,\boldsymbol{x}_l$
是输入向量,上标$h\in \{1,\ldots,H\}$
表示注意力头,$d,d_k,d_v$
分别表示输入维度、key维度、value维度。在第一步中,我们将每个token的表征独立地线性投射到不同的自空间上,第二步是标准的点积注意力,第三步是拼接(后续再做一步线性变换)。
我们重点来审视一下第一个步骤,这里由于每个token的表征是独立操作的(类似FFN),因此我们可以将每个token的表征看做一个样本点,这里做的事情就是将一个$d$
维度的向量,转化成3个矩阵$\tilde{\boldsymbol{Q}}\in\mathbb{R}^{H\times d_k},\tilde{\boldsymbol{K}}\in\mathbb{R}^{H\times d_k},\tilde{\boldsymbol{V}}\in\mathbb{R}^{H\times d_v}$
,每一个头对应这个矩阵中的一行。接着我们逐行计算每行对应的三组向量的点积注意力,就构成了标准的多头注意力。在标准实现中,这个变换是通过参数矩阵的线性变换+ reshape实现的。
在自回归模型推理的阶段,这里涉及到的$\boldsymbol{k}_i^{(h)},\boldsymbol{v}_i^{(h)}$
会被后续的token使用到,因此可以将其缓存起来,避免重复计算,这就是KV cache的思想。序列中每个位置则需要缓存$2Hd_{kv}$
个值(实现中一般$d_k=d_v$
)。
TPA:一种低秩重参化技巧
上面我们描述了在标准注意力中,在计算点积注意力之前,需要通过一系列形如$f:\mathbb{R}^d\to\mathbb{R}^{H\times d'}$
的映射,把输入投射到query-key-value子空间。在TPA这个文章中,作者介绍了一Contextual Factorization (CF)技巧来构造这个映射。笔者将这个技巧描述为「低秩重参化」。
以$\tilde{\boldsymbol{Q}}$
的构造为例,对于输入$\boldsymbol{x}_i$
,引入两个参数矩阵$\boldsymbol{W}^A\in\mathbb{R}^{d\times (r\cdot H)},\boldsymbol{W}^B\in\mathbb{R}^{d\times (r\cdot d_k)}$
,得到两个向量2
$$ \begin{align} \boldsymbol{a}_i &= \boldsymbol{x}_i \boldsymbol{W}^A\in\mathbb{R}^{r\cdot H}\\ \boldsymbol{b}_i &= \boldsymbol{x}_i \boldsymbol{W}^B\in\mathbb{R}^{r\cdot d_k}\\ \end{align}\\ $$
这里$R$
是我们指定的最大的秩,是一个超参数,将这两个向量reshape成矩阵形式:
$$ \begin{align} \tilde{\boldsymbol{A}}_i &\in\mathbb{R}^{r\times H}\\ \tilde{\boldsymbol{B}}_i &\in\mathbb{R}^{r\times d_k}\\ \end{align}\tag{1}\\ $$
这样我们就可以构造出一个$\tilde{\boldsymbol{Q}}$
$$ \tilde{\boldsymbol{Q}}_i = \frac{1}{R}\tilde{\boldsymbol{A}}_i^\top \tilde{\boldsymbol{B}}_i\in\mathbb{R}^{H\times d_k}\\ $$
同样的方法可以构造出$\tilde{\boldsymbol{K}}_i\in\mathbb{R}^{H\times d_k}$
和$\tilde{\boldsymbol{V}}_i\in\mathbb{R}^{H\times d_v}$
。按照上一节中的介绍,后面需要做的就是将三组矩阵的每一行分别当做每个注意力头的query-key-value做点积注意力。与标准MHA不同的是,这里的映射是带有非线性的。
顺便一提,原作中这个重参化的引入是用的外积和的形式,笔者觉得有点冗余了,因为外积和与矩阵的乘法是等价的,相信多数读者对于矩阵的乘法是更加熟悉的。
更少的KV Cache
使用这种重参化的形式的一个好处是,在推理的时候,只需要缓存$\{\tilde{\boldsymbol{A}}^K_j,\tilde{\boldsymbol{B}}^K_j,\tilde{\boldsymbol{A}}^V_j,\tilde{\boldsymbol{B}}^V_j\}_{j\le t}$
即可,对应每个token位置的KV Cache量在
$$ r_k(H+d_k)+r_v(H+d_v)\\ $$
如果代入原作的设定$r_k=r_v=2$
,TPA的KV Cache可以大致计算为$4(H+d_{kv})$
,比起标准MHA的$2Hd_{kv}$
要低不少。根据知乎@寒月灼华的计算,Medium大小的模型上,TPA每token的KV Cache为444,相比MHA的2048和MLA的1056,都是更有优势的。
兼容旋转位置编码
众所周知,现在最广泛使用的位置编码方式RoPE,可以通过在点积注意力的query-key上分别乘上分块对角旋转矩阵来实现高效的相对位置表征3。而在MLA中,由于KV Cache保存的压缩向量并不是点积注意力最终的key,因此不能直接兼容RoPE。而TPA的形式恰好可以直接兼容RoPE,原作中有比较完整的证明过程,这里笔者按照上一节的符号做一个简短的sketch proof。
RoPE的基本思想是,在第$i$
个token位置引入旋转编码矩阵$\boldsymbol{\mathcal{R}}_i$
,从而
$$ \left(\boldsymbol{q}_i\boldsymbol{\mathcal{R}}_i\right)\left(\boldsymbol{k}_j\boldsymbol{\mathcal{R}}_j\right)^\top = \boldsymbol{q}_i\boldsymbol{\mathcal{R}}_{j-i}\boldsymbol{k}_j^\top\tag{2}\\ $$
在KV Cache的框架下,问题的关键在于令key的旋转位置编码被包含在KV Cache中,刚好TPA能够满足这个要求。考虑上一节构造的$\tilde{\boldsymbol{Q}},\tilde{\boldsymbol{K}}$
$$ \begin{align} \tilde{\boldsymbol{Q}}_i &= \frac{1}{R}(\tilde{\boldsymbol{A}}_i^Q)^\top \tilde{\boldsymbol{B}}_i^Q\in\mathbb{R}^{H\times d_k}\\ \tilde{\boldsymbol{K}}_j &= \frac{1}{R}(\tilde{\boldsymbol{A}}_j^K)^\top \tilde{\boldsymbol{B}}_j^K\in\mathbb{R}^{H\times d_k}\\ \end{align}\\ $$
我们已经提到,第$h$
个注意力头就是在$\tilde{\boldsymbol{Q}},\tilde{\boldsymbol{K}},\tilde{\boldsymbol{V}}$
矩阵的第$h$
行向量基础上做点积注意力。我们假设$\boldsymbol{a}_i^Q$
是$\tilde{\boldsymbol{A}}_i^Q$
的第$h$
列,$\boldsymbol{a}_j^K$
是$\tilde{\boldsymbol{A}}_j^K$
的第$h$
列,则在TPA中,第$h$
个注意力头的点积注意力的输入分别是
$$ \begin{align} \boldsymbol{q}_i &= \frac{1}{R}\boldsymbol{a}_j^Q \tilde{\boldsymbol{B}}_i^Q\in\mathbb{R}^{d_k}\\ \boldsymbol{k}_j &= \frac{1}{R}\boldsymbol{a}_j^K \tilde{\boldsymbol{B}}_j^K\in\mathbb{R}^{d_k}\\ \end{align}\\ $$
按照公式$(2)$
的原理,只需要将旋转位置编码乘在上述两项的右侧即可
$$ \begin{align} \boldsymbol{q}_i{\color[rgb]{0, 0.5, 0.8}{\boldsymbol{\mathcal{R}}_i}} &= \frac{1}{R}\boldsymbol{a}_j^Q \tilde{\boldsymbol{B}}_i^Q{\color[rgb]{0, 0.5, 0.8}{\boldsymbol{\mathcal{R}}_i}}=\frac{1}{R}\boldsymbol{a}_j^Q \left(\tilde{\boldsymbol{B}}_i^Q{\color[rgb]{0, 0.5, 0.8}{\boldsymbol{\mathcal{R}}_i}}\right)\\ \boldsymbol{k}_j{\color[rgb]{0, 0.5, 0.8}{\boldsymbol{\mathcal{R}}_j}} &= \frac{1}{R}\boldsymbol{a}_j^K \tilde{\boldsymbol{B}}_j^K{\color[rgb]{0, 0.5, 0.8}{\boldsymbol{\mathcal{R}}_j}}=\frac{1}{R}\boldsymbol{a}_j^K \left(\tilde{\boldsymbol{B}}_j^K{\color[rgb]{0, 0.5, 0.8}{\boldsymbol{\mathcal{R}}_j}}\right)\\ \end{align}\\ $$
在实现中,这相当于在公式$(1)$
的变换之后,分别对$\tilde{\boldsymbol{B}}^Q, \tilde{\boldsymbol{B}}^K$
应用RoPE编码4
B_q, B_k = apply_rotary_emb(B_q, cos, sin), apply_rotary_emb(B_k, cos, sin)
在推理的时候,将$\{\tilde{\boldsymbol{A}}^K_j,\tilde{\boldsymbol{B}}^K_j{\color[rgb]{0, 0.5, 0.8}{\boldsymbol{\mathcal{R}}_j}},\tilde{\boldsymbol{A}}^V_j,\tilde{\boldsymbol{B}}^V_j\}_{j\le t}$
缓存,剩余部分正常计算。
总结
本文简单介绍了Tensor-Product Attention(TPA)的基本方法和两个性质:较少的KV Cache缓存量和RoPE的兼容性。更细节的描述和详尽的实验请阅读Zhang 2025. Tensor Product Attention Is All You Need以及作者维护的仓库tensorgi/T6。
参考阅读
- Zhang 2025. Tensor Product Attention Is All You Need
- 缓存与效果的极限拉扯:从MHA、MQA、GQA到MLA
- Transformer升级之路:2、博采众长的旋转式位置编码
本篇中的向量是行向量。 ↩︎
为了显示简洁,公式中去掉了Q的标注。 ↩︎
这是逻辑上的描述,实现上不会实例化一整个矩阵。 ↩︎
https://github.com/tensorgi/T6/blob/bd6dd4ab682a9955d256d395fa9bf0d5da8a804b/model/T6.py#L122C9-L122C84 ↩︎