返回列表

[LB Pub 29 Pvt 28] Enhancing the efficiency of a reasoner using SFT and GRPO [Fast-Math-R1-14B]

647. Al Mathematical Olympiad - Progress Prize 2 | ai-mathematical-olympiad-progress-prize-2

开始: 2024-10-17 结束: 2025-04-01 数学与计算 AI大模型赛
[LB Pub 29 Pvt 28] 使用 SFT 和 GRPO 提升推理器效率 [Fast-Math-R1-14B]

[公有榜 29 分 私有榜 28 分] 使用 SFT 和 GRPO 提升推理器效率 [Fast-Math-R1-14B]

作者: RabotniKuma (analokamus) | 发布时间: 2025-04-02 | 竞赛排名: 第 9 名

首先,恭喜所有参与者完成了这场漫长而充满挑战的比赛。我们还要向组织者表示诚挚的感谢,他们设计了如此有趣的比赛并提供了高质量的数据集,同时也感谢 Kaggle 团队提供了使这一切成为可能的 extensive 计算资源。

我们团队认为本次比赛的主要挑战在于 DeepSeek R1 系列模型推理过程中的冗余——具体来说,它们的输出 token 长度过长,使得很难在 5 小时的时间限制内解决 50 个问题。为了解决这个问题,我们旨在微调一个蒸馏版的 R1 模型,使其推理更高效。

我们的最终模型在公有榜得分为 29 分,私有榜得分为 28 分

第一阶段:使用高难度数据集进行 intensive SFT

数据集

  • OpenR1 Math: 我们随机采样了 3000 个 R1 轨迹超过 12800 tokens 且准确率超过 50% 的样本,以及另外 3000 个准确率在 50% 到 75% 之间的样本。
  • openr1_hard: "~2.5k 来自 open-r1-math-220k 的困难样本。被认为是困难的样本是 r1-distill-32b 尝试 4 次后无法解决的。” (感谢数据集提供者 @andy2709)
  • Light-R1-SFTData: 我们使用了 Light-R1-SFTData 的第 2 阶段数据。(非常感谢 LightR1 团队,@zouhaosheng)

我们合并了上述所有数据集,去重,并选择了 token 长度最短的正确生成结果。对于 Light-R1 数据集中未提供 ground truth 答案的样本,我们从 R1 轨迹中提取并替换了答案。结果,我们构建了一个包含 7900 个问题 - R1 轨迹 - 答案集合的高难度数据集

训练

根据我们之前的实验,我们观察到 14B 模型表现出更稳定的性能。因此,我们选择 DeepSeek-R1-Distill-Qwen-14B 作为起点。我们在配备 8 张 H200 GPU 的机器上进行了全参数监督微调训练,使用了 trl 库中的 SFTTrainer。

关键参数

  • per_device_train_batch_size = 1
  • gradient_accumulation_steps = 8
  • num_train_epochs = 20
    • 这个数字异常大,但我们发现只有在长时间训练后才会出现有意义的性能提升
  • max_seq_length = 24000
  • packing = True
  • learning_rate = 1e-5
  • lr_scheduler_type = cosine
  • system_prompt = "Please reason step by step, and put your final answer within \\boxed{{}}."

训练时间:约 10 小时 (8× H200 GPUs)

评估

我们使用包含 40 个问题的数据集评估了模型性能,其中包括10 个参考问题和 30 个来自 AIME 2025 的问题
最初,我们使用 16k tokens × 32 个提示生成答案。然后我们应用基于 token 长度和提示数量的后处理过滤,以评估各种推理条件下的性能。

结果

实验 Token 预算 准确率 (majority@32) 准确率 (pass@32) 收集的答案数量 平均生成长度 公有榜 (量化模型)
DeepSeek-R1-Distill-Qwen-14B 16384 0.700 0.775 21.775 9684 25
DeepSeek-R1-Distill-Qwen-14B 12800 0.675 0.775 16.775 8331
DeepSeek-R1-Distill-Qwen-14B 9000 0.525 0.600 12.500 4725
SFT 16384 0.750 0.825 20.725 10396 23
SFT 12800 0.725 0.725 15.700 7024
SFT 9000 0.550 0.550 11.600 4387

第一阶段 SFT 后,我们的本地验证分数显著提高。然而,我们观察到公有榜分数往往略低。
我们认为这是由于SFT 引入了更多的推理冗余,导致许多示例未能在时间限制内得出结论
为了解决这个问题,我们的下一个目标是应用强化学习,鼓励模型在使用更少 token 的同时达到准确的结论,同时保持性能。

第二阶段:使用 GRPO 进行更高效的推理

数据集

Light-R1-SFTData: 我们使用了 Light-R1-SFTData 的第 2 阶段数据。

训练

我们使用了由 @andy2709 创建的 更快版本的 trl GRPOTrainer (再次非常感谢!)。

我们使用了以下奖励函数:

  1. 格式奖励
    在我们的提交中,生成在 </think> 标签处停止以节省时间,因此我们设计的奖励匹配模式 r"^.*?oxed{(.*?)}.*?</think>.*?$"
  2. 余弦奖励 (正确 [1.0, 0.1], 错误 [-0.1, -1.0], max_len=30000)
    与普通的基于准确率的奖励相比,余弦奖励对较长的正确推理轨迹和较短的错误推理轨迹施加连续惩罚。
  3. 长度奖励
    基于长度的奖励,以 discourage 过度思考并促进 token 效率。
    论文:https://arxiv.org/abs/2501.12599

关键参数

  • num_generations = 8
  • beta = 0.04
  • per_device_train_batch_size = 2
  • gradient_accumulation_steps = 8
  • num_train_epochs = 1
  • max_completion_length = 16384
  • learning_rate = 4e-6
  • lr_scheduler_type = cosine
  • system_prompt = ( 'You are a helpful and harmless assistant. You are Qwen developed by Alibaba. ' 'You should think step-by-step. Return final answer within \\boxed{{}}.' )

训练时间:约 10 小时 (8× H200 GPUs)

评估

与第一阶段相同。

结果

GRPO Training Reward Chart

奖励在整个训练过程中稳步优化,但在第 60 步之后发生了灾难性转变,导致性能大幅下降。因此我们决定使用较早步骤的检查点进行评估。

实验 Token 预算 准确率 (majority@32) 准确率 (pass@32) 收集的答案数量 平均生成长度 公有榜 (量化模型)
DeepSeek-R1-Distill-Qwen-14B 12800 0.675 0.775 16.775 8331 25
DeepSeek-R1-Distill-Qwen-14B 9000 0.525 0.600 12.5 4725
SFT 12800 0.725 0.725 15.7 7024 23
SFT 9000 0.550 0.550 11.6 4387
SFT + GRPO (最佳检查点) 12800 0.725 0.775 18.5 6817 29
SFT + GRPO (最佳检查点) 9000 0.625 0.700 15.25 4759

GRPO 使我们能够训练出一个模型,在保持准确率的同时,通过更短的 token 长度显著提高了推理效率。

提交推理设置

量化

我们使用 AutoAWQ 应用了 4-bit 量化。
我们观察到量化后验证性能有所下降,并尝试通过在数学特定数据集上进行校准来恢复它。不幸的是,这种方法未能成功保留原始性能。

推理时间调度

我们训练了一个 ModernBERT 模型来预测 OpenR1 Math 数据集中每个问题的正确 R1 轨迹的最短 token 长度,从而量化问题难度。
在验证集上,我们观察到预测难度与实际生成的 token 数量之间存在中等相关性,如下图所示。

Token Length Scatter Plot

使用该模型,我们在推理过程中动态缩放输出 token 长度。这种方法稳定了公有榜分数,并带来了约 +1 分的提升(尽管可能是心理作用)。

杂项

  • 提示词:prompt_config0 * 8 + prompt_config1 * 2 (共 10 个提示)
prompt_config0 = dict(
    system=(
        'You are a helpful and harmless assistant. You are Qwen developed by Alibaba. '
        'You should think step-by-step. Return final answer within \\boxed{{}}, after taking modulo 1000.'
    ),
    prompt='{question}'
)

prompt_config1 = dict(
    system=(
        'You are a helpful and harmless assistant. You are Qwen developed by Alibaba. '
        'You should think step-by-step. After you get your final answer, take modulo 1000, and return the final answer within \\boxed{{}}.'
    ),
    prompt='{question}'
)
  • 输出 token 长度:10500 - 13300 (上述动态缩放)
  • vLLM 版本 0.7.3 with V1

未成功的方法

约束解码 (Constrained decoding)

许多研究探讨了通过约束解码方法控制 LLM 的推理过程,例如:

我们广泛实验了这些方法并调整了各种参数;然而,它们在我們的验证数据集上均未证明有效。我们的假设是,这些方法可能仅适用于相对简单的问题。

重写推理过程并重新训练

我们还尝试遵循 SkyThought 的方法,将原始 R1-Qwen 模型的推理过程重写为更紧凑的形式,并通过 SFT 和 DPO 重新训练模型。虽然这种方法显著减少了到达答案之前生成的 token 数量,但也导致准确率大幅下降。

对于功能强大的 LLM, externally 强行改变其自然推理过程可能不是有效的策略,特别是在处理像我们这样的复杂问题解决场景时。

同比赛其他方案