FastCorrect&Fairseq学习笔记
一 工作说明:
FastCorrect,字面意思就是快速纠错;这项主要是对asr的识别结果进行纠错,提升识别率;
目前大部分的纠错模型采用了基于注意力机制的端到端自回归模型(seq2seq model to correct an ASR output sentence autoregressively)结构。这种结构延迟较大;为此,微软亚洲研究院机器学习组与微软 Azure 语音团队合作,推出了 FastCorrect 系列工作,提出了低延迟的纠错模型;相关研究论文已被 NeurIPS 2021 和 EMNLP 2021 收录;
Fairseq,是一个开源的序列建模工具,由Facebook AI Research于2017年9月推出,Fairseq基于python&pytorch,更加简单,人性化;主要应用场景是nlp任务,支持多种常用模型;
FastCorrect就是基于Fairseq工具进行训练的模型;
本文记录FastCorrect的学习过程,中间对机器学习,Fairseq的学习,理解和记录;初入此道,欢迎讨论;
FastCorrect git:https://github.com/microsoft/NeuralSpeech/tree/master/FastCorrect
Fairseq git:https://github.com/facebookresearch/fairseq
Fairseq 文档:https://fairseq.readthedocs.io/en/latest/command_line_tools.html
二 Fairseq训练流程&核心代码阅读:
1 训练命令解读:
训练命令:
fairseq-train $DATA_PATH --task fastcorrect \
--arch fastcorrect --lr 5e-4 --lr-scheduler inverse_sqrt \
--length-loss-factor 0.5 \
--noise full_mask \
--src-with-werdur \
--dur-predictor-type "v2" \
--dropout 0.3 --weight-decay 0.0001 \
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
--criterion fc_loss --label-smoothing 0.1 \
--max-tokens 9000 \
--werdur-max-predict 3 \
--assist-edit-loss \
--save-dir $SAVE_DIR \
--user-dir $EXP_HOME/FastCorrect \
--left-pad-target False --left-pad-source False \
--encoder-layers 6 --decoder-layers 6 \
--max-epoch 30 --update-freq 4 --fp16 --num-workers 8 \
--share-all-embeddings --encoder-embed-dim=512 --decoder-embed-dim=512
这个命令的含义,可以自己去查Fairseq的文档,这里罗列出来:
--task,意思是任务类型,默认是“translation”(翻译);包括translation_from_pretrained_bart等Fairseq自带的任务类型,这里设置为本任务fastcorrect ;
--arch,意思是model architecture,模型结构,可选项包括,transformer,lstm等;这里采用自带的结构fastcorrect
--lr 学习率,为初始学习率,后续可能被--lr-scheduler修改;--lr-scheduler 是lr更新计划,这里采用inverse_sqrt 方法;
--length-loss-factor 0.5
--noise full_mask
--src-with-werdur
--dur-predictor-type "v2"
--dropout 字面意思就是丢弃,这里指的是在训练模型时,丢弃一部分数据来防止过拟合;
--weight-decay
--optimizer adam --adam-betas '(0.9, 0.98)' 参数优化策略;
--clip-norm 0.0
--criterion fc_loss 训练准则;
--label-smoothing 0.1
--max-tokens 9000 一个batch最大的token数量;
--werdur-max-predict 3
--assist-edit-loss
--save-dir $SAVE_DIR 存储checkpoints的路径,checkpoint即模型;
--user-dir $EXP_HOME/FastCorrect 一个包含扩展的python模块,这里的扩展是指模型结构或者任务,和task是相对的,一般不适用官方规定的arch时需要手动设置这个路径;
--left-pad-target False --left-pad-source False
--encoder-layers 6 --decoder-layers 6
--max-epoch 30 当达到这个30个epoch的时候,停止训练;
--update-freq 4 参数更新频率,每4个batch更新参数;
--fp16 使用FP16
--num-workers 8 8个子线程用于load数据;
--share-all-embeddings
--encoder-embed-dim=512 --decoder-embed-dim=512