返回列表

1st Place - 3D U-Net + Quantile Thresholding

653. BYU - Locating Bacterial Flagellar Motors 2025 | byu-locating-bacterial-flagellar-motors-2025

开始: 2025-03-05 结束: 2025-06-04 医学影像分析 数据算法赛
第一名 - 3D U-Net + 分位数阈值法

第一名 - 3D U-Net + 分位数阈值法

作者: Bartley (Grandmaster)
竞赛排名: 第 1 名
发布时间: 2025-06-05

感谢 BYU 和 Kaggle 举办此次竞赛。很高兴能再次参与一场组织完善的断层扫描竞赛,主办方非常棒。我简直不敢相信这个结果!

简而言之 (TLDR)

我的解决方案使用了一个经过大量数据增强和辅助损失函数训练的 3D U-Net。在推理过程中,我根据最大预测像素值对每个断层扫描图进行排名,并使用分位数阈值法来确定是否存在马达。

交叉验证 (Cross Validation)

为了验证模型,竞赛数据被分为 4 折。本地交叉验证 (CV) 与排行榜 (LB) 高度相关,直到大约 0.93 分。超过这个分数后,我使用公共排行榜进行验证。使用分位数阈值法对于从排行榜获得可靠反馈非常重要。更多细节见后处理部分。

预处理 (Preprocessing)

来自竞赛数据和 CryoET 数据门户的断层扫描图被用于创建训练集。每个断层扫描图使用 scipy.ndimage.zoom() 调整大小为 (128, 704, 704),并且没有马达的断层扫描图被丢弃。正如其他人指出的那样,竞赛数据相当嘈杂,因此使用 Napari 手动添加了缺失的马达。我将在这里添加更新后的数据。

对于标签,我使用以每个马达为中心的高斯热力图。类似于 CZII 竞赛中 @bloodaxe@christofhenkel 的解决方案,热力图的分辨率降低了 8 倍。这对本次竞赛特别有效,因为指标对距离误差有很高的容忍度。这意味着预测确切的像素不如预测马达的存在重要。如果你不信,下图展示了当体素间距等于 10 时,每个马达周围允许的大致误差范围。

断层扫描图误差容忍示意

模型 (Model)

该模型是一个 3D U-Net(sort of)。编码器是来自 Kenoshara 仓库的预训练 ResNet200,地址在这里。对于大多数实验,我使用了 ResNet101 变体,但增加编码器的容量会产生更好的性能。此外,应用随机 dropout 进行正则化,并使用梯度检查点来减少训练期间的显存 (vRAM) 使用。解码器在分割头之前使用单个反卷积块。

模型架构图

损失函数 (Loss)

模型使用 SmoothBCE 损失进行训练,包含 3 个贡献部分。主分割头预测输出 logits,深度监督头应用于倒数第二个特征图,最大池化损失(内核大小和步长为 4)应用于主分割头。此外,池化损失鼓励在马达区域周围具有高概率,同时减少对小定位误差的惩罚。

损失函数示意图

数据增强 (Augmentations)

大量的数据增强使得模型能够训练 400 个 epoch 而不过拟合。虽然,我可能本可以训练更长时间,但在 250 个 epoch 之后公共排行榜分数没有变化。

  • Mixup (100%)
  • Rescale/Zoom (100%)
  • Rotate90/180/270 (100%)
  • Axis Flips (100%)
  • Axis Swap (100%)
  • Coarse Dropout (50%)
  • Color inversion (25%)
  • Simple Cutmix (15%)

从磁盘加载断层扫描图很慢,这限制了 CPU 上进行数据增强的时间。为了解决这个问题,除了调整大小外,所有增强都在 GPU 上应用。为了尽可能保持调整大小的速度,使用了 scipy.ndimage.zoom(..., order=0)

推理 (Inference)

最初,在推理期间应用了相同的预处理管道。这效果很好,但匹配补丁的高度和宽度,并仅在深度上滑动要快 4 倍。这允许更多的时间进行测试时增强 (TTA) 和非常高的重叠率 (0.875)。两种方法的得分差不多,但我最终的解决方案使用了后者。

所有边缘预测都使用 roi_weight_map 参数进行降权。在聚合滑动窗口时,中间 40% 的 logits 权重为 1.0,其他 logits 权重为 0.001。

模型集成 (Ensembling)

最终提交使用了 8 种子集成。对每个模型输出应用 Sigmoid,并将 logits 相加。推理耗时约 10 小时。

后处理 (Postprocessing)

像许多人一样,我发现固定阈值不稳定。相反,我使用分位数阈值法来确定马达的存在。

为了应用这一点,所有断层扫描图都根据其最大预测像素值进行排名。然后,移除最低分位数的预测。我在公共排行榜上调整了分位数,然后向 Kaggle 之神祈祷私有排行榜也是类似的。在公共排行榜上,最佳阈值是 0.565,在私有排行榜上是 0.560。

排行榜分数示意图

最后说明 (Final Note)

感谢阅读,也感谢 everyone 对外部数据集表示赞赏。

Kaggle 快乐!

同比赛其他方案