社区供稿 | 使用 Firefly 在单卡V100 上对 Qwen1.5 进行 SFT 和 DPO,显著超越官方模型

762次阅读
没有评论

01

简介

Firefly 是开源的大模型一站式训练框架,支持对各种大模型进行预训练、指令微调、DPO,支持全量参数、LoRA、QLoRA 等训练方式。支持包括但不限于 Gemma、Qwen1.5、MiniCPM、Mixtral-8x7B、Mistral、Llama 等绝大多数主流的大模型。 


项目链接:https://github.com/yangjianxin1/Firefly


模型权重:

https://hf.co/YeungNLP/firefly-qwen1.5-en-7b

https://hf.co/YeungNLP/firefly-qwen1.5-en-7b-dpo-v0.1

本文将分享我们使用 Firefly 项目对 Qwen1.5-7B 进行训练的实验。我们对训练数据进行 精细化筛选,然后 在单张 V100 上进行 SFT 和 DPO。经过两阶段的训练,我们的模型 在 Open LLM Leaderboard 上的表现显著优于官方的 Qwen1.5-7B-Chat、Gemma-7B-it、Vicuna-13B 等模型。比 Qwen1.5-7B-Chat 高 7.12 分,比 Gemma-7B-it 高 8.8 分

社区供稿 | 使用 Firefly 在单卡V100 上对 Qwen1.5 进行 SFT 和 DPO,显著超越官方模型

通义千问 Qwen1.5 是阿里巴巴在春节前开源的大模型,支持 32K 的上下文长度,该模型本质上是 Qwen2 的 beta 版本,按照官方的说法,后续将会有 Qwen2 的正式版本。从评测结果来看,Qwen1.5 各个尺寸的模型都显著优于同量级的 Llama2。

社区供稿 | 使用 Firefly 在单卡V100 上对 Qwen1.5 进行 SFT 和 DPO,显著超越官方模型

在 2 月份的 SuperCLUE 大模型榜单中,Qwen1.5 也有非常优秀的表现,在开源模型中处于引领者的地位。

社区供稿 | 使用 Firefly 在单卡V100 上对 Qwen1.5 进行 SFT 和 DPO,显著超越官方模型

02

DPO 简介

大模型训练主要可以分为以下三大阶段:

  1. 预训练: 使用超大规模文本对模型进行训练,训练任务为“预测下一个 token”,训练的数据量往往需要几万亿 token。

  2. SFT (指令微调): 使用指令数据,让模型的输出格式与人类对齐,使其具备 chat 的能力。

  3. RLHF: 使用人类反馈或者偏好数据来训练模型,使模型的输出更加符合人类的价值观或者预期行为。


在 RLHF 阶段,以往的许多大模型,例如 Llama2、InstructGPT 等,大多采用 PPO 来对模型进行价值观对齐训练。但是采用 PPO 进行 RLHF 存在流程繁琐、显存需求多(需要将策略网络、参考网络、critic 网络、奖励模型同时加载到显存中)等问题,这导致大部分普通玩家对其敬而远之。


使用 PPO 进行 RLHF 的主要流程大致如下:

  1. 构建奖励模型的训练数据: 对于同一个 prompt 产生多个生成结果,对这些生成结果进行人工排序,两两一组,形成 chosen 和 rejected 的 pair。每条训练数据包含三个字段,prompt、chosen、rejected。

  2. 训练奖励模型: 使用上述数据训练奖励模型,对于每条训练数据,训练目标为最大化 chosen 与 rejected 的奖励的差值。

  3. PPO 训练: 使用奖励模型的反馈对语言模型进行训练。


上面描述的 PPO 流程复杂且冗长,而 DPO 则绕过了奖励模型的构建,可直接使用人类偏好数据对模型进行训练,且在训练时仅需加载策略网络和参考网络,极大地节省了显存占用。训练数据包含三个字段,prompt、chosen、rejected。


DPO 损失函数的计算过程也极具对称性,其公式如下所示:

社区供稿 | 使用 Firefly 在单卡V100 上对 Qwen1.5 进行 SFT 和 DPO,显著超越官方模型

对于上述公式,根据对数运算法则进行变换,在代码实现中,其计算过程大致如下:

  1. 计算对数概率:将 prompt 分别与 chosen 和 rejected 进行拼接,然后分别输入策略网络和参考网络,得到 4 个对数概率。

  2. 计算策略网络的 diff:策略网络的 chosen 对数概率 – rejected 对数概率。

  3. 计算参考网络的 diff:参考网络的 chosen 对数概率 – rejected 对数概率。

  4. 计算损失函数:策略网络的 diff – 参考网络的 diff。

03

训练设置

在 Qwen1.5-7B 的基础上,我们进行了 SFT 和 DPO 两阶段的训练,整个训练流程仅使用一张 V100 GPU,采用 QLoRA 技术,在所有 Linear 层都添加 adapter 以提升训练效果。两阶段均使用英文数据进行训练。我们与 Qwen1.5 官方的对话模板保持一致:

<|im_start|>systemYou are a helpful assistant.<|im_end|><|im_start|>userhello, who are you?<|im_end|><|im_start|>assistantI am a AI program developed by Firefly<|im_end|>

使用 Firefly 对 Qwen1.5 进行 SFT 的启动命令:

python train.py --train_args_file train_args/sft/qlora/qwen1.5-7b-sft-qlora.json

在 SFT 阶段,实验参数设置如下:

num_epochs: 1learning_rate: 2e-4total_train_batch_size: 32max_seq_length: 2048optimizer: paged_adamw_32bitlr_scheduler_type: constant_with_warmupwarmup_steps: 700lora_rank: 64lora_alpha: 16lora_dropout: 0.05gradient_checkpointing: truefp16: true

使用 Firefly 对 Qwen1.5 进行 DPO 的启动命令:

python train.py --train_args_file train_args/dpo/qlora/qwen1.5-7b-dpo-qlora.json

在 DPO 阶段,我们采用 ultrafeedback 数据集,实验设置如下:

num_epochs: 1learning_rate: 2e-4total_train_batch_size: 32max_seq_length: 1600max_prompt_length: 500optimizer: paged_adamw_32bitlr_scheduler_type: constant_with_warmupwarmup_steps: 200lora_rank: 64lora_alpha: 16lora_dropout: 0.05gradient_checkpointing: truefp16: true

04

模型评测 & 训练指标

我们在 Open LLM Leaderboard 上对模型进行评测,我们的模型的表现显著优于官方的 Qwen1.5-7B-Chat、Gemma-7B-it 等模型。经过 DPO 之后,模型的平均分也有接近 1 分左右的提升。

社区供稿 | 使用 Firefly 在单卡V100 上对 Qwen1.5 进行 SFT 和 DPO,显著超越官方模型

DPO 训练过程中的训练指标的变化如下图所示。在训练过程中, Rewards/accuracies 和 Rewards/margins 均处于上升趋势。

社区供稿 | 使用 Firefly 在单卡V100 上对 Qwen1.5 进行 SFT 和 DPO,显著超越官方模型

DPO 训练 loss 变化趋势如下:

社区供稿 | 使用 Firefly 在单卡V100 上对 Qwen1.5 进行 SFT 和 DPO,显著超越官方模型

DPO 训练的 Rewards/accuracies 的变化趋势如下,该指标表示较优回答的奖励大于较劣回答的奖励的频率的均值:

社区供稿 | 使用 Firefly 在单卡V100 上对 Qwen1.5 进行 SFT 和 DPO,显著超越官方模型

DPO 训练的 Rewards/margins 变化趋势如下,该指标表示较优回答的奖励与较劣回答的奖励二者之差的均值:

社区供稿 | 使用 Firefly 在单卡V100 上对 Qwen1.5 进行 SFT 和 DPO,显著超越官方模型

 

Read More 

正文完
可以使用微信扫码关注公众号(ID:xzluomor)
post-qrcode
 
评论(没有评论)
Generated by Feedzy