返回列表

6th Place Solution

644. WSDM Cup - Multilingual Chatbot Arena | wsdm-cup-multilingual-chatbot-arena

开始: 2024-11-18 结束: 2025-03-10 自然语言处理 数据算法赛
第六名解决方案

第六名解决方案

作者: Zhecheng Li
发布时间: 2025-03-11
竞赛: WSDM Cup 多语言聊天机器人竞技场 (WSDM Cup Multilingual Chatbot Arena)

首先,我衷心感谢 Kaggle 和 Chatbot Arena 组织了这场有意义的竞赛。我还要向我的队友们致以最深的谢意:@pingfan, @xuanmingzhang777, @xiaoqinglong1996, @tonyarobertson, 我与他们一起在 Kaggle 上竞争了将近一年。不幸的是,我们在之前的比赛中遇到了意外情况,导致排名下滑或错失金牌。然而,今天五位 Kaggle 竞赛专家一起成为了竞赛大师。这确实是我们旅程中最不可思议的时刻之一,能够与我的队友们分享这份成功!


TL;DR (太长不看版)

  1. 伪标签 (Pseudo Labeling):

    • 利用了 LMSYS 第 3 名团队生成的数据。
    • 从 100 万数据集中采样提示词,并使用 API 生成回复。
    • 结合了开源 DPO 数据(例如 RLHFlow),混合生成伪标签。
  2. 蒸馏 (Distillation):

    • 将 Llama3.3-70B 和 Qwen2.5-72B 模型蒸馏到 Gemma2-9B 和 Qwen2.5-14B。
    • 最大序列长度训练为 2500 tokens,使用 4-bit 量化。
  3. 多语言策略:

    • 多语言性能不是主要关注点,因为 Gemma 和 Qwen 已经是最强大的多语言模型之一。
    • 优先考虑前五种主要语言,尤其是英语,因为我们发现英语准确率欠佳。

伪标签 (Pseudo Labeling)

伪标签在我们的方法中起到了至关重要的作用。通过有效利用伪标签数据,即使不直接在竞赛数据集上训练,我们也实现了超过 0.693 的榜单 (LB) 分数。

数据集

我们聚合了多个数据源,过滤掉短回复后获得了约 56 万样本:

  1. LMSYS 第 3 名团队 生成的数据(特别感谢 @conjuring92)。
  2. 从 100 万数据集中采样的提示词,以及来自各种模型的 API 生成回复。
  3. 大约 10 个开源 DPO 数据集(例如 RLHFlow)。

我们处理了这些数据集以生成高质量的伪标签。

伪标签方法

为了确保标签准确性并最小化数据泄露,我们对评判模型尝试了两种方法:

  1. 基线方法: 在竞赛数据上微调 Gemma2-9B。
  2. 增强方法: 在竞赛数据上微调 Llama3.3-70B 和 Qwen2.5-72B。

虽然增强方法显示出轻微改进,但需要显著更长的推理时间。最终,使用伪标签数据重新训练 Gemma2-9B 在 KL 散度损失和交叉熵损失方面产生了可比的结果。


知识蒸馏 (Knowledge Distillation)

  • 教师模型: Qwen2.5-72B, Llama3.3-70B
  • 训练数据: WSDM + LMSYS
  • 损失函数: KL 散度损失,交叉熵损失,以及两种损失的等权重平均

我们进行了广泛的蒸馏后训练。虽然蒸馏过程对 Qwen2.5-14B 模型没有显著影响,但它证明了在 Gemma2-9B 上有可测量的改进。


最终训练阶段

  1. 直接 LoRA 训练 针对 4-bit 量化的 Qwen2.5-14B 和 Gemma2-9B 模型。
  2. 最大序列长度:2500 — 将其扩展到 3072 没有产生明显好处,且训练时间限制阻碍了进一步增加。

推理策略

  • 主要模型: Qwen2.5-14B
  • 辅助模型: Gemma2-9B
  • 推理机制:
    • 使用 Qwen2.5-14B 的 logits 进行主要分类。
    • 选择性部署 Gemma2-9B,优先考虑 Qwen2.5-14B 分类困难的情况。
    • 有效管理推理时间以充分利用 12 小时限制。

无效的方法 (What Didn't Work)

  1. TTA (测试时增强) 对序列分类没有显示出可测量的影响。
  2. LoRA 合并 导致性能显著下降,尽管进行了调试,该问题仍未解决。
  3. 多 LoRA 集成 方法未能提高性能。
  4. 基于回复长度的动态 token 分配 导致异常的长度分布。
  5. 模型选择: Gemma2-9B 和 Qwen2.5-14B 优于其他模型。
  6. 来自 DPO 数据集的原始标签 对于后训练无效,但对伪标签有用。
  7. 思维链 (Chain-of-Thought) 提示 策略没有产生有意义的改进。
  8. 推理期间的动态截断 最初提供了 +0.003 的榜单提升,但在更新训练代码和截断方法后,其有效性减弱。

我个人的“无效方法”列表

  1. 我使用 AutoModelForMultipleChoice 微调了一个 mDeBERTa 模型用于多项选择任务。然而,正如预期的那样,模型的参数限制导致性能不佳。
  2. 我尝试使用 GPT-3.5 作为伪标签的评判模型,通过在 1 万样本上微调它。然而,OpenAI 训练过程缺乏透明度以及对参数调整的敏感性导致结果不理想。
  3. 与研究建议 few-shot 提示能提高准确率相反,这在我们的竞赛 setting 中并不成立。我尝试了 2-shot 到 32-shot 的 GPT-3.5 配置,但没有显著的准确率提升 — 可能是由于过度的人工干预。

改进机会

  1. 教师模型的交叉验证和榜单性能欠佳 — 大参数模型训练仍然是一个主要挑战;例如,我们在训练 Gemma2-27B 时无法获得令人满意的结果。
  2. LoRA 合并和后训练量化问题 导致意外的性能下降,这仍然是一个关键瓶颈。

结论

再次感谢我的四位了不起的队友:@pingfan, @xuanmingzhang777, @xiaoqinglong1996, @tonyarobertson

同比赛其他方案