653. BYU - Locating Bacterial Flagellar Motors 2025 | byu-locating-bacterial-flagellar-motors-2025
感谢 Kaggle 及所有参与举办这场激动人心比赛的人。这是一次很棒的学习经历,很有趣看到我们获得第一名的 CryoET 方法论有多少可以应用到这里。感谢 @bloodaxe 带来的伟大团队体验。我还要感谢乌克兰武装部队为我的队友提供安全和保障,使其能够参加这次比赛。
该解决方案是一个简单的 3D-ResNet18 分类器和来自 MONAI 的目标检测模型的集成。我们还使用 MONAI 进行数据增强,并通过 jit 或 TensorRT 导出模型,这显著提高了速度,使我们能够拥有稍大一点的集成模型。我们使用了由 @brendanartley 分享的额外数据。
本文涵盖了基于 ResNet18 分类的方法。关于目标检测部分,请参阅 @bloodaxe 的撰写:第四名解决方案 [目标检测部分]
我按体素大小(Voxel Size)分割原始训练数据,按数据集 ID 分割外部数据,以某种程度模拟训练/测试的差异。使用了 4 折交叉验证。与 Leaderboard (LB) 的相关性不是很好,所以我主要依赖 LB 分数作为反馈。
3D 图像被缩放到固定的体素大小 15.6,并使用 int8 保存到磁盘。
由于模型是从头开始训练的,增强对于防止过拟合至关重要。
我在 torch dataloader 中使用了 RandomCrop(大小 96x160x160),每个轴上的翻转,以及在 GPU 上的缩放 + 旋转(全部来自 MONAI)。此外,我使用了一个自定义实现的 MixUp,这对于延长训练时间和防止过拟合非常有效。我实现了一个版本,确保混合 patch 中不超过 1 个马达。此外,正样本(即包含马达的 crops)被过采样,总比例为 12.5%。
建模在这次比赛中相当有趣。我从 CryoET 竞赛第一名解决方案中的 3D UNET 开始,效果已经相当不错。在了解到竞赛指标对于定位的宽容度后,我尝试进一步简化模型,去掉任何解码器,因为 32 倍下采样的模型输出应该已经足够了。令人惊讶的是,一个简单的 ResNet3D 编码器就奏效了。我的方法如下:
分类模型的输入是 96x160x160 的图像 patch,这在离开 ResNet 骨干网络后产生了 512x3x5x5 的特征图。我展平 3x5x5 的输出“像素”,并使用一个简单的全连接 512->1 层来对每个像素进行二元预测,这基本上是 75 个类,用于确定马达位置。我还添加了一个额外的类来反映 crop 中没有马达的情况。所以总共有 76 个类的简单 3D-ResNet18 分类器,使用 CrossEntropy Loss 进行训练。
对于推理,我使用了重叠率为 0.5 的滑动窗口方法。对于马达定位,直接取 75 个类中预测值最大的 patch。然后使用 patch 位置 + 来自 3x5x5 网格的偏移量来确定最终定位。这个非常简单且快速的模型单独得分就在 public LB 上达到 0.875(第 5 名)!
通过代码可能更容易理解架构:
import monai.networks.nets as mnn
def downscale(y, scale=32):
bs, c, d, h, w = y.shape
idxs = torch.where(y>0)
for item in idxs[2:]:
item //= scale
y2 = torch.zeros((bs,c,d//scale,h//scale,w//scale), dtype=y.dtype, layout=y.layout, device=y.device)
y2[idxs] += 1
return y2
cfg.backbone_args = dict(model_name='resnet18',
spatial_dims=3,
pretrained=False,
in_channels=1)
class Net(nn.Module):
def __init__(self, cfg):
super(Net, self).__init__()
self.backbone = mnn.ResNetFeatures(**cfg.backbone_args)
self.global_pool = nn.AdaptiveAvgPool3d(1)
self.reg_head = nn.Conv3d(512,1,kernel_size=1,stride=1)
self.cls_head = torch.nn.Linear(512,1)
def forward(self, batch):
x = batch['input']
out = self.backbone(x)[-1]
loc_logits = self.reg_head(out)
cls_logits = self.cls_head(self.global_pool(out).flatten(1))
loss = self.custom_loss(y,loc_logits,cls_logits)
outputs = {'loss':loss,'logits':loc_logits}
return outputs
def custom_loss(self, target, logits, cls_logits):
y2 = downscale(target,scale=32)
l = logits.flatten(1)
y3 = y2.flatten(1)
y3 = torch.cat([y3,1-y3.max(1)[0][:,None]],dim=-1)
l2 = torch.cat([l,cls_logits],dim=-1)
l_cls = DenseCrossEntropy1D()(l2,y3)
return l_cls
对于阈值处理,我使用了基于分位数的方法,因为在比较不同模型时这更稳定。
模型使用 bf16 训练,每个折在单个 A100 上训练大约需要 17 小时。虽然模型很小很简单,但我尝试了很多其他架构,替代方案的表现要差得多。所以我坚持使用了这个。
代码库非常接近 CryoET 第一名解决方案,所以我就不麻烦发布这个了。
再见。欢迎提问。