581. Google - Fast or Slow? Predict AI Model Runtime | predict-ai-model-runtime
首先,感谢Kaggle和Google主办这场精彩的竞赛。同时感谢我的队友和朋友 @tomirol 和 @thanhhau097a,这段旅程因你们而更加特别。
InstanceNorm → SAGEConv → SelfChannelAttention → CrossConfigAttention → +残差 → GELU
PairwiseHingeLoss损失函数我们加入竞赛时距离结束不到一个月,因此首要解决的问题是训练任务效率低下。
对于布局,我们发现只有Convolution、Dot和Reshape是可配置节点。此外,大多数节点在不同配置间是相同的。因此我们采用简单的剪枝策略:对每个图,只保留可配置节点本身或与可配置节点相连的节点(输入/输出)。这样将单个大图转换为多个(可能不连通的)子图,由于网络最后有全局池化层来融合子图信息,这不是问题。此方法将vRAM使用量减少4倍,训练速度提升5倍。
布局的配置集包含大量重复项,但重复配置的运行时可能差异很大,影响训练稳定性。因此我们移除了所有重复配置。
即使经过剪枝和去重,NLP集合加载所有配置到内存的RAM使用量仍然很高。我们预先压缩node_config_feat,在数据加载器中按需解压。这样可以在训练开始时加载所有数据到内存,显著减少IO/CPU瓶颈。
压缩原理:每个node_config_feat的6维向量(输入、输出和内核)只有7个可能值(-1,0,1,2,3,4,5),可用7进制整数(0到7^6)表示。
我们注意到node_feat使用0填充。虽然多数特征没问题,但对layout_minor_to_major_*等特征会产生歧义(0是有效轴索引)。而node_config_feat使用-1填充。因此我们将node_feat重新生成为-1填充,从而对node_feat[134:]和node_config_feat使用单个嵌入矩阵。
对于布局,我们将node_feat分为node_feat[:134]和node_feat[134:](layout_minor_to_major_*)。前者使用StandardScaler标准化,后者与node_config_feat一起输入学习嵌入矩阵(4通道)。我们发现标准化至关重要,因为node_feat包含*_sum和*_product等可能值很大的特征,会破坏优化过程。
对于node_opcode,我们使用独立的16通道嵌入层。网络输入是上述所有特征的拼接,每个图实时采样64(默认)或128(随机)个配置组成批次。对于tile,我们采用后期融合方式集成config_feat。
我们的网络结构非常简单:首先输入特征通过线性层映射到256维嵌入,然后是2个卷积块、全局图平均池化和最终线性层。
对于图卷积层,我们尝试了多种类型,但SAGEConv效果最好。过去我在其他应用中成功使用GAT变体,但这次效果不佳。我推测原因是:其他应用中图结构噪声较大,注意力有助于"忽略"不重要的连接;但TPU图的连接都是"真实"且重要的,因此图注意力帮助不大。不过我们发现两种有用的注意力:自通道注意力和跨配置注意力。
我们借鉴Squeeze-and-Excitation思想创建通道注意力层:先通过线性层压缩通道维度(8倍缩减)并应用ReLU,再用第二个线性层恢复原始通道数并应用sigmoid,最后与原始输入逐元素相乘。
其目的是捕捉通道间相关性,抑制不重要的通道,增强有用的通道。
我们还可以利用批处理维度(跨配置)的注意力。这个简单模块让模型在网络中显式"比较"每个配置与其他配置,比仅通过损失函数隐式比较更有效。代码如下:
class CrossConfigAttention(nn.Module):
def __init__(self):
super().__init__()
self.temperature = nn.Parameter(torch.tensor(0.5))
def forward(self, x):
# x形状:(配置数, 节点数, 特征数)
scores = (x / self.temperature).softmax(dim=0)
x = x * scores
return x
在每个块的通道注意力后应用此层,显著提升了默认集合的效果。推理时使用128的批次大小,由于预测依赖于批次,我们可通过TTA生成10种配置排列并排序后平均结果来进一步提升。
我们遵循计算机视觉的最佳实践:先用InstanceNorm归一化输入特征图,然后执行Linear/SAGEConv层、SelfChannelAttention和CrossConfigAttention(拼接输出与输入以保留每个样本的个体性),加上残差连接后使用GELU和dropout。
我们的最佳单模型在私有(公开)排行榜上得分为0.714(0.748)。但由于某些集合的测试样本较少,我们采用集成学习来提升结果并防止波动。通过对每个集合5-10个模型预测取平均,最佳结果达到0.736(0.757),但遗憾的是我们未提交此结果。
我制作了网络简图,虽然有点简陋但能说明整体思路 😅
