LSTM的隐藏状态和细胞状态
class LSTMModel(nn.Module):
def __init__(self, input_size, hidden_size, num_layers, output_size):
super(LSTMModel, self).__init__()
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
self.fc = nn.Linear(hidden_size, output_size) # 全连接层用于输出预测
def forward(self, x):
# 初始化隐藏状态和细胞状态
h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
# 隐藏状态和细胞状态还可以是随机的
# h0 = torch.randn(self.num_layers, x.size(0), self.hidden_size).to(x.device)
# c0 = torch.randn(self.num_layers, x.size(0), self.hidden_size).to(x.device)
# 假设输入x的形状是 (batch_size, seq_length, input_size)
lstm_out = self.lstm(x, (h0, c0))
# lstm_out 形状为 (batch_size, seq_length, hidden_size)
# 取最后一个时间步的输出作为全连接层的输入
out = lstm_out[:, -1, :]
out = self.fc(out) # 形状变为 (batch_size, output_size)
return out
forward
方法内部初始化了两个张量 h0
和 c0
分别代表隐藏状态和细胞状态。这两个状态都设置为零张量,并且它们的维度是根据 num_layers
、batch_size
和 hidden_size
来确定的。这确保了对于每一个新序列,LSTM 都会从零状态开始处理,避免了不同序列之间潜在的状态混淆。
具体来说:
- h0 是最顶层(即最后一层)LSTM 的初始隐藏状态,形状为 (num_layers * num_directions, batch_size, hidden_size)。
- c0 是最顶层(即最后一层)LSTM 的初始细胞状态,形状与 h0 相同。
当调用 self.lstm(x, hidden)
时,hidden
就是这个元组 (h0, c0)
,它告诉 LSTM 使用什么作为序列开始时的内部状态。如果未提供 hidden
(即传递 None
),那么默认情况下所有隐藏状态和细胞状态都会被初始化为零张量。
def forward(self, x, hidden=None):
# 如果没有提供hidden,则初始化为全零张量
if hidden is None:
h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
hidden = (h0, c0)
# 假设输入x的形状是 (batch_size, seq_length, input_size)
lstm_out, (hidden_state, cell_state) = self.lstm(x, hidden)
# lstm_out 形状为 (batch_size, seq_length, hidden_size)
# 取最后一个时间步的输出作为全连接层的输入
out = lstm_out[:, -1, :]
out = self.fc(out) # 形状变为 (batch_size, output_size)
return out, (hidden_state, cell_state)
如果在连续的序列间保留隐藏状态(例如,在生成任务或某些类型的序列预测任务中),则不应该每次都重新初始化这些状态,而是应该将前一个序列的最终状态传递给下一个序列作为初始状态。
def forward(self, x):
# 假设输入x的形状是 (batch_size, seq_length, input_size)
lstm_out= self.lstm(x)
# lstm_out 形状为 (batch_size, seq_length, hidden_size)
# 取最后一个时间步的输出作为全连接层的输入
out = lstm_out[:, -1, :]
out = self.fc(out) # 形状变为 (batch_size, output_size)
return out
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· winform 绘制太阳,地球,月球 运作规律
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 上周热点回顾(3.3-3.9)
· 超详细:普通电脑也行Windows部署deepseek R1训练数据并当服务器共享给他人
2023-01-19 ProcessPoolExecutor in concurrent
2023-01-19 多进程 multiprocessing in Python
2023-01-19 多线程threading in Python