510. Mayo Clinic - STRIP AI | mayo-clinic-strip-ai
When I started this competition two weeks ago, it reminded me of the RSNA-MICCAI Brain Tumor Radiogenomic Classification competition. The two competitions are very similar in how weak their signal is, but this one is slightly better and less random. I had some good experience from that competition and applied those approaches here.
To train more efficiently, I created a precomputed instance dataset. For each image, the precomputed dataset contains 16 instances of size 1024x1024 with 3 channels (RGB). The dataset creation process can be simplified as:
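A minimal sketch of how such a precomputed instance dataset could be built. The `extract_instances` helper and its tissue-darkness heuristic are hypothetical illustrations, not the original pipeline:

```python
import numpy as np

def extract_instances(image, n_instances=16, instance_size=1024):
    """Crop a slide image into a grid of tiles, keep the n_instances tiles
    with the most tissue (lowest mean brightness, since slide background is
    white), and return a (n_instances, instance_size, instance_size, 3) array.
    Hypothetical helper; the author's actual process is not shown in the post."""
    h, w, _ = image.shape
    tiles = []
    for y in range(0, h - instance_size + 1, instance_size):
        for x in range(0, w - instance_size + 1, instance_size):
            tiles.append(image[y:y + instance_size, x:x + instance_size])
    # Rank tiles by mean brightness: darker tiles contain more tissue
    tiles.sort(key=lambda tile: tile.mean())
    tiles = tiles[:n_instances]
    # Pad with white tiles if the slide yields fewer than n_instances
    while len(tiles) < n_instances:
        tiles.append(np.full((instance_size, instance_size, 3), 255, dtype=image.dtype))
    return np.stack(tiles)
```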
I used a multiple instance learning (MIL) model on the instances. My MIL model is slightly different from the ones shared in this competition. I believe the model in that notebook concatenates feature maps along the height dimension, which seemed strange to me, so I experimented with different pooling and concatenation methods.
```python
import timm
import torch
import torch.nn as nn


class ConvolutionalMultiInstanceLearningModel(nn.Module):

    def __init__(self, n_instances, model_name, pretrained, freeze_parameters, aggregation, head_class, head_args):

        super(ConvolutionalMultiInstanceLearningModel, self).__init__()

        self.backbone = timm.create_model(
            model_name=model_name,
            pretrained=pretrained,
            num_classes=head_args['n_classes']
        )

        if freeze_parameters is not None:
            if freeze_parameters == 'all':
                # Freeze all parameters in backbone
                for parameter in self.backbone.parameters():
                    parameter.requires_grad = False
            else:
                # Freeze specified parameter groups in backbone
                for group in freeze_parameters:
                    if isinstance(self.backbone, timm.models.DenseNet):
                        for parameter in self.backbone.features[group].parameters():
                            parameter.requires_grad = False
                    elif isinstance(self.backbone, timm.models.EfficientNet):
                        for parameter in self.backbone.blocks[group].parameters():
                            parameter.requires_grad = False

        self.aggregation = aggregation
        n_classifier_features = self.backbone.get_classifier().in_features
        input_features = (n_classifier_features * n_instances) if self.aggregation == 'concat' else n_classifier_features
        self.classification_head = eval(head_class)(input_features=input_features, **head_args)

    def forward(self, x):

        # Stack instances on batch dimension before passing input to feature extractor
        input_batch_size, input_instance, input_channel, input_height, input_width = x.shape
        x = x.view(input_batch_size * input_instance, input_channel, input_height, input_width)
        x = self.backbone.forward_features(x)
        feature_batch_size, feature_channel, feature_height, feature_width = x.shape

        if self.aggregation == 'avg':
            # Average feature maps of multiple instances
            x = x.contiguous().view(input_batch_size, input_instance, feature_channel, feature_height, feature_width)
            x = torch.mean(x, dim=1)
        elif self.aggregation == 'max':
            # Max feature maps of multiple instances
            x = x.contiguous().view(input_batch_size, input_instance, feature_channel, feature_height, feature_width)
            x = torch.max(x, dim=1)[0]
        elif self.aggregation == 'logsumexp':
            # LogSumExp feature maps of multiple instances
            x = x.contiguous().view(input_batch_size, input_instance, feature_channel, feature_height, feature_width)
            x = torch.logsumexp(x, dim=1)
        elif self.aggregation == 'concat':
            # Stack feature maps on channel dimension
            x = x.contiguous().view(input_batch_size, input_instance * feature_channel, feature_height, feature_width)

        output = self.classification_head(x)
        return output
```
The aggregation parameter controls the MIL pooling/aggregation part of the model.
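The shape arithmetic behind the four aggregation modes can be illustrated with plain NumPy on a dummy (batch, instance, channel, height, width) feature tensor (the dimension sizes below are arbitrary, chosen only for illustration):

```python
import numpy as np

batch, instances, channels, height, width = 2, 16, 8, 4, 4
features = np.random.rand(batch, instances, channels, height, width)

# 'avg', 'max' and 'logsumexp' all pool over the instance dimension,
# so the classification head sees `channels` input features
avg = features.mean(axis=1)
mx = features.max(axis=1)
lse = np.log(np.exp(features).sum(axis=1))
assert avg.shape == mx.shape == lse.shape == (batch, channels, height, width)

# 'concat' instead stacks instances on the channel dimension,
# so the head sees `instances * channels` input features
concat = features.reshape(batch, instances * channels, height, width)
assert concat.shape == (batch, instances * channels, height, width)
```

This is why `input_features` in the model above is multiplied by `n_instances` only for the `'concat'` mode.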