返回列表

22nd place solution: CTC Loss, Strong augmentations, CNN+MHSA

570. Google - American Sign Language Fingerspelling Recognition | asl-fingerspelling

开始: 2023-05-10 结束: 2023-08-24 计算机视觉 数据算法赛
第22名解决方案:CTC损失、强数据增强与CNN+MHSA

第22名解决方案:CTC损失、强数据增强与CNN+MHSA

排名:第22名

作者:Samrat Thapa

发布时间:2023-08-26

我的最佳模型配置文件)基于上一届比赛第一名解决方案(作者@hoyso48)构建。我使用标准的CTC损失进行训练,推理时采用贪婪解码。

有效改进方案:

  • 更长的输入帧序列:384帧的输入长度比256帧表现更好
  • 更深更大的模型:最佳模型包含7个模块,隐藏维度256,共950万参数
  • 姿态关键点:加入姿态信息很有帮助,最佳模型使用了姿态+手部+嘴唇+眼睛的关键点数据
  • CNN+MHSA:CNN+多头自注意力模型 > 纯CNN模型 > 纯MHSA模型
  • 强数据增强:我的模型在本地CV和公开榜都有良好表现,但用网络摄像头测试时无法识别我的手势,因此引入了强数据增强。时间掩码特别有效,因为我的手势比专业选手慢。这些增强还将公开榜分数提升了+0.006:
    • 水平翻转概率=0.5
    • 随机仿射变换概率=0.75
    • 冻结概率=0.5
    • 时间掩码概率=0.75
    • 时间掩码范围=(0.2,0.4)
  • 拼接增强:随机拼接两个短的关键点序列及其标签,将公开榜分数提升了+0.008。该增强应用于40%的训练样本

尝试但未采用的方法:

  • Transformer编码器+解码器:使用交叉熵损失的端到端Transformer模型
  • Transformer解码器:CNN+MHSA模型配合类Transformer解码器及交叉熵损失
  • 自注意力因果掩码:移除因果掩码比使用因果掩码效果更好
  • 自注意力跨度:原以为该任务不存在帧间长距离依赖,尝试减少注意力跨度,虽无性能下降但也无提升
  • Focal损失:基于CTC Focal损失的尝试未带来提升
  • 关键点擦除增强:随机擦除除手部外的关键点,提高模型对MediaPipe检测错误的鲁棒性

最初使用PyTorch,但因PyTorch转tfLite遇到困难而转用TensorFlow。然而TensorFlow训练CTC损失比PyTorch慢10倍。回顾来看,本应更努力解决模型转换问题以进行更多实验。

祝贺获奖团队!感谢Google举办此次比赛,这是一次宝贵的学习机会。也感谢所有分享笔记和独特思路的参与者。

同比赛其他方案