返回列表

7th place solution

590. UBC Ovarian Cancer Subtype Classification and Outlier Detection (UBC-OCEAN) | UBC-OCEAN

开始: 2023-10-06 结束: 2024-01-03 医学影像分析 数据算法赛
第7名解决方案 - UBC-OCEAN
排名: 第7名
作者: m1dsolo
发布日期: 2024-01-05

感谢Kaggle和UBC举办这场精彩的比赛,也祝贺所有获奖者所付出的努力!同时也要感谢我的队友以及论坛中的每一位朋友提供的帮助!

方法

概述

我们的最终方案基于多实例学习(MIL)进行卵巢癌亚型分类,并使用Sigmoid和阈值法进行异常值检测。在最终提交中,我们没有使用掩码标注和额外的数据集。

1. 预处理

1. 使用 pyvips 来加速PNG图像的读取速度。(感谢 GUNES EVITAN 的 pyvips 笔记本。)

image = pyvips.Image.new_from_file(image_id, access='sequential').numpy()
is_tma = image.shape[0] <= 5000 and image.shape[1] <= 5000

2. 将WSI和TMA分别从20倍和40倍下采样至10倍。(也许20倍的效果会更好,但由于资源限制,我无法提交。)

if is_tma:
    resize = A.Resize(image.shape[0] // 4, image.shape[1] // 4)
else:
    resize = A.Resize(image.shape[0] // 2, image.shape[1] // 2)
image = resize(image=image)['image']

3. 对WSI中相同的组织区域进行去重。(我不确定这是否对结果有贡献,但它节省了大量本地内存。)

def rgb2gray(image: np.ndarray):
    image = image.astype(np.float16)
    image = (image[..., 0] * 299 + image[..., 1] * 587 + image[..., 2] * 114) / 1000
    return image.astype(np.uint8)

if not is_tma:
    resize = A.Resize(image.shape[0] // 16, image.shape[1] // 16)
    thumbnail = resize(image=image)['image'].astype(np.float16)
    mask = rgb2gray(thumbnail) > 0
    x0, y0, x1, y1 = get_biggest_component_box(mask)

    scale_h = image.shape[0] / thumbnail.shape[0]
    scale_w = image.shape[1] / thumbnail.shape[1]

    x0 = max(0, math.floor(x0 * scale_w))
    y0 = max(0, math.floor(y0 * scale_h))
    x1 = min(image.shape[1] - 1, math.ceil(x1 * scale_w))
    y1 = min(image.shape[0] - 1, math.ceil(y1 * scale_h))
    image = image[y0: y1 + 1, x0: x1 + 1]

4. 使用非重叠滑动窗口方法将组织区域切割为256x256的图像块。(对于TMA我使用了重叠窗口,但不确定这是否会影响结果。)

def image2patches(image: np.ndarray, patch_size: int, step: int, ratio: float, transform, is_tma: bool):
    patches = []
    for i in range(0, image.shape[0], step):
        for j in range(0, image.shape[1], step):
            patch = image[i: i + patch_size, j: j + patch_size, :]
            if patch.shape != (patch_size, patch_size, 3):
                patch = np.pad(patch, ((0, patch_size - patch.shape[0]), (0, patch_size - patch.shape[1]), (0, 0)))

            if is_tma:
                patch = transform(image=patch)['image']
                patches.append(patch)
            else:
                patch_gray = rgb2gray(patch)  # (patch_size, patch_size)
                patch_binary = (patch_gray <= 220) & (patch_gray > 0)

                if np.count_nonzero(patch_binary) / patch_binary.size >= ratio:
                    patch = transform(image=patch)['image']
                    patches.append(patch)

    if len(patches) != 0:
        patches = torch.stack(patches, dim=0)
    else:
        patches = torch.zeros(0, dtype=torch.uint8)

    return patches

image2patches(image, 256, [256, 128][is_tma], 0.25, transform, is_tma)

2. 亚型分类

癌症亚型分类方法主要基于多实例学习(MIL)。在尝试了多种骨干网络和MIL方法后,最终选择了 CTransPathLunitDINO 作为骨干网络,DSMILPerceiver 作为MIL分类器。具体信息请参考:

  1. CTransPath, MIA2022
  2. LunitDINO, CVPR2023
  3. DSMIL, CVPR2021
  4. Perceiver, BMVA2023

本地交叉验证结果:

实验 CC EC HGSC LGSC MC 平均
CTransPath + DSMIL 0.9300 0.7657 0.8909 0.7822 0.7911 0.8320
CTransPath + Perceiver 0.9695 0.8147 0.8818 0.8044 0.9156 0.8772
LunitDINO + DSMIL 0.9400 0.7240 0.8864 0.8244 0.9356 0.8621
LunitDINO + Perceiver 0.9300 0.7983 0.8591 0.8711 0.8933 0.8704

排行榜结果:

实验 公共榜 私有榜
CTransPath + LunitDINO + DSMIL 0.57 0.54
CTransPath + LunitDINO + Perceiver 0.58 0.57
CTransPath + LunitDINO + DSMIL + Perceiver 0.6 0.58

我几乎没有调整MIL的超参数,因为发现交叉验证分数较高的模型在公共榜上的表现反而较低。

  1. 对于 DSMIL,我们使用 nn.CrossEntropyLoss 作为损失函数。
  2. 对于 Perceiver,我们使用 nn.BCEWithLogitsLoss 作为损失函数,并采用 mixuplabel smoothing 来缓解过拟合。

3. 异常值检测

我们尝试了许多方法,其中两种方法在私有榜上可以达到0.6的分数。(如果不使用异常值检测,私有榜分数为0.58。)

1. BCE + 阈值法

分数:公共榜 0.6,私有榜 0.6。

这个方法非常简单。使用 nn.BCEWithLogitsLoss 作为损失函数训练模型,然后对最大预测概率进行判断,如果小于0.4,则视为异常值。

logits = self.model(x)
probs = F.sigmoid(logits)  # (C,)
pred = probs.argmax(dim=0).item()
if max(probs) < PROB_THRESH:  # 根据验证集选择阈值
    pred = 5  # 'Other' 类别

2. 概率熵

分数:公共榜 0.54,私有榜 0.6。

这个方法也很简单。与设置概率阈值相比,该方法通过计算概率的熵来检测异常值。

logits = self.model(x)
probs = F.sigmoid(logits)  # (C,)
pred = probs.argmax(dim=0).item()
entropy = (probs * torch.log2(probs)).mean(dim=0)
if entropy > ENTROPY_THRESH:   # 根据验证集选择阈值
    pred = 5  # 'Other' 类别

总结

未成功的方法

  1. 额外数据集:ATEC、PTRC-HGSOC、CPTAC-OV、TCGA-OV、Bevacizumab。
  2. 通过注意力或掩码选择癌变区域,并端到端地微调骨干网络和MIL。
  3. 仅选择癌变区域的图像块用于MIL。
  4. 基于图像块预测概率熵检测异常值。(MIA2023
  5. 基于KNN分类器检测异常值。(Arxiv2023

补充材料

所有 PyTorch 代码(包括提交笔记本)均基于 一个简单的 PyTorch 深度学习框架 开发。
该框架仅有几百行代码,非常适合初学者学习。

  1. 提交笔记本
  2. 训练代码
同比赛其他方案