The ConvLSTM Paper
Paper title: Convolutional LSTM Network: A Machine Learning Approach for Precipitation Nowcasting
Reference article: https://blog.csdn.net/m0_64557752/article/details/125882525
Paper source: https://arxiv.org/abs/1506.04214v1
Code: https://github.com/xinxuann/ConvLSTM_pytorch
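For reference, these are the gate equations from the paper, which the cell implementation below follows. Here $*$ denotes convolution, $\circ$ the Hadamard product, and the $W_{c\cdot}$ terms are the peephole connections (elementwise products with the cell state):

$$
\begin{aligned}
i_t &= \sigma\left(W_{xi} * X_t + W_{hi} * H_{t-1} + W_{ci} \circ C_{t-1} + b_i\right) \\
f_t &= \sigma\left(W_{xf} * X_t + W_{hf} * H_{t-1} + W_{cf} \circ C_{t-1} + b_f\right) \\
C_t &= f_t \circ C_{t-1} + i_t \circ \tanh\left(W_{xc} * X_t + W_{hc} * H_{t-1} + b_c\right) \\
o_t &= \sigma\left(W_{xo} * X_t + W_{ho} * H_{t-1} + W_{co} \circ C_t + b_o\right) \\
H_t &= o_t \circ \tanh(C_t)
\end{aligned}
$$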
Full implementation: a ConvLSTMCell with peephole connections, plus a ConvLSTM wrapper that stacks several cells and unrolls them for a fixed number of steps.
```python
import torch
import torch.nn as nn


class ConvLSTMCell(nn.Module):
    """A single ConvLSTM cell: an LSTM cell whose input-to-state and
    state-to-state transitions are 2D convolutions, with peephole terms."""

    def __init__(self, input_channels, hidden_channels, kernel_size):
        super(ConvLSTMCell, self).__init__()

        assert hidden_channels % 2 == 0

        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.num_features = 4

        # "same" padding so the spatial size of the state is preserved
        self.padding = int((kernel_size - 1) / 2)

        # Input-to-state (Wx*) and state-to-state (Wh*) convolutions for the
        # input gate (i), forget gate (f), cell update (c), and output gate (o)
        self.Wxi = nn.Conv2d(self.input_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=True)
        self.Whi = nn.Conv2d(self.hidden_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=False)
        self.Wxf = nn.Conv2d(self.input_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=True)
        self.Whf = nn.Conv2d(self.hidden_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=False)
        self.Wxc = nn.Conv2d(self.input_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=True)
        self.Whc = nn.Conv2d(self.hidden_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=False)
        self.Wxo = nn.Conv2d(self.input_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=True)
        self.Who = nn.Conv2d(self.hidden_channels, self.hidden_channels, self.kernel_size, 1, self.padding, bias=False)

        # Peephole weights (Hadamard products with the cell state); created
        # lazily in init_hidden once the spatial size of the input is known
        self.Wci = None
        self.Wcf = None
        self.Wco = None

    def forward(self, x, h, c):
        ci = torch.sigmoid(self.Wxi(x) + self.Whi(h) + c * self.Wci)   # input gate
        cf = torch.sigmoid(self.Wxf(x) + self.Whf(h) + c * self.Wcf)   # forget gate
        cc = cf * c + ci * torch.tanh(self.Wxc(x) + self.Whc(h))       # new cell state
        co = torch.sigmoid(self.Wxo(x) + self.Who(h) + cc * self.Wco)  # output gate
        ch = co * torch.tanh(cc)                                       # new hidden state
        return ch, cc

    def init_hidden(self, batch_size, hidden, shape):
        if self.Wci is None:
            # .cuda() must be applied to the tensor *before* wrapping it in
            # nn.Parameter: nn.Parameter(...).cuda() returns a plain tensor,
            # so the peephole weights would never be registered or trained
            self.Wci = nn.Parameter(torch.zeros(1, hidden, shape[0], shape[1]).cuda())
            self.Wcf = nn.Parameter(torch.zeros(1, hidden, shape[0], shape[1]).cuda())
            self.Wco = nn.Parameter(torch.zeros(1, hidden, shape[0], shape[1]).cuda())
        else:
            assert shape[0] == self.Wci.size()[2], 'Input Height Mismatched!'
            assert shape[1] == self.Wci.size()[3], 'Input Width Mismatched!'
        return (torch.zeros(batch_size, hidden, shape[0], shape[1]).cuda(),
                torch.zeros(batch_size, hidden, shape[0], shape[1]).cuda())


class ConvLSTM(nn.Module):
    # input_channels corresponds to the first input feature map;
    # hidden_channels is a list with one entry per stacked ConvLSTM layer.
    def __init__(self, input_channels, hidden_channels, kernel_size, step=1, effective_step=[1]):
        super(ConvLSTM, self).__init__()
        self.input_channels = [input_channels] + hidden_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.num_layers = len(hidden_channels)
        self.step = step
        self.effective_step = effective_step
        self._all_layers = []
        for i in range(self.num_layers):
            name = 'cell{}'.format(i)
            cell = ConvLSTMCell(self.input_channels[i], self.hidden_channels[i], self.kernel_size)
            setattr(self, name, cell)
            self._all_layers.append(cell)

    def forward(self, input):
        internal_state = []
        outputs = []
        for step in range(self.step):
            # note: the same input tensor is fed at every unrolled step
            x = input
            for i in range(self.num_layers):
                name = 'cell{}'.format(i)
                # all cells are initialized in the first step
                if step == 0:
                    bsize, _, height, width = x.size()
                    (h, c) = getattr(self, name).init_hidden(batch_size=bsize,
                                                             hidden=self.hidden_channels[i],
                                                             shape=(height, width))
                    internal_state.append((h, c))

                # do forward: the hidden state of layer i is the input of layer i+1
                (h, c) = internal_state[i]
                x, new_c = getattr(self, name)(x, h, c)
                internal_state[i] = (x, new_c)
            # only record effective steps
            if step in self.effective_step:
                outputs.append(x)

        return outputs, (x, new_c)


if __name__ == '__main__':
    # gradient check (gradcheck needs double precision, hence the .double() casts)
    convlstm = ConvLSTM(input_channels=512, hidden_channels=[128, 64, 64, 32, 32],
                        kernel_size=3, step=5, effective_step=[4]).cuda()
    loss_fn = torch.nn.MSELoss()

    input = torch.randn(1, 512, 64, 32).cuda()
    target = torch.randn(1, 32, 64, 32).double().cuda()

    output = convlstm(input)
    output = output[0][0].double()
    res = torch.autograd.gradcheck(loss_fn, (output, target), eps=1e-6, raise_exception=True)
    print(res)
```
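A minimal usage sketch, assuming the ConvLSTM class above is in scope and a CUDA device is available (init_hidden hard-codes .cuda()); the channel counts and sizes here are illustrative, not from the original post:

```python
# Two stacked ConvLSTM layers over a 3-channel 32x32 input, unrolled 4 steps;
# only the hidden states at steps 2 and 3 are recorded in `outputs`.
model = ConvLSTM(input_channels=3, hidden_channels=[16, 8], kernel_size=3,
                 step=4, effective_step=[2, 3]).cuda()
x = torch.randn(1, 3, 32, 32).cuda()

outputs, (last_h, last_c) = model(x)
print(len(outputs))      # 2 -- one entry per effective step
print(outputs[0].shape)  # torch.Size([1, 8, 32, 32]) -- last layer's hidden state
print(last_c.shape)      # torch.Size([1, 8, 32, 32]) -- last layer's cell state
```

Note that this implementation feeds the same input tensor at every unrolled step rather than consuming a sequence of frames; to run it over a real video/radar sequence, forward would need to index into a time dimension instead of reusing `input`.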