随笔分类 - Pytorch
摘要:参考一 浅谈 PyTorch 中的 tensor 及使用 该博文分为以下6个部分: tensor.requires_grad torch.no_grad() 反向传播及网络的更新 tensor.detach() CPU and GPU tensor.item() torch.detach()和tor
阅读全文
摘要:模型和数据可以在CPU和GPU上来回迁移,怎么判断模型和数据在哪里呢? import torch import torch.nn as nn # 判断模型是在CPU还是GPU上 model = nn.LSTM(input_size=10, hidden_size=4, num_layers=1, b
阅读全文
摘要:Python自带的random库 例如: Python产生一个数值范围内的不重复的随机数,可以使用random模块中的random.sample函数。例如从0~99中,随机取10个不重复的数: random.sample(range(100), 10) numpy的random库 np.random
阅读全文
摘要:MSE是mean squared error的缩写,即平均平方误差,简称均方误差。 MSE是逐元素计算的,计算公式为: 旧版的nn.MSELoss()函数有reduce、size_average两个参数,新版的只有一个reduction参数了,功能是一样的。reduction的意思是维度要不要缩减,
阅读全文
摘要:PyTorch 训练 RNN 时,序列长度不固定怎么办? pytorch中如何在lstm中输入可变长的序列 上面两篇文章写得很好,把LSTM中训练变长序列所需的三个函数讲解的很清晰,但是这两篇文章没有给出完整的训练代码,并且没有写关于带label的情况,为此,本文给出一个完整的带label的训练代码
阅读全文
摘要:一. 普通全连接神经网络的计算过程 假设用全连接神经网络做MNIST手写数字分类,节点数如下: 第一层是输入层,784个节点; 第二层是隐层,100个节点; 第三层是输出层,10个节点。 对于这个神经网络,我们在脑海里浮现的可能是类似这样的画面: 但实际上,神经网络的计算过程,本质上是输入向量(矩阵
阅读全文
摘要:Pytorch之permute函数
阅读全文
摘要:torch.nn.lstm()接受的数据输入是(序列长度,batchsize,输入维数),使用batch_first=True,可以使lstm接受维度为(batchsize,序列长度,输入维数)的数据输入,同时,lstm的输出数据维度也会变为batchsize放在第一维(可参考这篇博客)。
阅读全文
摘要:在用PyTorch保存模型时,常常会遇到UserWarning: Couldn't retrieve source code for container of type Net. It won't be checked for correctness upon loading."type " + o
阅读全文
摘要:airflights passengers dataset下载地址https://raw.githubusercontent.com/jbrownlee/Datasets/master/airline-passengers.csv 这个dataset包含从1949年到1960年每个月的航空旅客数目,
阅读全文
摘要:参考一: PyTorch中LSTM的输出格式 该文章的核心内容截图如下: 总的结论: 注意:如果在搭建lstm网络时使用了batch_first=True,则lstm网络不仅接受的数据第一维是batch,而且输出的结果中,batch也会在第一维,即 output's shape (batch, se
阅读全文
摘要:torch.nn.Module()类有一些重要属性,我们可用其下面几个属性来实现对神经网络层结构的提取: torch.nn.Module.children() torch.nn.Module.modules() torch.nn.Module.named_children() torch.nn.Mo
阅读全文
摘要:Pytorch:利用预训练好的VGG16网络提取图片特征
阅读全文
摘要:在PyTorch自定义数据集中,我们介绍了如何通过重写Dataset类来自定义数据集,但其实对于图像数据,自定义数据集有一个更简单的方法,那就是直接调用ImageFolder,它是torchvision.datasets里的函数。 ImageFolder介绍 ImageFolder假设所有的文件按文
阅读全文
摘要:PyTorch中Dataset, DataLoader, Sampler的关系可用下图概括: 用文字表达就是:Dataloader中包含Sampler和Dataset,Sampler产生索引,Dataset拿着这个索引在数据集文件夹中找到对应的样本(每个样本对应一个索引,就像列表中每个元素对应一个索
阅读全文
摘要:数据传递机制 我们首先回顾识别手写数字的程序: ... Dataset = torchvision.datasets.MNIST(root='./mnist/', train=True, transform=transform, download=True,) dataloader = torch.
阅读全文
摘要:保存和加载模型 在PyTorch中使用torch.save来保存模型的结构和参数,有两种保存方式: # 方式一:保存模型的结构信息和参数信息 torch.save(model, './model.pth') # 方式二:仅保存模型的参数信息 torch.save(model.state_dict()
阅读全文
摘要:转:PyTorch 中,nn 与 nn.functional 有什么区别?
阅读全文
摘要:PyTorch有多种方法搭建神经网络,下面识别手写数字为例,介绍4种搭建神经网络的方法。 方法一:torch.nn.Sequential() torch.nn.Sequential类是torch.nn中的一种序列容器,参数会按照我们定义好的序列自动传递下去。 import torch.nn as n
阅读全文
摘要:torch.mm(mat1, mat2) performs a matrix multiplication of mat1 and mat2 a = torch.randint(0, 5, (2, 3)) # tensor([[3, 3, 2], # [2, 2, 2]]) b = torch.ra
阅读全文