Hand-writing the Conformer network structure


import torch
from torch import nn

# dummy input: batch of 5 sequences, 280 time steps, 80 features per frame
x = torch.randint(0, 10, size=(5, 280, 80))
# valid length of each sequence (reused later as the CTC target lengths)
length = torch.tensor([5, 4, 4, 4, 4])
x.size(), x.shape, x[0].shape, length

# make_mask(length): build a padding mask from the sequence lengths
batch = length.size(0)
max_len = length.max().item()
# seq_t = torch.arange(0, max_len).unsqueeze(0).expand(batch, max_len)
seq_t = torch.arange(0, max_len).unsqueeze(0).repeat(batch, 1)   # the repeated dims must line up with (batch, max_len)
length, seq_t

seq_mask = seq_t >= length.unsqueeze(1)   # True at padded positions
mask = ~seq_mask                          # True at valid positions
mask
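The same logic can be folded into the helper the commented-out line hints at. A minimal sketch (the helper name and signature are my own, not code from the original):

def make_mask(length):
    """Boolean mask of shape (batch, max_len), True at valid positions."""
    batch = length.size(0)
    max_len = length.max().item()
    seq_t = torch.arange(0, max_len, device=length.device).unsqueeze(0).expand(batch, max_len)
    return seq_t < length.unsqueeze(1)

make_mask(length)   # same result as the mask built step by step above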

odim = 256   # encoder / model dimension
idim = 80    # input feature dimension
# embedding: Conv2d subsampling, two stride-2 convolutions (roughly 4x downsampling in time)
conv1 = torch.nn.Conv2d(1, 256, 3, 2)
relu1 = torch.nn.ReLU()
conv2 = torch.nn.Conv2d(256, 256, 3, 2)

# x = x.unsqueeze(1).long()
x = x.unsqueeze(1).to(dtype=torch.float32)   # add a channel dim: (batch, 1, time, feat)

x1 = conv1(x)
x1 = relu1(x1)
x2 = conv2(x1)
x2 = relu1(x2)

# flatten (channels, freq) and project back to odim
linear1 = torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
b, c, t, f = x2.size()
x = linear1(x2.transpose(1, 2).contiguous().view(b, t, c * f))

conv1, conv2, x1.shape, x2.shape, linear1, x.shape
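A quick sanity check on the shapes, using my own small helper for the standard conv output formula (L - kernel) // stride + 1:

def conv_out(L, k=3, s=2):
    # output length of an unpadded convolution with kernel k and stride s
    return (L - k) // s + 1

t_out = conv_out(conv_out(280))   # time: 280 -> 139 -> 69
f_out = conv_out(conv_out(80))    # freq: 80 -> 39 -> 19
t_out, f_out, odim * f_out        # linear1 therefore expects 256 * 19 = 4864 inputs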

# encoder layer
layernorm1 = nn.LayerNorm(256, eps=1e-12)
layernorm2 = nn.LayerNorm(256, eps=1e-12)
dropout1 = nn.Dropout(0.1)
## ffn1 (first macaron feed-forward, weighted by 0.5)
drop1 = nn.Dropout(0.1)
linear1 = nn.Linear(256, 2048)
linear2 = nn.Linear(2048, 256)
activation1 = nn.ReLU()

residual = x
x1 = layernorm1(x)   # pre-norm
x2 = residual + 0.5 * linear2(drop1(activation1(linear1(x1))))   # feed the normalized x1 into the FFN
x = x2

# multi-head attention (pre-norm input)
x3 = layernorm2(x)

def generate_qkv(query, key, value):
    # project query/key/value and split the feature dim into heads
    linear_q = nn.Linear(256, 256)
    linear_k = nn.Linear(256, 256)
    linear_v = nn.Linear(256, 256)

    n_head = 4
    n_batch = query.size(0)
    n_feat = query.size(-1)
    d_k = n_feat // n_head

    # shape (batch, time, head, d_k); transpose to (batch, head, time, d_k) before the attention matmul
    q = linear_q(query).view(n_batch, -1, n_head, d_k)
    k = linear_k(key).view(n_batch, -1, n_head, d_k)
    v = linear_v(value).view(n_batch, -1, n_head, d_k)

    return q, k, v
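generate_qkv is defined but never driven in this notebook. A minimal scaled dot-product self-attention on top of it could look like the sketch below; the transposes, the output projection linear_out and the variable x_attn are my additions, and the real Conformer would also apply relative positional encoding and mask the scores with the padding mask, both omitted here:

q, k, v = generate_qkv(x3, x3, x3)           # self-attention: query = key = value
q = q.transpose(1, 2)                        # (batch, head, time, d_k)
k = k.transpose(1, 2)
v = v.transpose(1, 2)

scores = torch.matmul(q, k.transpose(-2, -1)) / (q.size(-1) ** 0.5)
attn = torch.softmax(scores, dim=-1)         # (batch, head, time, time)
att_out = torch.matmul(attn, v)              # (batch, head, time, d_k)
att_out = att_out.transpose(1, 2).contiguous().view(x3.size(0), -1, 256)

linear_out = nn.Linear(256, 256)             # output projection
x_attn = x + dropout1(linear_out(att_out))   # residual connection around the attention block
x_attn.shape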



# convolution module
channels = odim
residual = x
x = nn.LayerNorm(256, eps=1e-12)(x)
pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=True)
x = x.transpose(-1, -2)          # (batch, channels, time) for Conv1d
x = pointwise_conv1(x)
x = nn.functional.glu(x, dim=1)  # nn.functional holds stateless functions that can be called directly
depthwise_conv = nn.Conv1d(
    channels,
    channels,
    15,
    stride=1,
    padding=7,
    groups=channels,
    bias=True,
)
x = depthwise_conv(x)
x = nn.ReLU()(x)
pointwise_conv2 = nn.Conv1d(
    channels,
    channels,
    kernel_size=1,
    stride=1,
    padding=0,
    bias=True,
)
x = pointwise_conv2(x)
x = x.transpose(-1, -2)          # back to (batch, time, channels) so time stays the second axis
x = residual + x                 # residual connection around the convolution module
x.shape
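For reference, these pieces chain together in the macaron order of a Conformer encoder layer. The class below is a compact, self-contained sketch of that ordering and not the notebook's own code; it uses simplified stand-ins (nn.MultiheadAttention instead of relative positional attention, ReLU instead of Swish, no batch norm or dropout in the conv module):

class ConformerLayerSketch(nn.Module):
    def __init__(self, d_model=256, d_ff=2048, n_head=4, kernel_size=15):
        super().__init__()
        self.ff1 = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
        self.ff2 = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
        self.mha = nn.MultiheadAttention(d_model, n_head, batch_first=True)
        self.conv = nn.Sequential(
            nn.Conv1d(d_model, 2 * d_model, 1),
            nn.GLU(dim=1),
            nn.Conv1d(d_model, d_model, kernel_size, padding=kernel_size // 2, groups=d_model),
            nn.ReLU(),
            nn.Conv1d(d_model, d_model, 1),
        )
        self.norm_ff1 = nn.LayerNorm(d_model)
        self.norm_mha = nn.LayerNorm(d_model)
        self.norm_conv = nn.LayerNorm(d_model)
        self.norm_ff2 = nn.LayerNorm(d_model)
        self.norm_out = nn.LayerNorm(d_model)

    def forward(self, x):
        x = x + 0.5 * self.ff1(self.norm_ff1(x))                     # first half-step FFN
        y = self.norm_mha(x)
        x = x + self.mha(y, y, y, need_weights=False)[0]             # multi-head self-attention
        x = x + self.conv(self.norm_conv(x).transpose(1, 2)).transpose(1, 2)   # conv module
        x = x + 0.5 * self.ff2(self.norm_ff2(x))                     # second half-step FFN
        return self.norm_out(x)                                      # final layer norm

ConformerLayerSketch()(torch.randn(5, 69, 256)).shape   # torch.Size([5, 69, 256])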

# print(x)
# decoding
# CTC
linear_ctc = nn.Linear(x.size(-1), 108)          # project to the output vocabulary (108 symbols)

x_ctc = linear_ctc(nn.functional.dropout(x, p=0.1))

x_ctc = x_ctc.transpose(0, 1)                    # CTCLoss expects (time, batch, vocab)
x_ctc = nn.functional.log_softmax(x_ctc, dim=-1)

ctc_loss_fn = nn.CTCLoss(blank=0, reduction="sum")
y_ref = torch.tensor([[12, 13, 1, 15, 16],
                      [12, 13, 1, 15, -1],       # -1 is padding, never read thanks to the target lengths
                      [12, 13, 1, 15, -1],
                      [12, 13, 1, 15, -1],
                      [12, 13, 1, 15, -1]])
x_true_len = torch.tensor([69, 60, 60, 60, 60])  # valid encoder frames per utterance
print(x_ctc.shape, y_ref.shape, x_true_len.shape, length.shape)
ctc_loss = ctc_loss_fn(x_ctc, y_ref, x_true_len, length)   # length holds the target lengths
loss = ctc_loss / x_ctc.size(1)                  # normalise by batch size
x_ctc = x_ctc.transpose(0, 1)                    # back to (batch, time, vocab)
print(loss)
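As a quick sanity check on the (still untrained) outputs, a greedy CTC decode takes the per-frame argmax, collapses repeats, and drops the blank id 0. A minimal sketch, not part of the original notebook:

best = x_ctc.argmax(dim=-1)                  # (batch, time): best symbol per frame
hyps = []
for ids, n in zip(best, x_true_len):
    prev, hyp = None, []
    for i in ids[: int(n)].tolist():
        if i != prev and i != 0:             # collapse repeats, drop the blank (id 0)
            hyp.append(i)
        prev = i
    hyps.append(hyp)
print(hyps)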



View the parameters

for name, p in conv1.named_parameters():
    print(name, p.shape, p.numel())
for name, p in conv2.named_parameters():
    print(name, p.shape, p.numel())
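To see the overall size rather than layer-by-layer shapes, the parameter counts of the modules defined above can be summed. A small convenience sketch (note that linear1/linear2 here are the FFN projections, since the embedding's linear1 was shadowed, and the inline LayerNorms are not counted):

modules = [conv1, conv2, linear1, linear2, pointwise_conv1,
           depthwise_conv, pointwise_conv2, linear_ctc]
total = sum(p.numel() for m in modules for p in m.parameters())
print("total parameters:", total)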