返回列表

26th Solution - 3D UNet + CCL with Morphological Post Processing + Custom Loss Function

639. CZII - CryoET Object Identification | czii-cryo-et-object-identification

开始: 2024-11-06 结束: 2025-02-05 医学影像分析 数据算法赛
第 26 名解决方案 - 3D UNet + CCL 结合形态学后处理 + 自定义损失函数
作者: Andrei Zamfir
发布日期: 2025-02-06
竞赛排名: 第 26 名

第 26 名解决方案 - 3D UNet + CCL 结合形态学后处理 + 自定义损失函数

大家好,

首先,我要感谢竞赛主办方提出了这个非常有趣的问题。我对数据的同质性感到惊叹,我们训练的 7 个断层扫描图像足以很好地代表数据。私有/公有数据集的划分也非常棒,排行榜的变化很小。

其次,我要一如既往地感谢社区,感谢论坛上所有富有洞察力的讨论以及 Kaggle 分享 ideas 的精神。如果没有其他参与者的集体努力,我不可能完成这一切。特别感谢 @hengck23, @fnands, @davidlist@sacuscreed。我从 Hengck 那里学习了解决计算机视觉问题的基础知识,并通过剖析他的代码学到了很多。

解决方案总结

我的解决方案使用了 MONAI 的完整 3D UNet 实现,窗口大小为 184x184x184。我生成了多类硬标签,使用估计粒子半径 0.5 倍的球体。

我最终使用了 50% 的重叠,并为体积重建生成了高斯权重。对于数据增强管道,我使用了 XY 平面上的随机旋转、所有空间维度上的随机翻转、随机高斯噪声和随机缩放 (0.95-1.05)。
我也花了一些时间研究如何使用仿射变换复制缺失楔形伪影 (missing wedge artifact),但最终还是放弃了这个想法。

我尝试使用多种损失函数进行训练,最终选择了具有多种 alpha/beta 组合的 Tversky 损失,获得了不错的结果,最佳单个模型在公开排行榜上的得分为 0.70x-0.73x。

对于测试时增强 (TTA),我使用了 XY 平面上的 90-180-270 度旋转和全空间维度翻转。我使用了两种特定方法将 TTA 和模型集成在一起:

  1. 原始 Logit 平均:取 TTA 的平均值,将其与其他模型预测的 TTA 平均值相加并进行再次平均,重建体积,然后应用最终的 softmax。
  2. 每个 TTA 堆栈的 Softmax:然后将概率的最大值与其他模型对该子体积的预测相加,在模型之间执行平均并重建体积。

这两种方法的得分非常相似,尽管显然需要不同的 CCL 阈值,其中 TTA 的最大值对预测过于自信。回顾过去,我想如果我将这两种方法结合起来,可能会获得更多的集成多样性,但我还没有这样做。

在阈值处理将输出概率转换为硬二值掩码之后,我尝试了以下形态学操作:膨胀 (dilation)、腐蚀 (erosion)、开运算 (opening) 和闭运算 (closing)。在执行 CCL 之前完全膨胀所有的二值掩码,使排行榜得分提高了约 0.010-0.015。

随后,我使用 Optuna HPO 研究来搜索每个粒子的最佳阈值/形态学操作,以最大化该粒子的像素 FBeta(beta 为 1 或 1.25 表现最好,再高会牺牲太多精确率)得分(注意不要与竞赛 FBeta 混淆)。

在整个竞赛过程中,我只在 TS_99_9 上进行了验证,并在其他 6 个样本上进行了训练,我试图找到任何局部指标与排行榜之间的相关性。我最终使用了加权像素 FBeta,因为我发现排行榜 FBeta 非常不稳定,在排行榜上相关性不太好。

额外思考

一月份的大部分时间我都花在分析这个问题上,即如何找到这种相关性。一方面,我们并不是在做纯粹的分割问题,所以比较两个模型之间的 dice 指标并不能说明全部情况;另一方面,dice/IoU 或你想使用的任何其他指标与它如何转化为排行榜得分之间确实存在某种相关性。

假设你已经在正确的位置有一簇像素预测,无论你使用什么峰值检测方法都已经为你“得分”了该粒子。进一步改进这些概率(使它们更接近 1,或覆盖更多的真实值)可以提高 dice,但不会提高竞赛 FBeta。

这让我想到一个损失函数,可以优先考虑识别尽可能多的对象(以鼓励对象级召回率,直接转化为竞赛 FBeta)。

def CustomLoss(self, predictions, labels, metadata):
    
    pred_mask = torch.softmax(predictions, dim=0)
    threshold_sigmoid = 0.05
    sharpness_sigmoid = 10
    eps = 1e-6
    n_objects = 0
    total_object_loss = 0
    objects_found = 0
    object_counter = 0
    for class_label, obj_id, obj_mask in metadata:
            
        pred_probs = pred_mask[class_label]
        if class_label == 2:
            continue
            
        tp = torch.sum(pred_probs * obj_mask)
        fn = torch.sum((1 - pred_probs) * obj_mask)  
        weight = class_weights[class_label]
            
        recall = tp/(tp+fn+eps)
        sigmoid_recall = torch.sigmoid(sharpness_sigmoid * (recall - threshold_sigmoid))
            
        if recall > threshold_sigmoid:
            objects_found += 1
            
        total_object_loss += (-torch.log(sigmoid_recall + eps)) * weight
                
        object_counter += 1
        n_objects += weight
            
    loss = total_object_loss / (n_objects + eps)
    print(f'自定义损失等于 {loss}!')
    print(f'在召回率超过 {threshold_sigmoid} 时找到的对象:{objects_found} / {object_counter}!')
        
    return loss

上述函数遍历特定批次的元数据(在批次外部生成以保持可微性),迭代类别和单个对象并计算召回率(假设此批次中有 3 个核糖体和 5 个脱铁铁蛋白,元数据将创建 8 个真实掩码,每个粒子球体一个)。

然后我使用 sigmoid 进行软阈值处理,并将损失计算为召回率 sigmoid 的负对数。背后的想法是鼓励网络识别尽可能多的对象(由于元数据拆分对象的方式,只能计算召回率,因为预测的精确率会非常不准,考虑到它会考虑该类的所有预测,从而大幅降低精确率)。

该损失与具有较高 alpha 的 Tversky 损失相结合,以优先处理精确率,试图抵消自定义损失函数为达到每个对象 5% 召回率而生成的所有正预测。

这种组合需要大量调整以获得正确的召回率阈值、正确的类别权重以及自定义损失和 Tversky 损失混合的权重(-log of epsilon 将主导有界于 [0,1] 的 Tversky)。

不幸的是,我只用它得到了不错的结果,并没有真正提高我的排行榜得分,我不得不放弃这个想法,因为我缺乏 GPU 来进一步深入研究,加上竞赛截止日期临近。

最后,我使用了用这种组合训练的模型,这至少有助于集成混合的多样性。

经验教训

这是我第一次 dedicate 3 个月的时间参加竞赛,很有趣看到排行榜是如何演变的,论坛上的讨论以及共享的代码片段。我现在对如何管理时间、GPU 有了更好的预期,我认为最大的收获是始终重新审视代码管道的早期部分,并更加注意它们如何以及为何协同工作。

我倾向于对眼前的任务产生隧道视野,而较少考虑大局。

有时候,少即是多,暂时放下你无法解决的问题,会让你带着更新的视角回归。

同比赛其他方案