LSTM的隐藏状态和细胞状态

class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super(LSTMModel, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)  # 全连接层用于输出预测

    def forward(self, x):
        # 初始化隐藏状态和细胞状态
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        
        # 隐藏状态和细胞状态还可以是随机的
        # h0 = torch.randn(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        # c0 = torch.randn(self.num_layers, x.size(0), self.hidden_size).to(x.device)        

        # 假设输入x的形状是 (batch_size, seq_length, input_size)
        lstm_out = self.lstm(x, (h0, c0))
        
        # lstm_out 形状为 (batch_size, seq_length, hidden_size)
        # 取最后一个时间步的输出作为全连接层的输入
        out = lstm_out[:, -1, :]
        out = self.fc(out)  # 形状变为 (batch_size, output_size)
        return out

forward 方法内部初始化了两个张量 h0c0 分别代表隐藏状态和细胞状态。这两个状态都设置为零张量,并且它们的维度是根据 num_layersbatch_sizehidden_size 来确定的。这确保了对于每一个新序列,LSTM 都会从零状态开始处理,避免了不同序列之间潜在的状态混淆。

具体来说:

  • h0 是最顶层(即最后一层)LSTM 的初始隐藏状态,形状为 (num_layers * num_directions, batch_size, hidden_size)。
  • c0 是最顶层(即最后一层)LSTM 的初始细胞状态,形状与 h0 相同。

当调用 self.lstm(x, hidden) 时,hidden 就是这个元组 (h0, c0),它告诉 LSTM 使用什么作为序列开始时的内部状态。如果未提供 hidden(即传递 None),那么默认情况下所有隐藏状态和细胞状态都会被初始化为零张量。

    def forward(self, x, hidden=None):
        # 如果没有提供hidden,则初始化为全零张量
        if hidden is None:
            h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
            c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
            hidden = (h0, c0)

        # 假设输入x的形状是 (batch_size, seq_length, input_size)
        lstm_out, (hidden_state, cell_state) = self.lstm(x, hidden)
        
        # lstm_out 形状为 (batch_size, seq_length, hidden_size)
        # 取最后一个时间步的输出作为全连接层的输入
        out = lstm_out[:, -1, :]
        out = self.fc(out)  # 形状变为 (batch_size, output_size)
        return out, (hidden_state, cell_state)

如果在连续的序列间保留隐藏状态(例如,在生成任务或某些类型的序列预测任务中),则不应该每次都重新初始化这些状态,而是应该将前一个序列的最终状态传递给下一个序列作为初始状态。

    def forward(self, x):
        # 假设输入x的形状是 (batch_size, seq_length, input_size)
        lstm_out= self.lstm(x)
        
        # lstm_out 形状为 (batch_size, seq_length, hidden_size)
        # 取最后一个时间步的输出作为全连接层的输入
        out = lstm_out[:, -1, :]
        out = self.fc(out)  # 形状变为 (batch_size, output_size)
        return out
posted @   华小电  阅读(64)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· winform 绘制太阳,地球,月球 运作规律
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 上周热点回顾(3.3-3.9)
· 超详细:普通电脑也行Windows部署deepseek R1训练数据并当服务器共享给他人
历史上的今天:
2023-01-19 ProcessPoolExecutor in concurrent
2023-01-19 多进程 multiprocessing in Python
2023-01-19 多线程threading in Python
点击右上角即可分享
微信分享提示