TL;DR
笔者通过自定义$XX^\top$
算子,加速Muon优化器中的核心操作——Newton-Schulz迭代,算子单测在8192维度上相比原生实现计算时间降低约一半,端到端的Newton-Schulz迭代运行时间降低至原来的0.71倍。相关代码已经发布在:
https://github.com/nil0x9/flash-muon
读者可以通过如下方式来试用优化版本的Muon实现或者核心的CUDA算子。
git clone --recurse-submodules https://github.com/nil0x9/flash-muon.git
pip install -e ./
具体做法
Muon优化器的核心机制是通过Newton-Schulz迭代法来代替SVD分解,实现GPU友好的$\boldsymbol{U}\boldsymbol{V}^T$
计算(之前笔者写过这篇blog介绍Muon背后的数学原理)。
Newton-Schulz迭代法的核心公式如下所示,通过指定合适的多项式系数$a,b,c$
,将SVD分解的 $\boldsymbol{U}\boldsymbol{\Sigma}\boldsymbol{V}^T$
中的$\boldsymbol{\Sigma}$
在迭代中消除
$$ \boldsymbol{X}\leftarrow a\boldsymbol{X} + b(\boldsymbol{X}\boldsymbol{X}^\top)\boldsymbol{X} + c(\boldsymbol{X}\boldsymbol{X}^\top)(\boldsymbol{X}\boldsymbol{X}^\top)\boldsymbol{X} $$
其中$\boldsymbol{X}$
是梯度矩阵或梯度的动量矩阵。在Jordan的实现 中这个迭代过程由如下的PyTorch代码实现:
for _ in range(steps):
A = X @ X.T
B = b * A + c * A @ A
X = a * X + B @ X
注意到这里提前计算了$\boldsymbol{A}=\boldsymbol{X}\boldsymbol{X}^\top$
来减少重复计算。由于$\boldsymbol{A}$
是一个对称矩阵,则循环内第二行的$\boldsymbol{AA} = \boldsymbol{AA}^\top$
。这里我们可以看到一个被使用了两次的运算lambda x: matmul(x, x.T)
,如果这个算子可以有比matmul
更高效的实现,就可以实现更快的Newton-Schulz迭代。
幸运的是$\boldsymbol{X}\boldsymbol{X}^\top$
是一个对称矩阵,理论上我们可以只计算上三角部分,下三角部分可以直接复用下三角部分的结果。
这一点在Laker Newhouse的这篇文章中有提及,但是可惜的是,他们并没有实现一个可用的kernel版本。
笔者沿着这个思路实现了一个可用的CUDA kernel。它的设计思路其实非常直观(也可能比较clumsy,欢迎批评指正)——在一般的GEMM算子中,每个线程会从全局内存load计算一定分块的矩阵算元到自己的寄存器内,在计算完对应分块的结果矩阵后,将这部分的结果store到输出算元对应的全局内存空间。GEMM的相关分析教程实在是太多了,笔者就不赘述了。——而对于$\boldsymbol{X}\boldsymbol{X}^\top$
这种对称运算,笔者的让下三角对应的线程块在kernel一开始直接退出,上三角部分的线程在计算结束后做完一般的store操作后,根据线程和线程块的id,找到对应的下三角部分的全局内存空间,将当前线程对应的结果分块转置后store到下三角的分块。
用图像表达如下:
为了防止warp divergence,笔者这里采取了一种偷懒的做法,直接以线程块为粒度做提早退出,这样总能保证同一个warp内执行的指令尽量对齐。
CUDA kernel的具体实现就不赘述了,相关代码可以在https://github.com/nil0x9/flash-muon中的csrc/matmul_transpose.cu
文件中找到,笔者是基于@水木皇工仔的gemm实现的基础上,修正了predicate问题之后修改增加了早退和转置复制的逻辑得到的。
效果
笔者在RTX 4090设备上使用$\{2^i\}_{i=9}^{13}$
的不同维度的方阵上测试了这个自定义kernel的单测速度,以及对应的插入到Newton-Schulz迭代后的整体速度。如下图所示:
可以看到在8192的维度下,单测运行时间降低了一半左右,插入到NS迭代(5步)时,运行时间约为基线版本的0.71倍。
当然当前这种做法还有一定局限性,目前以线程块为最小粒度的早退策略,在矩阵维度较低的时候,由于SM调度不会打满,而每个kernel又比基线GEMM多了一系列转置拷贝的指令,因此速度反而达不到基线版本,因此这个实现建议在较大的梯度矩阵上使用才能获得应有的效率提升。
结语
本文介绍了一种使用自定义CUDA kernel的方式来提升Muon优化器的运行速度的方法,相关代码开源在nil0x9/flash-muon,欢迎读者Star、试用,并给笔者一些反馈(笔者水平有限,实现还存在诸多问题)!
值得一提的是,本文的这种计算[email protected]
的kernel,其实在cuBLAS中有类似的API,叫做SYRK
,但可惜的是,他们并没有提供半精度的版本,因此笔者才自己实现了这样一个功能。(番外:如何合理设计kernel(以及thread-data布局)实现cublas的syrk函数(XX’)?)