
3rd Place Solution for UBC Ovarian Cancer Subtype Classification and Outlier Detection (UBC-OCEAN)


Start: 2023-10-06  End: 2024-01-03  Medical image analysis  Data science competition
UBC-OCEAN Ovarian Cancer Subtype Classification and Outlier Detection: 3rd-Place Solution


Author: DanielT

Competition ranking: 3rd place

Published: January 4, 2024

Background

Business background: UBC ovarian cancer subtype classification and outlier detection

Data background: the challenge of this competition is to classify ovarian cancer subtypes from microscopy scans of biopsy samples. Data description link

Method Overview

  • Finding more public external data was key. With so few samples, overfitting was severe. The initial hope was that CLAM or multiple-instance learning (MIL) would mitigate this, since the huge images can be split into tens of thousands of tiles, but the models still overfit badly, likely because tiles from the same patient are similar and the model exploits these shortcuts rather than learning generalizable features.

  • The provided segmentation data was used to create synthetic tissue microarray (TMA) images: tiny crops taken from the segmented tumor regions of the large images serve as cancer tissue, while crops from healthy or stromal regions generate synthetic images for the "Other" class.

  • Features were extracted with the Lunit-DINO pretrained model. Following the paper "A Good Feature Extractor Is All You Need for Weakly Supervised Learning in Histopathology", 16-bit precision was used to speed up feature extraction, with negligible impact on feature quality.

  • PyVips was used to tile the images. Rewriting the CLAM feature-extraction code and using the large_image library were both tried, but Kaggle's resource limits made them impractical; the final implementation uses PyVips with asynchronous loading in PyTorch.

  • A CLAM model was trained on the extracted features. CLAM is similar to MIL but computes an attention matrix to weight the tiles. The instance-level loss was modified for the "Other" label, since healthy tissue regions within cancerous slides should be labeled "Other".
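The synthetic-TMA step above can be sketched as follows. The function `extract_tiles`, the integer mask encoding, and the 90% purity threshold are illustrative assumptions, not the author's exact pipeline: crops dominated by tumor-labelled pixels keep the slide's subtype label, while crops from healthy/stromal regions feed the "Other" class.

```python
import numpy as np

def extract_tiles(image, mask, target_value, tile_size=224, max_tiles=8, rng=None):
    """Crop tiles whose mask area is dominated by `target_value`.

    Tiles cropped from tumor-labelled regions keep the slide's cancer
    subtype; tiles from healthy/stromal regions become the 'Other' class.
    """
    if rng is None:
        rng = np.random.default_rng(0)
    h, w = mask.shape
    tiles = []
    for _ in range(max_tiles * 10):  # rejection-sampling budget
        if len(tiles) == max_tiles:
            break
        y = rng.integers(0, h - tile_size + 1)
        x = rng.integers(0, w - tile_size + 1)
        patch_mask = mask[y:y + tile_size, x:x + tile_size]
        # accept only tiles that are mostly the requested tissue type
        if (patch_mask == target_value).mean() > 0.9:
            tiles.append(image[y:y + tile_size, x:x + tile_size])
    return tiles
```

In the real pipeline the crops would come from the multi-gigapixel WSIs at a fixed magnification; here a small grayscale array stands in for the slide.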
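The 16-bit feature-extraction step can be sketched like this; the tiny convolutional backbone below is a stand-in for Lunit-DINO (an assumption), and `torch.autocast` supplies the fp16 precision that the write-up says barely affects feature quality.

```python
import torch
import torch.nn as nn

@torch.no_grad()
def extract_features(tiles: torch.Tensor, backbone: nn.Module,
                     device: str = "cpu", use_fp16: bool = True) -> torch.Tensor:
    """Run a batch of tiles through a frozen backbone, optionally under fp16 autocast."""
    backbone.eval().to(device)
    x = tiles.to(device)
    if use_fp16 and device == "cuda":
        # mixed precision roughly halves memory use and speeds up extraction on GPU
        with torch.autocast("cuda", dtype=torch.float16):
            feats = backbone(x)
    else:
        feats = backbone(x)
    return feats.float().cpu()  # store features in fp32 for downstream training
```

With Lunit-DINO the output would be one 1024-dimensional (or model-specific) vector per tile, which is exactly what the CLAM model below consumes.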

Submission Details

CLAM Model

The figure below shows the CLAM model [1] from the Mahmood Lab at Harvard/BWH & MGH. The model's input is the feature vectors of all tissue tiles in a whole-slide image. The upper half computes an attention score A (one value per tile); the lower half classifies the A-weighted sum of the features.

CLAM architecture diagram

Diagram by Paul Pham [2]

The adapted CLAM model, implemented in PyTorch:

import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class Attn_Net_Gated(nn.Module):
    def __init__(self, L = 1024, D = 256, dropout = 0, n_classes = 1):
        super(Attn_Net_Gated, self).__init__()
        self.attention_a = [
            nn.Linear(L, D),
            nn.Tanh()
        ]
        self.attention_b = [
            nn.Linear(L, D),
            nn.Sigmoid()
        ]
        if dropout > 0:
            self.attention_a.append(nn.Dropout(dropout))
            self.attention_b.append(nn.Dropout(dropout))

        self.attention_a = nn.Sequential(*self.attention_a)
        self.attention_b = nn.Sequential(*self.attention_b)
        
        self.attention_c = nn.Linear(D, n_classes)

    def forward(self, x):
        a = self.attention_a(x)
        b = self.attention_b(x)
        A = a.mul(b)
        A = self.attention_c(A)  # N x n_classes
        return A, x

class CLAM_SB(nn.Module):
    def __init__(self, gate = True, size_arg = "small", n_classes=2, dropout = 0, k_sample=8,
            instance_loss_fn=None, subtyping=False, feature_dim=1024, use_inst_predictions=True,
            label_mapping=None, class_weights=None, inst_class_depth=None, inst_dropout=None):
        super().__init__()
        self.size_dict = {
            "very small": [feature_dim, 256, 128],
            "small": [feature_dim, 512, 256],
            "big": [feature_dim, 1024, 512],
            "xl": [feature_dim, 2048, 1024]
        }
        size = self.size_dict[size_arg]
        fc = [nn.Linear(size[0], size[1]), nn.ReLU()]
        if dropout > 0:
            fc.append(nn.Dropout(dropout))
        if gate:
            attention_net = Attn_Net_Gated(L = size[1], D = size[2], dropout = dropout, n_classes = 1)
        else:
            attention_net = Attn_Net(L = size[1], D = size[2], dropout = dropout, n_classes = 1)
        fc.append(attention_net)
        self.attention_net = nn.Sequential(*fc)
        self.classifiers = nn.Linear(size[1], n_classes)
        instance_classifiers = []  
        for class_idx in range(n_classes):
            layers = []
            for depth_idx in range(inst_class_depth-1):
                divisor = 2 ** depth_idx        
                layers.append(nn.Linear(size[1] // divisor, size[1] // (divisor * 2)))
                layers.append(nn.ReLU())
                if inst_dropout is not None:
                    layers.append(nn.Dropout(inst_dropout))
            layers.append(nn.Linear(size[1] // 2**(inst_class_depth-1), 1))
            instance_classifiers.append(nn.Sequential(*layers))  
        self.instance_classifiers = nn.ModuleList(instance_classifiers)
        self.k_sample = k_sample
        self.instance_loss_fn = instance_loss_fn
        self.n_classes = n_classes
        self.subtyping = subtyping
        self.use_inst_predictions = use_inst_predictions
        self.other_idx = label_mapping['Other']
        self.class_weights = class_weights
        initialize_weights(self)
        self.to('cuda')

    @staticmethod
    def create_positive_targets(length, device):
        return torch.full((length, ), 1, device=device).float()
    @staticmethod
    def create_negative_targets(length, device):
        return torch.full((length, ), 0, device=device).float()
    
    # Instance-level evaluation for the in-class attention branch
    def inst_eval(self, A, h, classifier, is_tma, is_other_class):
        device=h.device
        if len(A.shape) == 1:
            A = A.view(1, -1)
        
        if is_tma:
            k_sample = self.k_sample // 2
        else:
            k_sample = self.k_sample

        if k_sample <= math.ceil(A.shape[1] / 2):
            top_p_ids = torch.topk(A, k_sample)[1][-1]  # [1][-1] selects the last index
        else:
            top_p_ids = torch.topk(A, math.ceil(A.shape[1] / 2))[1][-1]
            top_p_ids = top_p_ids.repeat(k_sample)[:k_sample]
        top_p = torch.index_select(h, dim=0, index=top_p_ids) # dim = k_sample x self.size_dict[1]
        if k_sample <= math.ceil(A.shape[1] / 2):
            top_n_ids = torch.topk(-A, k_sample, dim=1)[1][-1]
        else:
            top_n_ids = torch.topk(-A, math.ceil(A.shape[1] / 2))[1][-1]
            top_n_ids = top_n_ids.repeat(k_sample)[:k_sample]
        top_n = torch.index_select(h, dim=0, index=top_n_ids)
        p_targets = self.create_positive_targets(k_sample, device)
        n_targets = self.create_negative_targets(k_sample, device)

        # Evaluate with both positive and negative targets so that attention on negative tiles is pushed low
        p_logits = classifier(top_p)  # dim = k_sample
        n_logits = classifier(top_n)
        inst_preds = (p_logits.squeeze() > 0).long()
        p_loss = self.instance_loss_fn(p_logits.squeeze(), p_targets) * (self.n_classes - 1)
        n_loss = self.instance_loss_fn(n_logits.squeeze(), n_targets)
        if not is_tma and not is_other_class:
            loss = p_loss + n_loss
        else:
            loss = p_loss
        return loss, inst_preds, p_targets, p_logits
    
    # Instance-level evaluation for the out-of-class attention branch
    def inst_eval_out(self, A, h, classifier, is_tma):
        device=h.device
        if len(A.shape) == 1:
            A = A.view(1, -1)

        if is_tma:
            k_sample = self.k_sample // 2
        else:
            k_sample = self.k_sample

        if k_sample <= math.ceil(A.shape[1] / 2):
            top_ids = torch.topk(A, k_sample)[1][-1]
        else:
            top_ids = torch.topk(A, math.ceil(A.shape[1] / 2))[1][-1]
            top_ids = top_ids.repeat(k_sample)[:k_sample]
        top_inst = torch.index_select(h, dim=0, index=top_ids)
        top_targets = self.create_negative_targets(k_sample, device)
        
        logits = classifier(top_inst)
        inst_preds = (logits.squeeze() > 0).long()
        instance_loss = self.instance_loss_fn(logits.squeeze(), top_targets)
        return instance_loss, inst_preds, top_targets, logits
    
    def forward(self, h, bag_pred_weight:float, is_tma:bool, label=None, attention_only=False):
        A, h = self.attention_net(h)  # NxK        
        A = torch.transpose(A, 1, 0)  # KxN
        if attention_only:
            return A
        A_raw = A
        A = F.softmax(A, dim=1)  # softmax over N
        M = torch.mm(A, h) # shape 1 x self.size_dict[1]
        logits = self.classifiers(M)
        bag_Y_prob = F.softmax(logits.squeeze(), dim=0)

        if is_tma:
            k_sample = self.k_sample // 2
        else:
            k_sample = self.k_sample

        all_inst_logits = []
        top_p_ids = None
        if bag_pred_weight < 1 and label is not None:        
            total_inst_loss = 0.0
            all_inst_preds = []
            all_targets = []
            for i in range(len(self.instance_classifiers)):
                classifier = self.instance_classifiers[i]
                if i == label.item():  # in-class
                    is_other_class = (label.item() == self.other_idx)
                    instance_loss, inst_preds, targets, inst_logits = self.inst_eval(A, h, classifier, is_tma, is_other_class)
                    all_inst_preds.extend(inst_preds.cpu().numpy())                 
                    all_targets.extend(targets.cpu().numpy())
                    all_inst_logits.append(inst_logits)
                    if self.class_weights is not None:
                        instance_loss *= self.class_weights[i]
                else:  # out-of-class
                    if self.subtyping:
                        instance_loss, inst_preds, targets, inst_logits = self.inst_eval_out(A, h, classifier, is_tma)
                        all_inst_preds.extend(inst_preds.cpu().numpy())
                        all_targets.extend(targets.cpu().numpy())              
                        all_inst_logits.append(inst_logits)
                    else:
                        continue
                
                total_inst_loss += instance_loss 

            if self.subtyping:
                total_inst_loss /= 2 * len(self.instance_classifiers)
        else:
            if k_sample <= math.ceil(A.shape[1] / 2):
                top_p_ids = torch.topk(A, k_sample)[1][-1]  # [1][-1] selects the last index
            else:
                top_p_ids = torch.topk(A, math.ceil(A.shape[1] / 2))[1][-1]
                top_p_ids = top_p_ids.repeat(k_sample)[:k_sample]
            top_p = torch.index_select(h, dim=0, index=top_p_ids)
            for classifier in self.instance_classifiers:
                class_logits = classifier(top_p)
                all_inst_logits.append(class_logits)

        if self.use_inst_predictions:
            all_inst_logits = torch.cat(all_inst_logits, dim=1)  # dim k_sample x n_classes
            if k_sample <= math.ceil(A.shape[1] / 2):
                top_p_ids = torch.topk(A, k_sample)[1][-1]  # [1][-1] selects the last index
            else:
                top_p_ids = torch.topk(A, math.ceil(A.shape[1] / 2))[1][-1]
                top_p_ids = top_p_ids.repeat(k_sample)[:k_sample]

            all_inst_logits = A_raw[0, top_p_ids].reshape(-1, 1) * all_inst_logits
            softmax_inst_probs = torch.softmax(all_inst_logits, dim=1)
            agg_inst_probs = torch.mean(softmax_inst_probs, dim=0)  # resulting dim: n_classes
            Y_probs = bag_Y_prob * bag_pred_weight + agg_inst_probs * (1 - bag_pred_weight)
        else:
            # without instance predictions, fall back to the bag-level probabilities
            Y_probs = bag_Y_prob
        Y_hat = torch.topk(Y_probs, 1, dim=0)[1]
        
        results_dict = {}
        if bag_pred_weight < 1 and self.use_inst_predictions:
            results_dict.update({
                'all_inst_logits': all_inst_logits.detach().cpu().numpy(),
                'agg_inst_probs': agg_inst_probs.detach().cpu().numpy()
            })
        if self.use_inst_predictions: 
            results_dict.update({
                'softmax_inst_probs': softmax_inst_probs.detach().cpu().numpy()
            })
        if label is not None and bag_pred_weight < 1:
            results_dict.update({
                'inst_labels': np.array(all_targets),
                'inst_preds': np.array(all_inst_preds).flatten(),
                'instance_loss': total_inst_loss
            })

        return logits, Y_probs, Y_hat, A_raw, results_dict
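As a self-contained illustration of the gated-attention pooling at the heart of the model above (toy dimensions, not the competition configuration): a tanh branch and a sigmoid branch are multiplied elementwise, projected to one score per tile, and the softmaxed scores weight the tile features into a single slide-level embedding.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Toy gated attention: a (tanh) * b (sigmoid) -> one scalar score per tile
L, D, N = 32, 16, 10                      # feature dim, hidden dim, number of tiles
attn_a = nn.Sequential(nn.Linear(L, D), nn.Tanh())
attn_b = nn.Sequential(nn.Linear(L, D), nn.Sigmoid())
attn_c = nn.Linear(D, 1)

h = torch.randn(N, L)                     # per-tile features
A = attn_c(attn_a(h) * attn_b(h))         # N x 1 raw attention scores
A = F.softmax(A.transpose(1, 0), dim=1)   # 1 x N, sums to 1 over tiles
M = A @ h                                 # 1 x L slide-level embedding

print(A.shape, M.shape)                   # torch.Size([1, 10]) torch.Size([1, 32])
```

In the full model, `M` then feeds the bag-level classifier, while the raw scores select the top-k tiles for the instance-level losses.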

Data Description

The following data from The Cancer Imaging Archive was used:

Its labels do not correspond exactly to the competition's (e.g., "Papillary Serous Carcinoma" must be resolved to either HGSC or LGSC); models trained on the other data were used to help make this selection.

Also used:

Validation Scheme

Initially all data was pooled for 5-fold cross-validation (keeping images from the same patient in the same fold), but the validation scores were inflated. Later, the Harmanreh lab data was excluded from validation entirely, which gave more reliable cross-validation results.
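Patient-grouped folds of the kind described can be built with scikit-learn's `GroupKFold`; the slide/patient lists below are toy stand-ins, not the competition data.

```python
from sklearn.model_selection import GroupKFold

# toy stand-ins: 10 slides from 5 patients (2 slides each)
slides = [f"slide_{i}" for i in range(10)]
patients = ["p0", "p0", "p1", "p1", "p2", "p2", "p3", "p3", "p4", "p4"]

gkf = GroupKFold(n_splits=5)
for fold, (train_idx, val_idx) in enumerate(gkf.split(slides, groups=patients)):
    train_patients = {patients[i] for i in train_idx}
    val_patients = {patients[i] for i in val_idx}
    # grouping guarantees a patient never appears in both train and validation
    assert not (train_patients & val_patients)
```

Holding out an entire data source (as done with the Harmanreh lab data) goes one step further: it tests generalization across labs rather than just across patients.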

Technical Setup

Models were trained on a local desktop with an RTX 4090 GPU. Feature extraction took about 6 hours; model training about 1 hour.

Thanks for your interest in this solution; the author can be reached via Twitter.
