返回列表

1st Place Solution ➡️ Distill is all you need🔥🔥🔥

620. LMSYS - Chatbot Arena Human Preference Predictions | lmsys-chatbot-arena

开始: 2024-05-02 结束: 2024-08-12 自然语言处理 数据算法赛
第一名解决方案 ➡️ 蒸馏就是你所需要的🔥🔥🔥

第一名解决方案 ➡️ 蒸馏就是你所需要的🔥🔥🔥

作者: sayoulala (Grandmaster)
发布时间: 2024-08-19

介绍

我很激动能在这次比赛中获得第一名,这标志着我的第一枚个人金牌🏅。
我要感谢 Kaggle 和组织者举办这次比赛。
尽管数据泄露问题困扰了许多参与者,但我 appreciate Kaggle 为挽救比赛所做的努力。
在我看来,这次比赛的 CV 和 LB 非常一致,使其成为一场罕见且优秀的竞赛,尤其是在当前 LLM 的背景下。🔥
接下来,我将总结我的解决方案。

解决方案

数据集

Kaggle 训练数据
ut 数据 (https://www.kaggle.com/competitions/lmsys-chatbot-arena/discussion/499756)
33k 数据 (https://www.kaggle.com/competitions/lmsys-chatbot-arena/discussion/500973)

基础模型

llama3 70b, qwen2 72b, gemma2-9b

基础模型架构

AutoModelForSequenceClassification
lora(9b)
qlora(llama3 和 qwen2)
lora 应用于所有线性层
r=64, a=128
max_len=1024
epoch = 2
global batch_size = 64

后预训练 (Post-pretrain)

首先,使用 ut 数据集在三个模型上训练一个 epoch。(lr=1e-5)

获取 logits 分布

加载后预训练的权重,将数据集分为 5 折进行训练
(例如:train➡️4/5 Kaggle 训练数据 + 33k 数据,dev➡️1/5 Kaggle 训练数据)来训练 llama3 70b 和 qwen2 72b。
然后推断训练集的概率分布。

使用 logits 蒸馏到 9b 模型

获得 logits 分布后,加载 9b 模型进行微调,并在微调过程中加入蒸馏损失。(训练时至少使用三种损失,lr=5e-5)。

模型集成

直接平均 5 折的 LoRA 层。

获取 8bit 模型

使用 GPTQ 量化为 8-bit,并在提交期间使用 TTA(长度 2000)。

CV/LB

(在这里,我只提供最终结果。之前有过太多实验,但这些结果是最重要的。)

  • qwen72b
    5 折 CV: 0.875, 0.881, 0.869, 0.880, 0.875
  • llama3 70b
    5 折 CV: 0.874, 0.877, 0.877, 0.873, 0.873
  • 蒸馏 gemma 9b
    5 折 CV: 0.862, 0.876, 0.858, 0.872, 0.868
  • 合并 lora 并量化为 8bit
    LB: 0.882 (使用 TTA 0.876) 最终 PB:0.96898
    (在最终提交中,我有一个提交也运行失败了,因为我删除了另一个已上传的模型。)

总结

在我的解决方案中,最重要的方面是使用较大模型进行蒸馏。还有一些其他细节,如果您感兴趣可以自行探索。我相信蒸馏是一个非常有前景的方法,尤其是在当前的 Kaggle 比赛中,推理限制是一个 limiting factor。

一个小推荐

BlackPearl 也参加了今年的 KDD Cup 2024 OAG-Challenge 并横扫了该赛道的所有冠军。该赛道包括三个挑战:AQA、PST 和 WhoIsWho-IND。在我们的解决方案中,我们采用 LLM 来解决分类和向量召回问题,显著优于传统的特征提取和基于 BERT 的方法。我们也开源了我们的代码,欢迎您 star (https://github.com/BlackPearl-Lab/KddCup-2024-OAG-Challenge-1st-Solutions)。

同比赛其他方案