返回列表

2nd place solution overview

386. Tweet Sentiment Extraction | tweet-sentiment-extraction

开始: 2020-03-23 结束: 2020-06-16 自然语言处理 数据算法赛
第二名解决方案概述

第二名解决方案概述

作者:hiromu | 排名:第2名

首先,我们要感谢 Kaggle 团队举办这场有趣的比赛。同时,我也要感谢我的队友(@yuooka@futureboykid)的辛勤付出。

这篇文章是我们解决方案的概述。我们的解决方案由四个主要部分组成(预处理、基础模型训练、重排序模型训练、后处理)。

预处理

我们简单地将后处理方法应用于预处理,参考了这篇帖子。使用这种方法,我们可以训练噪声更少的模型。

例如:selected_text "e fun" -> "fun"

基础模型训练

我们的训练方法与这个优秀的 Kernel 基本相同。

模型架构

我们尝试了很多模型架构……但是,我们最终使用了这两个在 SQuAD2 上预训练的 RoBERTa 模型。

  1. 使用第 11 层或第 23 层隐藏层作为输出。(这篇帖子

  2. 使用可训练向量并在其上应用 softmax 及多重 dropout。(Google Quest 第一名方案)(下文简称 MDO。)

损失函数

我们尝试了很多,但最终选择了简单的 CrossEntropyLoss。

class CROSS_ENTROPY:
    def __init__(self):
        self.CELoss = nn.CrossEntropyLoss(reduction='none')

    def __call__(self, pred, target):
        s_pre, e_pre = pred.split(1, dim=-1)
        s_tar, e_tar = target.split(1, dim=-1)

        s_loss = self.CELoss(s_pre.squeeze(-1), s_tar.squeeze(-1))
        e_loss = self.CELoss(e_pre.squeeze(-1), e_tar.squeeze(-1))

        loss = (s_loss + e_loss) / 2
        return loss

训练策略

我们遇到了学习不稳定的问题。有两个主要思路可以获得稳定的结果。

SentimentSampler

大多数公共 Kernel 使用按情感分层的 KFold。但是,我们认为这还不够。看这篇帖子,情感之间存在巨大差异。

因此,我们采用 SentimentSampler 来平衡批次内的不平衡。

SentimentSampler 示意图

SWA

我们发现训练过程中验证分数不稳定(仅 1 次迭代就会产生约 ±0.001 的波动!)

因此,我们采用 SWA 来稳定结果。

这使得通过 10~50 次迭代监控获得稳定结果成为可能,从而我们可以获得未对验证集过拟合的结果(此外验证时间也更高效。)

集成

使用两个不同的种子(种子平均)

RoBERTa Base 第 11 层 + RoBERTa Large 第 23 层 + RoBERTa Base MDO + RoBERTa Large MDO

4 个模型 * 2 个种子 = 总共 8 个模型

重排序模型训练 (Public LB +0.002, Private LB +0.003~5)

方法

我们认为这部分是我们独特之处,也是与其他团队最大的不同点。该方法基于以下思路:“创建多个候选答案并选择最佳的一个。”

Step1 基于基础模型的起始和结束值(应用 softmax)计算前 n 个分数。

Step2 根据 Step1 的分数创建候选。候选包括 selected_text、Jaccard_score 和 Step1_score。

Step3 训练 RoBERTa Base 模型。(目标是候选的 jaccard_score)

Step4 根据 Step3 预测值 + Step1 分数 * 0.5 对候选进行排序。并选择最好的一个作为最终答案。

SequenceBucketing

到目前为止,我们已经构建了超过 50 个模型(基础模型和重排序模型)。为了在有限的时间内完成推理,我们决定选择 SequenceBucketing,即使有些模型没有使用它进行训练。

在这种情况下,每个批次包含相同的文本和略有不同的候选。

因此,推理时间加快了 2 倍,令人惊讶的是,结果比不使用时更好。我们需要找出原因……

后处理 (Public LB +0.01, Private LB +0.012)

就像其他团队发现的魔法一样,我们关注了额外的空格。

def pp(filtered_output, real_tweet):
    filtered_output = ' '.join(filtered_output.split())
    if len(real_tweet.split()) <