生成扩散模型漫谈(二十八):分步理解一致性模型
By 苏剑林 | 2024-12-18 | 25126位读者 |书接上文,在《生成扩散模型漫谈(二十七):将步长作为条件输入》中,我们介绍了加速采样的Shortcut模型,其对比的模型之一就是“一致性模型(Consistency Models)”。事实上,早在《生成扩散模型漫谈(十七):构建ODE的一般步骤(下)》介绍ReFlow时,就有读者提到了一致性模型,但笔者总感觉它更像是实践上的Trick,理论方面略显单薄,所以兴趣寥寥。
不过,既然我们开始关注扩散模型加速采样方面的进展,那么一致性模型就是一个绕不开的工作。因此,趁着这个机会,笔者在这里分享一下自己对一致性模型的理解。
熟悉配方 #
还是熟悉的配方,我们的出发点依旧是ReFlow,因为它大概是ODE式扩散最简单的理解方式。设x0∼p0(x0)是目标分布的真实样本,x1∼p1(x1)是先验分布的随机噪声,xt=(1−t)x0+tx1是加噪样本,那么ReFlow的训练目标是:
θ∗=argminθEt∼U[0,1],x0∼p0(x0),x1∼p1(x1)[w(t)‖vθ(xt,t)−(x1−x0)‖2]
其中w(t)是可调的权重。训练完成后可以通过求解dxt/dt=vθ∗(xt,t)来实现x1到x0的变换,从而完成采样。
需要指出的是,一致性模型的Noise Schedule是xt=x0+tx1(当t足够大时xt同样接近于纯噪声),跟ReFlow略有不同。不过本文的主要目的,是尝试一步步引导出跟一致性模型相同的训练思想和训练目标,笔者认为ReFlow的更好理解一些,所以还是按照ReFlow的来介绍,至于具体的训练细节大家按需自行调整就好。
利用xt=(1−t)x0+tx1,我们可以消去目标(1)中的x1:
θ∗=argminθEt∼U[0,1],x0∼p0(x0),x1∼p1(x1)[˜w(t)‖xt−tvθ(xt,t)⏟fθ(xt,t)−x0‖2]
其中˜w(t)=w(t)/t2。注意x0是真实样本,xt是加噪样本,所以ReFlow的训练目标实际上也是在去噪。预测干净样本的模型为fθ(xt,t)=xt−tvθ(xt,t),这个函数有一个重要特性是恒成立fθ(x0,0)=x0,这正是一致性模型的关键约束之一。
分步理解 #
接下来让我们一步步解构ReFlow的训练过程,试图从中找到更好的训练目标。首先我们将[0,1]等分为n份,每份大小为1/n,记tk=k/n,那么t就只需从有限集合{0,t1,t2,⋯,tn}均匀采样。当然我们也可以选择非均匀的离散化方式,这些都是非关键的细节问题。
由于t0=0是平凡的,我们从t1开始,第一步的训练目标是
θ∗1=argminθEx0∼p0(x0),x1∼p1(x1)[˜w(t1)‖fθ(xt1,t1)−x0‖2]
接着,考虑第二步的训练目标,还是按照(2)的话,那么应该是‖fθ(xt2,t2)−x0‖2的期望,但现在我们评估一个新目标:
θ∗2=argminθEx0∼p0(x0),x1∼p1(x1)[˜w(t2)‖fθ(xt2,t2)−fθ∗1(xt1,t1)‖2]
也就是说预测对象改为fθ∗1(xt1,t1)而不是x0。为什么要这样改呢?我们分可行性和必要性两方面来讨论。可行性方面,xt2相比xt1加了更多噪声,所以它去噪会更困难,换言之fθ∗2(xt2,t2)的复原程度是不如fθ∗1(xt1,t1)的,所以用fθ∗1(xt1,t1)替换掉x0作为第二步的训练目标完全是可行的。
可即便如此,那又有什么换的必要呢?答案是减少“轨迹交叉”。由于xtk=(1−tk)x0+tkx1,因此随着k的增大,xtk对x0的依赖会越来越弱,以至于两个不同的x0,它们对应的xtk会很接近,这时候还是以x0为预测目标的话,就会出现“一个输入,多个目标”的困境,这就是“轨迹交叉”。
面对这个困境,ReFlow的策略是事后蒸馏,因为预训练完后求解dxt/dt=vθ∗(xt,t)就可以得到很多(x0,x1)对,用这些配对的x0,x1去构建xt就能避免交叉。一致性模型的想法是把预测目标换成fθ∗k−1(xtk−1,tk−1),因为对于“同一x1、不同x0”,fθ∗k−1(xtk−1,tk−1)间的差异会比x0间的差异要小,所以也能减少交叉风险。
简单来说,就是fθ∗2(xt2,t2)预测fθ∗1(xt1,t1)比预测x0更容易,并且该达到的效果也能达到,所以调整了预测目标。类似地,我们可以写出
θ∗3=argminθEx0∼p0(x0),x1∼p1(x1)[˜w(t3)‖fθ(xt3,t3)−fθ∗2(xt2,t2)‖2]θ∗4=argminθEx0∼p0(x0),x1∼p1(x1)[˜w(t4)‖fθ(xt4,t4)−fθ∗3(xt3,t3)‖2]⋮θ∗n=argminθEx0∼p0(x0),x1∼p1(x1)[˜w(tn)‖fθ(xtn,tn)−fθ∗n−1(xtn−1,tn−1)‖2]
一致训练 #
现在我们已经完成了ReFlow模型的解构,并且得到了一个新的自认为更合理的训练目标,但代价是得到了n套参数θ∗1,θ∗2,⋯,θ∗n,这当然不是我们想要的,我们只想要一个模型。于是我们认为所有的θ∗i可以共用同一套参数,于是我们可以写出训练目标
θ∗=argminθEk∼[n],x0∼p0(x0),x1∼p1(x1)[˜w(tk)‖fθ(xtk,tk)−fθ∗(xtk−1,tk−1)‖2]
这里k∼[n]是指k从{1,2,⋯,n}中均匀采样。上式的问题是,θ∗是我们要求的参数,但它又出现在目标函数中,这显然是不科学的(知道θ∗了我还训练干嘛),因此必须修改上述目标使得它更为可行。
θ∗的意义是理论最优解,考虑到随着训练的推进,θ会慢慢逼近θ∗,所以在目标函数中我们可以将这个条件放宽为“超前解”,即它只要比当前的θ更好就行了。怎么构建“超前解”呢?一致性模型的做法是对历史权重进行EMA(Exponential Moving Average,指数滑动平均),这往往能得到一个更优秀的解,早些年我们在打比赛时就经常用到这个技巧。
因此,一致性模型最终的训练目标是:
θ∗=argminθEk∼[n],x0∼p0(x0),x1∼p1(x1)[˜w(tk)‖fθ(xtk,tk)−fˉθ(xtk−1,tk−1)‖2]
其中ˉθ是θ的EMA。这就是原论文中的“一致性训练(Consistency Training,CT)”。从实践上来看,我们也可以将‖⋅−⋅‖2换成更一般的度量d(⋅,⋅),以更贴合数据特性。
采样分析 #
由于我们是从ReFlow出发一步步“等价变换”过来的,所以训练完成后一种基本的采样方式就是跟ReFlow一样求解ODE
dxt/dt=vθ∗(xt,t)=xt−fθ∗(xt,t)t
当然,如果费那么大劲得到的是跟ReFlow一样的结果,那么就纯粹是瞎折腾了。幸运的是,一致性训练所得的模型,有一个重要的优势是可以使用更大的采样步长——甚至等于1的步长,这就可以实现单步生成:
x0=x1−vθ∗(x1,1)×1=fθ∗(x1,1)
理由是
‖fθ∗(x1,1)−x0‖=‖n∑k=1[fθ∗(xtk,tk)−fθ∗(xtk−1,tk−1)]‖≤n∑k=1‖fθ∗(xtk,tk)−fθ∗(xtk−1,tk−1)‖
可以看到,一致性训练相当于在优化‖fθ∗(x1,1)−x0‖的上界,当损失足够小时,意味着‖fθ∗(x1,1)−x0‖也足够小,因此可以一步生成。
可‖fθ∗(x1,1)−x0‖是原本ReFlow的训练目标,为什么直接优化它会不如优化它的上界呢?这又回到了“轨迹交叉”的问题了,直接训练的话,x0,x1都是随机采样的,没有一一配对关系,所以无法直接训练出一步生成模型。但训练上界的话,通过多个fθ∗(xtk,tk),fθ∗(xtk−1,tk−1)的传递性,隐含地实现了x0,x1的配对。
如果单步生成的效果不能让我们满意,我们也可以增加采样步数来提高生成质量,这里边又有两种思路:1、用更小的步长来数值求解(8);2、转化为类似SDE的随机迭代。前者比较常规,我们主要讨论后者。
首先注意到式(10)中的fθ∗(x1,1)换成任意fθ∗(xt,t),也可以得到类似的不等关系,这意味着任意的fθ∗(xt,t)预测的都是x0,这样一来,我们从x1出发,通过fθ∗(x1,1)就得到一个初步的x0,但可能不够完美,于是我们通过加噪来“掩饰”这种不完美,得到一个xtn−1,代入fθ∗(xtn−1,tn−1)得到一个更好一点的结果,依此类推:
x1∼N(0,I)x0←fθ∗(x1,1)for k=n−1,n−2,⋯,1:z∼N(0,I)xtk←(1−tk)x0+tkzx0←fθ∗(xtk,tk)
用于蒸馏 #
一致性模型的训练思想同样可以用于现成扩散模型的蒸馏,结果称为“一致性蒸馏(Consistency Distillation,CD)”,方法是把式(7)中fθ(xtk,tk)的学习目标由fˉθ(xtk−1,tk−1)换成fˉθ(ˆxφ∗tk−1,tk−1):
θ∗=argminθEk∼[n],x0∼p0(x0),x1∼p1(x1)[˜w(tk)‖fθ(xtk,tk)−fˉθ(ˆxφ∗tk−1,tk−1)‖2]
其中ˆxφ∗tk−1是由教师扩散模型以xtk为初值所预测的xtk−1,比如最简单的欧拉求解器,我们有
ˆxφ∗tk−1≈xtk−(tk−tk−1)vφ∗(xtk,tk)
这样做的理由也很简单,如果有了预训练好的扩散模型,那么我们就没必要在直线xt=(1−t)x0+tx1上找学习目标了,因为这是人为定义的,终究有交叉的风险,而是改为由预训练好扩散模型来预测轨迹,这样找出来的学习目标可能并不一定是“最直”的,但肯定不会有交叉。
如果不计成本,我们也可以从随机采样的x1出发,加上预训练扩散模型解出的x0,通过成对的(x0,x1)来构建学习目标,这差不多就是ReFlow的蒸馏思路,缺点是必须对教师模型运行完整的采样过程,费时费力。相比之下,一致性蒸馏只需要运行单步教师模型,计算成本更低。
不过,一致性蒸馏在蒸馏过程中还需要真实样本,这在某些场景下也是一个缺点。如果蒸馏过程既不想运行完整的教师模型采样,又不想提供真实数据,那么有一个选择就是我们之前介绍过的SiD,代价是模型的推导更加复杂了。
文章小结 #
本文通过逐步解构和优化ReFLow训练流程的方式,提供了一个从ReFlow逐渐过渡到一致性模型(Consistency Models)的直观理解路径。
转载到请包括本文地址:https://kexue.fm/archives/10633
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Dec. 18, 2024). 《生成扩散模型漫谈(二十八):分步理解一致性模型 》[Blog post]. Retrieved from https://kexue.fm/archives/10633
@online{kexuefm-10633,
title={生成扩散模型漫谈(二十八):分步理解一致性模型},
author={苏剑林},
year={2024},
month={Dec},
url={\url{https://kexue.fm/archives/10633}},
}
December 26th, 2024
“轨迹交叉”这个概念,总有种雾里看花之感,懂了又好像没完全懂。请问这是一种感性的解释嘛,还是有严谨数学定义的理论,如果是的话,有没有什么推荐参考资料,谢谢:)
这里边涉及到一些背景知识,你可以先看看reflow、flow matching等的原始论文。
举个例子:假如真实样本有两个(0,0)、(0,1),噪声样本也有两个(1,0),(1,1),那么随机采样的情况下,就有四种组合,两点确定一条直线,所以确定了四条直线,其中存在交叉的直线,这就是轨迹交叉。理想情况下,应该是(0,0)只跟(1,0)配对、(0,1)只跟(1,1)配对,这样就只有两条不相交直线。
好像sCM说,在reFlow, OT可以等效转换为其他噪声调度。 其更好的经验性能的原因本质上是由于训练期间不同的加权以及采样期间缺乏 DPM-Solver 系列等高级扩散采样器,而不是 “直线路径” 本身。
直线路径本身可能不会太本质,因为不同noise schedule之间其实可以相互变换的。
January 9th, 2025
如果是从t_k预测x_0更难的话 可不可以通过用更小的weight来解决这个问题呢
更小的weight是指什么weight?
比如随着时间递减的weight 比如1/t这样的
你是说loss的权重?更难为什么要用更小的权重?另外个人认为这个权重其实不大好说能够根据预测难度的大小来调,参考@苏剑林|comment-26263
January 13th, 2025
请问公式(2)中为什么t越接近0,权重越大,此时不应该基本上接近干净样本loss非常小了吗。
你说˜w(t)=w(t)/t2?那还有w(t)呢。
而且损失函数小,就算加大权重不也是合理的选择?还有,预测x0准确,那么预测x1必然没那么准确,因此从预测x1的角度看,加大权重也是合理的。
我的主观感觉就是x1这个噪声无论在训练还是推理都是已知的,因此x1-x0的预测等价于预测x0,这和重参数化似乎一致。从你提到的离x0越近预测x1变难的角度来说,如果把x1作为条件输入到网络是否能提高模型的效果呢。
denoise模型的输入是xt,不论是x1还是x0都是未知的。
February 14th, 2025
感谢大佬,这个分析直击核心,consistency model的论文我看了好几遍,不得要领。
February 22nd, 2025
请问大佬,如果在损失项后面加上一个一致性模型预测的结果与真实数据之间的损失是合理的嘛?
意思是既回归fθ∗、又回归x0?不大好说,从采样步数来说应该不利。
February 27th, 2025
请问下,在使用k-Rectified Flow蒸馏时,需要事先获得大量噪声-图像对,请问在生成图像时需要用classifer free guidance吗?
细节我不清楚,但直觉上这不是取决于你训练时有没有加入class?ifer free guidance
March 3rd, 2025
苏神,一致性蒸馏,是不是相当于把原来x_t的加噪路线更换了,因为预测x_0的target改变了,而更换之后的加噪路线,不容易和别的加噪路线相交?
是的