返回列表

4th place solution

576. Bengali.AI Speech Recognition | bengaliai-speech

开始: 2023-07-17 结束: 2023-10-17 音视频处理 数据算法赛

第4名解决方案

作者:hanx

发布日期:2023年10月18日

竞赛排名:第4名

非常感谢 organizers 举办这次有趣的比赛。语音识别是一个非常有趣的方向,我从许多参赛选手的讨论和公开代码中学到了很多。过去的三个月充满压力但也收获颇丰。我将尝试用我蹩脚的英语清晰地阐述我的解决方案。

摘要

我的解决方案相对简单,使用 wav2vec2 1b 模型作为预训练模型,并训练一个 Wav2Vec2ForCTC 模型。在后处理阶段,使用 KenLM 训练了一个 6-gram 语言模型,并对输出结果进行归一化和 dari 的进一步后处理。

Wav2Vec2ForCTC 训练

具体来说,我使用 facebook/wav2vec2-xls-r-1b 作为预训练模型。该模型的训练分为三个阶段,每个阶段使用不同的随机种子,并采用一致的数据增强和参数:

  • 优化器:AdamW(weight_decay: 0.05; betas: (0.9, 0.999))
  • 学习率调度器:改进的 linear_warmup_cosine 调度器(init_lr: 1e-5; min_lr: 5e-6; warmup_start_lr: 1e-6; warmup_steps: 1000; max_epoch: 120; iters_per_epoch: 1000)
class LinearWarmupCosine3LongTailLRScheduler:
    def __init__(
            self,
            optimizer,
            max_epoch,
            min_lr,
            init_lr,
            iters_per_epoch,
            warmup_steps=0,
            warmup_start_lr=-1,
            **kwargs
    ):
        self.optimizer = optimizer

        self.max_epoch = max_epoch
        self.min_lr = min_lr

        self.init_lr = init_lr
        self.warmup_steps = warmup_steps
        self.iters_per_epoch = iters_per_epoch
        self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
        self.max_iters = max_epoch * iters_per_epoch

    def step(self, cur_epoch, cur_step):
        total_steps = cur_epoch * self.iters_per_epoch + cur_step
        if total_steps < self.warmup_steps:
            warmup_lr_schedule(
                step=cur_step,
                optimizer=self.optimizer,
                max_step=self.warmup_steps,
                init_lr=self.warmup_start_lr,
                max_lr=self.init_lr,
            )
        elif total_steps <= self.max_iters // 4:
            cosine_lr_schedule(
                epoch=total_steps,
                optimizer=self.optimizer,
                max_epoch=self.max_iters // 4,
                init_lr=self.init_lr,
                min_lr=self.min_lr,
            )
        elif total_steps <= self.max_iters // 2:
            cosine_lr_schedule(
                epoch=self.max_iters // 4,
                optimizer=self.optimizer,
                max_epoch=self.max_iters // 4,
                init_lr=self.init_lr,
                min_lr=self.min_lr,
            )
        else:
            cosine_lr_schedule(
                epoch=total_steps - self.max_iters // 2,
                optimizer=self.optimizer,
                max_epoch=self.max_iters // 2,
                init_lr=self.min_lr,
                min_lr=0,
            )

def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
    lr = (init_lr - min_lr) * 0.5 * (
            1.0 + math.cos(math.pi * epoch / max_epoch)
    ) + min_lr
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr
  • 基础数据增强(记为 `base_aug`):
def get_transform(musan_dir):
    trans = Compose(
        [
            TimeStretch(min_rate=0.9, max_rate=1.1, p=0.2, leave_length_unchanged=False),
            Gain(min_gain_in_db=-6, max_gain_in_db=6, p=0.1),
            PitchShift(min_semitones=-4, max_semitones=4, p=0.2),
            OneOf(
                [
                    AddBackgroundNoise(sounds_path=musan_dir, min_snr_in_db=3.0, max_snr_in_db=30.0,
                                       noise_transform=PolarityInversion(), p=1.0),
                    AddGaussianNoise(min_amplitude=0.005, max_amplitude=0.015, p=1.0),
                ] if musan_dir is not None else [
                    AddGaussianNoise(min_amplitude=0.005, max_amplitude=0.015, p=1.0), ],
                p=0.5,
            ),
        ]
    )
    return trans
  • 复合数据增强(记为 `comp_aug`)。我使用了三种复合增强方法:
    1. 将音频波形均匀分割为3段,对每段执行 base_aug
    2. 从数据集中随机选择两段语音,分别执行 base_aug 后拼接在一起
    3. 结合上述两种数据增强方法
  • 数据集
    对于 Wav2Vec2ForCTC 训练,我没有使用外部数据,通过以下步骤筛选竞赛数据集:
    1. 基于 arijitx/wav2vec2-xls-r-300m-bengali 训练模型
    2. 使用上述模型对整个数据集进行推理,将所有样本得分从小到大排序,保留前70%的数据

KenLM 训练

  • 数据集:使用 IndicCorpv1 和 IndicCorpv2 作为语料库。清洗后将两个语料库去重合并
  • 语料清洗:使用以下代码清洗每个句子:
chars_to_ignore = re.compile(r'[^\u0980-\u09FF\s]')
long_space_to_ignore = re.compile(r'\s+')
bnorm = Normalizer()

def fix_text(text: str):
    text = re.sub(chars_to_ignore, ' ', text)
    text = re.sub(long_space_to_ignore, ' ', text).strip()
    return text

def norm_sentence(sentence):
    sentence = normalize(sentence)
    sentence = fix_text(sentence)
    words = sentence.split()
    try:
        all_words = [bnorm(word)["normalized"] for word in words]
        all_words = [_ for _ in all_words if _]
        if len(all_words) < 2:
            return ""
        return " ".join(all_words).strip()
    except TypeError:
        return None
  • 训练了 6gram 语言模型

尝试但未奏效的方法

  1. 使用外部数据(如 openslr, shrutilipi)
  2. 使用 deepfilternet 对音频去噪
  3. 使用更大的模型(wav2vec2-xls-r-2b)
  4. 使用 whisper-small
  5. 使用更多语料训练语言模型(BanglaLM)
  6. 训练拼写纠错模型
  7. 训练第四阶段模型
  8. ……
    还有很多其他尝试

更新说明

同比赛其他方案