Hand-writing the Conformer network structure

import torch
from torch import nn

# dummy fbank-style input: (batch=5, frames=280, feature dim=80)
x = torch.randint(0, 10, size=(5, 280, 80))
length = torch.tensor([10, 9, 9, 9, 9])   # valid frames per utterance, used for the pad mask
x.size(), x.shape, x[0].shape, length

# def make_pad_mask(length):
batch = length.size(0)
max_len = length.max().item()
# seq_t = torch.arange(0, max_len).unsqueeze(0).expand(batch, max_len)
seq_t = torch.arange(0, max_len).unsqueeze(0).repeat(batch, 1)     # the repeated dim has to match the batch size
length, seq_t

seq_mask = seq_t >= length.unsqueeze(1)   # True where the frame index is past the valid length
mask = ~seq_mask                          # True on valid frames
mask
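The cell above can be wrapped into a small helper; a minimal sketch (the name make_pad_mask and the True-on-valid-frames convention are my choice, mirroring the mask just built):

def make_pad_mask(length: torch.Tensor) -> torch.Tensor:
    # (batch,) lengths -> (batch, max_len) bool mask, True on valid frames
    batch = length.size(0)
    max_len = length.max().item()
    seq_t = torch.arange(0, max_len, device=length.device).unsqueeze(0).expand(batch, max_len)
    return seq_t < length.unsqueeze(1)

# e.g. make_pad_mask(torch.tensor([10, 9, 9, 9, 9]))  ->  shape (5, 10)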

odim = 256   # model dimension after subsampling
idim = 80    # input feature dimension
# def embedding (Conv2d subsampling: two 3x3 convs with stride 2, i.e. 4x downsampling in time)
conv1 = torch.nn.Conv2d(1, 256, 3, 2)
relu1 = torch.nn.ReLU()
conv2 = torch.nn.Conv2d(256, 256, 3, 2)

# x = x.unsqueeze(1).long()
x = x.unsqueeze(1).to(dtype=torch.float32)   # (b, 1, t, f): treat the fbank features as a 1-channel image

x1 = conv1(x)
x1 = relu1(x1)
x2 = conv2(x1)
x2 = relu1(x2)

# each unpadded stride-2 conv maps a length n to (n - 1)//2, so the frequency axis ends up at ((idim - 1)//2 - 1)//2
linear1 = torch.nn.Linear(odim * (((idim - 1)//2 - 1)//2), odim)
b, c, t, f = x2.size()
x = linear1(x2.transpose(1, 2).contiguous().view(b, t, c * f))   # (b, t', 256)

conv1, conv2, x1.shape, x2.shape, linear1, x2.shape, x.shape
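A quick sanity check of those shapes; this small helper is not in the original post, it just evaluates the unpadded-convolution formula floor((n - kernel)/stride) + 1 for the sizes used above:

def conv_out(n, kernel=3, stride=2):
    # output length of an unpadded convolution
    return (n - kernel) // stride + 1

t1, f1 = conv_out(280), conv_out(80)   # after conv1: 139 x 39
t2, f2 = conv_out(t1), conv_out(f1)    # after conv2: 69 x 19
print(t2, f2, 256 * f2)                # 69 19 4864 -> matches linear1's input size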


# encoder
layernorm1 = nn.LayerNorm(256, eps=1e-12)
layernorm2 = nn.LayerNorm(256, eps=1e-12)
dropout1 = nn.Dropout(0.1)
## ffn1
drop1 = nn.Dropout(0.1)
linear1 = nn.Linear(256, 2048)
linear2 = nn.Linear(2048, 256)
activation1 = nn.ReLU()

# macaron-style half-step feed-forward: residual + 0.5 * FFN(LayerNorm(x))
residual = x
x1 = layernorm1(x)
x2 = residual + 0.5 * linear2(drop1(activation1(linear1(x1))))   # feed the normalized x1, not x
x = x2

# multi_head_attn
x3 = layernorm2(x)

def generate_qkv(query, key, value):
    # project the 256-d features into per-head query / key / value tensors
    linear_q = nn.Linear(256, 256)
    linear_k = nn.Linear(256, 256)
    linear_v = nn.Linear(256, 256)

    n_head = 4
    n_batch = query.size(0)
    n_feat = query.size(-1)
    d_k = n_feat // n_head   # 256 / 4 = 64 dims per head

    q = linear_q(query).view(n_batch, -1, n_head, d_k)
    k = linear_k(key).view(n_batch, -1, n_head, d_k)
    v = linear_v(value).view(n_batch, -1, n_head, d_k)

    return q, k, v   # each (batch, time, n_head, d_k)
# convolution module
channels = odim
residual = x
print(x.shape)
x = nn.LayerNorm(256, eps=1e-12)(x)
pointwise1_conv = nn.Conv1d(channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=True)
x = x.transpose(-1, -2)   # Conv1d expects (batch, channels, time)
x = pointwise1_conv(x)
# x = nn.functional.glu(x, dim=1)   # things in nn.functional are plain functions that can be called directly
# depthwise_conv = nn.Conv1d(
#             channels,
#             channels,
#             15,
#             stride=1,
#             padding=7,
#             groups=channels,
#             bias=True
#         )
# x = depthwise_conv(x)
# x = nn.ReLU()(x)
# pointwise_conv2 = nn.Conv1d(
#             channels,
#             channels,
#             kernel_size=1,
#             stride=1,
#             padding=0,
#             bias=True,
#         )
# x = pointwise_conv2(x)

Viewing the parameters. The complete pass is repeated below, this time running the convolution module and the CTC loss end to end, and finishing by printing the parameters of the two subsampling convolutions.


import torch
from torch import nn
x = torch.randint(0, 10, size=(5, 280,80))
length = torch.tensor([5,4,4,4,4])
x.size(),x.shape,x[0].shape,length

# def make_pad_mask(length):
batch = length.size(0)
max_len = length.max().item()
# seq_t = torch.arange(0, max_len).unsqueeze(0).expand(batch, max_len)
seq_t = torch.arange(0, max_len).unsqueeze(0).repeat(batch, 1)     # the repeated dim has to match the batch size
length, seq_t

seq_mask = seq_t >= length.unsqueeze(1)
mask = ~seq_mask
mask

odim = 256
idim = 80
# def embedding
conv1 = torch.nn.Conv2d(1, 256, 3, 2)
relu1 = torch.nn.ReLU()
conv2 = torch.nn.Conv2d(256, 256, 3, 2)

# x = x.unsqueeze(1).long()
x = x.unsqueeze(1).to(dtype=torch.float32)

x1 = conv1(x)
x1 = relu1(x1)
x2 = conv2(x1)
x2 = relu1(x2)

linear1 = torch.nn.Linear(odim * (((idim-1)//2 -1)//2), odim)
b, c, t, f = x2.size()
x = linear1(x2.transpose(1,2).contiguous().view(b, t, c*f))

conv1, conv2, x1.shape,x2.shape , linear1, x2.shape , x.shape

# encoder
layernorm1 = nn.LayerNorm((256), eps=1e-12)
layernorm2 = nn.LayerNorm((256), eps=1e-12)
dropout1 = nn.Dropout(0.1)
## ffn1 
drop1 = nn.Dropout(0.1)
linear1 = nn.Linear(256, 2048)
linear2 = nn.Linear(2048, 256)
activation1 = nn.ReLU()

# macaron-style half-step FFN: residual + 0.5 * FFN(LayerNorm(x))
residual = x
x1 = layernorm1(x)
x2 = residual + 0.5 * linear2(drop1(activation1(linear1(x1))))   # feed the normalized x1, not x
x = x2

# multi_head_attn
x3 = layernorm2(x)

def generate_qkv(query, key, value):
    linear_q = nn.Linear(256, 256)
    linear_k = nn.Linear(256, 256)
    linear_v = nn.Linear(256, 256)
    
    n_head = 4
    n_batch = query.size(0)
    n_feat = query.size(-1)
    d_k = n_feat//n_head
    
    q = linear_q(query).view(n_batch, -1 , n_head, d_k)
    k = linear_k(key).view(n_batch, -1 , n_head, d_k)
    v = linear_v(value).view(n_batch, -1 , n_head, d_k)

    return q, k, v   # each (batch, time, n_head, d_k)
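generate_qkv is defined but never wired into an attention step in the post. Below is a minimal sketch of masked scaled dot-product attention on top of it; the helper name and the output projection in the usage comment are my own additions, not the post's code:

def scaled_dot_product_attn(q, k, v, mask=None):
    # q, k, v: (batch, time, n_head, d_k) as returned by generate_qkv above
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))             # -> (batch, n_head, time, d_k)
    d_k = q.size(-1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / d_k ** 0.5   # (batch, n_head, time, time)
    if mask is not None:
        # mask: (batch, time) bool, True on valid frames
        scores = scores.masked_fill(~mask.unsqueeze(1).unsqueeze(2), float("-inf"))
    attn = torch.softmax(scores, dim=-1)
    out = torch.matmul(attn, v)                                  # (batch, n_head, time, d_k)
    n_batch, n_head, n_time, d_k = out.size()
    return out.transpose(1, 2).reshape(n_batch, n_time, n_head * d_k)

# e.g.:
# q, k, v = generate_qkv(x3, x3, x3)
# attn_out = scaled_dot_product_attn(q, k, v)    # (batch, time, 256)
# x = x + nn.Linear(256, 256)(attn_out)          # residual + output projection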



# convolution module
channels = odim
residual = x                      # (b, t', 256)
x = nn.LayerNorm(256, eps=1e-12)(x)
pointwise1_conv = nn.Conv1d(channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=True)
x = x.transpose(-1, -2)           # Conv1d expects (batch, channels, time)
x = pointwise1_conv(x)
x = nn.functional.glu(x, dim=1)   # GLU halves the channels back to 256; nn.functional ops are plain functions
depthwise_conv = nn.Conv1d(
            channels,
            channels,
            15,                   # kernel_size; padding=7 keeps the time length unchanged
            stride=1,
            padding=7,
            groups=channels,      # one filter per channel -> depthwise
            bias=True
        )
x = depthwise_conv(x)
x = nn.ReLU()(x)
pointwise_conv2 = nn.Conv1d(
            channels,
            channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )
x = pointwise_conv2(x)
x = residual + x.transpose(-1, -2)   # back to (b, t', 256) and add the saved residual
x.shape
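In the full Conformer block the convolution module is followed by a second half-step FFN and a final LayerNorm, which this walkthrough skips before moving on to CTC. Left commented out so the numbers below are untouched, a sketch reusing the FFN modules defined earlier could look like:

# x = x + 0.5 * linear2(drop1(activation1(linear1(nn.LayerNorm(256, eps=1e-12)(x)))))   # second half-step FFN
# x = nn.LayerNorm(256, eps=1e-12)(x)                                                   # final LayerNorm of the block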

# print(x)
# decoding
# CTC
linear_ctc = nn.Linear(x.size(-1), 108)   # project 256-d features to a 108-token vocabulary (blank = 0)

x_ctc = linear_ctc(nn.functional.dropout(x, p=0.1))   # (b, t', 108)

x_ctc = x_ctc.transpose(0, 1)             # nn.CTCLoss wants (T, N, C)
# logsoftmax = nn.functional.softmax()
x_ctc = nn.functional.log_softmax(x_ctc, dim=-1)

ctc_loss_fn = nn.CTCLoss(blank=0, reduction="sum")
# reference labels, padded with -1 beyond each target length (the padding is never read
# because the target lengths passed below are [5, 4, 4, 4, 4])
y_ref = torch.tensor([[12, 13, 1, 15, 16],
                      [12, 13, 1, 15, -1],
                      [12, 13, 1, 15, -1],
                      [12, 13, 1, 15, -1],
                      [12, 13, 1, 15, -1]])
x_true_len = torch.tensor([69, 60, 60, 60, 60])   # valid encoder frames per utterance (<= 69)
print(x_ctc.shape, y_ref.shape, x_true_len.shape, length.shape)
ctc_loss = ctc_loss_fn(x_ctc, y_ref, x_true_len, length)
loss = ctc_loss / x_ctc.size(1)           # normalize the summed loss by the batch size
x_ctc = x_ctc.transpose(0, 1)             # back to (N, T, C)
print(loss)
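The decoding section stops at the loss; here is a hedged sketch of greedy CTC decoding on top of x_ctc (per-frame argmax, collapse repeats, drop blanks). The helper name ctc_greedy_decode is my own:

def ctc_greedy_decode(log_probs, lengths, blank=0):
    # log_probs: (N, T, C) log-probabilities, lengths: (N,) valid frame counts
    hyps = []
    best = log_probs.argmax(dim=-1)            # (N, T) best token id per frame
    for ids, n in zip(best, lengths):
        prev = blank
        hyp = []
        for t in ids[:n].tolist():
            if t != blank and t != prev:       # collapse repeats, drop blanks
                hyp.append(t)
            prev = t
        hyps.append(hyp)
    return hyps

# print(ctc_greedy_decode(x_ctc, x_true_len))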



for name, p in conv1.named_parameters():
    print(name, p.shape, p.numel())
for name, p in conv2.named_parameters():
    print(name, p.shape, p.numel())
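The same per-parameter loop extends to a rough total; a small sketch summing over the modules created in this post (only the ones listed below, so not a full Conformer parameter count):

modules = [conv1, conv2, linear1, linear2, linear_ctc,
           pointwise1_conv, depthwise_conv, pointwise_conv2]
total = sum(p.numel() for m in modules for p in m.parameters())
print("total parameters in the listed modules:", total)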