Wenet模型流程梳理
text2token.py
prepare data:
# Build the token dictionary from the training transcripts:
#   1. tokenize the transcript file (tools/text2token.py),
#   2. drop the utterance-id column (keep fields 2..end),
#   3. put one token per line,
#   4. keep each unique non-blank token in sorted order,
#   5. number tokens starting from 2 (awk NR+1) and append to ${dict}.
tools/text2token.py -s 1 -n 1 data/train/text \
    | cut -f 2- -d" " \
    | tr " " "\n" \
    | sort -u \
    | grep -a -v -e '^\s*$' \
    | awk '{print $0 " " NR+1}' >> ${dict}
asr_model
-
encoder
input: speech(16,80,183) # 183 由 batch 中最长语音的帧数决定 speech_length text (16,6) # 6 由 batch 中最长文本决定 text_length
-
make_pad_mask
mask :(16,183)
-
subsampling
input(speech,mask)
-
conv(speech)
torch.nn.Conv2d(1, odim, 3, 2), torch.nn.ReLU(), torch.nn.Conv2d(odim, odim, 3, 2), torch.nn.ReLU() # output (16,256,45,19)
-
self.out: linear
torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) # output (16,45,256)
-
self.pos_enc
pos_emb (1,45,256) # output # speech = torch.nn.Dropout(speech) (16,45,256) # pos_emb = torch.nn.Dropout(pos_emb) (1,45,256)
-
-
$subsampling
#output: speech , pos_emb , mask(16,1,45) #x_mask[:, :, :-2:2][:, :, :-2:2]
-
add_optional_chunk_mask
add_optional_chunk_mask
-
feed_forward_macaron:Feed_Forward
xs = xs + 0.5 * Dropout( feed_forward_macaron( LN(xs) ) ) # feed_forward_macaron: # PositionwiseFeedForward( # (w_1): Linear(in_features=256, out_features=2048, bias=True) # (activation): SiLU() # (dropout): Dropout(p=0.1, inplace=False) # (w_2): Linear(in_features=2048, out_features=256, bias=True) # ) xs = LN(xs)
-
self_attn:
q, k, v = self.forward_qkv(query, key, value)
-
QKV
q, k, v = self.forward_qkv(query, key, value) # query(bs , 71 , 256 ) # q (batch, time1, head, d_k)
-
-