Tensor Product Attention (TPA) 导读

最近Deepseek比较出圈,连带着里面用到的MLA也被讨论得更多了。MLA无疑是一个非常出色的attention改进,但是由于它的KV cache设计,不能很好地兼容RoPE,因此作者们使用了decoupled RoPE这样的「补丁」来引入位置关系,这无疑也增加了实现的复杂度。 最近,修改注意力KV Cache这一线工作又增添了TPA这个新成员,笔者觉得这篇文章比较有趣,因此希望写一篇简短的导读。之所以叫「导读」,是因为本文不打算太深入文章的formulation和实验,而是从笔者自己的视角出发介绍文章的一些重点贡献。。 MHA的拆解 最标准的多头注意力(Vaswani的版本),大致可以拆解成3个步骤(这里默认讨论自注意力)1: $$ \begin{aligned} \text{step 1 }& \begin{cases} \boldsymbol{q}_i^{(h)} = \boldsymbol{x}_i\boldsymbol{W}_q^{(h)}\in\mathbb{R}^{d_k},\quad \boldsymbol{W}_q^{(h)}\in\mathbb{R}^{d\times d_k} \\ \boldsymbol{k}_i^{(h)} = \boldsymbol{x}_i\boldsymbol{W}_k^{(h)}\in\mathbb{R}^{d_k},\quad \boldsymbol{W}_k^{(h)}\in\mathbb{R}^{d\times d_k}\\ \boldsymbol{v}_i^{(h)} = \boldsymbol{x}_i\boldsymbol{W}_v^{(h)}\in\mathbb{R}^{d_v},\quad \boldsymbol{W}_v^{(h)}\in\mathbb{R}^{d\times d_v} \end{cases} \\ \text{step 2 }& \begin{cases} \boldsymbol{o}_t^{(h)} = \text{Attention}\left(\boldsymbol{q}_t^{(h)}, \boldsymbol{k}_{\leq t}^{(h)}, \boldsymbol{v}_{\leq t}^{(h)}\right)\triangleq\frac{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(h)} \boldsymbol{k}_i^{(h)}{}^{\top}\right)\boldsymbol{v}_i^{(h)}}{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(h)} \boldsymbol{k}_i^{(h)}{}^{\top}\right)} \\ \end{cases} \\ \text{step 3 }& \begin{cases} \boldsymbol{o}_t = \left[\boldsymbol{o}_t^{(1)}, \boldsymbol{o}_t^{(2)}, \cdots, \boldsymbol{o}_t^{(H)}\right] \end{cases} \\ \end{aligned}\\ $$ 这里$\boldsymbol{x}_1,\boldsymbol{x}_2,\cdots,\boldsymbol{x}_l$是输入向量,上标$h\in \{1,\ldots,H\}$表示注意力头,$d,d_k,d_v$分别表示输入维度、key维度、value维度。在第一步中,我们将每个token的表征独立地线性投射到不同的自空间上,第二步是标准的点积注意力,第三步是拼接(后续再做一步线性变换)。 我们重点来审视一下第一个步骤,这里由于每个token的表征是独立操作的(类似FFN),因此我们可以将每个token的表征看做一个样本点,这里做的事情就是将一个$d$维度的向量,转化成3个矩阵$\tilde{\boldsymbol{Q}}\in\mathbb{R}^{H\times d_k},\tilde{\boldsymbol{K}}\in\mathbb{R}^{H\times d_k},\tilde{\boldsymbol{V}}\in\mathbb{R}^{H\times d_v}$,每一个头对应这个矩阵中的一行。接着我们逐行计算每行对应的三组向量的点积注意力,就构成了标准的多头注意力。在标准实现中,这个变换是通过参数矩阵的线性变换+ reshape实现的。 ...

2025-02-13 · Tianyang Lin