返回列表

10th Place Solution – Boey – End-to-End JAX RL

645. NeurIPS 2024 - Lux AI Season 3 | lux-ai-season-3

开始: 2024-12-09 结束: 2025-03-24 游戏AI AI大模型赛
第 10 名解决方案 – Boey – 端到端 JAX RL
作者: Boey
发布日期: 2025-03-26
竞赛排名: 第 10 名

第 10 名解决方案 – Boey – 端到端 JAX RL

首先,我要感谢组织者和 Lux AI 团队设计了如此高效且引人入胜的 Lux S3 环境。这既是我第一次参加 Kaggle 竞赛,也是我第一次参加强化学习竞赛,我非常感激有机会在这样的竞争环境中应用和测试我的 RL 知识。

方法概述

我的解决方案建立在 PureJaxRL(一个类似 CleanRL 的 JAX 最小化 PPO 实现)之上,我将其扩展为一个完全端到端的 JAX 强化学习 pipeline。每个组件,从观察预处理到训练循环,都是完全可 JIT 编译的,从而实现高速训练性能。

关键组件:

  1. 端到端 JAX RL – 完全可 JIT 编译的 pipeline。
  2. 带有时间反向传播 (BPTT) 的 PPO
  3. 基于优先虚构自我博弈 (PFSP) 的多智能体学习

性能基准:

  • 训练吞吐量:80,000 步/秒
  • 模型大小:~180 万参数,(+140 万用于独立的 Critic 网络)

输入特征

单位特征 (每 32 个单位):

  • 盟友/敌人标志
  • 可见性
  • 最近 5 回合能量
  • 自从上次被发现以来的回合数
  • 当前位置/最后已知位置

空间特征 (24×24 网格):

  • 能量场
  • 星云 & 小行星场
  • 遗迹节点 & 点数
  • 传感器掩码
  • 访问过的瓦片

*空间信息在多个步骤中持久存在。
*星云/小行星瓦片根据漂移速度动态移动。
*大多数空间特征在地图上是对称的(传感器掩码除外)。

标量特征:

  • 二进制编码的团队点数和胜利次数
  • 二进制编码的比赛步数
  • 隐藏的游戏参数(仅在可推导时包含)

动作空间

组件 描述 掩码
动作类型 noop, move_up, move_down, move_left, move_right, sap 能量不足,地图边界,小行星碰撞
SAP 头 15×15 = 225 个可能的 sap 目标(仅在选择 sap 动作时使用) 仅允许 targeting 围绕可见的敌人单位和已知遗迹点

奖励工程

阶段 奖励
早期到中期训练 每获得一点 +0.002,每赢得一场比赛 +1
后期训练 (稀疏) 仅每赢得一场比赛 +1

所有奖励均为零和。我仅在比赛的最后一周切换到稀疏奖励,因此两个提交版本都只使用了稀疏奖励训练了总训练步数的最后 10%。尽管使用有限,但这一切换仍然带来了显著的性能提升。


网络架构

网络架构图

这是我提交使用的最终网络架构。它与我的初始设计相似,相对轻量,约有 180 万可训练参数。虽然我尝试过更大的模型,但它们消耗了更多的显存,减少了并行环境的数量并减慢了训练速度,而在中期训练期间并没有带来显著的性能提升。

尝试的实验

  • 用 ResBlocks 替换 Conv2d。
  • 为 Sap 添加 3x ResBlocks 上采样以创建 15x15 Sap 地图。
  • 使用 Pointer Network 选择 SAP 目标

集中式 Critic (全观测性)

对于我的第二个提交,我使用了具有全局信息访问权限的集中式 Critic。策略和 Critic 使用单独的输入 pipeline,并通过两个独立的网络进行训练(+140 万参数)。

  • 优点:
    • 改进的价值估计
    • 样本效率大约提高 2 倍
  • 缺点:
    • 训练较慢(降至 ~40k SPS)
    • 实际耗时性能增益为中性

优先虚构自我博弈 (PFSP)

我实施了优先虚构自我博弈 (PFSP),灵感来自 AlphaStar 的方法,以增强我强化学习代理的训练效率和鲁棒性。与标准的虚构自我博弈 (FSP) 不同(后者从过去版本中均匀采样对手),PFSP 根据对手相对于当前代理的胜率分配更高的选择概率。这种针对性采样确保代理更多地关注具有挑战性的对手,从而促进持续改进并避免停滞。

  • 75% 游戏:自我博弈
  • 25% 游戏:冻结的过去版本
  • PFSP 更频繁地与代理难以战胜的对手匹配(通过胜率跟踪)

训练设置

我大部分实验是在本地机器上运行的,配置为 AMD Ryzen 7950x, 64gb RAM 和 Nvidia RTX 4090。对于超参数调整,我在大约 10 天内租用了 2-4 个云 GPU (RTX 4090)。

挑战

虽然构建完全端到端的 JAX pipeline 提供了出色的训练性能,但它显著降低了代码可读性,使得发现 bug 变得更加困难。直到比赛进行到三分之二时,我才发现了一些严重阻碍代理性能的关键 bug。事后看来,我应该从一开始就包含广泛的单元测试,因为它们对于确保正确性和稳定性至关重要。

最终提交

提交 设置 训练环境步数 实际耗时 最终排行榜评分
#1 共享网络 ~580 亿 8 天 1771.5
#2 集中式 Critic ~230 亿 7 天 1884.5

结束语

我非常自豪能在我的第一次 Kaggle 竞赛中进入前 10 名。

回顾过去,我相信选择更高容量的模型(1000 万 + 参数)而不是最大化训练吞吐量会带来更好的结果。我的最终提交使用了具有约 320 万参数的集中式 Critic,其评分显著高于早期的 180 万模型。

同比赛其他方案