587. Stanford Ribonanza RNA Folding | stanford-ribonanza-rna-folding
非常感谢组织者举办这场比赛,并致力于解决RNA结构预测等难题。在整个比赛过程中,这句话一直萦绕在我心头:"...如果不能理解RNA分子如何折叠,我们就无法深入了解自然如何运作,生命如何起源,以及我们如何能够设计..."
许多人会略过这段话,但这却是一个巨大的谜团,有时甚至让我夜不能寐。我的意思是,这太奇怪了...进化/创造了4种具有特定物理化学性质的核苷酸,它们能够编码并形成不同结构以在体内完成特定功能,而我们仍然不了解这些分子的起源/目的...希望我们正朝着发现真相和背后原因的方向前进,甚至是发现"谁"创造了它们。
解决方案并非基于数十个集成模型,而是2个强大的独立模型。之所以有2个模型,是因为我不确定是否能在规定时间内产生可行的AlphaFold风格解决方案,所以我先专注于一个较小的"安全"版本模型,然后在比赛后期再开发一个更大的"风险更高"的模型。但这2个模型的集成对获得第三名起到了关键作用。
两个独立模型的融合。较小"更安全"版本模型基于增强的Squeezeformer架构,包含相对多头自注意力、卷积和前馈模块。通过2D卷积将可学习的碱基对概率(BPP)添加到注意力分数中,除了干净训练数据外,还额外使用增强的低信噪比数据进行训练。较大的双塔模型基于增强的AlphaFold风格架构,包含MSA堆栈表示和配对堆栈表示,通过外积均值和配对表示偏置以交叉方式进行通信。
训练数据按照此笔记本的方式分为4折。本地CV与LB之间存在稳定的相关性。BPP被处理并缓存为.npz文件。大多数时候,模型仅在fold0上进行评估。
双塔模型仅在干净训练数据上进行训练,它能够预测其置信度/误差估计,类似于AlphaFold中的pLDDT。在双塔模型从干净数据上获得良好的泛化能力后,通过将模型置信度/误差估计与噪声低信噪比数据进行增强,创建了新的训练数据。对于低信噪比数据集中的特定核苷酸位置,其思想是结合模型预测的置信度与实验中该位置的反应性误差,从而"修复"噪声数据。这在较小"更安全"的模型上带来了显著改进,该模型在干净数据集+改进的低信噪比数据集上进行了训练。即使这个有用的新的数据对较大的双塔模型可用,但由于时间限制,它从未在此额外数据上进行训练,最终提交仅使用了干净数据集。双塔实验目前正在完整数据集上进行。
较小的"更安全"模型基于Squeezeformer架构,该架构曾在之前的Google美国手语竞赛中使用。模型输入为标记化的RNA序列和BPP矩阵。模型由14个Squeezeformer块和一个输出投影层组成,该投影层预测每个位置的化学反应性。一个Squeezeformer块包含三个模块:相对多头自注意力模块、卷积模块和前馈模块。
为了在更长序列上实现泛化,使用相对位置编码而非绝对位置编码。除了相对位置分数外,注意力分数还受到碱基对概率矩阵的影响,尽管起初我不愿使用它们,因为它们是由无法检测长距离伪结的软件创建的,这种偏差不幸地被引入模型中。Transformer中的注意力分数按此方式计算:(内容分数 + 相对位置分数)/sqrt(头维度) + bpp偏置分数。BPP偏置分数通过将BPP矩阵通过2D卷积块获得。
该模型灵感来自Google的AlphaFold及其衍生模型OpenComplex/RhoFold。原始AlphaFold架构依赖两种不同的输入表示进行预测。它联合使用MSA(多序列比对)和配对表示特征。MSA表示使用行方向注意力查找序列内特征,而列方向注意力用于从MSA堆栈中获取序列间进化信号。由于本次竞赛的MSA在提取进化信息方面没有帮助,因为RNA序列是合成的,因此仅使用了标记化的输入序列。由于未使用MSA,轴向自注意力被替换为相对多头注意力和卷积。
(提交时,仅使用了单表示分支进行预测,由于时间限制,配对特征被完全忽略)
1) 输入序列被标记化,生成序列和配对掩码
2) 标记化序列通过嵌入网络进行嵌入:MSANet和PairNet
3) 嵌入的MSA和配对特征通过网络的主干,该主干由上面图示的8个Chemformer块组成。一些关键要点:
4) 处理后的表示MSA特征和配对特征随后通过各种头部传递,以提取有意义的信息,如置信度/误差估计、化学反应性、碱基对概率等。
首先,双塔模型在干净数据集上以1e-3的学习率训练60个epoch。每个GPU的批大小为8,梯度累积为4,在2个GPU上实现有效批大小64。使用的优化器为AdamW,采用余弦调度,权重衰减参数为0.05,预热比例为0.5。模型在2张4090上总共训练30小时。该模型未进行广泛的参数调优。
这带来了0.13746公共/0.14398私有单模型分数。此后,通过将模型预测与0.35<数据<1信噪比范围内的噪声数据子集混合,创建了合成数据集。选择较低的信噪比阈值增加了训练样本但降低了质量。当前合成数据集采用简单加权创建,由于时间限制,尽管模型已训练并输出有效的plddt/误差估计,但它们未被充分利用。我没有时间尝试在逐个核苷酸基础上结合plddt分数与实验反应性误差的公式,这必定能产生更好的合成数据。
较小的"更安全"模型以7e-4的学习率、64的批大小训练了200个epoch。优化器和学习率调度与之前相同,但未使用梯度累积。该模型在干净数据集和合成创建的数据集上均进行了训练,使训练样本翻倍。该单模型达到了0.13865公共/0.14256私有分数。
竞赛结束前一天,我决定在双塔模型上运行更长的epoch,因此我从上一次迭代的第30个epoch开始权重,继续相同的训练过程。仅通过更长时间运行模型就带来了改进,使该单模型的最终分数达到0.13706公共/0.14366私有。最终分数是双塔模型与挤压模型的融合,前者仅在输入序列上训练并尝试自行学习所有交互,后者通过BPP和合成数据对最终预测做出贡献。
所有这些调整很可能为模型带来好处。部分测试目前正在进行中。
双塔模型是一项庞大任务,部分原因是我独自参赛。提交时模型仅在干净数据集上训练,预测仅使用了MSA特征通路,另一个配对特征通路被完全忽略,但它可以用来尝试重建BPP,这应该能提高分数(我正在写此文时测试正在进行中)。不使用BPP、环类型或任何预处理/后处理的模型最佳提交成绩为0.13709公共和0.14366私有。然而,与较小"更安全"的Squeezeformer模型相比,双塔模型在泛化能力上存在较大差距,后者得分为0.13865公共但0.14265私有,因此目前正在研究其背后的原因。
我是机器学习新手。我今年五月开始学习该领域,因此我相信代码和方法中会有很多错误,如果我的解释杂乱无章,请见谅,这一切对我来说都是新的,我仍在学习中。
开源代码:
https://github.com/GosUxD/OpenChemFold