自定义CUDA kernel加速Muon优化器

TL;DR 笔者通过自定义$XX^\top$算子,加速Muon优化器中的核心操作——Newton-Schulz迭代,算子单测在8192维度上相比原生实现计算时间降低约一半,端到端的Newton-Schulz迭代运行时间降低至原来的0.71倍。相关代码已经发布在: https://github.com/nil0x9/flash-muon 读者可以通过如下方式来试用优化版本的Muon实现或者核心的CUDA算子。 git clone --recurse-submodules https://github.com/nil0x9/flash-muon.git pip install -e ./ 具体做法 Muon优化器的核心机制是通过Newton-Schulz迭代法来代替SVD分解,实现GPU友好的$\boldsymbol{U}\boldsymbol{V}^T$计算(之前笔者写过这篇blog介绍Muon背后的数学原理)。 Newton-Schulz迭代法的核心公式如下所示,通过指定合适的多项式系数$a,b,c$,将SVD分解的 $\boldsymbol{U}\boldsymbol{\Sigma}\boldsymbol{V}^T$中的$\boldsymbol{\Sigma}$在迭代中消除 $$ \boldsymbol{X}\leftarrow a\boldsymbol{X} + b(\boldsymbol{X}\boldsymbol{X}^\top)\boldsymbol{X} + c(\boldsymbol{X}\boldsymbol{X}^\top)(\boldsymbol{X}\boldsymbol{X}^\top)\boldsymbol{X} $$ 其中$\boldsymbol{X}$是梯度矩阵或梯度的动量矩阵。在Jordan的实现 中这个迭代过程由如下的PyTorch代码实现: for _ in range(steps): A = X @ X.T B = b * A + c * A @ A X = a * X + B @ X 注意到这里提前计算了$\boldsymbol{A}=\boldsymbol{X}\boldsymbol{X}^\top$来减少重复计算。由于$\boldsymbol{A}$是一个对称矩阵,则循环内第二行的$\boldsymbol{AA} = \boldsymbol{AA}^\top$。这里我们可以看到一个被使用了两次的运算lambda x: matmul(x, x.T),如果这个算子可以有比matmul更高效的实现,就可以实现更快的Newton-Schulz迭代。 幸运的是$\boldsymbol{X}\boldsymbol{X}^\top$是一个对称矩阵,理论上我们可以只计算上三角部分,下三角部分可以直接复用下三角部分的结果。 这一点在Laker Newhouse的这篇文章中有提及,但是可惜的是,他们并没有实现一个可用的kernel版本。 笔者沿着这个思路实现了一个可用的CUDA kernel。它的设计思路其实非常直观(也可能比较clumsy,欢迎批评指正)——在一般的GEMM算子中,每个线程会从全局内存load计算一定分块的矩阵算元到自己的寄存器内,在计算完对应分块的结果矩阵后,将这部分的结果store到输出算元对应的全局内存空间。GEMM的相关分析教程实在是太多了,笔者就不赘述了。——而对于$\boldsymbol{X}\boldsymbol{X}^\top$这种对称运算,笔者的让下三角对应的线程块在kernel一开始直接退出,上三角部分的线程在计算结束后做完一般的store操作后,根据线程和线程块的id,找到对应的下三角部分的全局内存空间,将当前线程对应的结果分块转置后store到下三角的分块。 用图像表达如下: matmul_transpose的kernel设计 ...

2025-04-25 · Tianyang Lin

谱条件:如何衡量神经网络参数空间中的距离?

Prologue 如何衡量神经网络参数空间中的距离(i.e., 选取合适的范数)?这个问题我们之前已经在这篇文章中有所涉及,本文是更完整的拓展。 首先应该明确为什么这个问题很重要?因为在梯度下降这样的优化算法框架下,总是依赖目标参数的某种距离的衡量。例如,当目标参数可以对应到欧几里得空间的坐标时,就可以得到局部目标函数下降最快的方向是梯度负方向。 尽管在深度学习优化中,直接使用欧几里得范数(衡量参数空间的距离)非常吸引人,并且实际上也有效。但是,这种做法会丢失结构信息,因为它实质上是将参数看做一个扁平的向量。那么得到的优化方向就可能是相对低效的。 对于常见的神经网络而言,对参数矩阵使用谱范数似乎是个更好的选择[1][2]。使用谱范数的相关约束,可以引出在现在的大规模模型训练中非常重要的两个工作: $\mu P$ (Maximal Update Parameterization):使用特定的初始化与学习率的参数化,可以做到在小模型上调节部分超参数,迁移到同构的大模型。 Muon优化器的更新规则,即对梯度矩阵做SVD分解,消除奇异值矩阵后将$-\boldsymbol{U}\boldsymbol{V}^\top$作为对应参数的更新方向。 从而涵盖了给定神经网络结构下的参数初始化、参数更新的步长和方向三个方面(模型成功训练的三个基本要素)。 feature learning的谱条件 为什么需要使用谱范数来衡量参数的大小和参数更新步长呢?在Greg Yang的系列工作中,考虑的主要是feature learning的问题。Yang在他的TP4中发现,在标准参数化或者NTK参数化的设定下,无穷宽的神经网络没有办法学习特征(一次更新后特征与初始化状态无异),这就与传统的linear transfer learning的智慧形成了悖论,因为如果没有feature learning,那么预训练就对后续的transfer没有任何增益。 于是为了达成非平凡的feature learning,就需要对神经网络的特征$\boldsymbol{h}_\ell(\boldsymbol{x})\in\mathbb{R}^{n_\ell}$的范数与一步更新的特征范数变化量$\boldsymbol{h}_\ell(\boldsymbol{x})$做如下的约束[1]: $$ \|\boldsymbol{h}_\ell\|_2=\Theta(\sqrt{n_\ell})\text{ and } \|\Delta\boldsymbol{h}_\ell\|_2=\Theta(\sqrt{n_\ell}), \text{ for }\ell=1,\ldots,L-1\tag{1} $$ 这个条件意味着,对于中间的任何特征向量而言,每个元素平均下来的范数是常数阶的,一次更新后的变化量也是常数阶的(不会随着宽度$n_\ell$趋近无穷而爆炸或者弥散)。 对于常见的由一系列矩阵参数构成的神经网络而言,实现这样的约束,需要对参数矩阵做如下的谱条件约束: $$ \|\boldsymbol{W}_\ell\|_2=\Theta\left(\sqrt{\frac{n_\ell}{n_{\ell-1}}}\right)\text{ and } \|\Delta\boldsymbol{W}_\ell\|_2=\Theta\left(\sqrt{\frac{n_\ell}{n_{\ell-1}}}\right), \text{ for }\ell=1,\ldots,L-1\tag{2} $$ 这里的$\|\cdot\|_2$是矩阵的谱范数,即从向量$\ell_2$空间映射到向量$\ell_2$空间的算子诱导范数 $$ \|\boldsymbol{A}\|_2=\underset{\boldsymbol{x}\in\mathbb{R}^n}{\max} \frac{\|\boldsymbol{A}\boldsymbol{x}\|_2}{\|\boldsymbol{x}\|_2}\text{ for } \boldsymbol{A}\in\mathbb{R}^{m\times n} $$ 这个谱条件能够保证公式$(1)$成立,证明思路主要利用到谱范数的性质$\|\boldsymbol{Av}\|_2\le\|\boldsymbol{A}\|_2\|\boldsymbol{v}\|_2$,证明出上述谱条件诱导的$\|\boldsymbol{h}_\ell\|_2,\|\Delta\boldsymbol{h}_\ell\|_2$的上界均为$\Theta(\sqrt{n_\ell})$,接着证明这个上界依概率总是取得。具体可以看这篇论文的第三节(相比Greg的TP系列,还算是很容易理解的)。 公式$(2)$中的第二个条件是很直观的,对于SGD优化器而言,直接按照每个参数矩阵的维度设定对应的学习率即可 $$ \eta_\ell = \Theta\left(\sqrt{\frac{n_\ell}{n_{\ell-1}}}\right) $$ 对于第一个条件$\boldsymbol{W}_\ell = \Theta\left(\sqrt{\frac{n_\ell}{n_{\ell-1}}}\right)$,需要修改初始化的方式——我们可以从一个标准正态分布中i.i.d.采样一个权重矩阵$\boldsymbol{W}'\in\mathbb{R}^{n_\ell\times n_{\ell-1}}$,接着做缩放$\boldsymbol{W} = \sigma\boldsymbol{W}'$(这里略去下标$\ell$)得到最终的初始化参数矩阵。 这个$\sigma$怎么确定呢?由于矩阵的stable rank按定义可以将矩阵的谱范数和Frobenius范数联系起来: $$ \text{stable-rank}(\boldsymbol{W}) = \frac{\|\boldsymbol{W}\|_F^2}{\|\boldsymbol{W}\|_2^2} $$ ...

2025-03-31 · Tianyang Lin

简记:Muon中设计Newton-Schulz迭代的系数?

上篇文章介绍了Muon等新兴深度学习优化器背后的原理,即约束参数矩阵的诱导范数下得到新的更新方向。 在Muon对参数更新方向$-\boldsymbol{U}\boldsymbol{V}^\top$的计算中用到了Newton-Schulz迭代方法,本质上是在寻找这样一个多项式函数 $$ f(x)=ax+bx^3+cx^5+\ldots $$ 使其满足对任意$x\in(0, 1]$,对$x$应用多次$f(\cdot)$,都能收敛到1附近。这里我们尝试设计一个能work的参数组合。 我的一个简单的想法是,设计一个多项式函数,使$x=1$是它的一个吸引不动点: 定义1(不动点)当$x_0$被函数$f(\cdot)$映射到自身,即$f(x_0)=x_0$时,称$x_0$是函数$f(\cdot)$的一个不动点。 定义2(吸引不动点)$f$的吸引不动点是$f$的不动点$x_0$使得,对在足够接近$x_0$的定义域中的任何$x$值而言,迭代函数序列$x,f(x),f(f(x)),f(f(f(x))),\ldots$收敛于$x_0$。 要令$x=1$是$f(x)$的一个吸引不动点,要满足如下的必要条件: $f(1)=1$ $|f'(1)|<1$ 使用这两个条件是无法确定具体的参数值$a,b,\ldots$的,但是对于三阶(参数包括$a,b$两个)或者五阶(参数包括$a,b,c$三个)的Newton-Schulz迭代,可以大大缩小搜索的空间。下面展开看下。 三阶迭代 先讨论三阶迭代的形式 $$ f(x)=ax+bx^3 $$ 代入上面的两个必要条件: $$ \begin{split} f(1)=a+b = 1\\ -1 < f'(1)=a+3b < 1\\ \end{split} $$ 根据第一个条件,可以把$b$用$1-a$重参数化,然后就有可行的条件 $$ 1 < a < 2 $$ 我们记五次迭代后的函数$\phi(x)=f(f(f(f(f(x)))))$,可视化看一下不同$a$取值下对应的情况(理想情况下,对于$(0,1]$区间内的$x$,曲线要尽可能接近$y=1$) 三阶迭代下,a取不同取值时对应的φ(x) 注意到在$a$接近1的时候,$\phi(x)$收敛到1附近的邻域是比较窄的,随着$a\to 2$,收敛到1附近的「邻域」范围逐渐拓宽,但在$a=2$附近,曲线开始出现一定的抖动。对于优化器而言,这样的局部近似的方差是可以容忍的,因此我们可以选取一个比较接近2的值作为$a$的参数,例如$a=1.99,b=-0.99$。 作为对比,在Bernstein & Newhouse 2024.中,作者给出的参数是$a=3/2,b=-1/2$。可以在下图中对照两种设定下的$\phi(x)$. 两种φ(x)对比 可以看到Bernstein给出的参数虽然更平滑地收敛于1,但是对于在0附近的初始$x$,普遍无法收敛到1。也就是说对于较小的奇异值对应的$\boldsymbol{u}_i, \boldsymbol{v}_i$,倾向于在更新中被忽略。 在$x=0$附近$\phi(x)$能否快速接近1,主要取决于参数$a$的大小,这是因为$\phi'(0)=a^5$。所以应该在尽可能保证$\phi(x)\approx 1,\forall x\in(0,1]$的同时,让$a$尽可能大。 五阶迭代 现在来考虑五阶迭代的形式 $$ f(x)=ax+bx^3+cx^5 $$ 代入上面的两个必要条件: $$ \begin{split} f(1)=a+b+c = 1\\ -1 < f'(1)=a+3b+5c < 1\\ \end{split} $$ ...

2025-03-08 · Tianyang Lin

从约束视角看深度学习优化若干新进展

在深度学习中最常用的优化方法是梯度下降方法及其变体。在过去很长一段时间中,Adam优化器是NLP社区的默认选择,在ViT出现之后,CV方面的工作也逐渐开始使用Adam和Adam的变体(在ViT之前,一种常见的观点是Adam不适用于Vision任务)。 最近Muon优化器在Kimi的新工作带动下又火了一把,相较于Adam优化器需要同时维护一阶、二阶矩估计量,Muon只需要维护一份梯度的动量估计,因此在大规模训练中有很大优势。最近笔者顺着Muon的reference看了Jeremy Bernstein在优化的一些文章,觉得很有意思,因此写这篇文章梳理一下这一系列工作的部分精要。本文的核心论点是:使用诱导范数来约束梯度更新,可以推导出最近的一些新出现的优化方法,这也可能是未来深度学习优化的一个有潜力的探索方向。 梯度下降(Recap) 当前深度学习优化算法的基石是梯度下降。之前笔者写过一篇拙文(自然梯度(二):黎曼距离下的最速下降)整理过梯度下降的推导,核心的结论是:当我们假设参数空间是一个欧几里得空间、参数的距离可以用欧几里得距离来衡量时,我们在某个点约束$\|\Delta\theta\|_2\le\epsilon(\epsilon>0)$时,$\Delta\theta$取梯度的反方向时可以让目标函数下降最多(具体的证明请参阅上述引文)。 使用梯度下降最大的问题是,它实际上忽略了模型的结构。换句话说,梯度下降相当于将模型所有参数展平为1维向量,并且用向量2范数来衡量每次更新的「步长」。这种抽象是实用的,但是也存在一定的问题。两组参数有可能在欧几里得空间中距离很近,但是诱导的模型输出空间距离很远。造成的结果就是更新的方向实际上不是目标函数下降最快的方向。 这个问题要如何解决呢?在自然梯度(二):黎曼距离下的最速下降中,我们介绍了自然梯度方法,即使用Fisher信息矩阵的逆作为梯度的pre-conditioner来矫正梯度下降的方向,从原理上是使用参数更新前后引导的概率分布的KL散度作为每次更新的步长约束。但是对于常见的深度神经网络来说,这样做仍然是不切实际的,因为FIM是一个$N\times N$的大矩阵(其中$N$是参数量),对于这么大的矩阵存储或求逆都是很难做到的。 诱导范数作为步长约束 是否有一种更「廉价」的方法,可以考虑模型的参数结构,同时将参数的变化对于输出的影响作为约束呢? 幸运的是,对于当下最流行的神经网络(e.g., Transformer)而言,模型往往可以拆解为很多小模块,其中最常见的是Linear模块(线性映射,这里忽略bias term) $$ f(\boldsymbol{x};\boldsymbol{W})=\boldsymbol{Wx},\ \boldsymbol{W}\in\mathbb{R}^{n\times m},\boldsymbol{x}\in \mathbb{R}^{m}\\ $$ 在标准的Transformer中,Attention、FFN、LM分类器都是由Linear模块组成的,Embedding从数学原理上也是输入为one-hot encoding的线性映射。假设现在对于某个Linear模块的参数$\boldsymbol{W}$做$\Delta\boldsymbol{W}$的更新($\boldsymbol{W}'\leftarrow \boldsymbol{W}+\Delta\boldsymbol{W}$),我们需要衡量这个更新对于最终输出的影响是多少(从而可以约束这个影响)。由于神经网络比较复杂,衡量$\Delta\boldsymbol{W}$对于最终目标函数的影响是相对繁琐的,但我们可以退而求其次,衡量$\Delta\boldsymbol{W}$对于这个Linear模块的输出$\boldsymbol{Wx}$的影响。 考虑线性模块的输入与输出空间的距离都使用欧几里得范数$\|\cdot\|_{\ell_2}$衡量,那么这个约束可以通过如下不等式实现 $$ \|\Delta\boldsymbol{W}\boldsymbol{x}\|_{\ell_2} \le {\color[rgb]{0, 0.5, 0.8}{\|\Delta\boldsymbol{W}\|_{\ell_2\to\ell_2}}}\|\boldsymbol{x}\|_{\ell_2} $$ 这里的${\color[rgb]{0, 0.5, 0.8}{\|\Delta\boldsymbol{W}\|_{\ell_2\to\ell_2}}}$是矩阵2范数。这个不等式告诉我们,如果约束了参数更新量的谱范数(不等式右侧),也就约束了更新前后这个线性模块输出的变化量。 假设现在需要优化的神经网络是由一系列的线性模块堆叠组成(e.g., MLP),我们可以参照梯度下降的推导构造如下的更新1 $$ \underset{\Delta\boldsymbol{W}_1,\ldots,\Delta\boldsymbol{W}_L}{\text{arg min}}\left[ \sum_{l=1}^L {\langle \boldsymbol{G}_l, \Delta\boldsymbol{W}_l \rangle}_F + \frac{\lambda}{2}\max_{l=1}^L{\color[rgb]{0, 0.5, 0.8}{\|\Delta\boldsymbol{W}_l\|^2_{\ell_2\to\ell_2}}} \right]\\ $$ 这里$\boldsymbol{G}_l$表示$\boldsymbol{W}_l$对应的梯度矩阵(布局与原参数矩阵相同),${\langle \cdot, \cdot \rangle}_F$表示Frobenius内积(对矩阵而言,逐元素相乘求和)。这里之所以使用$\max_{l=1}^L$(而不是直接求和),是因为我们引入这个约束时希望目标函数在$\Delta\boldsymbol{W}_l$变化下,能够保持平滑的性质2,因此需要bound所有参数矩阵更新量的谱范数的最大值。 我们来逐步推导这个最小值成立时的$\Delta\boldsymbol{W}_1,\ldots,\Delta\boldsymbol{W}_L$取值3。为了方便,把每个$\Delta\boldsymbol{W}_l$拆解成大小和方向两部分:$\Delta\boldsymbol{W}_l=c_l\boldsymbol{T}_l(c_l\triangleq\|\Delta\boldsymbol{W}_l\|_{\ell_2\to\ell_2})$ (为了可读性,下面的$\|\cdots\|$均表示谱范数$\|\cdot\|_{\ell_2\to\ell_2}$) $$ \begin{align} &\underset{\Delta\boldsymbol{W}_1,\ldots,\Delta\boldsymbol{W}_L}{\text{min}}\left[ \sum_{l=1}^L {\langle \boldsymbol{G}_l, \Delta\boldsymbol{W}_l \rangle}_F + \frac{\lambda}{2}\max_{l=1}^L\|\Delta\boldsymbol{W}_l\|^2 \right]\\ &=\underset{c_1,\ldots,c_L\ge 0}{\text{min}}\left[ \sum_{l=1}^L c_l\min_{\|\boldsymbol{T}_l\|=1}{\langle \boldsymbol{G}_l, \boldsymbol{T}_l \rangle}_F + \frac{\lambda}{2}\max_{l=1}^Lc_l^2 \right]\\ &=\underset{c_1,\ldots,c_L\ge 0}{\text{min}}\left[ -\sum_{l=1}^L c_l\max_{\|\boldsymbol{T}_l\|=1}{\langle \boldsymbol{G}_l, \boldsymbol{T}_l \rangle}_F + \frac{\lambda}{2}\max_{l=1}^Lc_l^2 \right]\\ &=\underset{c_1,\ldots,c_L\ge 0}{\text{min}}\left[ -\sum_{l=1}^L c_l \|\boldsymbol{G}_l\|_* + \frac{\lambda}{2}\max_{l=1}^Lc_l^2 \right]\quad\triangleright\|\cdot\|_*\text{表示核范数}\\ &=\underset{\eta\ge 0}{\text{min}}\left[ -\sum_{l=1}^L \eta\|\boldsymbol{G}_l\|_* + \frac{\lambda}{2}\max_{l=1}^L \eta^2 \right]\tag{1}\\ \end{align} $$ ...

2025-03-05 · Tianyang Lin

为什么LLM一般使用较大的权重衰减系数?

最近在阅读Muon is Scalable for LLM Training这篇文章的时候注意到他们使用无权重衰减(weight decay)版本的Muon优化LLM的时候,优化器的收敛优势会随着训练过程逐渐消失,又看到@小明同学在评论区提到的一个细节,很多开源的LLM在技术报告中都提到了使用0.1作为权重衰减的系数,觉得是个比较有意思的发现。结合Kimi的文章中关于bf16的简单陈述,笔者在本文中稍微展开讲下,权重衰减对于LLM的低精度训练中有什么作用。 首先把结论放在前面:除了一般认知中的正则化作用,权重衰减也可能降低精度损失的风险——对于计算机的浮点数而言,绝对值越大,精度越低。对于低精度/混合精度训练而言,使用权重衰减可以控制参数的绝对值范围,从而保证模型参数不落入低精度的数值区间。 浮点数的存储与精度 上述结论主要与浮点数在计算机内的存储形式有关。学过计算机的一些基本课程的读者可能有印象,浮点数的存储是二进制的形式,分为符号位、指数位和尾数位三段。深度学习中常见的浮点数协议(fp32、fp16、bf16、tf32)的区别在于指数位和尾数位的比特数量不同。由于浮点数是一个「指数」的形式,因此它在实数空间的分布是不均匀的。 这里我们考虑规范数的情形(指数位非全0),做一点分析。假设符号位、指数位、尾数位(mantissa)的二进制编码分别是$S$、$E$、$M$,那么对应的浮点数为: $$ \text{value}=(-1)^S\times 1.M\times 2^{E-\text{bias}} $$ 在单精度fp32标准中,$\text{bias}$取${01111111}_{2}=127_{10}$ . fp32浮点数的一个例子 例如在图中的例子中,$S=0$,$E=\underbrace{00\cdots0}_{7\ 0's}1$,$M=\underbrace{00\cdots0}_{22\ 0's}1$,相应的值为 $$ \begin{align} &{-1}^0\times 1.\underbrace{00\cdots0}_{22\ 0's}1_2\times 2^{00000001_2-{01111111}_{2}}\\ &\approx [1.175494490952134\times 10^{-38}]_{10} \end{align} $$ 现在我们来考虑不同的数值范围内的浮点数精度。对于区间范围$[2^{x}, 2^{x+1}],\forall -126\le x\le127$(这里的x已经是经过-bias之后得到的最终指数),我们希望在给定任意浮点数$y\in[2^{x}, 2^{x+1}]$的基础上增加一个最小量$\varepsilon$(即区间内两个浮点数的最小间隔),这个增加的过程是通过操纵二进制编码实现的,那么最小间隔只能是通过在$y$的尾数部分加上 $$ 0.\underbrace{00\cdots0}_{22\ 0's}1 $$ 来实现,对应的最小间隔是 $$ \begin{align} \varepsilon &= 0.\underbrace{00\cdots0}_{22\text{ 0's}}1_2\times 2^{x}\\ &=2^{-23}\times2^{x}=2^{x-23}\\ \end{align} $$ 如果你将上面的公式带入不同的指数位,可以验证与Wikipedia中给出的这张表的Gap是吻合的: 不同指数位下的最小精度,来源:Wikipedia 这个计算可以拓展到其他的精度格式: 对于半精度格式fp16而言,其尾数位有10位,对应的最小间隔是$2^{x-10}$; 对于半精度格式bf16而言,其尾数位有7位,对应的最小间隔是$2^{x-7}$; 从这里也可以看到,bf16相比fp16虽然拓宽了表示范围,但是减少了精度(同样数值范围内的最小间隔更宽了)。 从这里我们得到了一个结论:计算机存储的浮点数之间的最小间隔随着浮点数绝对值数值增加,指数级地增大,换言之,浮点数(绝对值)数值越大,精度越低。并且这个问题对于fp16或bf16格式的浮点数,问题要更加显著。 这个结论的另一个引申的问题是舍入误差,假如一个较大的浮点数和一个较小的浮点数相加,由于浮点数的加法(减法过程相当于取补码后相加,结论是类似的)过程需要先将两个数的指数位对齐,因此绝对值较小的数字的尾数的最后几位数字可能会在加法中丢失。这里我们举一个极端的例子来说明。 假设浮点数存储为fp32格式(8位指数、23位尾数)。 $$ \begin{align} x=(-1)^0\times 1.0_2\times 2^{-1}\\ y=(-1)^0\times 1.0_2\times 2^{-25}\\ \end{align} $$ ...

2025-02-26 · Tianyang Lin