最近在阅读Muon is Scalable for LLM Training这篇文章的时候注意到他们使用无权重衰减(weight decay)版本的Muon优化LLM的时候,优化器的收敛优势会随着训练过程逐渐消失,又看到@小明同学在评论区提到的一个细节,很多开源的LLM在技术报告中都提到了使用0.1作为权重衰减的系数,觉得是个比较有意思的发现。结合Kimi的文章中关于bf16的简单陈述,笔者在本文中稍微展开讲下,权重衰减对于LLM的低精度训练中有什么作用。
首先把结论放在前面:除了一般认知中的正则化作用,权重衰减也可能降低精度损失的风险——对于计算机的浮点数而言,绝对值越大,精度越低。对于低精度/混合精度训练而言,使用权重衰减可以控制参数的绝对值范围,从而保证模型参数不落入低精度的数值区间。
浮点数的存储与精度
上述结论主要与浮点数在计算机内的存储形式有关。学过计算机的一些基本课程的读者可能有印象,浮点数的存储是二进制的形式,分为符号位、指数位和尾数位三段。深度学习中常见的浮点数协议(fp32、fp16、bf16、tf32)的区别在于指数位和尾数位的比特数量不同。由于浮点数是一个「指数」的形式,因此它在实数空间的分布是不均匀的。
这里我们考虑规范数的情形(指数位非全0),做一点分析。假设符号位、指数位、尾数位(mantissa)的二进制编码分别是$S$
、$E$
、$M$
,那么对应的浮点数为:
$$ \text{value}=(-1)^S\times 1.M\times 2^{E-\text{bias}} $$
在单精度fp32标准中,$\text{bias}$
取${01111111}_{2}=127_{10}$
.
例如在图中的例子中,$S=0$
,$E=\underbrace{00\cdots0}_{7\ 0's}1$
,$M=\underbrace{00\cdots0}_{22\ 0's}1$
,相应的值为
$$ \begin{align} &{-1}^0\times 1.\underbrace{00\cdots0}_{22\ 0's}1_2\times 2^{00000001_2-{01111111}_{2}}\\ &\approx [1.175494490952134\times 10^{-38}]_{10} \end{align} $$
现在我们来考虑不同的数值范围内的浮点数精度。对于区间范围$[2^{x}, 2^{x+1}],\forall -126\le x\le127$
(这里的x已经是经过-bias之后得到的最终指数),我们希望在给定任意浮点数$y\in[2^{x}, 2^{x+1}]$
的基础上增加一个最小量$\varepsilon$
(即区间内两个浮点数的最小间隔),这个增加的过程是通过操纵二进制编码实现的,那么最小间隔只能是通过在$y$
的尾数部分加上
$$ 0.\underbrace{00\cdots0}_{22\ 0's}1 $$
来实现,对应的最小间隔是
$$ \begin{align} \varepsilon &= 0.\underbrace{00\cdots0}_{22\text{ 0's}}1_2\times 2^{x}\\ &=2^{-23}\times2^{x}=2^{x-23}\\ \end{align} $$
如果你将上面的公式带入不同的指数位,可以验证与Wikipedia中给出的这张表的Gap是吻合的:
这个计算可以拓展到其他的精度格式:
- 对于半精度格式fp16而言,其尾数位有10位,对应的最小间隔是
$2^{x-10}$
; - 对于半精度格式bf16而言,其尾数位有7位,对应的最小间隔是
$2^{x-7}$
;
从这里也可以看到,bf16相比fp16虽然拓宽了表示范围,但是减少了精度(同样数值范围内的最小间隔更宽了)。
从这里我们得到了一个结论:计算机存储的浮点数之间的最小间隔随着浮点数绝对值数值增加,指数级地增大,换言之,浮点数(绝对值)数值越大,精度越低。并且这个问题对于fp16或bf16格式的浮点数,问题要更加显著。
这个结论的另一个引申的问题是舍入误差,假如一个较大的浮点数和一个较小的浮点数相加,由于浮点数的加法(减法过程相当于取补码后相加,结论是类似的)过程需要先将两个数的指数位对齐,因此绝对值较小的数字的尾数的最后几位数字可能会在加法中丢失。这里我们举一个极端的例子来说明。
假设浮点数存储为fp32格式(8位指数、23位尾数)。
$$ \begin{align} x=(-1)^0\times 1.0_2\times 2^{-1}\\ y=(-1)^0\times 1.0_2\times 2^{-25}\\ \end{align} $$
在执行浮点数加法的时候,指数位首先按照大的一个operand的指数对齐,也就是$y$
的指数变为$2^{-25}\to 2^{-1}$
,相应地需要将尾数(包括前面隐含的1)向右移动$-1-(-25)=24$
位。由于fp32的尾数位只有23位,因此在移动之后$y$
变成了$0.000\cdots0\times2^{-1}$
。也就是说$y$
在对齐的过程中直接变成了0!
读者可以运行如下的代码验证这一结论:
import numpy as np
x = np.float32(2**-1) # 0.5
y = np.float32(2**-25) # 0.0000001192092896
result = x + y
print(f"x = {x}, y = {y}")
print(f"x+y = {result}")
print(f"x == x+y? {x==result}")
# x = 0.5, y = 2.9802322387695312e-08
# x+y = 0.5
# x == x+y? True
对深度模型训练的影响
上面我们从浮点数的存储格式建立了「计算机浮点数的数值绝对值越大,则精度越低」的结论,并且引申到舍入误差的问题。接下来我们把这个现象带入到一个神经网络的训练过程中来看可能引发什么样的问题。
在神经网络的训练过程中,一个经典的训练过程是(以监督学习为例):
- (forward)给定当前参数
$\boldsymbol{\theta}$
和小批量数据$\boldsymbol{x},\boldsymbol{y}$
,计算损失函数$\mathcal{L}(\boldsymbol{\theta};\boldsymbol{x},\boldsymbol{y})$
; - (backward)反向传播,得到每个参数的梯度;
- (update)更新优化器状态(梯度的统计量,例如梯度动量,Adam中的一、二阶矩统计量),更新模型参数
$\boldsymbol{\theta}$
.
对于这个过程而言,前一节中的结论会造成两方面的结果:
- 使用低精度浮点数保存和更新模型参数时,如果模型参数绝对值比较大,而更新的步幅比较小,那么更新会由于舍入误差而失效;
- 从一个高精度的模型转化为低精度模型的时候,参数的绝对值越大,则丢失的精度越多。
如果读者有训练一些LM或者其他神经网络的经验,可能会发现Transformer这类深度模型在训练过程中,参数的范数会随着训练过程中逐渐增大。Merrill 2020指出对于T5模型而言,其参数范数的增长正比于$\sqrt{t}$
($t$
是更新次数). 因此,对于训练后期,随着参数的量级逐渐变大,精度变差的风险也会增加。
值得指出的是,当下的混合精度训练范式一般会在低精度的权重之外,维护一份fp32的权重,优化器的states一般也会使用高精度版本,防止累加的过程中出现严重的舍入误差。而且在低精度的GEMM中,也会使用高精度的accumulator来存储分块内的内积。但是这些操作仍然不能完全杜绝由浮点数存储带来的精度问题。
例如,在模型更新了fp32的备份之后,还需要将fp32的权重转化为低精度的版本,参与后续的forward过程。由于浮点数的精度随着绝对值的增加而降低,因此参数的绝对值越大,在精度的转化中损失的精度也越多。此外,在前向和反向计算的过程中,激活值夜会存在类似的精度损失问题。
如果我们在训练过程中引入权重衰减:
$$ \theta^+\leftarrow \theta - \eta(\tilde{\Delta}+{\color[rgb]{0, 0.5, 0.8}{\lambda\theta}}) $$
其中$\tilde{\Delta}$
是优化器计算出来的权重更新量,$\lambda$
是权重衰减的系数,那么模型的权重的绝对值就可以得到一定的控制。除了提供一定的正则化效应之外,也能够降低由于模型的参数范数增长而导致的精度损失的风险。
总结
本文从浮点数的存储原理出发,建立了「数值越大,精度越低」的结论,从而(ad-hoc地)解释了LLM的训练对权重衰减的依赖。但是也需要指出的是,权重的范数增长与模型的结构是有一定关系的,这个规律不一定对所有模型成立。