Deep Learning (RNN, LSTM, GRU)

Architecture diagrams of the three networks (RNN, LSTM, and GRU cell diagrams; the images are not reproduced here). A plain-tensor sketch of one update step per cell follows.
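Since the diagrams are missing, here is a minimal sketch of a single update step for each cell, written with plain tensor operations. It follows the standard formulations (the same ones covered in colah's post referenced at the end); the sizes `d_in` and `d_h` and the random weights are placeholders, not learned parameters:

```python
import torch

d_in, d_h = 4, 8
x_t = torch.randn(d_in)      # input at time step t
h_prev = torch.zeros(d_h)    # previous hidden state
c_prev = torch.zeros(d_h)    # previous cell state (LSTM only)

def lin(out_dim, in_dim):
    # stand-in for a learned affine map: weight matrix + bias
    return torch.randn(out_dim, in_dim), torch.zeros(out_dim)

xh = torch.cat([x_t, h_prev])  # concatenated [x_t, h_{t-1}]

# Vanilla RNN: one tanh layer over [x_t, h_{t-1}], no gates (1 weight set)
W, b = lin(d_h, d_in + d_h)
h_rnn = torch.tanh(W @ xh + b)

# LSTM: input/forget/output gates plus a candidate (4 weight sets)
(Wi, bi), (Wf, bf), (Wo, bo), (Wg, bg) = (lin(d_h, d_in + d_h) for _ in range(4))
i = torch.sigmoid(Wi @ xh + bi)   # input gate
f = torch.sigmoid(Wf @ xh + bf)   # forget gate
o = torch.sigmoid(Wo @ xh + bo)   # output gate
g = torch.tanh(Wg @ xh + bg)      # candidate values
c_lstm = f * c_prev + i * g       # explicit cell state C_t
h_lstm = o * torch.tanh(c_lstm)   # hidden state h_t

# GRU: update/reset gates plus a candidate (3 weight sets)
(Wz, bz), (Wr, br), (Wn, bn) = (lin(d_h, d_in + d_h) for _ in range(3))
z = torch.sigmoid(Wz @ xh + bz)                          # update gate
r = torch.sigmoid(Wr @ xh + br)                          # reset gate
n = torch.tanh(Wn @ torch.cat([x_t, r * h_prev]) + bn)   # candidate state
h_gru = (1 - z) * n + z * h_prev  # gated mix of candidate and old state
```

The gate counts here (0, 3, 2) and weight-set counts (1, 4, 3) are exactly the rows compared in the table below.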

Feature comparison (the parameter ratios are verified in the sketch after the table):

| Feature | RNN | LSTM | GRU |
| --- | --- | --- | --- |
| Gate count | no gates | 3 gates (input / forget / output) | 2 gates (update / reset) |
| Memory mechanism | hidden state h_t only | explicit cell state C_t plus hidden state h_t | implicit memory (state updated through the gates) |
| Core operation | direct state passing | gated cell-state update | gated mixing of candidate state |
| Computational complexity | O(d²) (1 set of weights) | O(4d²) (4 sets of weights) | O(3d²) (3 sets of weights) |
| Long-term dependency learning | poor (<10 steps) | excellent (>1000 steps) | good (~100 steps) |
| Vanishing-gradient problem | severe | largely mitigated | well mitigated |
| Parameter count | fewest | most (about 4x the RNN) | medium (about 3x the RNN) |
| Training speed | fastest | slowest | fast |
| Overfitting risk | medium | relatively high (most parameters) | medium |
| Typical applications | simple sequence classification | machine translation / speech recognition | text generation / time-series forecasting |
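The complexity and parameter rows can be checked directly against PyTorch's built-in layers. A minimal sketch (when input_size equals hidden_size, the counts including biases come out in an exact 1 : 4 : 3 ratio):

```python
import torch.nn as nn

d = 128
for name, layer in [("RNN", nn.RNN(d, d)),
                    ("LSTM", nn.LSTM(d, d)),
                    ("GRU", nn.GRU(d, d))]:
    n_params = sum(p.numel() for p in layer.parameters())
    print(f"{name}: {n_params} parameters")
# RNN:  33024  (1 set of input/hidden weights)
# LSTM: 132096 (4 sets: i, f, o, g)  -> 4x the RNN
# GRU:  99072  (3 sets: r, z, n)     -> 3x the RNN
```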
Below are two examples:

1. Recognizing handwritten digits with an LSTM:

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)

class RNN(nn.Module):
    """Recurrent classifier: an nn.LSTM followed by a linear output layer."""
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # zero-initialize hidden and cell states: (num_layers, batch, hidden_size)
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
        out, _ = self.lstm(x, (h0, c0))
        # classify from the hidden state at the last time step
        out = self.fc(out[:, -1, :])
        return out

sequence_length = 28  # each 28x28 image is read as a sequence of 28 rows
input_size = 28       # one row of 28 pixels is the input at each time step
hidden_size = 128
num_layers = 2
num_classes = 10

model = RNN(input_size, hidden_size, num_layers, num_classes).to(device)
model.train()

optimizer = optim.Adam(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

num_epochs = 10
for epoch in range(num_epochs):

    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in train_loader:
        B, _, _, _ = images.shape
        # reshape (B, 1, 28, 28) images into (B, 28, 28): 28-step sequences of 28-pixel rows
        images = images.reshape(B, sequence_length, input_size)

        images = images.to(device)
        labels = labels.to(device)
 
        output = model(images)
        loss = criterion(output, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        _, predicted = torch.max(output, 1)  # predicted class = argmax over logits
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}, Accuracy: {(100 * correct / total):.2f}%")
```
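The loop above only reports training accuracy. A minimal evaluation sketch on the held-out MNIST test split, reusing `model`, `device`, and the shape constants defined above (my addition, not part of the original example):

```python
testset = datasets.MNIST(root='./data', train=False, download=True,
                         transform=transforms.ToTensor())
test_loader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False)

model.eval()                # switch to evaluation mode
correct = total = 0
with torch.no_grad():       # no gradients needed for evaluation
    for images, labels in test_loader:
        images = images.reshape(-1, sequence_length, input_size).to(device)
        labels = labels.to(device)
        _, predicted = torch.max(model(images), 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f"Test accuracy: {100 * correct / total:.2f}%")
```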

2. Fitting a curve with a GRU:

```python
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

class RNN(nn.Module):
    def __init__(self):
        super().__init__()
        # swap in nn.RNN or nn.LSTM here to compare the three cells:
        # self.rnn = nn.RNN(input_size=1, hidden_size=128, num_layers=1, batch_first=True)
        # self.rnn = nn.LSTM(input_size=1, hidden_size=128, num_layers=1, batch_first=True)
        self.rnn = nn.GRU(input_size=1, hidden_size=128, num_layers=1, batch_first=True)
        self.linear = nn.Linear(128, 1)

    def forward(self, x):
        output, _ = self.rnn(x)       # hidden states at every time step
        x = self.linear(output)       # map each hidden state to one output value
        return x

if __name__ == '__main__':

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # target curve: y = sin(3x) + x (a sine wave on top of a linear trend)
    x = torch.linspace(-300, 300, 1000) * 0.01
    y = torch.sin(x * 3.0) + torch.linspace(-300, 300, 1000) * 0.01
    plt.plot(x, y, 'r')

    # shape (1000, 1): one unbatched sequence of 1000 steps with input_size=1
    x = x.unsqueeze(1).to(device)
    y = y.unsqueeze(1).to(device)

    model = RNN().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
    criterion = nn.MSELoss()

    for epoch in range(5000):
        preds = model(x)
        loss = criterion(preds, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if epoch % 500 == 0:          # log every 500 epochs to keep output short
            print('loss', epoch, loss.item())

    # predict on a slightly wider range to see a little extrapolation
    x = torch.linspace(-300, 310, 1000) * 0.01
    x = x.unsqueeze(1).to(device)
    pred = model(x)
    plt.plot(x.cpu().detach().numpy(), pred.cpu().detach().numpy(), 'b')
    plt.show()
```
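One subtlety in this example: `x.unsqueeze(1)` makes `x` a (1000, 1) tensor, which `nn.GRU` treats as a single unbatched sequence of length 1000 with input_size 1 (recent PyTorch versions accept unbatched 2-D inputs to recurrent layers; older versions need an explicit batch dimension). A quick shape check, as a sketch:

```python
import torch
import torch.nn as nn

gru = nn.GRU(input_size=1, hidden_size=128, num_layers=1, batch_first=True)

x = torch.linspace(-3, 3, 1000).unsqueeze(1)  # (seq_len=1000, input_size=1), unbatched
out, h = gru(x)
print(out.shape, h.shape)      # torch.Size([1000, 128]) torch.Size([1, 128])

xb = x.unsqueeze(0)            # (batch=1, seq_len=1000, input_size=1)
out_b, h_b = gru(xb)
print(out_b.shape, h_b.shape)  # torch.Size([1, 1000, 128]) torch.Size([1, 1, 128])
```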

Reference: Understanding LSTM Networks -- colah's blog (https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
