590. UBC Ovarian Cancer Subtype Classification and Outlier Detection (UBC-OCEAN) | UBC-OCEAN
业务背景:UBC卵巢癌亚型分类与异常检测
数据背景:本竞赛的挑战是从活检样本的显微扫描图像中对卵巢癌进行亚型分类。数据描述链接
寻找更多公开外部数据是关键。由于样本数量少,过拟合问题严重。最初希望使用CLAM或多实例学习(MIL)方法缓解,因为图像巨大可分割成数万切片。但模型仍严重过拟合。可能因为同一患者的切片存在相似性,模型利用这些捷径而非泛化特征。
利用提供的分割数据创建合成组织微阵列(TMA)图像。从大图像的分割区域裁剪微小图像作为癌症组织,同时裁剪健康或基质区域生成"其他"类别合成图像。
使用Lunit-DINO预训练模型提取特征。参考论文"组织病理学弱监督学习:好的特征提取器足矣"(原题 "A Good Feature Extractor Is All You Need for Weakly Supervised Learning in Histopathology"),使用16位精度加速特征提取,对特征质量影响甚微。
使用PyVips处理图像切片。曾尝试重写CLAM特征提取代码和使用large_image库,但受Kaggle资源限制困扰。最终使用PyVips和PyTorch异步加载实现。
在提取特征上训练CLAM模型。该模型类似MIL但计算注意力矩阵加权切片。针对"其他"标签修改了实例级损失函数,因为癌变切片中的健康组织区域应标记为"其他"。
下图展示了哈佛/BWH & MGH Mahmood实验室的CLAM模型[1]。模型输入是全切片图像中所有组织切片的特征向量。上半部分计算注意力得分A(每个切片一个值),下半部分计算A加权特征和的类别分类。

示意图来自Paul Pham[2]
PyTorch实现的适配CLAM模型:
class Attn_Net_Gated(nn.Module):
    """Gated attention network from attention-based MIL.

    Two parallel projections of the input — a tanh branch and a sigmoid
    "gate" branch — are multiplied elementwise, then mapped to one
    attention logit per class for every instance.

    Args:
        L: input feature dimension.
        D: hidden dimension of both attention branches.
        dropout: if > 0, a Dropout(dropout) is appended to each branch.
        n_classes: number of attention heads (logits per instance).
    """

    def __init__(self, L = 1024, D = 256, dropout = 0, n_classes = 1):
        super(Attn_Net_Gated, self).__init__()
        branch_a = [nn.Linear(L, D), nn.Tanh()]
        branch_b = [nn.Linear(L, D), nn.Sigmoid()]
        if dropout > 0:
            branch_a.append(nn.Dropout(dropout))
            branch_b.append(nn.Dropout(dropout))
        self.attention_a = nn.Sequential(*branch_a)
        self.attention_b = nn.Sequential(*branch_b)
        self.attention_c = nn.Linear(D, n_classes)

    def forward(self, x):
        """Return (attention logits of shape N x n_classes, x unchanged)."""
        gated = self.attention_a(x) * self.attention_b(x)
        return self.attention_c(gated), x
class CLAM_SB(nn.Module):
    """Single-branch CLAM (attention MIL with instance-level clustering
    losses), adapted for UBC-OCEAN's synthetic 'Other' class.

    NOTE(review): this block was recovered from an indentation-stripped
    paste; the nesting below is reconstructed from CLAM's reference
    implementation and the surrounding text — confirm against the
    original notebook.
    """

    def __init__(self, gate = True, size_arg = "small", n_classes=2, dropout = 0, k_sample=8,
                 instance_loss_fn=None, subtyping=False, feature_dim=1024, use_inst_predictions=True,
                 label_mapping=None, class_weights=None, inst_class_depth=None, inst_dropout=None):
        super().__init__()
        # Per-preset sizes: [input_dim, attention_input, attention_hidden].
        self.size_dict = {
            "very small": [feature_dim, 256, 128],
            "small": [feature_dim, 512, 256],
            "big": [feature_dim, 1024, 512],
            "xl": [feature_dim, 2048, 1024]
        }
        size = self.size_dict[size_arg]
        # Shared projection of tile features, followed by the attention net.
        fc = [nn.Linear(size[0], size[1]), nn.ReLU()]
        if dropout > 0:
            fc.append(nn.Dropout(dropout))
        if gate:
            attention_net = Attn_Net_Gated(L = size[1], D = size[2], dropout = dropout, n_classes = 1)
        else:
            attention_net = Attn_Net(L = size[1], D = size[2], dropout = dropout, n_classes = 1)
        fc.append(attention_net)
        self.attention_net = nn.Sequential(*fc)
        # Bag-level classifier over the attention-pooled feature M.
        self.classifiers = nn.Linear(size[1], n_classes)
        # One binary instance classifier per class; each halves the width
        # per extra layer. NOTE(review): the default inst_class_depth=None
        # makes range(inst_class_depth-1) raise — callers must pass an int >= 1.
        instance_classifiers = []
        for class_idx in range(n_classes):
            layers = []
            for depth_idx in range(inst_class_depth-1):
                divisor = 2 ** depth_idx
                layers.append(nn.Linear(size[1] // divisor, size[1] // (divisor * 2)))
                layers.append(nn.ReLU())
                if inst_dropout is not None:
                    layers.append(nn.Dropout(inst_dropout))
            layers.append(nn.Linear(size[1] // 2**(inst_class_depth-1), 1))
            instance_classifiers.append(nn.Sequential(*layers))
        self.instance_classifiers = nn.ModuleList(instance_classifiers)
        self.k_sample = k_sample
        self.instance_loss_fn = instance_loss_fn
        self.n_classes = n_classes
        self.subtyping = subtyping
        self.use_inst_predictions = use_inst_predictions
        # Index of the synthetic 'Other' class (healthy/stroma tiles).
        self.other_idx = label_mapping['Other']
        self.class_weights = class_weights
        initialize_weights(self)
        # NOTE(review): hard-codes CUDA; fails on CPU-only hosts.
        self.to('cuda')

    @staticmethod
    def create_positive_targets(length, device):
        # All-ones float targets for "belongs to this class" instances.
        return torch.full((length, ), 1, device=device).float()

    @staticmethod
    def create_negative_targets(length, device):
        # All-zeros float targets for "does not belong" instances.
        return torch.full((length, ), 0, device=device).float()

    # Instance-level evaluation for the in-the-class attention branch.
    def inst_eval(self, A, h, classifier, is_tma, is_other_class):
        """Score top-/bottom-attended tiles with this class's binary
        instance classifier; TMA bags draw only k_sample // 2 tiles.

        Returns (loss, inst_preds, p_targets, p_logits).
        """
        device=h.device
        if len(A.shape) == 1:
            A = A.view(1, -1)
        if is_tma:
            k_sample = self.k_sample // 2
        else:
            k_sample = self.k_sample
        if k_sample <= math.ceil(A.shape[1] / 2):
            top_p_ids = torch.topk(A, k_sample)[1][-1]  # [1][-1] selects the last index row
        else:
            # Bag smaller than 2*k_sample: take half the bag and tile the
            # indices up to k_sample.
            top_p_ids = torch.topk(A, math.ceil(A.shape[1] / 2))[1][-1]
            top_p_ids = top_p_ids.repeat(k_sample)[:k_sample]
        top_p = torch.index_select(h, dim=0, index=top_p_ids)  # dim = k_sample x self.size_dict[1]
        if k_sample <= math.ceil(A.shape[1] / 2):
            top_n_ids = torch.topk(-A, k_sample, dim=1)[1][-1]
        else:
            top_n_ids = torch.topk(-A, math.ceil(A.shape[1] / 2))[1][-1]
            top_n_ids = top_n_ids.repeat(k_sample)[:k_sample]
        top_n = torch.index_select(h, dim=0, index=top_n_ids)
        p_targets = self.create_positive_targets(k_sample, device)
        n_targets = self.create_negative_targets(k_sample, device)
        # Evaluate with positive and negative labels to constrain
        # low-attention tiles toward the negative side.
        p_logits = classifier(top_p)  # dim = k_sample
        n_logits = classifier(top_n)
        inst_preds = (p_logits.squeeze() > 0).long()
        p_loss = self.instance_loss_fn(p_logits.squeeze(), p_targets) * (self.n_classes -1)
        n_loss = self.instance_loss_fn(n_logits.squeeze(), n_targets)
        # Negative loss only for WSI bags of a cancer class: there,
        # low-attention tiles are healthy tissue and should score negative;
        # TMA bags and the 'Other' class skip it.
        if not is_tma and not is_other_class:
            loss = p_loss + n_loss
        else: loss = p_loss
        return loss, inst_preds, p_targets, p_logits

    # Instance-level evaluation for the out-of-class attention branch.
    def inst_eval_out(self, A, h, classifier, is_tma):
        """Top-attended tiles should NOT belong to a non-label class:
        score them against all-negative targets (subtyping mode)."""
        device=h.device
        if len(A.shape) == 1:
            A = A.view(1, -1)
        if is_tma:
            k_sample = self.k_sample // 2
        else:
            k_sample = self.k_sample
        if k_sample <= math.ceil(A.shape[1] / 2):
            top_ids = torch.topk(A, k_sample)[1][-1]
        else:
            top_ids = torch.topk(A, math.ceil(A.shape[1] / 2))[1][-1]
            top_ids = top_ids.repeat(k_sample)[:k_sample]
        top_inst = torch.index_select(h, dim=0, index=top_ids)
        top_targets = self.create_negative_targets(k_sample, device)
        logits = classifier(top_inst)
        inst_preds = (logits.squeeze() > 0).long()
        instance_loss = self.instance_loss_fn(logits.squeeze(), top_targets)
        return instance_loss, inst_preds, top_targets, logits

    def forward(self, h, bag_pred_weight:float, is_tma:bool, label=None, attention_only=False):
        """Run the bag-level and instance-level heads on one bag.

        Args:
            h: N x feature_dim tile features of one slide/TMA.
            bag_pred_weight: blend weight between the bag softmax and the
                mean instance softmax (1.0 = bag prediction only).
            is_tma: whether the bag is a (synthetic) TMA image.
            label: scalar class-label tensor; enables instance losses.
            attention_only: if True, return raw attention scores only.

        Returns:
            (logits, Y_probs, Y_hat, A_raw, results_dict).
        """
        A, h = self.attention_net(h)  # NxK
        A = torch.transpose(A, 1, 0)  # KxN
        if attention_only:
            return A
        A_raw = A
        A = F.softmax(A, dim=1)  # softmax over N
        M = torch.mm(A, h)  # shape 1 x self.size_dict[1]
        logits = self.classifiers(M)
        bag_Y_prob = F.softmax(logits.squeeze(), dim=0)
        if is_tma:
            k_sample = self.k_sample // 2
        else:
            k_sample = self.k_sample
        all_inst_logits = []
        top_p_ids = None
        if bag_pred_weight < 1 and label is not None:
            # Training path: accumulate instance losses per class head.
            total_inst_loss = 0.0
            all_inst_preds = []
            all_targets = []
            for i in range(len(self.instance_classifiers)):
                classifier = self.instance_classifiers[i]
                if i == label.item():  # in-the-class
                    is_other_class = (label.item() == self.other_idx)
                    instance_loss, inst_preds, targets, inst_logits = self.inst_eval(A, h, classifier, is_tma, is_other_class)
                    all_inst_preds.extend(inst_preds.cpu().numpy())
                    all_targets.extend(targets.cpu().numpy())
                    all_inst_logits.append(inst_logits)
                    if self.class_weights is not None:
                        instance_loss *= self.class_weights[i]
                else:  # out-of-the-class
                    if self.subtyping:
                        instance_loss, inst_preds, targets, inst_logits = self.inst_eval_out(A, h, classifier, is_tma)
                        all_inst_preds.extend(inst_preds.cpu().numpy())
                        all_targets.extend(targets.cpu().numpy())
                        all_inst_logits.append(inst_logits)
                    else:
                        continue
                total_inst_loss += instance_loss
            if self.subtyping:
                total_inst_loss /= 2 * len(self.instance_classifiers)
        else:
            # Inference path: score the top-attended tiles with every
            # class's instance classifier.
            if self.k_sample <= math.ceil(A.shape[1] / 2):
                top_p_ids = torch.topk(A, k_sample)[1][-1]  # [1][-1] selects the last index row
            else:
                top_p_ids = torch.topk(A, math.ceil(A.shape[1] / 2))[1][-1]
                top_p_ids = top_p_ids.repeat(k_sample)[:k_sample]
            top_p = torch.index_select(h, dim=0, index=top_p_ids)
            for classifier in self.instance_classifiers:
                class_logits = classifier(top_p)
                all_inst_logits.append(class_logits)
        if self.use_inst_predictions:
            all_inst_logits = torch.concatenate(all_inst_logits, axis=1)  # dim k_sample x n_classes
            if self.k_sample <= math.ceil(A.shape[1] / 2):
                top_p_ids = torch.topk(A, k_sample)[1][-1]  # [1][-1] selects the last index row
            else:
                top_p_ids = torch.topk(A, math.ceil(A.shape[1] / 2))[1][-1]
                top_p_ids = top_p_ids.repeat(k_sample)[:k_sample]
            # Weight each sampled tile's class logits by its raw attention.
            all_inst_logits =A_raw[0, top_p_ids].reshape(-1, 1) * all_inst_logits
            softmax_inst_probs = torch.softmax(all_inst_logits, dim=1)
            agg_inst_probs = softmax_inst_probs
            agg_inst_probs = torch.mean(agg_inst_probs, dim=0)  # mean over sampled tiles -> per-class probs
        # NOTE(review): if use_inst_predictions is False, agg_inst_probs is
        # never defined and the blend below raises NameError — the flag is
        # presumably always True in practice; confirm.
        Y_probs = bag_Y_prob * bag_pred_weight + agg_inst_probs * (1 - bag_pred_weight)
        Y_hat = torch.topk(Y_probs, 1, dim=0)[1]
        results_dict = {}
        if bag_pred_weight < 1:
            results_dict.update({
                'all_inst_logits': all_inst_logits.detach().cpu().numpy(),
                'agg_inst_probs': agg_inst_probs.detach().cpu().numpy()
            })
            if self.use_inst_predictions:
                results_dict.update({
                    'softmax_inst_probs': softmax_inst_probs.detach().cpu().numpy()
                })
            if label is not None:
                results_dict.update({
                    'inst_labels': np.array(all_targets),
                    'inst_preds': np.array(all_inst_preds).flatten(),
                    'instance_loss': total_inst_loss
                })
        return logits, Y_probs, Y_hat, A_raw, results_dict
从癌症影像档案使用了以下数据:
标签与竞赛不完全对应(如"Papillary Serous Carcinoma"需判断为HGSC或LGSC),使用其他数据训练的模型辅助选择。
还使用了:
曾将所有数据混合进行5折交叉验证(确保同一患者图像在同一折),但验证分数虚高。后完全排除Harmanreh实验室数据验证,获得更可靠的交叉验证结果。
在本地台式机(RTX 4090显卡)训练模型。特征提取约需6小时,模型训练约1小时。
感谢各位对解决方案的关注,可通过Twitter联系作者