返回列表

[1st place solution] Transformer model with Dynamic positional encoding + CNN for BPPM features

587. Stanford Ribonanza RNA Folding | stanford-ribonanza-rna-folding

开始: 2023-09-07 结束: 2023-12-07 基因组学与生物信息 数据算法赛
使用动态位置编码和CNN处理BPPM特征的Transformer模型 - Ribonanza RNA折叠竞赛第一名解决方案

使用动态位置编码和CNN处理BPPM特征的Transformer模型 - Ribonanza RNA折叠竞赛第一名解决方案

作者:vyaltsevvaleriy(团队成员:Penzar Dmitry、Elizaveta Aristova、Artemy Bakulin、polymerase)
竞赛排名:第1名
发布时间:2023年12月8日
投票数:147

Ribonanza RNA折叠竞赛是一次非常棒的机会,我们非常高兴能够参与其中。向主办方在保持高水平组织和为参赛者提供大量宝贵见解与支持方面所做的巨大努力表示感谢。我们还要感谢社区富有成效的公开讨论和编码活动。祝贺所有参与者!🎉

代码

开源代码可在GitHub上获取。

概要

我们的解决方案基于Transformer架构。我们使用EternaFold预测每个RNA序列的碱基对概率矩阵(BPPM),并在架构中添加卷积块来处理BPPM特征。对于某些模型,我们在卷积块中还添加了挤压-Excitation(Squeeze-and-Excitation)层。此外,我们将这些特征与自注意力块中的注意力值在softmax操作之前相加。为了更好地泛化到更长的输入序列,我们实现了动态位置偏置(Dynamic Positional Bias)。最后,我们融合了架构和训练过程略有差异的多个模型。

数据预处理

首先,我们使用EternaFold为训练和测试数据集中的每个序列计算碱基对概率矩阵(BPPM)。在训练和推理阶段,我们将RNA序列通过可学习嵌入层编码后输入模型。每个核苷酸被视为一个token,并在两端添加特殊的<start><end> token。为了向模型提供序列是否来自训练数据集中"干净"子集的信息(由SN_filter值表示 - 1对应高信噪比的"干净"序列,0表示其他情况),我们使用可学习嵌入层对SN_filter值进行编码,并将相应的嵌入向量与序列嵌入向量相加。嵌入维度选择为192。BPPM在边缘用零填充,以考虑添加<start><end> token。下图总结了数据预处理部分。

数据预处理示意图

模型

该模型的想法部分受到@shujun717在Open Vaccine挑战赛中解决方案的启发,链接

模型接收token序列和BPPM作为输入,并为每个输入核苷酸输出DMS_MaP和2A3_MaP反应性。其架构包含12个连续的Transformer编码器层和输出投影线性层。每个Transformer块接收来自前一层的token序列和BPPM特征,并输出更新的特征图,如下图所示。Transformer编码器块采用常见的Transformer编码器架构,但我们修改了自注意力块以确保BPPM特征与序列特征之间的交互。

模型总体架构图

自注意力块

在自注意力(SA)块中,我们实现了具有以下修改的注意力机制:在计算每个头的注意力值后,我们添加由"卷积块"更新的BPPM特征,该卷积块输出通道数对应于SA块中头数的BPPM特征。我们将SA中的头数和相应的BPPM特征通道数都设为6。因此,Q、K、V矩阵的隐藏维度大小为32。SA块的整体结构如下图所示。

自注意力块结构图

动态位置偏置

训练和测试数据集中的序列长度分布不同(测试序列通常更长)。为了更好地泛化到更长的输入,我们实现了一种添加到注意力值上的位置编码。我们发现动态位置偏置相比我们尝试过的其他相对位置编码方法(如xPos(旋转位置编码)和ALiBi)更有用。动态位置偏置为每个头计算一个相对位置偏置图,该偏置是可学习的,并且取决于序列长度。相对位置偏置无法利用序列起始和结束的距离信息,因此我们添加<start><end> token来解决这个问题。

动态位置偏置示意图

卷积块和SE块

最终集成中的模型有两个版本,它们在卷积块的结构上略有不同。基本卷积块由2D卷积层、批归一化层、激活函数和可学习的gamma参数组成,这些参数对输出特征通道进行缩放;而该块的修改版本(SE-卷积块)还包含挤压-Excitation层,因此得名。SE层沿通道方向应用输入相关的值重缩放,如下图所示。因此,集成中模型之间的唯一区别在于SE层的存在与否。

卷积块和SE块结构图

训练过程

训练过程已根据@IAFOSS的notebook进行了调整
我们使用了单周期学习率调度(pct_start=0.05, lr_max=2.5e-3)配合AdamW优化器(wd=0.05),批大小设为128。
模型训练轮数根据数据集大小确定。
对于在几乎整个数据集上训练的最终模型,我们使用了270轮。
每轮处理1791个批次(由于历史原因我们保留了这个数字),每个批次的元素根据权重 = 0.5 * torch.clamp_min(torch.log(sn + 1.01),0.01) 从数据集中采样

我们还发现,使用简单的SGD优化器训练模型约15轮(每轮500个批次)可以进一步提升模型性能(轮数会变化,因此我们使用一个小的验证集来确定确切的轮数)

此外,我们训练了一个模型,在给定DMS、RNA序列和BPPM的情况下预测2A3(dms-to-2a3模型)。

推理

在推理阶段,对于所有输入,我们将SN filter值设为1,就像它们来自"干净"数据集一样。

模型集成

我们集成了15个使用SE-卷积块的模型、10个使用普通卷积块的模型,以及2个在按序列长度分割数据上训练的模型(使用普通卷积块,其中一个模型接受括号特征)。我们对它们的预测结果取平均值。然后我们使用dms-to-2a3模型基于平均后的DMS反应性预测2a3反应性,并以以下方式将这些预测添加到平均后的2a3反应性中:(27/28)*averaged_2a3 + (1/28)*predicted_2a3。

其他数据分割方式

简单KFold分割的一个明显问题是训练数据集中存在高度序列相似性,且测试序列与训练数据差异很大。这可能会阻碍模型开发,因为验证集上的质量提升可能是由于过拟合而非实际改进所致。

我们计算了训练集中所有序列的汉明距离矩阵,并执行了修改后的DBSCAN聚类过程,距离阈值设为0.2。在下图中,我们展示了聚类映射到各自聚类标识符的结果(聚类ID被分配为训练数据集中该聚类内最小序列的编号)。

序列聚类结果图

我们测试了多种基于序列同一性的分割方式(最简单的方法是不打乱数据直接分割),发现虽然随着我们选择越来越严格的距离阈值,模型的最终验证性能会下降,但其相对值与简单KFold分割的表现方式相同。因此,我们决定在训练最终模型时使用简单的KFold分割。

此外,我们还对基于长度的分割进行了模型性能测试。为此,我们在短序列(长度<206)上训练模型,并在长度为206的序列上进行验证。验证指标的行为略为嘈杂,但仍与简单KFold分割的指标高度相关。

公共数据泄露

大约13%的公共测试序列与训练数据集中的序列完全相同(按序列)。为避免选择过度记忆这些序列而非学习RNA相关内容的模型,在大多数提交中我们将这些序列的预测结果设为零,有时也会提交未设为零的结果以与其他参与者进行比较。

我们还尝试过的特征

capR

对于该特征我们没有明确的结论。似乎它对模型平均而言并无益处,有时会产生更好的模型,有时则更差。我们决定不使用该特征。

括号

我们尝试使用EternaFold、ContraFold、ViennaRNA等生成的括号,以及伪结预测程序(如IPknot)。然而,在将EternaFold BPPM添加到模型后,添加其他特征并未带来模型性能的显著提升。对于某些模型,我们仅使用括号来进行模型增强。

不同的BPPM

所有其他BPPM(ContraFold、ViennaRNA、RNAsoft、RNAstructure)都会导致次优模型。
对BPPM进行平均并不能带来更好的性能。
RFold的BPPM与ViennaRNA BPPM质量相同。
RNA-FM模型既产生每个核苷酸的嵌入,也产生类似BPPM的矩阵。然而,这些对模型完全没有帮助,与使用仅序列模型相比,仅带来轻微的性能提升。

SQUARNA矩阵

SQUARNA(github)输出的矩阵不同于BPPM,但可以以相同方式使用。不幸的是,该特征也没有带来额外的性能提升。

我们还尝试过的方法

全卷积架构

我们参加竞赛的最初原因是测试我们的LegNet模型链接,该模型在DNA序列上展现了SOTA结果,并在一些RNA相关任务上表现良好(尚未发表)。
然而,该架构的任何修改在与适当调优的Transformer模型相比时,都表现出次优性能。这可以解释为预测RNA二级结构需要关注长程接触,而Transformer架构更适合此类情况。

数据子集

对数据进行子集划分(通过SN比的不同阈值进行过滤)为所有模型带来了性能提升。然而,这项技术被权重采样所取代,后者被证明更加有效。

在公共数据集上微调

我们尝试通过在组织者收集的公共数据集(链接)上微调模型来提升性能,方法是同时训练模型预测公共实验的结果(排除样本数较少的实验)和Ribonanza数据。不幸的是,这也未能提升模型性能。

使用3D数据

我们尝试使用训练数据集中10万条序列的预测3D结构数据,但在对它们进行可视化分析后放弃了:

3D结构数据示例

绝对位置嵌入

使用绝对位置嵌入在泛化到更长序列时会导致无法解决的问题。

相对位置嵌入

最简单的方法是在训练阶段增强绝对位置编码,使其从0到(Lmax - seqlen)位置随机偏移。这确实解决了外推问题,但效果不如其他方法。
旋转位置编码不幸地无法帮助模型泛化到更大的长度。
ALiBi位置编码解决了外推问题,但即使仅对部分头保留(如https://github.com/lucidrains/x-transformers中所建议),其表现仍不如动态位置偏置。

数据增强

首先,我们尝试使用反向增强。这可以通过三种方式完成:
在任何修改之前反转序列,
在填充前但添加 token后反转序列,
在填充后反转序列。

前两种方式对我们测试的所有模型变体都没有增益。然而,第三种方式(结合额外的微调)为我们带来了使用xpos位置编码模型的良好结果,该模型的单模型公共排行榜性能为0.13937。不幸的是,xpos在长序列上表现相当差,因此我们放弃在最终提交中使用该模型。

我们还尝试了偏移增强和不同的序列填充方法。这也没有改善我们的模型性能。

滑动窗口

泛化到更长序列的另一种可能方式是使用滑动窗口预测反应性。然而,这个想法在生物学意义上有些错误,并且在测试训练数据集序列时会导致性能下降。

伪标签

一旦我们获得了表现最佳的模型集成,就尝试使用它们对测试数据集进行伪标签标注,并使用置信度最高的预测来训练新模型。虽然这确实带来了更好的单模型性能,但将这样的模型添加到集成中并不能改善集成性能。

修改损失函数

我们没有过滤低SN序列,而是尝试像@nullrecurrent在Open Vaccine挑战中所做的那样,对高反应性误差的位置进行掩码(链接
这导致了较差性能。
我们尝试根据每个序列的SN来加权损失——这没有带来任何改进。

工具

arnie
EternaFold

参考文献

挤压-Excitation块

Hu, J., Shen, L., & Sun, G. (2018). Squeeze-and-excitation networks. In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 7132-7141).

动态位置偏置

Wang, W., Chen, W., Qiu, Q., Chen, L., Wu, B., Lin, B., ... & Liu, W. (2023). Crossformer++: A versatile vision transformer hinging on cross-scale attention. arXiv preprint arXiv:2303.06908.

同比赛其他方案