645. NeurIPS 2024 - Lux AI Season 3 | lux-ai-season-3
大家好,首先我要感谢组织者以及 Discord 上所有提供帮助的人,让这次比赛变得非常愉快!此外,我还要感谢我的大学和研究所,允许我使用他们的一些计算资源!这是我第一次认真参加比赛,所以我的方法和代码非常杂乱,策略在很多方面也肯定有所欠缺。尽管如此,这里还是分享我的最终解决方案。代码可以在 GitHub 上找到。
总的来说,我开发了一种基于 xLSTM [1] 的方法,这似乎非常适合这种部分可观测的决策问题。我将此与强化学习(RL)和带有自博弈(self-play)的循环 PPO 方法相结合。我利用这次机会学习了 JAX,并确保整个训练循环都是可 JIT 编译的(jittable)。由于我们处理的是一个部分可观测的环境,我使用了独立的 Actor(演员)和(全知视角的)Critic(评论家)。该模型基于 JointPPO [2],它使用单个网络来控制所有单位。
在这次比赛中我很快学到的一点是:可视化。因为这是我第一次使用 JAX,而且我从零开始编写了很多框架代码(大部分灵感来自 PureJaxRL [3]),我的实现有很多(现在可能仍然有)bug。如果不调整可视化器以提供有关模型输入和输出的更多信息,我永远无法发现其中许多错误。以下是最终可视化器更改的示例:
我在地图下方显示全局特征,在其下方显示地图输入通道,在单位列表中显示友方单位特征。我还可以用蓝色色调可视化“有效汲取位置”掩码。
在这里,我想解释一下我是如何编码输入和模型输出的。
输入编码: 我将观测值分为三种类型:
我还向地图特征添加了 x 和 y 坐标作为位置编码,并翻转所有特征(位置、地图等),以便从模型的角度来看,玩家始终位于左上角。
遗物碎片的位置是通过保留单位位置和点数增益的历史记录来计算的。每当我们获得新信息以排除/重新包含位置时,遗物概率就会迭代更新。算法并不完美,但我认为循环网络能够弄清楚其余部分。
输出: 模型为每个单位输出一个动作类型和一个汲取位置。无效动作(如移出地图或汲取 unreachable 区域)会被掩码处理。单位只能汲取可见对手周围 3x3 区域或不可见遗物碎片位置。
我开发了一种基于 xLSTM 进行时间序列建模和基于 Transformer 编码当前状态的架构。模型和训练的概念灵感来自 JointPPO,它将多智能体决策问题视为序列建模问题,通过迭代预测每个单位的动作。虽然 Actor 和 Critic 共享许多架构,但它们是完全独立的模型,Critic 获得更详细的信息(如对手位置、遗物碎片位置、星云速度等)。模型大致如下*:
*这里还有很多 LayerNorm、跳跃连接和其他小细节我没有画出来。
编码器(Encoder): 编码器由 a) 用于编码地图的 ConvNeXt [4],b) 友方单位编码器,c) 敌方单位编码器,以及 d) 全局状态编码器组成。
Transformer 元编码器: 编码所有特征后,我得到一个地图 token、一个(可学习的)循环 token、32 个单位 token 和一个全局 token。我将这些输入到四个带有自注意力和门控 MLP 的 Transformer 层中,同时掩码掉已死亡或尚未生成的单位。
xLSTM 核心: 为了通过循环神经网络处理游戏的时间序列特性,我使用了一个由 mLSTM 层(用于记忆容量)和 sLSTM 层组成的 xLSTM。xLSTM 的表现比简单的 LSTM 好得多,同时需要更少的参数。我在纯 JAX 中重新实现了 xLSTM,这使得整个基于 JAX 的流水线成为可能,但该实现在效率、速度和内存使用方面存在一些限制。我将循环 token 传递通过 xLSTM,然后在将其用于值头(value head)之前或将其添加到每个友方单位向量之前使用它。
输出头(Heads): xLSTM 模型随后辅以 Actor 和 Critic 头以允许 PPO 训练。Critic 头只是一个带有谱归一化(spectral norm)的 MLP。Actor 头是一个 Transformer 解码器,它使用最终的单位嵌入作为查询(queries),预测的动作作为键(keys)和值(values)进行交叉注意力。它还包括一个从地图特征到预测汲取位置的跳跃连接。
最终的 Actor 和 Critic 各有约 200 万参数。但在 Kaggle 服务器上只部署了 Actor。对于每个单位,模型预测动作类型的概率分布和每个可能汲取位置的概率分布。在 Kaggle 上,动作和汲取位置是通过使用概率最大的元素来确定的。动作被重新编码并反馈到 Transformer 中。
对于训练,我最初训练了一个小得多的模型(约 50 万参数),只有 1 个 Transformer 层,一个 mLSTM 层,并且同时预测所有单位动作而不是按顺序预测。小模型分 3 个阶段训练:
在最后阶段之后,模型停止改进。虽然这足以获得金牌(至少在当时),但我想要更进一步,所以我决定训练一个更大的模型。这个模型分 2 个阶段训练:
我发现切换到稀疏奖励时模型改进较慢。即使在我提交最终检查点之前,模型仍在不断改进。我本可以训练一个更大的模型,因为我在提交服务器上从未遇到时间管理问题。然而,我不知道在比赛最后一周切换到更大的模型是否值得。
自博弈是通过在 25% 的游戏中对抗最后 128 个检查点,在 75% 的游戏中对抗最新检查点来完成的。因为 JAX 允许你并行玩所有这些游戏,我可以将所有权重保留在 GPU 上并对它们进行 vmap 操作。
我使用了以下超参数来训练最终模型:
| 参数 | 值 |
|---|---|
| LR (学习率) | 3e-4 |
| NUM_ENVS (环境数) | 1024 |
| NUM_STEPS_BETWEEN_UPDATE (更新间隔步数) | 128 |
| BPTT_HORIZON (BPTT 视界) | 16 |
| OPPONENT_UPDATE_STEPS (对手更新步数) | 2^20 |
| OPPONENT_BUFFER_SIZE (对手缓冲区大小) | 128 |
| LATEST_VARIABLES_ENVS (最新变量环境数) | 768 |
| UPDATE_EPOCHS (更新轮数) | 2 |
| MINIBATCH_SIZE (小批量大小) | 64 |
| GAMMA (折扣因子) | 0.997 |
| GAE_LAMBDA | 0.9 |
| CLIP_EPS | 0.05 |
| ENT_COEF (熵系数) | 0.001 |
| VF_COEF (价值函数系数) | 0.5 |
| MAX_GRAD_NORM (最大梯度范数) | 5 |
所有训练均使用 bfloat16 进行。
总的来说,我玩得很开心,学习了 JAX(并成为了它的绝对粉丝),并在 Discord 上与非常乐于助人和酷的人进行了交流。
关于比赛本身,我真的很喜欢游戏的循环方面,并认为引入稍后生成的遗物的平衡补丁是本次比赛中最好的决定之一,让我想尝试循环模型。我也非常欣赏小地图尺寸,这使得像我这样的人可以在只有 8GB 内存的小型 GPU 上开始在家训练!