返回列表

10th Place Solution (ST-GCN + Transformer)

680. MABe Challenge - Social Action Recognition in Mice | MABe-mouse-behavior-detection

开始: 2025-09-18 结束: 2025-12-15 智慧养殖 数据算法赛
第 10 名解决方案 (ST-GCN + Transformer)

第 10 名解决方案 (ST-GCN + Transformer)

作者: Ebi (Master)

发布时间: 2025-12-16

竞赛排名: 第 10 名

概述

  • ST-GCN + Transformer
    • 联合建模所有 16 个(代理鼠,目标鼠)对
    • 直接优化竞赛指标(Macro Soft F1 Loss)
    • 集成不同关键点(4-7 个)和序列长度的模型

预处理

FPS 标准化

  • 将所有视频重采样至 30 FPS(线性插值)
  • 原始 FPS 作为额外特征传入模型

坐标标准化(按实验室)

  • 按实验室标准化 (x, y) 坐标:x_norm = (x - mean) / std

关键点映射

  • 不同实验室有关键点名称差异 → 映射为统一名称
    • head → nose (鼻子)
    • spine_1 → neck (颈部)
    • hip_left/right → lateral_left/right (左侧/右侧)

特征工程 (每个关键点 24 维)

从标准化后的 (x, y) 坐标计算 24 个特征:

每个关键点特征 (8 维):
  - x, y         (2): 标准化坐标
  - vx, vy       (2): 速度 (Δt=2 帧)
  - ax, ay       (2): 加速度 (Δt=2 帧)

鼠标间特征 (16 维,相对于其他 3 只鼠标):
  - rel_x, rel_y (6): 相对于鼠标 i 的位置 (i=1,2,3)
  - rel_vx, rel_vy (6): 相对于鼠标 i 的速度
  - dist         (3): 到鼠标 i 的欧几里得距离
  - approach_vel (3): 接近速度 (d(dist)/dt)

总计:8 + 16 = 24 特征

交叉验证

  • 4 折 StratifiedGroupKFold
    • 分组:video_id
    • 分层:lab_id
    • CalMS21, CRIM13 实验室始终在训练集中
  • OOF 评估: 在 3 折上训练,在 1 折上验证
  • 最终提交: 在所有数据上训练(无验证集分割)

模型架构

ST-GCN + Transformer

输入:           [B, 4, K, 24, F]   # 批次,鼠标,关键点,特征,帧
    ↓
ST-GCN (×4) + 时间池化:
                 [B, 4, K, 128, F']  # 空间图卷积 + 时间卷积 + 池化 (F'=F/p)
    ↓
关键点注意力池化:
                 [B, 4, 128, F']    # 关键点上的可学习注意力
    ↓
成对特征提取:
  连接 (agent, target, agent-target, target-agent) 对于 4×4 对
                 [B, 16, 512, F']
    ↓
连接嵌入:
  + 实验室 emb (16) + FPS emb (16) + 动作 emb (32)
                 [B, 16, 576, F']
    ↓
特征压缩:
                 [B, 16, 192, F']
    ↓
带 RoPE 的 Transformer (每对):
                 [B, 16, F', 192]
    ↓
时间上采样:
                 [B, 16, F, 192]    # 恢复到原始帧长度
    ↓
跨对注意力:
  16 对相互关注
                 [B, 4, 4, F, 192]
    ↓
分类器:      [B, 4, 4, 38, F]   # 代理,目标,类别,帧

注意:集成中也使用了非成对模型(without Pairwise Feature Extraction / Cross-Pair Attention)。

损失函数

logits: [B, 4, 4, 38, F] → 展平 → [N, 152] # N = B × 4 代理 × F

三元组掩码 (来自 behaviors_labeled):

  • 在 softmax 之前设置 logits[invalid] = -inf
  • 只有 behaviors_labeled 中的 (agent, target, action) 是有效的

Loss = α × CE + (1-α) × MacroSoftF1

  • CE: 每个代理的所有 (目标 × 动作:152) 上的交叉熵
  • MacroSoftF1: 竞赛指标的可微近似(软 TP/FP/FN,排除背景)
  • 调度:α 在 32 个 epoch 内从 0.2 过渡到 0(最后纯 F1)

数据增强

训练

  • 仿射变换:旋转、缩放、平移、剪切、水平/垂直翻转
  • 鼠标洗牌:排列鼠标 ID
  • CutMix:混合时间片段(相同实验室和鼠标数量)
  • 关键点 dropout

TTA (测试时增强)

  • 水平翻转平均

集成

  • 不同的关键点和序列长度 (160-224)
    • 4kp: ear×2, nose, tail_base
    • 5kp: + neck
    • 7kp: + lateral×2
  • 成对模型 (LB: ~0.525, CV: ~0.510) + 非成对模型 (LB: ~0.500, CV: ~0.500)
  • 跨模型概率平均

无效尝试

  • 测试数据上的伪标签
  • 使用 MABe22 未标记数据的半监督学习
  • 训练中合并每个实验室的稀有类别 (例如,AdaptableSnail: avoid → escape)
同比赛其他方案