李飞飞团队提出架构设计新思路!无需从头训练,直接“嫁接”预训练模型关键组件
所使用的算力不到预训练阶段的2%
预训练模型能否作为探索新架构设计的“底座” ?
最新答案是:yes!
简单来说,按照研究人员设计好的架构方案从头训练模型,往往是检验一个架构是否有效的重要手段。
但问题在于,从头训练模型的成本也太高了!

对此,包括李飞飞团队在内的研究人员提出了一种被称为“Grafting(嫁接)”的新思路——
直接将预训练好的模型作为“底座”,通过修改其组件来研究新架构。
这就好比软件开发中,程序员常基于现有代码修改而非重写,以此省时省力。
基于这一思路,他们重点关注了DiTs这一广泛用于图像和视频生成的Transformer模型。
具体而言,这群人先是构建了一个基于DiT-XL/2设计的测试平台,以方便后续研究“嫁接”对模型质量的影响,然后实际使用“嫁接”技术开发了一系列混合设计。
结果发现,许多混合设计在使用不到2%的预训练计算量的情况下,获得了和原来大差不差的模型性能。
将这一方法应用于文生图模型PixArt-Σ,其生成速度提高了1.43倍,但生成图像的质量只下降了不到2%。
以上说明,“嫁接”确实能成为一种轻量级、高效的架构探索工具,可以让研究者在缺少计算资源的情况下测试新想法。

下面详细揭秘团队提出的这种新方法——
两阶段架构编辑法
众所周知,模型架构设计通常涉及测试不同的组件(如注意力机制、卷积层)和配置(如模型深度、宽度)。
而作为一种架构编辑方法,“嫁接”主要通过修改预训练DiTs的计算图来实现新架构的验证,具体则主要通过激活蒸馏和轻量级微调这两个关键阶段来实现。

所谓计算图,是指模型内部的运算逻辑结构——由多层Transformer块组成,每个块包含自注意力(MHA)、多层感知器(MLP)等算子(Operator),这些算子按特定顺序连接,形成执行生成任务的“数据流路径”。
要实现这种修改替换,关键要解决两个问题:
问题1:在将新算子整合到计算图之前,应该如何初始化新算子?
如果简单地把新算子的权重随机初始化,它可能一开始就会和模型的其他部分不协调,导致模型性能下降。
问题2:如何减轻因替换多个算子而导致的错误累积?
当替换多个算子时,每个替换都可能引入一些误差。一旦误差逐渐积累,最终可能会导致模型性能大幅下降。
对此,新方法采用了以下两阶段架构编辑法:
- 激活蒸馏(Activation Distillation):新算子(如卷积)初始化时,通过回归任务学习原算子的“行为”,即用少量数据训练新算子,使其输出与原算子的激活值尽可能接近。
- 轻量级微调(Lightweight Fine-tuning):替换多个组件后,用有限数据进行端到端微调,减少误差累积,恢复模型性能。

并且,为了评估“嫁接”本身的效果,研究正式开始前还引入了自嫁接(self-grafting)作为对照实验。
所谓自嫁接,是指将现有MHA、MLP等替换为相同类型但权重随机初始化的算子 。
其作用主要有三个:
- 评估在不改变架构的情况下,“嫁接”过程本身对模型的影响;
- 为后续比较不同的替换方案提供一个基准性能,便于判断新方案的优劣;
- 研究影响模型性能的各种因素,比如数据规模大小、回归目标的选择以及超参数设置等。
结果发现,在实际操作中,仅需8k样本就能实现较好的初始化。
此外,即便替换DiT-XL/2中所有的多头注意力(MHA)或多层感知器(MLP)层,仅使用10%的训练数据进行微调,模型也能正常恢复。
实验结果
研究人员进行了三项实验,并得出以下主要结论:
实验1:混合架构实验,验证替换的可行性。
通过将DiT-XL/2中的注意力层MHA替换为滑动窗口注意力(SWA)或门控卷积(Hyena-X),在50%替换比例下,FID仅比基线高0.4(FID值越低,说明越接近原始性能)。
而100%全替换会导致FID骤降(数值>75),生成质量崩溃,这说明并非所有层都能被局部算子替代,即模型中存在“必须依赖全局信息” 的层,而另一部分层可接受局部计算。
团队还尝试将DiT-XL/2中的感知器层MLP也进行了替换,结果在将MLP的扩展比改成r=3或r=6的情况下,就算全换掉,模型效果也挺好,这说明MLP宽度改起来不容易出问题。
一言以蔽之,多种混合设计的生成质量均接近原模型,且计算成本不到预训练的2%。

实验2:文本到图像生成实验,验证新架构的有效性。
接下来,研究人员对文生图模型PixArt-Σ进行了“嫁接”,将MHA替换为Hyena-X,结果使用12k合成数据微调后,实现了1.43倍速度提升(从235ms→164ms),GenEval分数从49.75→47.78(下降小于2%)。

实验3:并行化改造实验,验证架构重组的有效性。
通过将DiT-XL/2的28层顺序块转为14层并行块(每对顺序块并行执行),在深度减半的情况下,模型生成质量优于同类深度模型。
这验证了,并行架构在减少深度的同时可提升质量,可用作模型轻量化的思路。

不过最后,团队也提到了研究的局限性。一是仅在DiT-XL/2模型上进行了验证,二是仅测试了替换成Hyena-X和SWA的效果,结论的普适性受限。
但不管怎样,团队认为“嫁接”这种方法在探索新的模型架构方面显示出很大的潜力,尤其是在需要高效利用计算资源的场景中。
BTW,目前研究所涉及的22种“嫁接”模型均已开源。
论文:
https://grafting.stanford.edu/
博客:
https://www.liquid.ai/research/exploring-diffusion-transformer-designs-via-grafting
开源地址:
https://huggingface.co/grafting
https://github.com/keshik6/grafting
