Hand-written Conformer network structure
import torch
from torch import nn
x = torch.randint(0, 10, size=(5, 280, 80))   # dummy batch: 5 utterances, 280 frames, 80-dim features
length = torch.tensor([10, 9, 9, 9, 9])       # valid length per utterance, used for the pad mask below
x.size(), x.shape, x[0].shape, length
# def make_pad_mask(length):
batch = length.size(0)
max_len = length.max().item()
# seq_t = torch.arange(0, max_len).unsqueeze(0).expand(batch, max_len)
seq_t = torch.arange(0, max_len).unsqueeze(0).repeat(batch, 1)  # dimensions must line up with length
length, seq_t
seq_mask = seq_t >= length.unsqueeze(1)   # True at padded positions
mask = ~seq_mask                          # True at valid frames
mask
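# --- sketch: how this pad mask would typically be applied inside attention ---
# (assumed usage, not in the original notebook; `scores` here is just random data)
scores = torch.randn(batch, 4, max_len, max_len)         # (batch, n_head, time_q, time_k)
attn_mask = mask.unsqueeze(1).unsqueeze(2)               # (batch, 1, 1, time_k), broadcastable
scores = scores.masked_fill(~attn_mask, float("-inf"))   # padded key positions get -inf
attn_weights = torch.softmax(scores, dim=-1)             # so they receive zero attention weight
attn_weights.shape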
odim = 256   # encoder output dimension
idim = 80    # input feature dimension
# def embedding (Conv2d subsampling)
conv1 = torch.nn.Conv2d(1, 256, 3, 2)    # kernel 3, stride 2: roughly halves time and feature axes
relu1 = torch.nn.ReLU()
conv2 = torch.nn.Conv2d(256, 256, 3, 2)  # second stride-2 conv: ~1/4 time subsampling overall
# x = x.unsqueeze(1).long()
x = x.unsqueeze(1).to(dtype=torch.float32)   # add a channel axis: (batch, 1, time, feat), as floats for the conv
x1 = conv1(x)
x1 = relu1(x1)
x2 = conv2(x1)
x2 = relu1(x2)
linear1 = torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)   # flatten (channels, freq') back to odim
b, c, t, f = x2.size()
x = linear1(x2.transpose(1, 2).contiguous().view(b, t, c * f))         # -> (batch, time', 256)
conv1, conv2, x1.shape, x2.shape, linear1, x.shape
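# --- sketch: tracking the valid frame count through the two stride-2 convs ---
# (assumed helper; it reuses the ((n - 1)//2 - 1)//2 formula from the Linear above,
# since each Conv2d(kernel=3, stride=2, no padding) maps n frames to (n - 3)//2 + 1 = (n - 1)//2)
def subsampled_len(n):
    return ((n - 1) // 2 - 1) // 2

print(subsampled_len(280), t)              # both 69, matching x2's time dimension
sub_length = ((length - 1) // 2 - 1) // 2  # elementwise on the length tensor
sub_length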
# encoder
layernorm1 = nn.LayerNorm((256), eps=1e-12)
layernorm2 = nn.LayerNorm((256), eps=1e-12)
dropout1 = nn.Dropout(0.1)
## ffn1 (macaron half-step feed-forward)
drop1 = nn.Dropout(0.1)
linear1 = nn.Linear(256, 2048)
linear2 = nn.Linear(2048, 256)
activation1 = nn.ReLU()
residual = x
x1 = layernorm1(x)   # pre-norm
x2 = residual + 0.5 * linear2(drop1(activation1(linear1(x1))))   # macaron half-step residual, using the normed x1
x = x2
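# --- sketch: the same half-step feed-forward packaged as a reusable module ---
# (the class name and layout are my own, not from the notebook)
class FeedForwardModule(nn.Module):
    def __init__(self, d_model=256, d_ff=2048, dropout=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(d_model, eps=1e-12)
        self.w1 = nn.Linear(d_model, d_ff)
        self.w2 = nn.Linear(d_ff, d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        # pre-norm, expand, activate, project back, half-step residual
        return x + 0.5 * self.w2(self.drop(torch.relu(self.w1(self.norm(x)))))

FeedForwardModule()(x).shape   # same (batch, time', 256) as x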
# multi_head_attn
x3 = layernorm2(x)   # pre-norm before self-attention
def generate_qkv(query, key, value):
    # project to per-head q/k/v; n_feat = n_head * d_k
    linear_q = nn.Linear(256, 256)
    linear_k = nn.Linear(256, 256)
    linear_v = nn.Linear(256, 256)
    n_head = 4
    n_batch = query.size(0)
    n_feat = query.size(-1)
    d_k = n_feat // n_head
    q = linear_q(query).view(n_batch, -1, n_head, d_k)   # (batch, time, n_head, d_k)
    k = linear_k(key).view(n_batch, -1, n_head, d_k)
    v = linear_v(value).view(n_batch, -1, n_head, d_k)
    return q, k, v
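# --- sketch: finishing multi-head self-attention from the q/k/v above ---
# (my continuation, not part of the original notebook; it writes to x_attn so the cells below are unchanged)
q, k, v = generate_qkv(x3, x3, x3)                    # each (batch, time', n_head, d_k)
q = q.transpose(1, 2)                                 # (batch, n_head, time', d_k)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
scores = torch.matmul(q, k.transpose(-2, -1)) / q.size(-1) ** 0.5   # (batch, n_head, time', time')
# the pad mask from the top of the notebook would be applied to `scores` with masked_fill here;
# it was built for a different max_len, so this sketch skips it
attn = torch.softmax(scores, dim=-1)
out = torch.matmul(attn, v).transpose(1, 2).contiguous().view(x3.size(0), -1, 256)
linear_out = nn.Linear(256, 256)                      # output projection, as in standard MHA
x_attn = x + dropout1(linear_out(out))                # residual around the attention block
x_attn.shape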
# convolution module
channels = odim
residual = x
print(x.shape)
x = nn.LayerNorm((256), eps=1e-12)(x)
pointwise1_conv = nn.Conv1d(channels, 2*channels, kernel_size=1, stride=1, padding=0, bias=True)
x = x.transpose(-1, -2)   # Conv1d expects (batch, channels, time)
x = pointwise1_conv(x)
# x = nn.functional.glu(x, dim=1)  # things in nn.functional should all be plain functions you can call directly
# depthwise_conv = nn.Conv1d(
# channels,
# channels,
# 15,
# stride=1,
# padding=7,
# groups=channels,
# bias=True
# )
# x = depthwise_conv(x)
# x = nn.ReLU()(x)
# pointwise_conv2 = nn.Conv1d(
# channels,
# channels,
# kernel_size=1,
# stride=1,
# padding=0,
# bias=True,
# )
# x = pointwise_conv2(x)
Inspect parameters
import torch
from torch import nn
x = torch.randint(0, 10, size=(5, 280,80))
length = torch.tensor([5, 4, 4, 4, 4])   # per-utterance lengths (reused below as CTC target_lengths)
x.size(),x.shape,x[0].shape,length
# def make_pad_mask(length):
batch = length.size(0)
max_len = length.max().item()
# seq_t = torch.arange(0, max_len).unsqueeze(0).expand(batch, max_len)
seq_t = torch.arange(0, max_len).unsqueeze(0).repeat(batch, 1)  # dimensions must line up with length
length, seq_t
seq_mask = seq_t >= length.unsqueeze(1)
mask = ~seq_mask
mask
odim = 256
idim = 80
# def embedding
conv1 = torch.nn.Conv2d(1, 256, 3, 2)
relu1 = torch.nn.ReLU()
conv2 = torch.nn.Conv2d(256, 256, 3, 2)
# x = x.unsqueeze(1).long()
x = x.unsqueeze(1).to(dtype=torch.float32)
x1 = conv1(x)
x1 = relu1(x1)
x2 = conv2(x1)
x2 = relu1(x2)
linear1 = torch.nn.Linear(odim * (((idim-1)//2 -1)//2), odim)
b, c, t, f = x2.size()
x = linear1(x2.transpose(1,2).contiguous().view(b, t, c*f))
conv1, conv2, x1.shape,x2.shape , linear1, x2.shape , x.shape
# encoder
layernorm1 = nn.LayerNorm((256), eps=1e-12)
layernorm2 = nn.LayerNorm((256), eps=1e-12)
dropout1 = nn.Dropout(0.1)
## ffn1
drop1 = nn.Dropout(0.1)
linear1 = nn.Linear(256, 2048)
linear2 = nn.Linear(2048, 256)
activation1 = nn.ReLU()
residual = x
x1 = layernorm1(x)
x2 = residual + 0.5 * linear2(drop1(activation1(linear1(x1))))   # macaron half-step residual, using the normed x1
x = x2
# multi_head_attn
x3 = layernorm2(x)
def generate_qkv(query, key, value):
    linear_q = nn.Linear(256, 256)
    linear_k = nn.Linear(256, 256)
    linear_v = nn.Linear(256, 256)
    n_head = 4
    n_batch = query.size(0)
    n_feat = query.size(-1)
    d_k = n_feat // n_head
    q = linear_q(query).view(n_batch, -1, n_head, d_k)
    k = linear_k(key).view(n_batch, -1, n_head, d_k)
    v = linear_v(value).view(n_batch, -1, n_head, d_k)
    return q, k, v
# convolution module
channels = odim
residual = x   # saved for the residual connection after the conv module
x = nn.LayerNorm((256), eps=1e-12)(x)
pointwise1_conv = nn.Conv1d(channels, 2*channels, kernel_size=1, stride=1, padding=0, bias=True)
x = x.transpose(-1, -2)
x = pointwise1_conv(x)
x = nn.functional.glu(x, dim=1)  # things in nn.functional should all be plain functions you can call directly; GLU halves 2*channels back to channels
depthwise_conv = nn.Conv1d(
channels,
channels,
15,
stride=1,
padding=7,
groups=channels,
bias=True
)
x = depthwise_conv(x)
x = nn.ReLU()(x)   # the Conformer paper puts a BatchNorm and Swish (SiLU) here
pointwise_conv2 = nn.Conv1d(
channels,
channels,
kernel_size=1,
stride=1,
padding=0,
bias=True,
)
x = pointwise_conv2(x)
x = residual + x.transpose(-1, -2)   # back to (batch, time, channels) and add the residual saved above
x.shape
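# --- sketch: why pointwise1_conv doubles the channels — GLU gates them and halves them again ---
# (standalone toy tensor, not the notebook's x)
toy = torch.randn(2, 2 * channels, 7)          # (batch, 2*channels, time)
gated, gate = toy.chunk(2, dim=1)              # GLU splits along the channel axis
manual_glu = gated * torch.sigmoid(gate)       # (batch, channels, time)
print(manual_glu.shape, torch.allclose(manual_glu, nn.functional.glu(toy, dim=1)))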
# print(x)
# decoding
# CTC
linear_ctc = nn.Linear(x.size(-1), 108)   # project the 256-dim encoder output to 108 output classes
x_ctc = linear_ctc(nn.functional.dropout(x, p=0.1))
x_ctc = x_ctc.transpose(0, 1)             # CTCLoss expects (time, batch, classes)
# logsoftmax = nn.functional.softmax()
x_ctc = nn.functional.log_softmax(x_ctc, dim=-1)
ctc_loss_fn = nn.CTCLoss(blank=0, reduction="sum")
y_ref = torch.tensor([[12, 13, 1, 15, 16],
                      [12, 13, 1, 15, -1],
                      [12, 13, 1, 15, -1],
                      [12, 13, 1, 15, -1],
                      [12, 13, 1, 15, -1]])   # padded labels; the -1 entries are never read because of target_lengths
x_true_len = torch.tensor([69, 60, 60, 60, 60])   # valid encoder frames per utterance after subsampling
print(x_ctc.shape, y_ref.shape, x_true_len.shape, length.shape)
ctc_loss = ctc_loss_fn(x_ctc, y_ref, x_true_len, length)   # (log_probs, targets, input_lengths, target_lengths)
loss = ctc_loss / x_ctc.size(1)                            # average over the batch
x_ctc = x_ctc.transpose(0, 1)                              # back to (batch, time, classes)
print(loss)
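# --- sketch: greedy (best-path) CTC decoding from the log-probs above ---
# (my addition, not in the original notebook)
best = x_ctc.argmax(dim=-1)                        # x_ctc is (batch, time, classes) again after the transpose above
for i in range(best.size(0)):
    frames = best[i, : int(x_true_len[i])].tolist()   # keep only the valid encoder frames
    hyp, prev = [], None
    for tok in frames:
        if tok != prev and tok != 0:               # collapse repeats, then drop the blank (id 0)
            hyp.append(tok)
        prev = tok
    print(i, hyp)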
for name, p in conv1.named_parameters():
    print(name, p.shape, p.numel())
for name, p in conv2.named_parameters():
    print(name, p.shape, p.numel())
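# --- sketch: totalling parameters over the modules built above ---
# (the inline LayerNorms and the Linears created inside generate_qkv are not counted)
modules = [conv1, conv2, linear1, linear2, layernorm1, layernorm2,
           pointwise1_conv, depthwise_conv, pointwise_conv2, linear_ctc]
total = sum(p.numel() for m in modules for p in m.parameters())
print(f"total parameters: {total:,}")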