返回列表

2nd place solution

620. LMSYS - Chatbot Arena Human Preference Predictions | lmsys-chatbot-arena

开始: 2024-05-02 结束: 2024-08-12 自然语言处理 数据算法赛
第二名解决方案 - LMSys Chatbot Arena

第二名解决方案

作者: tascj (Grandmaster)
发布日期: 2024-08-13
队友: @liushuzhi, @kapenon

首先,感谢组织者。我们非常享受这次比赛的内容。虽然数据泄露令人沮丧,但我们坚持下来并获得了第二名。

恭喜我的队友 @liushuzhi@kapenon 成为 Grandmaster。

解决方案

基线 (Baseline)

  • 我们使用了基于 prompt 的 StratifiedGroupKFold,预留 20% 作为验证集。
  • 21k 来自去重后 33k 数据集的数据被添加到训练数据中,感谢 @abdullahmeda
  • 使用 2306.05685 中的指令来格式化输入。我们尝试了 prompt-res_a-res_b 和 prompt-res_a-prompt-res_b 两种格式。对于 1.5B 模型,后者似乎更好,而对于 7B 及以上的模型,差别不大。考虑到 token 效率,我们主要使用 PAB 格式。gemma2-9b 的最大序列长度约为 4340。
  • 使用自定义头进行分类。XXXForSequenceClassification 初始化头会在早期迭代中产生高 loss,因此在模型初始化后重新初始化头。
    model.score = torch.nn.Sequential(
        torch.nn.Dropout(0.1),
        torch.nn.Linear(hdim, hdim // 2),
        torch.nn.Dropout(0.1),
        torch.nn.GELU(),
        torch.nn.Linear(hdim // 2, 3),
    )

根据我之前参加比赛的经验,我没有尝试 LoRA,只使用了全参数训练。使用 BF16 和支持 kahan summation 的优化器,可以使用单个 A100 80G 训练 7B 模型,9B 模型则需要两个 A100。

在本次比赛的最后 10 天,我使用了 A100 80G x4 进行所有实验。

全交换 (Full Swap)

在早期实验中,我执行了 response_a 和 response_b 的随机交换作为增强,这稍微提高了验证 log_loss。@kapenon 发现同时包含原始样本及其交换样本更好。为了避免过拟合,原始样本及其交换样本的梯度必须为同一个 optimizer.step 累积。虽然训练时间翻倍,但与随机交换相比,全交换为 gemma2-9b 带来了稳定的 0.003 提升。
此外,我尝试以同样的方式添加不同的输入格式(PAB 和 PAPB)作为增强,带来了微小 (0.001) 的提升。

训练最终模型的步骤

阶段 1 (Stage 1)

我们微调了 google/gemma-2-9b-itgoogle/gemma-2-27b-itRLHFlow/ArmoRM-Llama3-8B-v0.1。 without TTA 的验证 log_loss 分别为 0.891、0.883 和 0.899。平均集成后,log_loss 为 0.876。

在完成 gemma-2-9b 训练后,我花了很多时间尝试 gemma-2-27b 但没有得到好的结果。通过与 @kapenon 的代码比较,我将 batch_size 调整为 80 并关闭了 grad_clip,最终成功训练了模型。

阶段 2 (伪标签 Pseudo Labeling)

我们使用阶段 1 获得的集成模型为 240k 数据生成了伪标签。其中,110k 数据来自 lmsys-1m(由 @kapenon 准备),130k 来自其他数据集(@liushuzhi 使用 1.5B 模型测试了许多外部数据集)。
在这个数据集上,我们微调了 gemma-2-9bRLHFlow/ArmoRM-Llama3-8B-v0.1
从这个阶段开始,我关闭了 gemma-2-9b 的窗口注意力 (window attention),因为我不确定支持 sm75 的高效注意力实现是否可以做窗口注意力。最长输入长度为 4340(包括指令),所以这对分数的影响应该很小。

生成伪标签时未应用 TTA(这是我的失误)。

阶段 3 (Stage 3)

基于阶段 2 获得的检查点,我们使用 55k+21k 数据进行了微调。由于时间限制,禁用了多输入格式增强。
在 20% 验证集上,两个模型分别达到了 0.884 和 0.890,平均集成 log_loss 为 0.876~0.877。
提交时,RLHFlow/ArmoRM-Llama3-8B-v0.1 的输入进行了 AB 交换。该模型在旧 LB 上得分为 0.873。用所有数据训练后,该模型在旧 LB 上达到 0.869。调整集成比例为 2:1 后,得分为 0.868。

更快的训练 (Faster Training)

我们使用 flash-attn==2.6.2 以支持 logit_softcapping

当使用 flash_attn_varlen_func 时,attention_mask 和 padding 是不必要的。为了避免浪费计算在 pad token 上,我:

  1. 实现了一个自定义 collator 来对样本进行序列连接并准备 cu_seqlens
  2. 基于 huggingface 的实现修改了代码,因此模型只接受 input_idscu_seqlens。模型从头到尾不涉及 padding。

此外,使用了 transformer_engine 中的 RMSNormFusedRoPEFunc 来进一步加速训练。

更快的推理 (Faster Inference)

T4x2 足以以 fp16 运行 7b-9b 模型。Transformers 可以几乎均匀地分布在 2 个 GPU 上,只需要 slight 代码修改使两个 GPU 上的执行流水线化。

在 T4 (sm75) 上使用最新的高效算子似乎不容易。经过一些尝试,我使用了以下 triton 算子进行推理:

  1. context_attention_fwd 来自 ModelTC/lightllm,带有一些优化和 logit_softcapping 支持。
  2. rms_normfused_rotary_emb 来自 InternLM/lmdeploy
  3. gelu_and_mul_fwdsilu_and_mul_fwd 来自 ModelTC/lightllm

Llama3 使用了 xformers 中的 memory_efficient_attention

与训练一样,整个推理过程也基于序列 collate,不需要 padding。

同比赛其他方案