2 Jan

为什么梯度裁剪的默认模长是1?

我们知道,梯度裁剪(Gradient Clipping)是让模型训练更加平稳的常用技巧。常用的梯度裁剪是根据所有参数的梯度总模长来对梯度进行裁剪,其运算可以表示为
\begin{equation}\text{clip}(\boldsymbol{g},\tau)=\left\{\begin{aligned}&\boldsymbol{g}, &\Vert\boldsymbol{g}\Vert\leq \tau \\
&\frac{\tau}{\Vert\boldsymbol{g}\Vert}\boldsymbol{g},&\Vert\boldsymbol{g}\Vert > \tau
\end{aligned}\right.\end{equation}
这样一来,$\text{clip}(\boldsymbol{g},\tau)$保持跟$\boldsymbol{g}$相同的方向,但模长不超过$\tau$。注意这里的$\Vert\boldsymbol{g}\Vert$是整个模型所有的参数梯度放在一起视为单个向量所算的模长,也就是所谓的Global Gradient Norm。

不知道大家有没有留意到一个细节:不管是数百万参数还是数百亿参数的模型,$\tau$的取值在很多时候都是1。这意味着什么呢?是单纯地复用默认值,还是背后隐含着什么深刻的原理呢?

点击阅读全文...

25 Dec

从谱范数梯度到新式权重衰减的思考

在文章《Muon优化器赏析:从向量到矩阵的本质跨越》中,我们介绍了一个名为“Muon”的新优化器,其中一个理解视角是作为谱范数正则下的最速梯度下降,这似乎揭示了矩阵参数的更本质的优化方向。众所周知,对于矩阵参数我们经常也会加权重衰减(Weight Decay),它可以理解为$F$范数平方的梯度,那么从Muon的视角看,通过谱范数平方的梯度来构建新的权重衰减,会不会能起到更好的效果呢?

那么问题来了,谱范数的梯度或者说导数长啥样呢?用它来设计的新权重衰减又是什么样的?接下来我们围绕这些问题展开。

基础回顾

谱范数(Spectral Norm),又称“$2$范数”,是最常用的矩阵范数之一,相比更简单的$F$范数(Frobenius Norm),它往往能揭示一些与矩阵乘法相关的更本质的信号,这是因为它定义上就跟矩阵乘法相关:对于矩阵参数$\boldsymbol{W}\in\mathbb{R}^{n\times m}$,它的谱范数定义为

点击阅读全文...

10 Dec

Muon优化器赏析:从向量到矩阵的本质跨越

随着LLM时代的到来,学术界对于优化器的研究热情似乎有所减退。这主要是因为目前主流的AdamW已经能够满足大多数需求,而如果对优化器“大动干戈”,那么需要巨大的验证成本。因此,当前优化器的变化,多数都只是工业界根据自己的训练经验来对AdamW打的一些小补丁。

不过,最近推特上一个名为“Muon”的优化器颇为热闹,它声称比AdamW更为高效,且并不只是在Adam基础上的“小打小闹”,而是体现了关于向量与矩阵差异的一些值得深思的原理。本文让我们一起赏析一番。

Muon与AdamW效果对比(来源:推特@Yuchenj_UW)

Muon与AdamW效果对比(来源:推特@Yuchenj_UW)

点击阅读全文...

29 Nov

从Hessian近似看自适应学习率优化器

这几天在重温去年的Meta的一篇论文《A Theory on Adam Instability in Large-Scale Machine Learning》,里边给出了看待Adam等自适应学习率优化器的新视角:它指出梯度平方的滑动平均某种程度上近似于在估计Hessian矩阵的平方,从而Adam、RMSprop等优化器实际上近似于二阶的Newton法。

这个角度颇为新颖,而且表面上跟以往的一些Hessian近似有明显的差异,因此值得我们去学习和思考一番。

牛顿下降

设损失函数为$\mathcal{L}(\boldsymbol{\theta})$,其中待优化参数为$\boldsymbol{\theta}$,我们的优化目标是
\begin{equation}\boldsymbol{\theta}^* = \mathop{\text{argmin}}_{\boldsymbol{\theta}} \mathcal{L}(\boldsymbol{\theta})\label{eq:loss}\end{equation}
假设$\boldsymbol{\theta}$的当前值是$\boldsymbol{\theta}_t$,Newton法通过将损失函数展开到二阶来寻求$\boldsymbol{\theta}_{t+1}$:
\begin{equation}\mathcal{L}(\boldsymbol{\theta})\approx \mathcal{L}(\boldsymbol{\theta}_t) + \boldsymbol{g}_t^{\top}(\boldsymbol{\theta} - \boldsymbol{\theta}_t) + \frac{1}{2}(\boldsymbol{\theta} - \boldsymbol{\theta}_t)^{\top}\boldsymbol{\mathcal{H}}_t(\boldsymbol{\theta} - \boldsymbol{\theta}_t)\end{equation}

点击阅读全文...

14 Nov

当Batch Size增大时,学习率该如何随之变化?

随着算力的飞速进步,有越多越多的场景希望能够实现“算力换时间”,即通过堆砌算力来缩短模型训练时间。理想情况下,我们希望投入$n$倍的算力,那么达到同样效果的时间则缩短为$1/n$,此时总的算力成本是一致的。这个“希望”看上去很合理和自然,但实际上并不平凡,即便我们不考虑通信之类的瓶颈,当算力超过一定规模或者模型小于一定规模时,增加算力往往只能增大Batch Size。然而,增大Batch Size一定可以缩短训练时间并保持效果不变吗?

这就是接下来我们要讨论的话题:当Batch Size增大时,各种超参数尤其是学习率该如何调整,才能保持原本的训练效果并最大化训练效率?我们也可以称之为Batch Size与学习率之间的Scaling Law。

方差视角

直觉上,当Batch Size增大时,每个Batch的梯度将会更准,所以步子就可以迈大一点,也就是增大学习率,以求更快达到终点,缩短训练时间,这一点大体上都能想到。问题就是,增大多少才是最合适的呢?

点击阅读全文...

6 Aug

通向最优分布之路:概率空间的最小化

当要求函数的最小值时,我们通常会先求导函数然后寻找其零点,比较幸运的情况下,这些零点之一正好是原函数的最小值点。如果是向量函数,则将导数改为梯度并求其零点。当梯度零点不易求得时,我们可以使用梯度下降来逐渐逼近最小值点。

以上这些都是无约束优化的基础结果,相信不少读者都有所了解。然而,本文的主题是概率空间中的优化,即目标函数的输入是一个概率分布,这类目标的优化更为复杂,因为它的搜索空间不再是无约束的,如果我们依旧去求解梯度零点或者执行梯度下降,所得结果未必能保证是一个概率分布。因此,我们需要寻找一种新的分析和计算方法,以确保优化结果能够符合概率分布的特性。

对此,笔者一直以来也感到颇为头疼,所以近来决定”痛定思痛“,针对概率分布的优化问题系统学习了一番,最后将学习所得整理在此,供大家参考。

点击阅读全文...

13 May

缓存与效果的极限拉扯:从MHA、MQA、GQA到MLA

前几天,幻方发布的DeepSeek-V2引起了大家的热烈讨论。首先,最让人哗然的是1块钱100万token的价格,普遍比现有的各种竞品API便宜了两个数量级,以至于有人调侃“这个价格哪怕它输出乱码,我也会认为这个乱码是一种艺术”;其次,从模型的技术报告看,如此便宜的价格背后的关键技术之一是它新提出的MLA(Multi-head Latent Attention),这是对GQA的改进,据说能比GQA更省更好,也引起了读者的广泛关注。

接下来,本文将跟大家一起梳理一下从MHA、MQA、GQA到MLA的演变历程,并着重介绍一下MLA的设计思路。

MHA

MHA(Multi-Head Attention),也就是多头注意力,是开山之作《Attention is all you need》所提出的一种Attention形式,可以说它是当前主流LLM的基础工作。在数学上,多头注意力MHA等价于多个独立的单头注意力的拼接,假设输入的(行)向量序列为$\boldsymbol{x}_1,\boldsymbol{x}_2,\cdots,\boldsymbol{x}_l$,其中$\boldsymbol{x}_i\in\mathbb{R}^d$,那么MHA可以形式地记为

点击阅读全文...

14 Jan

旁门左道之如何让Python的重试代码更加优雅

这篇文章我们讨论一个编程题:如何更优雅地在Python中实现重试。

在文章《新年快乐!记录一下 Cool Papers 的开发体验》中,笔者分享了开发Cool Papers的一些经验,其中就提到了Cool Papers所需要的一些网络通信步骤。但凡涉及到网络通信,就有失败的风险(谁也无法保证网络不会间歇性抽风),所以重试是网络通信的基本操作。此外,当涉及到多进程、数据库、硬件交互等操作时,通常也需要引入重试机制。

在Python中,实现重试并不难,但如何更加简单而又不失可读性地实现重试,还是有一定技巧的。接下来笔者分享一下自己的尝试。

循环重试

完整的重试流程大致上包含循环重试、异常处理、延时等待、后续操作等部分,其标准写法就是用for循环,用“try ... except ...”来捕捉异常,一个参考代码是:

点击阅读全文...