549. Stable Diffusion - Image to Prompts | stable-diffusion-image-to-prompts
我们感谢组织者举办这场激动人心的竞赛。
祝贺所有完成竞赛的选手和获奖者。
我们的方法包括训练多个图像模型,使其能够直接从图像预测提示词嵌入(由句子转换器生成),然后对这些模型的预测结果进行集成。最终提交的是四个模型的集成:
竞赛初期,我们尝试了直接预测图像提示词的方法。但后来发现直接预测提示词嵌入能获得更好的分数。因此从竞赛中期开始,我们专注于改进直接嵌入预测的方法。
我们生成了约500万对提示词-图像组合(部分图像由相同提示词多次生成),详情如下表:
| 数据集名称 | 提示词数量 | 每个提示词的图像数 | 总图像数 |
|---|---|---|---|
| cc3m | 249,593 | 3 | 748,781 |
| lexica | 67,101 | 3 | 201,303 |
| diffusiondb | 326,035 | 3 | 978,105 |
| cc3m part2 | 1,166,852 | 1 | 1,166,852 |
| diffusiondb part2 | 1,371,480 | 1 | 1,371,480 |
| mscoco | 197,527 | 3 | 592,583 |
| 总计: | 3,378,588 | 1或3 | 5,059,104 |
训练过程中,对于具有三个对应图像的提示词样本,我们随机选择一个图像,确保同一epoch内不使用相同提示词的重复图像。
数据集参考:
为划分生成的图像-提示词对用于训练和评估,我们首先使用句子转换器计算提示词嵌入。然后将具有相似提示词嵌入的样本分组。在嵌入余弦相似度≥0.7的样本分组后,使用GroupKFold将数据集划分为10折。模型训练在除fold0外的所有折上进行,fold0保留用于评估。
此外,在创建cc3m和diffusiondb数据集时,我们预先计算了提示词嵌入的余弦相似度。对于相似度≥0.9的提示词,仅使用其中一个生成图像。竞赛中期开始,我们也利用之前未使用的剩余提示词,创建了cc3m_part2和diffusiondb_part2数据集。
我们的模型架构简单直接,由一个骨干网络连接384维全连接层(无偏置)组成,与句子转换器的输出维度匹配。
部分模型具有更大的输入图像分辨率,如下表所示:
| 模型 | 预训练骨干网络 | 输入尺寸 |
|---|---|---|
| eva02_large | eva02_large_patch14_448.mim_m38m_ft_in22k (timm) | 448 (预训练@448) |
| convnext_xxlarge | laion2b_s34b_b82k_augreg_rewind (open_clip) | 384 (预训练@256) |
| convnext_large | laion2b_s29b_b131k_ft_soup (open_clip) | 384 (预训练@320) |
| vit_large | laion2b_s32b_b82k (open_clip) | 336 (预训练@224) |
open_clip的ViT模型可按如下方式更改输入图像尺寸:
backbone = open_clip_model.visual
hgrid, wgrid = backbone.grid_size
hpatch = input_height // hgrid
wpatch = input_width // wgrid
backbone.patch_size = (hpatch, wpatch)
backbone.conv1.kernel_size = (hpatch, wpatch)
backbone.conv1.stride = (hpatch, wpatch)
我们使用PyTorch框架训练模型,配置如下:
数据增强配置相对较轻:
竞赛中期前,我们仅使用三个数据集训练模型:cc3m、lexica和diffusiondb。
| 模型 | epoch数 | 热身 | 初始学习率 | 最终学习率 | 数据集 |
|---|---|---|---|---|---|
| eva02_large | 10 | 1 | 1e-4 | 1e-6 | cc3m,lexica,diffusiondb |
| convnext_xxlarge | 10 | 1 | 1e-4 | 1e-6 | cc3m,lexica,diffusiondb |
| convnext_large | 15 | 3 | 1e-4 | 1e-6 | cc3m,lexica,diffusiondb |
| vit_large | 15 | 3 | 1e-4 | 1e-6 | cc3m,lexica,diffusiondb |
中期开始后,我们加入cc3m_part2、diffusiondb_part2和mscoco数据集进行模型微调。
| 模型 | epoch数 | 热身 | 初始学习率 | 最终学习率 | 数据集 |
|---|---|---|---|---|---|
| eva02_large | 3 | 0 | 1e-6 | 1e-7 | 全部 |
| convnext_xxlarge | 3 | 0 | 1e-6 | 1e-7 | 全部 |
| convnext_large | 5 | 0 | 1e-5 | 1e-7 | 全部 |
| vit_large | 5 | 0 | 1e-5 | 1e-7 | 全部 |
最终提交使用了四个模型。每个模型使用4-5个不同的交叉验证折,因此我们在四个模型上集成了18个折。
| 模型 | 折 |
|---|---|
| eva02_large | 0,2,2*,5,9 |
| convnext_xxlarge | 0,1,6,6*,7 |
| convnext_large | 0,1,2,3 |
| vit_large | 0,3,4,8 |
标有*的折表示仅微调1个epoch且无数据增强。
我们为每个模型准备了fold0的验证模型,并使用这些fold0验证数据调整集成权重。
我们还尝试了不仅按模型确定权重,还沿着每个模型的384个输出维度确定权重(共384×4个权重),这些权重也通过训练确定。但训练结果在特定维度上权重变化不明显,与使用单模型权重几乎相同。
提交两种方法并比较分数后,使用每模型384个权重的方法仅提升了+0.0002分。因此我们采用后一种方法作为最终提交,但需注意该改进可能在误差范围内。
使用CosineEmbeddingLoss时,损失计算针对每个图像-提示词对,不考虑不同对之间的关系。为此我们尝试实现一种损失,使模型输出中所有对的余弦相似度接近小批量内对应目标的所有对余弦相似度,考虑所有样本组合。
但我们未采用此方法,因为无论单独使用还是与CosineEmbeddingLoss结合,它都未提升性能。