返回列表

5th place solution

549. Stable Diffusion - Image to Prompts | stable-diffusion-image-to-prompts

开始: 2023-02-13 结束: 2023-05-15 AIGC与多模态 数据算法赛
第五名解决方案

第五名解决方案

作者:kenji(团队成员:datelier, Yuya Otsuka, tarokd, ryoshih)
发布日期:2023-05-16

我们感谢组织者举办这场激动人心的竞赛。
祝贺所有完成竞赛的选手和获奖者。

解决方案概述

我们的方法包括训练多个图像模型,使其能够直接从图像预测提示词嵌入(由句子转换器生成),然后对这些模型的预测结果进行集成。最终提交的是四个模型的集成:

  • eva02_large
  • convnext_xxlarge
  • convnext_large
  • vit_large

竞赛初期,我们尝试了直接预测图像提示词的方法。但后来发现直接预测提示词嵌入能获得更好的分数。因此从竞赛中期开始,我们专注于改进直接嵌入预测的方法。

数据集生成

我们生成了约500万对提示词-图像组合(部分图像由相同提示词多次生成),详情如下表:

数据集名称 提示词数量 每个提示词的图像数 总图像数
cc3m 249,593 3 748,781
lexica 67,101 3 201,303
diffusiondb 326,035 3 978,105
cc3m part2 1,166,852 1 1,166,852
diffusiondb part2 1,371,480 1 1,371,480
mscoco 197,527 3 592,583
总计: 3,378,588 1或3 5,059,104

训练过程中,对于具有三个对应图像的提示词样本,我们随机选择一个图像,确保同一epoch内不使用相同提示词的重复图像。

数据集参考:

验证策略

为划分生成的图像-提示词对用于训练和评估,我们首先使用句子转换器计算提示词嵌入。然后将具有相似提示词嵌入的样本分组。在嵌入余弦相似度≥0.7的样本分组后,使用GroupKFold将数据集划分为10折。模型训练在除fold0外的所有折上进行,fold0保留用于评估。

此外,在创建cc3m和diffusiondb数据集时,我们预先计算了提示词嵌入的余弦相似度。对于相似度≥0.9的提示词,仅使用其中一个生成图像。竞赛中期开始,我们也利用之前未使用的剩余提示词,创建了cc3m_part2和diffusiondb_part2数据集。

模型架构

我们的模型架构简单直接,由一个骨干网络连接384维全连接层(无偏置)组成,与句子转换器的输出维度匹配。

部分模型具有更大的输入图像分辨率,如下表所示:

模型 预训练骨干网络 输入尺寸
eva02_large eva02_large_patch14_448.mim_m38m_ft_in22k (timm) 448 (预训练@448)
convnext_xxlarge laion2b_s34b_b82k_augreg_rewind (open_clip) 384 (预训练@256)
convnext_large laion2b_s29b_b131k_ft_soup (open_clip) 384 (预训练@320)
vit_large laion2b_s32b_b82k (open_clip) 336 (预训练@224)

open_clip的ViT模型可按如下方式更改输入图像尺寸:

backbone = open_clip_model.visual
hgrid, wgrid = backbone.grid_size
hpatch = input_height // hgrid
wpatch = input_width // wgrid
backbone.patch_size = (hpatch, wpatch)
backbone.conv1.kernel_size = (hpatch, wpatch)
backbone.conv1.stride = (hpatch, wpatch)

训练过程

我们使用PyTorch框架训练模型,配置如下:

  • 10或15个epoch,使用分布式数据并行(DDP)、自动混合精度(AMP)和梯度检查点
  • 使用CosineEmbeddingLoss作为损失函数
  • 使用MADGRAD优化器
  • 热身1或3个epoch,仅训练FC层,骨干网络冻结(学习率:从1e-2到1e-4)
  • 余弦学习率调度器,学习率范围1e-4到1e-7
  • 使用完整数据集微调3或5个epoch

数据增强配置相对较轻:

  • 无旋转
  • 无水平翻转
  • RandomResizedCrop,比例0.5到1.0
  • ColorJitter(亮度=0.05,对比度=0.05,饱和度=0.05,色调=0.05)

竞赛中期前,我们仅使用三个数据集训练模型:cc3m、lexica和diffusiondb。

模型 epoch数 热身 初始学习率 最终学习率 数据集
eva02_large 10 1 1e-4 1e-6 cc3m,lexica,diffusiondb
convnext_xxlarge 10 1 1e-4 1e-6 cc3m,lexica,diffusiondb
convnext_large 15 3 1e-4 1e-6 cc3m,lexica,diffusiondb
vit_large 15 3 1e-4 1e-6 cc3m,lexica,diffusiondb

中期开始后,我们加入cc3m_part2、diffusiondb_part2和mscoco数据集进行模型微调。

模型 epoch数 热身 初始学习率 最终学习率 数据集
eva02_large 3 0 1e-6 1e-7 全部
convnext_xxlarge 3 0 1e-6 1e-7 全部
convnext_large 5 0 1e-5 1e-7 全部
vit_large 5 0 1e-5 1e-7 全部

集成策略

最终提交使用了四个模型。每个模型使用4-5个不同的交叉验证折,因此我们在四个模型上集成了18个折。

模型
eva02_large 0,2,2*,5,9
convnext_xxlarge 0,1,6,6*,7
convnext_large 0,1,2,3
vit_large 0,3,4,8

标有*的折表示仅微调1个epoch且无数据增强。

我们为每个模型准备了fold0的验证模型,并使用这些fold0验证数据调整集成权重。

我们还尝试了不仅按模型确定权重,还沿着每个模型的384个输出维度确定权重(共384×4个权重),这些权重也通过训练确定。但训练结果在特定维度上权重变化不明显,与使用单模型权重几乎相同。

提交两种方法并比较分数后,使用每模型384个权重的方法仅提升了+0.0002分。因此我们采用后一种方法作为最终提交,但需注意该改进可能在误差范围内。

未奏效的方法

  • Mixup
  • GeM(广义平均)池化
  • GGeM池化(视觉变换器的分组广义平均池化
  • 三元注意力(卷积三元注意力模块
  • BLIP、BLIP2、CoCa等图像字幕模型
    • 直接预测提示词嵌入在计算成本和分数上都更有效
  • 基于模型OOF预测分数加权损失(考虑简单或困难样本的损失)
  • 使用预训练clip模型的kNN方法
    • 与单模型预测结果集成时观察到改进,但加入最终提交的4模型集成后无效果
  • 所有样本组合的余弦相似度损失(详情如下)

所有样本组合的余弦相似度损失

使用CosineEmbeddingLoss时,损失计算针对每个图像-提示词对,不考虑不同对之间的关系。为此我们尝试实现一种损失,使模型输出中所有对的余弦相似度接近小批量内对应目标的所有对余弦相似度,考虑所有样本组合。

但我们未采用此方法,因为无论单独使用还是与CosineEmbeddingLoss结合,它都未提升性能。

同比赛其他方案