返回列表

4th place: Simple ResNet18 classification

653. BYU - Locating Bacterial Flagellar Motors 2025 | byu-locating-bacterial-flagellar-motors-2025

开始: 2025-03-05 结束: 2025-06-04 医学影像分析 数据算法赛
第四名:简单的 ResNet18 分类

第四名:简单的 ResNet18 分类

作者: Dieter (ChristofHenkel)
发布时间: 2025-06-11
竞赛排名: 第 4 名

感谢 Kaggle 及所有参与举办这场激动人心比赛的人。这是一次很棒的学习经历,很有趣看到我们获得第一名的 CryoET 方法论有多少可以应用到这里。感谢 @bloodaxe 带来的伟大团队体验。我还要感谢乌克兰武装部队为我的队友提供安全和保障,使其能够参加这次比赛。

简而言之 (TLDR)

该解决方案是一个简单的 3D-ResNet18 分类器和来自 MONAI 的目标检测模型的集成。我们还使用 MONAI 进行数据增强,并通过 jit 或 TensorRT 导出模型,这显著提高了速度,使我们能够拥有稍大一点的集成模型。我们使用了由 @brendanartley 分享的额外数据。

本文涵盖了基于 ResNet18 分类的方法。关于目标检测部分,请参阅 @bloodaxe 的撰写:第四名解决方案 [目标检测部分]

交叉验证

我按体素大小(Voxel Size)分割原始训练数据,按数据集 ID 分割外部数据,以某种程度模拟训练/测试的差异。使用了 4 折交叉验证。与 Leaderboard (LB) 的相关性不是很好,所以我主要依赖 LB 分数作为反馈。

数据预处理/增强

3D 图像被缩放到固定的体素大小 15.6,并使用 int8 保存到磁盘。
由于模型是从头开始训练的,增强对于防止过拟合至关重要。
我在 torch dataloader 中使用了 RandomCrop(大小 96x160x160),每个轴上的翻转,以及在 GPU 上的缩放 + 旋转(全部来自 MONAI)。此外,我使用了一个自定义实现的 MixUp,这对于延长训练时间和防止过拟合非常有效。我实现了一个版本,确保混合 patch 中不超过 1 个马达。此外,正样本(即包含马达的 crops)被过采样,总比例为 12.5%。

模型

建模在这次比赛中相当有趣。我从 CryoET 竞赛第一名解决方案中的 3D UNET 开始,效果已经相当不错。在了解到竞赛指标对于定位的宽容度后,我尝试进一步简化模型,去掉任何解码器,因为 32 倍下采样的模型输出应该已经足够了。令人惊讶的是,一个简单的 ResNet3D 编码器就奏效了。我的方法如下:

分类模型的输入是 96x160x160 的图像 patch,这在离开 ResNet 骨干网络后产生了 512x3x5x5 的特征图。我展平 3x5x5 的输出“像素”,并使用一个简单的全连接 512->1 层来对每个像素进行二元预测,这基本上是 75 个类,用于确定马达位置。我还添加了一个额外的类来反映 crop 中没有马达的情况。所以总共有 76 个类的简单 3D-ResNet18 分类器,使用 CrossEntropy Loss 进行训练。
对于推理,我使用了重叠率为 0.5 的滑动窗口方法。对于马达定位,直接取 75 个类中预测值最大的 patch。然后使用 patch 位置 + 来自 3x5x5 网格的偏移量来确定最终定位。这个非常简单且快速的模型单独得分就在 public LB 上达到 0.875(第 5 名)!

模型架构示意图

通过代码可能更容易理解架构:

import monai.networks.nets as mnn

def downscale(y, scale=32):
    bs, c, d, h, w = y.shape
    idxs = torch.where(y>0)
    for item in idxs[2:]:
        item //= scale
    y2 = torch.zeros((bs,c,d//scale,h//scale,w//scale), dtype=y.dtype, layout=y.layout, device=y.device)
    y2[idxs] += 1
    return y2

cfg.backbone_args = dict(model_name='resnet18',
                         spatial_dims=3,    
                         pretrained=False, 
in_channels=1)

class Net(nn.Module):

    def __init__(self, cfg):
        super(Net, self).__init__()
        
        self.backbone = mnn.ResNetFeatures(**cfg.backbone_args)
        self.global_pool = nn.AdaptiveAvgPool3d(1)
        self.reg_head = nn.Conv3d(512,1,kernel_size=1,stride=1)
        self.cls_head = torch.nn.Linear(512,1)        
           
    def forward(self, batch):

        x = batch['input']
        
        out = self.backbone(x)[-1]
        loc_logits = self.reg_head(out)
        cls_logits = self.cls_head(self.global_pool(out).flatten(1))

        loss = self.custom_loss(y,loc_logits,cls_logits)
        outputs = {'loss':loss,'logits':loc_logits}
        return outputs

    def custom_loss(self, target, logits, cls_logits):

        y2 = downscale(target,scale=32)
        l = logits.flatten(1)
        y3 = y2.flatten(1)
        y3 = torch.cat([y3,1-y3.max(1)[0][:,None]],dim=-1)
        l2 = torch.cat([l,cls_logits],dim=-1)
        l_cls = DenseCrossEntropy1D()(l2,y3)
        return l_cls

对于阈值处理,我使用了基于分位数的方法,因为在比较不同模型时这更稳定。

模型使用 bf16 训练,每个折在单个 A100 上训练大约需要 17 小时。虽然模型很小很简单,但我尝试了很多其他架构,替代方案的表现要差得多。所以我坚持使用了这个。

代码库非常接近 CryoET 第一名解决方案,所以我就不麻烦发布这个了。
再见。欢迎提问。

同比赛其他方案