返回列表

25th Solution - Custom MIL with timm backbones

510. Mayo Clinic - STRIP AI | mayo-clinic-strip-ai

开始: 2022-07-06 结束: 2022-10-05 医学影像分析 数据算法赛
第25名方案 - 基于 timm 骨干网络的自定义 MIL

第25名方案 - 基于 timm 骨干网络的自定义 MIL

作者: Gunes Evitan | 发布时间: 2022-10-06

两周前我刚开始参加这场比赛时,我想起了 RSNA-MICCAI 脑肿瘤放射基因组分类比赛的往事。这两场比赛在信号弱方面非常相似,但这场比赛稍微好一点,随机性也小一些。我在那场比赛中有一些不错的经验,并在这里应用了这些方法。

数据集

为了更高效地训练,我创建了一个预计算的实例数据集。预计算数据集每张图像包含 16 个大小为 1024x1024 的 3 通道(RGB)实例。数据集创建过程可简化为:

  • 使用 JPEG 压缩图像(100% JPEG 质量)
  • 将最长边调整为 20,000 像素
  • 提取不重叠的实例并用白色背景填充它们
  • 按总和降序对实例进行排序,并取前 16 个实例

模型

我使用多实例学习(MIL)模型来处理实例。我的 MIL 模型与本次比赛中分享的模型略有不同。我认为这个笔记本中的模型在高度维度上拼接特征图,这让我觉得很奇怪,所以我尝试了不同的池化和拼接方法。

class ConvolutionalMultiInstanceLearningModel(nn.Module):

    def __init__(self, n_instances, model_name, pretrained, freeze_parameters, aggregation, head_class, head_args):

        super(ConvolutionalMultiInstanceLearningModel, self).__init__()

        self.backbone = timm.create_model(
            model_name=model_name,
            pretrained=pretrained,
            num_classes=head_args['n_classes']
        )

        if freeze_parameters is not None:
            # Freeze all parameters in backbone
            if freeze_parameters == 'all':
                for parameter in self.backbone.parameters():
                    parameter.requires_grad = False
            else:
                # Freeze specified parameters in backbone
                for group in freeze_parameters:
                    if isinstance(self.backbone, timm.models.DenseNet):
                        for parameter in self.backbone.features[group].parameters():
                            parameter.requires_grad = False
                    elif isinstance(self.backbone, timm.models.EfficientNet):
                        for parameter in self.backbone.blocks[group].parameters():
                            parameter.requires_grad = False

        self.aggregation = aggregation
        n_classifier_features = self.backbone.get_classifier().in_features
        input_features = (n_classifier_features * n_instances) if self.aggregation == 'concat' else n_classifier_features
        self.classification_head = eval(head_class)(input_features=input_features, **head_args)

    def forward(self, x):

        # Stack instances on batch dimension before passing input to feature extractor
        input_batch_size, input_instance, input_channel, input_height, input_width = x.shape
        x = x.view(input_batch_size * input_instance, input_channel, input_height, input_width)
        x = self.backbone.forward_features(x)
        feature_batch_size, feature_channel, feature_height, feature_width = x.shape

        if self.aggregation == 'avg':
            # Average feature maps of multiple instances
            x = x.contiguous().view(input_batch_size, input_instance, feature_channel, feature_height, feature_width)
            x = torch.mean(x, dim=1)
        elif self.aggregation == 'max':
            # Max feature maps of multiple instances
            x = x.contiguous().view(input_batch_size, input_instance, feature_channel, feature_height, feature_width)
            x = torch.max(x, dim=1)[0]
        elif self.aggregation == 'logsumexp':
            # LogSumExp feature maps of multiple instances
            x = x.contiguous().view(input_batch_size, input_instance, feature_channel, feature_height, feature_width)
            x = torch.logsumexp(x, dim=1)
        elif self.aggregation == 'concat':
            # Stack feature maps on channel dimension
            x = x.contiguous().view(input_batch_size, input_instance * feature_channel, feature_height, feature_width)

        output = self.classification_head(x)
        return output

聚合参数调整模型的 MIL 池化或聚合部分。

同比赛其他方案