返回列表

40th place solution + code

560. Benetech - Making Graphs Accessible | benetech-making-graphs-accessible

开始: 2023-03-21 结束: 2023-06-19 教育内容生成 数据算法赛
第40名解决方案 + 代码

第40名解决方案 + 代码

作者: moth (Kaggle Grandmaster)
团队成员: Cody_Null
发布日期: 2023-06-20

代码

我希望你们觉得训练和推理代码整洁有序,我已经尽力不制造混乱!

简要说明

以下是我们与 @cody11null 团队合作解决方案的简要说明。

我们的解决方案由对象检测器 (EfficientDet) 和图像编码器/文本解码器 (Matcha 和 Donut) 组成:

  1. 使用 Donut 模型作为 OCR (对象字符识别) 模型来提取图表的坐标轴标签(红色)
  2. 对象检测器 (EfficientDet) 检测数据系列(绿色)。最终我们仅将其用于散点图
  3. 第二个对象检测器 (EfficientDet) 检测图表的绘图区域(数据点所在的区域)(紫色)
  4. 了解图表的限制范围(来自步骤3)、数据系列的边界框(来自步骤2)和坐标轴数值(来自步骤1),您可以通过简单的交叉相乘得到数据系列的值。这个方法基本上如我在 讨论 中所述
  5. 对于其他图表类型 (vertical_barhorizontal_barlinedot),我们使用了 Matcha 模型
解决方案流程图

详细解决方案

1. Donut:坐标轴标签检测

我们训练了不同的图像编码器/文本解码器架构来检测坐标轴标签。我训练了一个 Donut 模型(令人惊讶地获得了最佳结果)和另外两个 Matcha 模型,但表现较差。尽管使用不同的 max_patches 进行训练,所有模型的结果都相当相似。大多数复杂情况与坐标轴标签的共享原点、长序列、具有多位小数的浮点数等有关。在这一步中,最重要的是正确获取每个轴的最小值和最大值,以便正确进行交叉相乘。

2. EfficientDet:数据系列对象检测器

这是我耗时最多的部分之一。在比赛初期,直到 Nicholas 分享了他的 Donut 模型之前,我不知道图像编码器/文本解码器模型能在比赛中表现如此出色,所以我专注于一个能够良好检测数据系列的对象检测器模型。为此,我缺乏边界框训练数据,于是开始手动标记数千张图像。这花费了我很多天时间。我的 EfficientDet 对象检测器在生成图像上表现极好,但在提取图像上表现不佳。我尽可能多地标记了图像。我的对象检测器在散点图上的精确匹配准确率约为50%。它在其他图表类型上表现更好,但由于散点图可能包含*大量*点,因此这是一项更艰巨的任务。

以下是一些我的对象检测器的预测结果:

对象检测器预测示例

3. EfficientDet:图表边界框对象检测器

作为管道的一部分,我必须检测图表的边界框,以便将像素坐标映射到坐标轴标签的真实数值。这一步相当简单,基本上是将第(2)部分相同的代码应用于注释中的 plot-bb 边界框。我们在 JSON 文件中拥有这些训练数据,这是一项简单的任务,因此对象检测器达到了高准确率。

4. Matcha 和 Donut 模型

在比赛初期,我开始对 Nicholas 的 Donut 模型进行多次更改,很快发现通过执行一些基本的后处理可以获得更高的性能。然而,即使训练更多轮次、执行数据增强和其他技巧,我意识到无法获得更高的性能。然后我尝试实现 Matcha 模型,并通过很多努力才使其工作(再次感谢 Nicholas 提出了 GitHub 讨论的问题)。一旦 Matcha 模型开始工作,我尝试了各种方法,直到能够尽可能提高性能。以下是在公开排行榜上达到 0.74 的路线图:

  • 0.20: 仅使用对象检测和 Donut 进行坐标轴标签的模型
  • 0.47: Nicholas 模型,但只是修改后处理,例如使用 xy 系列的最大长度而不是最小长度
  • 0.48: 进行额外的后处理,例如用数据系列平均值的平均值填充值
  • 0.49: 使用 Donut 和对象检测器的混合模型
  • 0.50: 将 Donut 训练10轮而不是5轮
  • 0.56: 爆发!通过使 Matcha 模型工作获得大幅提升。非常简单,没有花哨的东西
  • 0.61: 将 vanilla Matcha 与我的散点图对象检测器结合
  • 0.64: 将 Matcha 训练更多轮次(10轮)
  • 0.69: 使用 Bartley 从代码生成的图像训练 Matcha。这些改进大大提高了除散点图外所有图表类型的性能,每类图表仅增加了相对较少的额外图像(+5k)
  • 0.71: 将 Matcha 的 max_patches 从512增加到1024
  • 0.74: 在100%的提取图像上训练 Matcha,而不是我通常使用的75%,这样我就可以获得25%的验证集

无效尝试

太多无效尝试!我将写下我能记住的:

  1. 在比赛初期,我尝试为每种图表类型训练一个不同的图像编码器/文本解码器模型,而不是训练一个通用模型。我认为这没有起效的原因有两个:1. 我认为我为这个训练了一个 donut 模型 2. 我没有足够的图像(那时我还没有使用 Bartley 生成的图像)
  2. 我尝试了 Matcha 模型的不同预训练权重,但所有结果都更差。这些权重包括:statistachartqaplotqa。最佳性能来自 matcha-base
  3. 尝试为每个轴训练两个 Matcha 模型。我真的很认为这个想法可行,直到今天我也不知道为什么它不起作用!我的理由是:预测更长的序列更难,那么为什么不训练一个模型预测 x 轴系列,另一个预测 y 轴系列呢?但最终它的表现比组合模型更差
  4. 尝试使用大量生成的图像进行训练。我尝试为每类图表增加5k、10k和25k张额外图像。因此基本上我用比赛的60k张加上额外生成的图像训练模型。所以我训练了一些总图像数达18万张的模型!我很快意识到,如果额外生成的图像没有太多差异,无论增加多少图像,模型都会达到其容量。我在这里讨论了这个问题,这与收益递减规律有关
  5. 使用平衡采样器/过采样进行训练。我使用平衡采样器进行训练,这确保在每批中至少有一个少数类图像。我的少数类图像是提取的图像。这有两个效果:它过采样少数类,并且(理论上)使模型收敛更快。我没有看到过采样有任何好处(这并不奇怪)
  6. 数据增强。我尝试通过应用颜色相关的增强来增强数据,例如 RGBShiftRandomBrightnessColorJitter
  7. max_patches 参数增加到1536(介于1024和2048之间)。因此,当我看到将 max_patches 从512增加到1024的改进后,我想,为什么不进一步增加一点看看会发生什么?嗯,它似乎没有提高分数,当然消耗了更多的显存和计算时间。增加 max_patches 会大大增加显存消耗,因此必须降低批量大小,训练时间也更长
  8. 使用不同的调度器进行训练。我最终使用 OneCycleLR 进行训练,但之前尝试过其他调度器,甚至是恒定学习率,从未通过修改学习率调度器看到性能提升
  9. 对于对象检测器,我尝试使用加权框融合来提高数据系列的精确匹配。这场比赛最大的挑战之一就是获得正确数量的数据系列点。我的对象检测器有时会产生过多低置信度的边界框。我想也许如果我能将它们与高置信度的边界框融合,就能获得更高的准确率,但没能做到。最终我调整了保留/丢弃边界框的概率阈值。最终,保留 p>0.22 的边界框。如果有人知道怎么做,我会很高兴!
  10. 不同的对象检测器骨干网络。我使用 tf_efficientnetv2_s 作为我的对象检测器的骨干网络,但尝试过更大的模型,如 tf_efficientnetv2_l,但没有成功
  11. 自动混合精度。并不是说它不起作用,而是使用脑浮点张量 (bfloat16) 进行混合精度训练没有提供任何性能提升。然而,由于与 float32 相比你使用了更少的显存,你可以使用更大的批量大小,这是一个推荐的做法(我听说 Karpathy 说最好总是使用你能使用的最大批量进行训练,我可能在这里错了)
  12. 许多其他我不记得的事情

硬件

使用 Kaggle 提供的 P100 训练像 Matcha 和 Donut 这样的模型几乎是不可能的。我们最终购买了 Google Colab Pro+,它提供带有40 GB显存的 A100 GPU,因为我们俩都没有像 RTX 3090/4090 这样的深度学习 GPU。使用 A100,我们能够训练大型模型并更快地进行更多实验!我训练了超过30个不同的模型!

结论

尽管我本希望进入金牌区(总是希望做到最好),但我对我们获得的结果和在整个比赛中吸取的所有教训心怀感激。我坚信,如果你想学习某些东西,就去改变它!从头开始编写代码!这是最好的学习方式。我真的很期待阅读其他团队的解决方案。我还想知道是否有人尝试更改 Pix2Struct 的视觉模型和文本模型。我不知道是否可行或如何操作,所以如果你知道,请在评论中留言。

谢谢大家!

同比赛其他方案