TL;DR
笔者通过自定义$XX^\top$
算子,加速Muon优化器中的核心操作——Newton-Schulz迭代,算子单测在8192维度上相比原生实现计算时间降低约一半,端到端的Newton-Schulz迭代运行时间降低至原来的0.71倍。相关代码已经发布在:
https://github.com/nil0x9/flash-muon
读者可以通过如下方式来试用优化版本的Muon实现或者核心算子。
git clone --recurse-submodules https://github.com/nil0x9/flash-muon.git
pip install -e ./
具体做法
Muon优化器的核心机制是通过Newton-Schulz迭代法来代替SVD分解,实现GPU友好的$\boldsymbol{U}\boldsymbol{V}^T$
计算(之前笔者写过这篇blog介绍Muon背后的数学原理)。
Newton-Schulz迭代法的核心公式如下所示,通过指定合适的多项式系数$a,b,c$
,将SVD分解的 $\boldsymbol{U}\boldsymbol{\Sigma}\boldsymbol{V}^T$
中的$\boldsymbol{\Sigma}$
在迭代中消除
$$ \boldsymbol{X}\leftarrow a\boldsymbol{X} + b(\boldsymbol{X}\boldsymbol{X}^\top)\boldsymbol{X} + c(\boldsymbol{X}\boldsymbol{X}^\top)(\boldsymbol{X}\boldsymbol{X}^\top)\boldsymbol{X} $$
其中$\boldsymbol{X}$
是梯度矩阵或梯度的动量矩阵。在Jordan的实现 中这个迭代过程由如下的PyTorch代码实现:
for _ in range(steps):
A = X @ X.T
B = b * A + c * A @ A
X = a * X + B @ X
注意到这里提前计算了$\boldsymbol{A}=\boldsymbol{X}\boldsymbol{X}^\top$
来减少重复计算。由于$\boldsymbol{A}$
是一个对称矩阵,则循环内第二行的$\boldsymbol{AA} = \boldsymbol{AA}^\top$
。这里我们可以看到一个被使用了两次的运算lambda x: matmul(x, x.T)
,如果这个算子可以有比matmul
更高效的实现,就可以实现更快的Newton-Schulz迭代。
幸运的是$\boldsymbol{X}\boldsymbol{X}^\top$
是一个对称矩阵,理论上我们可以只计算上三角部分,下三角部分可以直接复用下三角部分的结果。
这一点在Laker Newhouse的这篇文章中有提及,但是可惜的是,他们并没有实现一个可用的kernel版本。
笔者沿着这个思路实现了一个可用的CUDA kernel。它的设计思路其实非常直观(也可能比较clumsy,欢迎批评指正)——在一般的GEMM算子中,每个线程会从全局内存load计算一定分块的矩阵算元到自己的寄存器内,在计算完对应分块的结果矩阵后,将这部分的结果store到输出算元对应的全局内存空间。GEMM的相关分析教程实在是太多了,笔者就不赘述了。——而对于$\boldsymbol{X}\boldsymbol{X}^\top$
这种对称运算,笔者让下三角对应的线程块在kernel一开始直接退出,上三角部分的线程在计算结束后做完一般的store操作后,根据线程和线程块的id,找到对应的下三角部分的全局内存空间,将当前线程对应的结果分块转置后store到下三角的分块。
用图像表达如下:
为了防止warp divergence,笔者这里采取了一种偷懒的做法,直接以线程块为粒度做提早退出,这样总能保证同一个warp内执行的指令尽量对齐。
CUDA kernel的具体实现就不赘述了,相关代码可以在flash-muon中找到,CUDA版本在这里,triton版本在这里。
效果
笔者在RTX 4090设备上使用$\{2^i\}_{i=9}^{13}$
的不同维度的方阵上测试了这个自定义kernel的单测速度,以及对应的插入到Newton-Schulz迭代后的整体速度。如下图所示:
可以看到在8192的维度下,单测运行时间降低了一半左右,插入到NS迭代(5步)时,运行时间约为基线版本的0.71倍。
笔者也在不同的设备上测试了相应的执行速度(triton实现):
device | dim | flash | torch | compiled |
---|---|---|---|---|
H800 | 1024 | 0.0124 | 0.0112 | 0.0107 |
H800 | 2048 | 0.0322 | 0.0384 | 0.0384 |
H800 | 4096 | 0.1838 | 0.2955 | 0.3000 |
H800 | 8192 | 1.4528 | 2.2643 | 2.2804 |
H20 | 1024 | 0.0164 | 0.0275 | 0.0275 |
H20 | 2048 | 0.0746 | 0.1588 | 0.1587 |
H20 | 4096 | 0.5068 | 1.0431 | 1.0431 |
H20 | 8192 | 3.9265 | 7.9691 | 7.9508 |
A100 | 1024 | 0.0191 | 0.0228 | 0.0232 |
A100 | 2048 | 0.0689 | 0.1166 | 0.1164 |
A100 | 4096 | 0.3733 | 0.6644 | 0.6649 |
A100 | 8192 | 2.9815 | 5.1604 | 5.2858 |
4090 | 1024 | 0.0208 | 0.0213 | 0.0208 |
4090 | 2048 | 0.0823 | 0.1098 | 0.1095 |
4090 | 4096 | 0.5249 | 0.8535 | 0.8546 |
4090 | 8192 | 3.5689 | 6.7631 | 6.7869 |
结语
本文介绍了一种使用自定义CUDA kernel的方式来提升Muon优化器的运行速度的方法,相关代码开源在nil0x9/flash-muon,欢迎读者Star、试用,并给笔者一些反馈(笔者水平有限,实现还存在诸多问题)!
值得一提的是,本文的这种计算[email protected]
的kernel,其实在cuBLAS中有类似的API,叫做SYRK
,但可惜的是,他们并没有提供半精度的版本,因此笔者才自己实现了这样一个功能。(番外:如何合理设计kernel(以及thread-data布局)实现cublas的syrk函数(XX’)?)