返回列表

2nd Place Solution - UBC-OCEAN

590. UBC Ovarian Cancer Subtype Classification and Outlier Detection (UBC-OCEAN) | UBC-OCEAN

开始: 2023-10-06 结束: 2024-01-03 医学影像分析 数据算法赛

第二名解决方案 - UBC-OCEAN

前言

全切片图像(WSI)分类中最主要的困难在于极高的分辨率,这应该是所有参赛者都经历过的。虽然竞赛组织者提供了处理难度较高的数据类型,但幸运的是,其分辨率远低于典型的WSI。在本文中,我们将详细介绍我们的方法。

概述

遵循学术界常用的方法,我们采取了以下步骤:

  1. 裁剪整个WSI为数千个图像块
  2. 使用特征提取器提取特征
  3. 训练MIL模型。

外部数据

我们使用了两个带有标签的外部数据集,所有参赛者均可免费下载。我们发现,尽管使用了更多外部数据和"其他"类别进行训练,但分数并未显著提高。我们认为这是由于外部数据的质量问题或与竞赛数据存在显著差异。正如一些参赛者不使用外部数据也能取得高分,我们认为在这个竞赛中外部数据并非必要。

裁剪图像块并提取特征

我们为每个WSI创建一个Dataset,代码如下:

class SingleWSIDataset(Dataset):
    def __init__(self, data_path: str, wsi_name: str, patch_size: int, mode: str):
        super().__init__()
        self.data_path = data_path
        self.wsi_name = wsi_name
        self.ratio = ratio
        assert mode in ['train', 'test']
        self.mode = mode
        self.wsi = pyvips.Image.new_from_file(os.path.join(data_path, f'{mode}_images', wsi_name + '.png'))
        self.is_tma = self.wsi.height < 5000 and self.wsi.width < 5000
        self.patch_size = patch_size
        self.transform = T.Compose([T.ToTensor(), T.Resize((224, 224), antialias=True), T.Normalize(mean=[0.2585, 0.2556, 0.2506], std=[0.229, 0.224, 0.225])])
        self.cor_list = self.get_patch()

    def get_patch(self):
        cor_list = []
        if self.is_tma:
            thumbnail = self.wsi
        else:
            thumbnail = pyvips.Image.new_from_file(os.path.join(self.data_path, f'{self.mode}_thumbnails', self.wsi_name + '_thumbnail.png'))
        wsi_width, wsi_height = self.wsi.width, self.wsi.height
        thu_width, thu_height = thumbnail.width, thumbnail.height
        h_r, w_r = wsi_height / thu_height, wsi_width / thu_width
        down_h, down_w = int(self.patch_size / h_r), int(self.patch_size / w_r)
        cors = [(x, y) for y in range(0, thu_height, down_h) for x in range(0, thu_width, down_w)]
        for x, y in cors:
            tile = thumbnail.crop(x, y, min(down_w, thu_width - x), min(down_h, thu_height - y)).numpy()[..., :3]
            black_bg = np.mean(tile, axis=2) < 20
            tile[black_bg, :] = 255
            mask_bg = np.mean(tile, axis=2) > 235
            if np.sum(mask_bg) < min(down_h, thu_height - y) * min(down_w, thu_width - x) * 0.7 or len(cor_list) == 0 or self.is_tma:
                cor_list.append((int(x * w_r), int(y * h_r)))
        if self.is_tma:
            return cor_list
        if self.wsi.height < 40000 and self.wsi.width < 40000:
            R_ratio = 0.8
        elif self.wsi.height < 80000 and self.wsi.width < 80000:
            R_ratio = 0.6
        else:
            R_ratio = 0.5
        random.shuffle(cor_list)
        cor_list = cor_list[:max(int(len(cor_list) * R_ratio), 1)]
        return cor_list

    def __len__(self):
        return len(self.cor_list)

    def __getitem__(self, idx):
        x, y = self.cor_list[idx]
        tile = self.wsi.crop(x, y, min(self.patch_size, self.wsi.width - x), min(self.patch_size, self.wsi.height - y)).numpy()[..., :3]
        tile = self.transform(tile)
        return tile

特征提取模型

我们使用了dino_vit_small_patch16_200ep.torchdino_vit_small_patch8_200ep.torch

MIL模型

  • ABMIL
  • DSMIL
  • TransMIL

代码

简化版本

最终版本

特征提取代码

同比赛其他方案