620. LMSYS - Chatbot Arena Human Preference Predictions | lmsys-chatbot-arena
首先,感谢组织者。我们非常享受这次比赛的内容。虽然数据泄露令人沮丧,但我们坚持下来并获得了第二名。
恭喜我的队友 @liushuzhi 和 @kapenon 成为 Grandmaster。
StratifiedGroupKFold,预留 20% 作为验证集。XXXForSequenceClassification 初始化头会在早期迭代中产生高 loss,因此在模型初始化后重新初始化头。 model.score = torch.nn.Sequential(
torch.nn.Dropout(0.1),
torch.nn.Linear(hdim, hdim // 2),
torch.nn.Dropout(0.1),
torch.nn.GELU(),
torch.nn.Linear(hdim // 2, 3),
)
根据我之前参加比赛的经验,我没有尝试 LoRA,只使用了全参数训练。使用 BF16 和支持 kahan summation 的优化器,可以使用单个 A100 80G 训练 7B 模型,9B 模型则需要两个 A100。
在本次比赛的最后 10 天,我使用了 A100 80G x4 进行所有实验。
在早期实验中,我执行了 response_a 和 response_b 的随机交换作为增强,这稍微提高了验证 log_loss。@kapenon 发现同时包含原始样本及其交换样本更好。为了避免过拟合,原始样本及其交换样本的梯度必须为同一个 optimizer.step 累积。虽然训练时间翻倍,但与随机交换相比,全交换为 gemma2-9b 带来了稳定的 0.003 提升。
此外,我尝试以同样的方式添加不同的输入格式(PAB 和 PAPB)作为增强,带来了微小 (0.001) 的提升。
我们微调了 google/gemma-2-9b-it、google/gemma-2-27b-it 和 RLHFlow/ArmoRM-Llama3-8B-v0.1。 without TTA 的验证 log_loss 分别为 0.891、0.883 和 0.899。平均集成后,log_loss 为 0.876。
在完成 gemma-2-9b 训练后,我花了很多时间尝试 gemma-2-27b 但没有得到好的结果。通过与 @kapenon 的代码比较,我将 batch_size 调整为 80 并关闭了 grad_clip,最终成功训练了模型。
我们使用阶段 1 获得的集成模型为 240k 数据生成了伪标签。其中,110k 数据来自 lmsys-1m(由 @kapenon 准备),130k 来自其他数据集(@liushuzhi 使用 1.5B 模型测试了许多外部数据集)。
在这个数据集上,我们微调了 gemma-2-9b 和 RLHFlow/ArmoRM-Llama3-8B-v0.1。
从这个阶段开始,我关闭了 gemma-2-9b 的窗口注意力 (window attention),因为我不确定支持 sm75 的高效注意力实现是否可以做窗口注意力。最长输入长度为 4340(包括指令),所以这对分数的影响应该很小。
生成伪标签时未应用 TTA(这是我的失误)。
基于阶段 2 获得的检查点,我们使用 55k+21k 数据进行了微调。由于时间限制,禁用了多输入格式增强。
在 20% 验证集上,两个模型分别达到了 0.884 和 0.890,平均集成 log_loss 为 0.876~0.877。
提交时,RLHFlow/ArmoRM-Llama3-8B-v0.1 的输入进行了 AB 交换。该模型在旧 LB 上得分为 0.873。用所有数据训练后,该模型在旧 LB 上达到 0.869。调整集成比例为 2:1 后,得分为 0.868。
我们使用 flash-attn==2.6.2 以支持 logit_softcapping。
当使用 flash_attn_varlen_func 时,attention_mask 和 padding 是不必要的。为了避免浪费计算在 pad token 上,我:
cu_seqlensinput_ids 和 cu_seqlens。模型从头到尾不涉及 padding。此外,使用了 transformer_engine 中的 RMSNorm 和 FusedRoPEFunc 来进一步加速训练。
T4x2 足以以 fp16 运行 7b-9b 模型。Transformers 可以几乎均匀地分布在 2 个 GPU 上,只需要 slight 代码修改使两个 GPU 上的执行流水线化。
在 T4 (sm75) 上使用最新的高效算子似乎不容易。经过一些尝试,我使用了以下 triton 算子进行推理:
context_attention_fwd 来自 ModelTC/lightllm,带有一些优化和 logit_softcapping 支持。rms_norm 和 fused_rotary_emb 来自 InternLM/lmdeploygelu_and_mul_fwd 和 silu_and_mul_fwd 来自 ModelTC/lightllmLlama3 使用了 xformers 中的 memory_efficient_attention。
与训练一样,整个推理过程也基于序列 collate,不需要 padding。