9.7.2 解码器

训练时解码器使用目标句子作为输入,这样即使某一个时间步预测错了也不要紧,我们输入的目标句子一定是对的

repeat这个函数就是广播张量,但是具体机制好像很复杂,只解释书上那一句代码。现在X的形状是(num_steps,batch_size,embed_size),而在广播之前,context的形状是(batch_size,num_hiddens),由于context的形状跟X的后两维一样,所以repeat的后两个参数1,1就可以保证这两维不动,而第一个参数为X.shape[0],也就是将context在第零维广播,最后context的形状是(X.shape[0],batch_size,num_hiddens)

后面的self.densenn.Linear,但是现在却传入了一个三维的张量,会怎么样?
在 PyTorch 中,nn.Linear(num_hiddens, vocab_size) 的输入张量形状 可以灵活处理,即使输入形状是 (num_steps, batch_size, num_hiddens),也能正常工作。以下是详细解释:


关键机制

nn.Linear 的底层实现会自动处理输入张量的维度,它会将除最后一个维度外的所有维度视为“批量维度”,只对最后一个维度进行线性变换。

1. 输入形状为 (num_steps, batch_size, num_hiddens) 时:

  • 实际计算时,nn.Linear 会将输入隐式展平为 (num_steps * batch_size, num_hiddens)
  • 进行矩阵乘法:(num_steps * batch_size, num_hiddens) @ (num_hiddens, vocab_size) → (num_steps * batch_size, vocab_size)
  • 最终输出的形状会自动恢复为 (num_steps, batch_size, vocab_size)

2. 输出形状:

  • 输出张量形状为 (num_steps, batch_size, vocab_size),与输入的前两个维度保持一致。

代码示例

import torch
import torch.nn as nn

# 定义线性层
num_hiddens = 64
vocab_size = 10000
linear = nn.Linear(num_hiddens, vocab_size)

# 定义输入张量 (num_steps=5, batch_size=32, num_hiddens=64)
X = torch.randn(5, 32, 64)  # 3D 输入

# 通过线性层
output = linear(X)

# 检查输出形状
print(output.shape)  # 输出: torch.Size([5, 32, 10000])

应用场景

这种设计在处理序列数据时非常有用,例如:

  • 语言模型:输入是 (sequence_length, batch_size, hidden_size),需要将其映射到词汇表维度 (sequence_length, batch_size, vocab_size)
  • 时间序列预测:每个时间步的特征需要映射到目标维度。

为什么能自动处理多维输入?

PyTorch 的 nn.Linear 底层通过 torch.matmul 实现,其规则是:

  • 输入张量形状(*, in_features),其中 * 表示任意多个额外维度。
  • 输出张量形状(*, out_features),保持前面的维度不变,只变换最后一个维度。

注意事项

  1. 参数不匹配时会报错:如果输入张量的最后一个维度与 nn.Linearin_features 不一致(例如 X.shape[-1] != num_hiddens),会抛出维度错误。
  2. 手动展平:如果某些场景需要强制展平输入(例如合并时间步和批次),可以用 X.view(-1, num_hiddens),但大多数情况下无需手动操作。

总结

  • 输入形状(num_steps, batch_size, num_hiddens) → 合法。
  • 输出形状(num_steps, batch_size, vocab_size)
  • 底层逻辑:PyTorch 自动处理批量维度,仅对最后一个维度进行线性变换。

iinit_state这个函数返回的就是enc_outputs[1],因为enc_outputs接收的是一个元组(outputs,state),所以enc_outputs[1]就是state

posted @   最爱丁珰  阅读(4)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· TypeScript + Deepseek 打造卜卦网站:技术与玄学的结合
· 阿里巴巴 QwQ-32B真的超越了 DeepSeek R-1吗?
· 【译】Visual Studio 中新的强大生产力特性
· 10年+ .NET Coder 心语 ── 封装的思维:从隐藏、稳定开始理解其本质意义
· 【设计模式】告别冗长if-else语句:使用策略模式优化代码结构
历史上的今天:
2024-02-13 连续攻击游戏
2024-02-13 捉迷藏
2024-02-13 路径覆盖与二分图匹配一一对应
2024-02-13 会议
点击右上角即可分享
微信分享提示