9.7.2 解码器
训练时解码器使用目标句子作为输入,这样即使某一个时间步预测错了也不要紧,我们输入的目标句子一定是对的
repeat
这个函数就是广播张量,但是具体机制好像很复杂,只解释书上那一句代码。现在X
的形状是(num_steps,batch_size,embed_size)
,而在广播之前,context
的形状是(batch_size,num_hiddens)
,由于context
的形状跟X
的后两维一样,所以repeat
的后两个参数1,1
就可以保证这两维不动,而第一个参数为X.shape[0]
,也就是将context
在第零维广播,最后context
的形状是(X.shape[0],batch_size,num_hiddens)
后面的self.dense
是nn.Linear
,但是现在却传入了一个三维的张量,会怎么样?
在 PyTorch 中,nn.Linear(num_hiddens, vocab_size)
的输入张量形状 可以灵活处理,即使输入形状是 (num_steps, batch_size, num_hiddens)
,也能正常工作。以下是详细解释:
关键机制
nn.Linear
的底层实现会自动处理输入张量的维度,它会将除最后一个维度外的所有维度视为“批量维度”,只对最后一个维度进行线性变换。
1. 输入形状为 (num_steps, batch_size, num_hiddens)
时:
- 实际计算时,
nn.Linear
会将输入隐式展平为(num_steps * batch_size, num_hiddens)
。 - 进行矩阵乘法:
(num_steps * batch_size, num_hiddens) @ (num_hiddens, vocab_size) → (num_steps * batch_size, vocab_size)
。 - 最终输出的形状会自动恢复为
(num_steps, batch_size, vocab_size)
。
2. 输出形状:
- 输出张量形状为
(num_steps, batch_size, vocab_size)
,与输入的前两个维度保持一致。
代码示例
import torch
import torch.nn as nn
# 定义线性层
num_hiddens = 64
vocab_size = 10000
linear = nn.Linear(num_hiddens, vocab_size)
# 定义输入张量 (num_steps=5, batch_size=32, num_hiddens=64)
X = torch.randn(5, 32, 64) # 3D 输入
# 通过线性层
output = linear(X)
# 检查输出形状
print(output.shape) # 输出: torch.Size([5, 32, 10000])
应用场景
这种设计在处理序列数据时非常有用,例如:
- 语言模型:输入是
(sequence_length, batch_size, hidden_size)
,需要将其映射到词汇表维度(sequence_length, batch_size, vocab_size)
。 - 时间序列预测:每个时间步的特征需要映射到目标维度。
为什么能自动处理多维输入?
PyTorch 的 nn.Linear
底层通过 torch.matmul
实现,其规则是:
- 输入张量形状:
(*, in_features)
,其中*
表示任意多个额外维度。 - 输出张量形状:
(*, out_features)
,保持前面的维度不变,只变换最后一个维度。
注意事项
- 参数不匹配时会报错:如果输入张量的最后一个维度与
nn.Linear
的in_features
不一致(例如X.shape[-1] != num_hiddens
),会抛出维度错误。 - 手动展平:如果某些场景需要强制展平输入(例如合并时间步和批次),可以用
X.view(-1, num_hiddens)
,但大多数情况下无需手动操作。
总结
- 输入形状:
(num_steps, batch_size, num_hiddens)
→ 合法。 - 输出形状:
(num_steps, batch_size, vocab_size)
。 - 底层逻辑:PyTorch 自动处理批量维度,仅对最后一个维度进行线性变换。
iinit_state
这个函数返回的就是enc_outputs[1]
,因为enc_outputs
接收的是一个元组(outputs,state)
,所以enc_outputs[1]
就是state
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· TypeScript + Deepseek 打造卜卦网站:技术与玄学的结合
· 阿里巴巴 QwQ-32B真的超越了 DeepSeek R-1吗?
· 【译】Visual Studio 中新的强大生产力特性
· 10年+ .NET Coder 心语 ── 封装的思维:从隐藏、稳定开始理解其本质意义
· 【设计模式】告别冗长if-else语句:使用策略模式优化代码结构
2024-02-13 连续攻击游戏
2024-02-13 捉迷藏
2024-02-13 路径覆盖与二分图匹配一一对应
2024-02-13 会议