突破单token预测局限!南洋理工首次将多token预测引入微调
推理和生成能力增强的同时,也不会额外增加多余成本
CAFT团队 投稿
量子位 | 公众号 QbitAI
告别Next-token,现在模型微调阶段就能直接多token预测!
从GPT到Claude,当前主流LLM都依赖next-token prediction(下一token预测)进行训练,但它却让AI很难真正理解跨越多token的完整概念。
于是南洋理工大学最近提出了一项新技术——概念感知微调(CAFT),首次实现将multi-token prediction(多token预测)引入微调阶段,让模型能够像人类一样理解和学习完整概念。

原来LLM只能碎片化理解每个token,现在CAFT可以为模型添加额外的辅助头,在主模型学习下一个词的同时,帮助学习后续token,并通过动态调整权重,确保模型始终优先优化主要任务的损失。
最终LLM可以兼顾多token概念学习,形成更为完整的认知,在推理和生成能力增强的同时,既不会影响模型本身,也不会额外增加多余成本。

另外研究人员通过实验发现,CAFT在编程、数学、生物医学等多个领域都能显著提升模型性能,或许未来将会让AI训练范式迎来根本性转变。
下面是有关CAFT的更多详细内容。
Next-token预测:AI的“基因密码”
首先,next-token prediction的基本思想是在已知上下文的基础上,预测最有可能的下一个token。
举个例子,针对句子“人工智能将改变_”,你可能会直接预测出“世界”、“未来”或“社会”,但是next-token prediction的预测流程则分为以下三步:
- 分词:例如将“人工智能”拆分为“人工”和“智能”。
- 序列建模:让模型逐个学习每个token与其前文的关系。
- 概率预测:为所有候选token分配概率,并选择最高者作为输出。
Next-token将会在预训练里的大规模语料上学习语言统计规律与通识知识,然后在微调中通过特定任务数据学习具体行为模式,决定模型实际表现。
但无论是预训练还是微调,next-token prediction都只会在每一步中只预测下一个token,再依次进行。
与此同时,这也带来了一个根本性缺陷,即它将完整概念拆解为碎片,阻碍模型形成整体认知。
例如“ribonucleic acid”(核糖核酸),Llama 3分词器就会将其拆解为:“rib”→“on”→“ucle”→“ic”→“acid”,当模型预测“rib”时,无法预见“onucleic acid”,因此无法理解这是一个生物学分子概念。
又比如说将“北京大学”拆成“北”、“京”、“大”、“学”分开记忆,这严重破坏了语义完整性。
所以next-token prediction存在前瞻能力差、不擅长处理跨概念的复杂推理、学习效率低、结果高度依赖具体分词器等问题。

Meta等机构对此提出可以在预训练阶段尝试multi-token prediction,但同样也面临以下限制:
- 预训练成本过大,是微调阶段的上千倍。
- 仅能提升通用语言能力,对具体概念理解帮助有限。
- 直接应用于微调时会造成分布偏移,从而导致性能下降。
这让multi-token prediction只适用于预训练阶段,难以普及,所以研究团队提出了新技术CAFT,将multi-token prediction引入微调。
CAFT:打破瓶颈的概念感知微调方法
CAFT在架构上主要包括辅助头、损失函数两部分,辅助头含独立隐藏层,且共享输出层,以降低参数成本,损失函数为:

其中L₁指原始next-token损失,β是控制辅助损失的权重(设为0.01,确保主任务优先),γ是反射正弦动态调整因子,训练初期高,后期低,α是几何衰减因子,越远的token权重越小,t指token位置。
在微调结束后,还可以直接丢弃辅助头,让推理开销为零。

CAFT采取分阶段训练策略,可分为两个阶段:
- 辅助头预训练
在原模型上添加n-1个辅助预测头,然后使用通用指令数据集训练辅助头,分别预测第2、3、4…个未来token。
其中需要使用原模型自己生成的回答作为“伪标签”,避免分布偏移,且辅助头训练一次即可,多任务可通用复用。
- 概念感知微调
在特定任务上同时优化原始预测头和辅助头,然后用特殊设计的损失函数确保主目标仍是第一个token。
利用动态权重调整策略,训练初期关注多token概念学习,后期聚焦任务表现。

最终CAFT可实现极低的使用门槛,只需要几行代码,就能结合任意预训练模型,在成本上远低于重新预训练,只略高于传统微调。
CAFT的全面验证:从代码到生命科学
研究团队在五个不同领域任务上测试了CAFT,将其与传统的next-token微调(包括全量微调与LoRA微调)进行对比。
所有结果均为5次独立评估的平均值及95%置信区间,部分任务在微调前会对辅助头进行1个epoch的预训练。

在编程任务中,由于存在大量跨token的语义单元,例如Python中的“_name_”会被分为“_”、“name”、“_”三个token,但需整体理解,所以借助HumanEval数据集,判断CAFT能否让模型能够整体理解这类编程概念。

实验结果表明,LoRA CAFT在准确率上从40.9%提升至45.1%,Full CAFT则从40.5%提升到49.3%。

然后将题目按概念密度分类,发现CAFT在高概念密集题目上提升更显著(+11.67%vs+7.59%),证实了概念学习的有效性。
在数学推理上,LoRA CAFT在MATH-500数据集里性能提升了1.7%(22.9%到24.6%),Full CAFT则是1.5%(23.7%到25.2%)。

而当CAFT置于临床文本中,由于医学文本充满复杂专业术语,被拆分后往往失去意义,此时让CAFT完成医学术语整体理解极具挑战性。
但CAFT仍然在MIMIC-IV-BHC数据集上表现良好,在ROUGE等指标上全面优于传统方法,其中ROUGE-1从44.57提高到45.93,ROUGE-2从22.94提高到24.44,ROUGE-L从32.17提高到33.76,说明其能更好地捕捉长文本中的概念。

在官能团结构理解上,由于化学分子包含功能性“官能团”,如苯环、酰胺基团等,而SMILES序列中的官能团是典型的多token概念,传统方法很难整体学习。
CAFT可以很好地弥补这一点,准确匹配率从原来的0.14%,提升了4倍,到0.54%,有效分子比例从92.38%改进到97.14%,结构相似性也得到了显著改善。

进一步进行官能团学习验证,发现苯环识别中F1分数大幅提升、酰胺识别中准确率和召回率双重改善、羧酸识别中复杂分子的识别能力增强。

另外为考验CAFT泛化能力,让CAFT根据功能设计蛋白质序列,由于蛋白质使用氨基酸编码,与自然语言差异极大,测试环境相当极限。
实验结果显示,序列同一性从20.32%提升到22.14%,序列对比分数也从原来的负值(-16.01)提升到正值(3.18),结构置信度从52.60变为54.30,结构相似性从33.07%变为35.12%。

其中,25.0%的生成序列具有高结构置信度(>70),比传统方法的20.0%有了显著提升。
最终,研究团队通过在广泛领域中实验,验证了CAFT实现multi-token prediction在微调阶段的可行性,其易用性和低成本也展示了其可能替代现有next-token prediction的巨大潜力,为理解模型内部机制提供了新视角。
论文链接:https://www.arxiv.org/abs/2506.07833
项目链接: https://github.com/michaelchen-lab/caft-llm
