返回列表

3rd place solution

600. HMS - Harmful Brain Activity Classification | hms-harmful-brain-activity-classification

开始: 2024-01-09 结束: 2024-04-08 临床决策支持 数据算法赛
第三名方案

第三名方案

感谢 Kaggle 和所有参与举办这场有趣比赛的人员。我们学到了很多关于 EEG 数据以及如何为之构建强模型的知识。特别感谢 @darraghdog,在比赛的最后几周承担了大部分工作,当时我正忙于训练我的第一个真实神经网络。

摘要

我们的方案是将来自两种不同建模思路的多个模型进行集成。第一种思路是把预训练的 2D‑CNN 网络应用于数据的 MelSpectrogram 变换;第二种思路是先用 1D 卷积对原始 EEG 进行编码,再使用 Squeezeformer 块进行建模。我们方案的关键要素包括基于数据筛选的稳健交叉验证以及创意十足的数据增强。

交叉验证与数据过滤

拥有可靠的交叉验证显然是我们在这场比赛中取得好成绩的关键!我们投入了大量时间设计合适的验证策略。总体上我们采用 4 折交叉验证,以患者 id 为划分依据。一个关键技巧是对验证集仅保留投票数大于 9 的样本。SPaRCNet 论文可能解释了数据的生成方式。

image

测试数据不仅过滤掉投票数不足 9 的样本,而且原始数据通过平移 EEG 并使用相同标签进行了增强。这意味着给定的 106k 训练行高度冗余,应该被过滤。

我们通过添加同一 EEG 稳定期的额外片段,扩展了高质量和低质量的 EEG 片段集合。

简而言之,作者对原始 EEG 进行了类似平移的增强来生成更多数据,并使用相同标签。这解释了大量相同 EEG_id 拥有相同标签的行。我们相信将原始数据与增强数据分离可以使数据更干净,并提升交叉验证。于是我们投入精力逆向还原这一过程,寻找给定标签对应的原始“真实”数据点。最终我们得到一个仅包含 6350 行的过滤数据集,不仅用于 4 折验证,还主要用于训练,而放弃了其余约 10 万行。最后我们的 CV 与 LB 相关性相当好。

image

数据来源

对于 2D 模型,因为我们拥有预训练权重,观察到高质量数据非常关键,使用伪标签以及投票数小于 8 的数据没有收益。所有模型都使用高质量数据进行训练。

我们的 1D 模型是从零开始训练的,因此加入低投票数的数据也会有帮助。关于如何在训练中使用低投票数样本,详见后面的训练流程。我们尝试了很多伪标签方案,但没有取得成效。

最终我们没有使用 10 分钟的频谱图。它们在部分 2D 模型实验中略有帮助,但对整体影响不大。

数据预处理

我们没有对数据进行磁盘层面的预处理,这大大加速了实验并提升了方法的灵活性。我们使用 torchaudio 的 GPU 实现直接在运行时生成 MelSpectrogram 并对信号进行归一化。

整个方案都使用了 @cdeotte 解释的双香蕉(double‑banana) Montage。如其他人所说,将 16 个信号并排堆叠会导致每个节点在模型中与其相邻节点的交互更强。我们尝试了不同的排序方式,发现以下顺序效果最好,它实际上把大脑左右两侧的同一节点配在一起:
Fp1>F7 Fp2>F8 F7>T3 F8>T4 T3>T5 T4>T6 T5>O1 T6>O2 Fp1>F3 Fp2>F4 F3>C3 F4>C4 C3>P3 C4>P4 P3>O1 P4>O2

使用 Scipy 的 butter 带通滤波器,阶数为 2,低通在 0~1.5 Hz 之间(不同模型略有不同),高通在 20~30 Hz 之间。没有使用陷波滤波器。

我们发现不对数据进行样本级或批次级归一化非常重要。这一做法来源于 @medali1992 的 ResNet 1D GRU 实现——不确定这是否是最初的想法。于是,在 butter 滤波后,我们采用 x = x.clip(-1024, 1024) / 32. 的方式进行处理。

数据增强

对于 2D 模型,因为我们有预训练权重,发现高质量数据更重要——并非所有增强都有帮助。我们的 1D 模型是从零开始训练的,因此使用了较重的增强手段。两种模型的增强都极其有效。

两种模型共用的增强方式如下:

  • 在 50% 的样本中,将 16 个 mel‑spec 节点中的 1~8 个节点用零覆盖。
  • 在 20% 的样本中,随机选取不同的更窄的 butter 带通范围,覆盖 1~8 个节点。
  • 在 50% 的样本中,以中心点为基准随机平移 50 秒窗口,最多平移 20 秒。

仅在 1D 模型中使用的增强:

  • 在 50% 的样本中,对信号进行左右翻转(时间维度)。
  • 在 50% 的样本中,交换大脑左右侧。

仅在 2D 模型中使用的增强:

  • 在 50% 的样本中,针对放大的中心区域(使用固定的 50 秒窗口),将 10 秒中心点在窗口内随机平移最多 5 秒。

模型

MelSpectrogram + 2D‑CNN 主干网络

与大多数竞争对手相同,我们使用预训练的 2D‑CNN 主干网络并将 MelSpectrogram 变换后的数据输入模型。然而我们对公开方法做了几项改进。第一个主要提升来源于不把 EEG 电极的数据合并为区域(LL、LP、RP、RL),而是为双香蕉 Montage 各自生成 16 张独立的梅尔频谱图,然后把 16 张图片拼接成一张大图。在最后一周,我们发现了另一个重要改进——“放大”到中心 10 秒。放大时使用不同的窗口和步长,以增加多样性。Mel‑Spectrogram 变换直接在 GPU 上使用 torchaudio 的实现,作为模型的一部分完成。

image

1D‑CNN + Squeezeformer 块

我们的 1D‑CNN 极大受到了 Yuri Sun(@sunyuri)的工作启发,特别是他用于癫痫检测的轻量级 CNN 实现(代码与论文)。如架构所示,该模型关注时间维度的分组卷积,然后是通道维度的卷积。前两个块专注于时间维度,保持通道信息独立并使用分组卷积。在对时间维度进行池化后,进行一次跨脑区节点的通道卷积,随后再进行一次时间卷积。实际上我们使用 conv2d,但效果等同于 conv1d,并且可以并行卷积所有通道。原始实现使用了 CBAM 注意力块,我们将其替换为 3 层 Squeezeformer。此外,在深度卷积前后加入了逐点卷积,这有助于网络稳定并学习更好的表示。
由于在相对较小的数据上从零训练,保持参数量低非常重要。在 ASL 竞赛期间我们在这方面投入了大量时间,因此直接复用了 那里的 Squeezeformer 实现。

image

训练流程

MelSpectrogram + 2D‑CNN 主干网络

训练参数一般为 12~16 轮,使用余弦学习率衰减,初始学习率 0.0012,批大小 32。Drop path 效果显著,设为 0.2。训练的主要区别在于放大 10 秒窗口的粒度,不同模型使用不同的窗口长度训练,并且配合 mixnet_l 与 mixnet_xl 主干网络。

1D‑CNN + SqueezeFormer 块

训练参数约为 32 轮,余弦学习率衰减,初始学习率 0.001,批大小 64。当使用低投票数样本时,批大小提升至 256。隐藏维度为 128,dropout(注意力、前馈、头、卷积)均为 0.1,增大隐藏维度或增加层数并没有帮助。
我们针对数据训练了 3 种不同版本的 1D 模型:

  • 仅使用投票数 ≥8 的样本。
  • 使用投票数 3~8 与 ≥8 的样本。每一轮对 3~8 投票数据集进行子采样,并为每个集合使用不同的损失。随着训练进行,3~8 批的权重从 0.8 线性或余弦衰减至约 0.2。该加权损失加到 ≥8 投票数据集的损失上。
  • 进一步加入投票数 1~2 的样本,作为一个单独的衰减权重损失,与 3~8 投票集合和 ≥8 投票集合一起训练。1~2 投票集合噪声很大,权重从 0.5 开始并在结束时衰减至 0。

为了增加多样性,我们还训练了一些使用 butter 带通阶数为 0 或 1 的 1D 模型。虽然表现稍弱,但能为融合提供多样性。

集成与后处理

截至最后一天,我们一直使用平均加权的方式,2D 模型的权重更高。在距离结束 15 小时时,借助 GPT4 我们构建了一个简易神经网络来学习 23 个基模型的最好融合权重。其中 5 个模型是不同训练方式的 1D‑CNN,18 个模型是不同主干网络、不同 10 秒放大窗口的 Mixnet 2D‑CNN。
在 CV 上学习的融合权重提升了分数,但对 Public LB 的影响并不显著。
此前我们曾考虑预测分布与训练集中 >8 投票部分的分布不同。因此在距离结束约 10 小时时,我们为每个类别(共 6 类)添加了一个偏置项,加入到 logits 中。CV 变化很小,但 Public LB 分数公布后,我们从第 9 名跃升至第 4 名,令人振奋。该网络结构非常简单,如下所示。在接下来几天的大多数融合方案(包括平均权重)下,我们的私有 LB 分数都保持在 0.27 左右的获奖区间。

class Net(nn.Module):
    def __init__(self, n_models = 23):
        super(Net, self).__init__()
        self.fc = nn.Linear(n_models, 1, bias = False)
        self.fc_c = torch.nn.Parameter(torch.zeros(6)[None,:,None])
    def forward(self, x):
        return self.fc(x) + self.fc_c

补充数据

没有使用任何补充数据。

消融实验(约略)

  • 16 个节点信号的排序 -0.02
  • 在 1D 模型中加入低投票样本 -0.01
  • 数据增强 -0.03 或更多
  • 放大标注窗口 -0.006;不同视角的混合还能进一步提升
  • 去除样本级和批次级归一化 -0.02
  • 使用全连接层学习融合权重 -0.003
  • 使用偏置项后处理 -0.001
模型折数种子数CV / Public / Private
最佳单模型(1D)4 折2 种子/折0.257 / 0.25 / 0.30
最佳单模型(2D)4 折2 种子/折0.232 / 0.24 / 0.29
最终融合4 折2 种子,23 模型/折0.207 / 0.22 / 0.27

使用的工具 / 仓库

PyTorch
Timm、Huggingface、Albumentations。
@darraghdog 使用了从 runpod.io 租用的 4090 GPU 实例。
Neptune.ai 是我们的 MLOps 平台,用于追踪、比较和共享模型。它被大量使用,使得 4 折分组视图和所有折验证分数的平均值变得非常方便,极大地简化了模型追踪工作。

image

感谢阅读,欢迎提问。

编辑:
推理代码:https://www.kaggle.com/code/darraghdog/3rd-place-solution
GitHub:https://github.com/darraghdog/kaggle-hms-3rd-place-solution

同比赛其他方案