References: 10.5. Multi-Head Attention — Dive into Deep Learning 2.0.0 documentation (d2l.ai)
Multi-Head Attention | Algorithm + Code (bilibili)
Code implementation
The input x has shape [1, 4, 2]: 1 is the number of samples (sentences), 4 is the number of time steps (4 words, i.e. the sequence length), and 2 is the feature dimension of each word after encoding.
from math import sqrt

import torch
import torch.nn as nn


class MultiHeadSelfAttention(nn.Module):
    def __init__(self, dim_in, d_model, num_heads=3):
        super(MultiHeadSelfAttention, self).__init__()
        self.dim_in = dim_in
        self.d_model = d_model
        self.num_heads = num_heads
        # d_model must be divisible by num_heads
        assert d_model % num_heads == 0, "d_model must be multiple of num_heads"
        # linear projections for Q, K, V
        self.linear_q = nn.Linear(dim_in, d_model)
        self.linear_k = nn.Linear(dim_in, d_model)
        self.linear_v = nn.Linear(dim_in, d_model)
        self.scale = 1 / sqrt(d_model // num_heads)
        # final linear layer
        self.fc = nn.Linear(d_model, d_model)

    def forward(self, x):
        # x: tensor of shape (batch, n, dim_in)
        batch, n, dim_in = x.shape
        assert dim_in == self.dim_in

        nh = self.num_heads
        dk = self.d_model // nh  # dim of each head

        q = self.linear_q(x).reshape(batch, n, nh, dk).transpose(1, 2)  # (batch, nh, n, dk)
        k = self.linear_k(x).reshape(batch, n, nh, dk).transpose(1, 2)  # (batch, nh, n, dk)
        v = self.linear_v(x).reshape(batch, n, nh, dk).transpose(1, 2)  # (batch, nh, n, dk)

        dist = torch.matmul(q, k.transpose(2, 3)) * self.scale  # (batch, nh, n, n)
        dist = torch.softmax(dist, dim=-1)                      # (batch, nh, n, n)

        att = torch.matmul(dist, v)                              # (batch, nh, n, dk)
        att = att.transpose(1, 2).reshape(batch, n, self.d_model)  # (batch, n, d_model)

        # final linear transformation over the concatenated heads
        output = self.fc(att)
        return output


x = torch.rand((1, 4, 2))
multi_head_att = MultiHeadSelfAttention(x.shape[2], 6, 3)  # d_model=6, num_heads=3
output = multi_head_att(x)
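A quick sanity check (a minimal sketch, assuming the class and the tensors x / output defined above are in scope): with dim_in = 2, d_model = 6 and num_heads = 3, each head attends in d_model // num_heads = 2 dimensions, and the module keeps the batch and sequence dimensions while mapping the feature dimension from 2 to 6.

print(x.shape)       # torch.Size([1, 4, 2])  (batch, seq_len, dim_in)
print(output.shape)  # torch.Size([1, 4, 6])  (batch, seq_len, d_model)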
See also: 101 - The role of Multi-head (bilibili)
After the heads are concatenated, a final linear layer mixes the per-head results and can project the feature dimension back down.
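To make this "merge and project" step concrete, here is a standalone sketch; the shapes and variable names are chosen only for illustration (they are not from the code above, although the sizes match it). The per-head outputs are swapped back to (batch, seq, heads, dk), flattened so the heads are concatenated along the feature axis, and then passed through the final linear layer:

import torch
import torch.nn as nn

batch, n, nh, dk = 1, 4, 3, 2                            # 3 heads, 2 dims per head
att = torch.rand(batch, nh, n, dk)                       # per-head outputs: (batch, heads, seq, dk)
merged = att.transpose(1, 2).reshape(batch, n, nh * dk)  # concatenate heads: (1, 4, 6)
fc = nn.Linear(nh * dk, nh * dk)                         # final layer mixes information across heads
out = fc(merged)
print(out.shape)                                         # torch.Size([1, 4, 6])

In the class above, fc keeps the size at d_model, but the same layer could map to a smaller output size if you want to compress the dimension.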