返回列表

1st Place Solution Writeup for Open Problems – Single-Cell Perturbations

584. Open Problems – Single-Cell Perturbations | open-problems-single-cell-perturbations

开始: 2023-09-12 结束: 2023-11-30 基因组学与生物信息 数据算法赛
作者:JK-Piece
发布时间:2023-12-04
竞赛排名:第1名

Open Problems – Single-Cell Perturbations 第一名解决方案详解

我要感谢组织者和Kaggle举办这场激动人心的比赛。同时感谢分享入门笔记本、数据集和富有洞察力想法的参赛者们。以下是我的解决方案更详细的说明,包括后期发现。

比赛页面

https://www.kaggle.com/competitions/open-problems-single-cell-perturbations/overview

https://openproblems.bio/

1. 生物学知识的整合

由于输入特征仅由短关键词对组成(即细胞类型和小分子名称),且目标变量规模庞大,我很快意识到需要丰富输入特征空间。因此我在比赛初期专注于此任务。首先我在文献中搜索生物学词/术语嵌入,找到了Zhang等人发表的论文'BioWordVec,通过子词信息和MeSH改进生物医学词嵌入'[1]。论文引导我找到GitHub上的代码库,其中包含预训练的生物学术语嵌入。这样做的原因是:1)我能在这些嵌入中找到大多数细胞类型和小分子名称;2)这些嵌入能编码每个术语的丰富语义信息。利用这些嵌入,我创建了更大的输入特征并训练回归模型,在公共排行榜上获得0.767的分数。通过更好的超参数搜索和特征工程,我将分数提升到0.614。

由于这个方向看起来很有前景,我决定进一步丰富输入特征。这次我在维基百科上搜索每种细胞类型和小分子名称的定义,使用Python库wikipedia。然后用描述句子表示每种细胞类型和小分子,并从描述中引导生成嵌入。例如,NK细胞被描述为:'自然杀伤细胞(NK细胞)是细胞毒性淋巴细胞的一种,对先天免疫系统至关重要,属于已知先天淋巴样细胞的快速扩张家族,占人类所有循环淋巴细胞的5-20%。NK细胞的作用类似于脊椎动物适应性免疫反应中的细胞毒性T细胞。NK细胞在感染后约3天对病毒感染细胞和其他胞内病原体产生快速反应,并响应肿瘤形成。'

虽然这在生物学角度很有趣,但并未提升排行榜分数,实际上分数变得更差(0.656 vs 之前的0.614)。这可以解释为自然语言描述带来了噪声,而预训练嵌入可能并非为处理这种情况而设计。在生物学术语的自然语言描述上微调嵌入也未能奏效。

由于输入特征丰富的初始想法未达预期,我决定寻找替代方案。通过论坛讨论,我了解到一个使用SMILES编码小分子化学结构的笔记本。我立即决定使用SMILES编码的ChemBERTa嵌入,并在验证数据分割上观察到评估指标MRRMSE的显著提升(我在整个比赛中使用5折交叉验证)。在此基础上,我开发了额外的数据增强技术,包括训练数据中每种细胞类型和小分子的差异表达均值、标准差以及(25%、50%、75%)百分位数。

2. 问题探索

如前所述,我比赛初期尝试为输入对(细胞类型,小分子名称)构建丰富特征。最终,使用小分子SMILES的ChemBERTa特征似乎是实现这一目标的重要步骤。结合每种细胞类型和小分子的均值、标准差和百分位数,我实现了最优的输入特征表示。

在我的实验中,我使用固定种子(42)的5折交叉验证。第2折和第4折的验证集上很难获得良好分数,这些折的验证集MRRMSE分别约为1.19和1.15。第1、3、5折的平均分数分别为0.86、0.86和0.90。分数是在不同模型架构(LSTM、1D-CNN、GRU)和不同输入特征组合('initial'、'light'、'heavy')上平均的。三种输入特征表示如下:

  • "initial":ChemBERTa嵌入,细胞类型/小分子名称对独热编码,每种细胞类型和小分子的目标均值、标准差、百分位数
  • "light":ChemBERTa嵌入,细胞类型/小分子名称对独热编码,每种细胞类型和小分子的目标均值
  • "heavy":ChemBERTa嵌入,细胞类型/小分子名称对独热编码,每种细胞类型和小分子的均值、25%、50%、75%百分位数

下图显示了每折在所有三种模型架构上的训练曲线(MRRMSE)。

MRRMSE训练曲线

上述图中验证MRRMSE的差异促使我仔细查看验证集,发现细胞类型分布不同。下图显示了每折的主要细胞类型和相应的平均验证MRRMSE。

困难细胞类型

在第1、3、5折的验证集中,主要细胞类型(按百分比)分别是'T调节细胞'、'B细胞'和'NK细胞'。在第2折和第4折上,'CD8+ T细胞'和'髓系细胞'分别是验证集中最具代表性的细胞类型。百分比计算为验证集中某细胞类型的出现次数除以训练集中该细胞类型的出现次数。

从上述条形图可以看出,'T调节细胞'、'B细胞'和'NK细胞'是较易预测的细胞类型,而'CD8+ T细胞'和'髓系细胞'是最难预测的。基于此观察,理想的训练集应包含更多'CD8+ T细胞'和'髓系细胞',而不是其他细胞类型。这样训练的ML模型才能泛化到其他细胞类型。

3. 模型设计

模型架构

我尝试了不同的模型架构,包括梯度提升模型、MLP和2D CNN,但效果不佳。最终我选择LSTM、GRU和1D CNN架构,因为它们在验证集上表现更好。以下是GRU模型的粗略实现:

dims_dict = {
    'conv': {'heavy': 13400, 'light': 4576, 'initial': 8992},
    'rnn': {
        'linear': {'heavy': 99968, 'light': 24192, 'initial': 29568},
        'input_shape': {'heavy': [779,142], 'light': [187,202], 'initial': [229,324]}
    }
}

class GRU(nn.Module):
    def __init__(self, scheme):
        super(GRU, self).__init__()
        self.name = 'GRU'
        self.scheme = scheme
        self.gru = nn.GRU(dims_dict['rnn']['input_shape'][self.scheme][1], 128, num_layers=2, batch_first=True)
        self.linear = nn.Sequential(
            nn.Linear(dims_dict['rnn']['linear'][self.scheme], 1024),
            nn.Dropout(0.3),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.Dropout(0.3),
            nn.ReLU())
        self.head = nn.Linear(512, 18211)
        
        self.loss1 = nn.MSELoss()
        self.loss2 = LogCoshLoss()
        self.loss3 = nn.L1Loss()
        self.loss4 = nn.BCELoss()
        
    def forward(self, x, y=None):
        shape1, shape2 = dims_dict['rnn']['input_shape'][self.scheme]
        x = x.reshape(x.shape[0],shape1,shape2)
        if y is None:
            out, hn = self.gru(x)
            out = out.reshape(out.shape[0],-1)
            out = torch.cat([out, hn.reshape(hn.shape[1], -1)], dim=1)
            out = self.head(self.linear(out))
            return out
        else:
            out, hn = self.gru(x)
            out = out.reshape(out.shape[0],-1)
            out = torch.cat([out, hn.reshape(hn.shape[1], -1)], dim=1)
            out = self.head(self.linear(out))
            loss1 = 0.4*self.loss1(out, y) + 0.3*self.loss2(out, y) + 0.3*self.loss3(out, y)
            yhat = torch.sigmoid(out)
            yy = torch.sigmoid(y)
            loss2 = self.loss4(yhat, yy)
            return 0.8*loss1 + 0.2*loss2

在后期实验中,我发现1D-CNN和GRU实际上是最佳架构,它们单独就能获得最好分数(GRU在私有LB上为0.733,1D-CNN为0.745)。LSTM单独在私有LB上达到0.839。组合0.25xLSTM + 0.65xCNN在私有LB上为0.725,而0.25xLSTM + 0.65xGRU为0.723。

损失函数和优化器

我通过加权平均同时优化4个损失函数:MSE、MAE、LogCosh和BCE,权重分别为0.32、0.24、0.24和0.2。这被发现能提升模型的预测性能。使用Adam优化器,LSTM和CNN的学习率为0.001,GRU为0.0003。LogCosh定义如下:

class LogCoshLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, y_prime_t, y_t):
        ey_t = (y_t - y_prime_t)/3  # 除以3避免cosh数值溢出
        return torch.mean(torch.log(torch.cosh(ey_t + 1e-12)))

LogCosh与MAE类似,但它是更柔和的版本,能实现更平滑的收敛。改编自https://github.com/tuantle/regression-losses-pytorch

BCE损失确实很特殊,因为它常用于分类任务。但我认为当目标值接近零时,它能向模型和优化器发送更好的信号。为证明这一点,考虑以下两段代码:

m1 = nn.Sigmoid()
loss = nn.BCELoss()
input = torch.tensor([0.05], requires_grad=True).unsqueeze(0)
target = torch.sigmoid(torch.tensor([-0.05], requires_grad=False).unsqueeze(0))
output1 = loss(m1(input), target)
print(output1.item())  # 0.694

m2 = nn.Identity()
loss = nn.MSELoss()
input = torch.tensor([0.05], requires_grad=True).unsqueeze(0)
target = torch.tensor([-0.05], requires_grad=False).unsqueeze(0)
output2 = loss(m2(input), target)
print(output2.item())  # 0.010

通过这个例子可以看出,MSELoss告诉模型和优化器"没问题,这里没有错误"。显然存在错误,而BCELoss能识别到,它返回较高的损失值(0.694 vs MSELoss的0.010)。我选择BCELoss是因为大多数目标值来自均值为0的高斯分布,如下图所示。

高斯分布

超参数

  • 训练轮数:250轮
  • 学习率:LSTM和CNN为0.001,GRU为0.0003
  • 梯度裁剪值:三种方案'initial'、'light'、'heavy'分别为[5.0, 1.0, 1.0]

4. 鲁棒性

我使用不同子集的训练数据进行了4个实验,并监控私有排行榜分数。我考虑初始训练数据(de_train)的25%、50%、75%和100%子集。低于25%时,即使使用sm_name分层分割,也无法覆盖测试集(id_map)中的所有小分子,因此独热编码算法无法运行。25%数据获得0.946分,50%获得0.815分,75%为0.769分,完整数据私有排行榜为0.719分(比我获胜提交更好,因为我移除了ChemBERTa模型中的填充)。下图显示了我的方法的鲁棒性,即MRRMSE随训练数据量增加而改善的递减曲线。

鲁棒性分析

我的第二种数据增强技术可视为添加噪声。我随机将30%的输入特征条目替换为零,并将结果输入特征与正确目标一起作为新的训练数据点。这被证明能提高模型的预测性能。因此,我的模型对噪声具有鲁棒性,性能不仅不受损反而提升。生物学动机在于:我们可能不需要知道分子的完整化学结构(假设被丢弃的输入特征来自sm_name)就能知道其对细胞的影响。类似地,细胞可能存在生物学紊乱,但我们仍期望该细胞对分子(药物)的反应与正常细胞相同。

以下是数据增强函数:

def augment_data(x_, y_):
    copy_x = x_.copy()
    new_x = []
    new_y = y_.copy()
    dim = x_.shape[2]
    k = int(0.3*dim)
    for i in range(x_.shape[0]):
        idx = random.sample(range(dim), k=k)
        copy_x[i,:,idx] = 0
        new_x.append(copy_x[i])
    return np.stack(new_x, axis=0), new_y

5. 文档和代码风格

文档和软件依赖在GitHub上可用:https://github.com/Jean-KOUAGOU/1st-place-solution-single-cell-pbs

6. 可复现性

代码在GitHub上可用且文档完善:https://github.com/Jean-KOUAGOU/1st-place-solution-single-cell-pbs,并添加了复现脚本。

同比赛其他方案