py深度学习指南
常用函数
- 获取当前运行目录(类似c++)
import sys
curent_dir = sys.argv[0]
- 在Python中,每个模块都有一个特殊的属性__name__,它指向该模块的名称。当一个模块被导入时,__name__将被设置为该模块的名称,而当一个模块作为脚本直接执行时,__name__将被设置为字符串"main"
- 它的存在意义在于判断当前模块是否被直接运行,而不是被作为模块被其他模块导入。
if __name__ == "__main__":
- 改变当前工作路径
os.chdir(path)
- 模型保存与读取
import torch
# 保存模型步骤
torch.save(model, 'net.pth') # 保存整个神经网络的模型结构以及参数
torch.save(model, 'net.pkl') # 同上
torch.save(model.state_dict(), 'net_params.pth') # 只保存模型参数
torch.save(model.state_dict(), 'net_params.pkl') # 同上
# 加载模型步骤
model = torch.load('net.pth') # 加载整个神经网络的模型结构以及参数
model = torch.load('net.pkl') # 同上
model.load_state_dict(torch.load('net_params.pth')) # 仅加载参数
model.load_state_dict(torch.load('net_params.pkl')) # 同上
- pip 官源
-i https://pypi.org/simple
- gtx1650 pytorch
pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116
- 读取类别文件夹
- 有以下成员变量
- self.classes:用一个 list 保存类别名称
- self.class_to_idx:类别对应的索引,与不做任何转换返回的 target 对应
- self.imgs:保存(img-path, class) tuple的 list
- 有以下成员变量
torchvision.datasets.ImageFolder(root=os.path.join(new_data_dir, 'train_valid'), transform=transform_train)
- 图片迭代及参数
valid_iter = torch.utils.data.DataLoader(valid_ds, batch_size=batch_size, shuffle=True)
dataset (Dataset) – 加载数据的数据集。
batch_size (int, optional) – 每个batch加载多少个样本(默认: 1)。
shuffle (bool, optional) – 设置为True时会在每个epoch重新打乱数据(默认: False).
sampler (Sampler, optional) – 定义从数据集中提取样本的策略,即生成index的方式,可以顺序也可以乱序
num_workers (int, optional) – 用多少个子进程加载数据。0表示数据将在主进程中加载(默认: 0)
collate_fn (callable, optional) –将一个batch的数据和标签进行合并操作。
pin_memory (bool, optional) –设置pin_memory=True,则意味着生成的Tensor数据最开始是属于内存中的锁页内存,这样将内存的Tensor转义到GPU的显存就会更快一些。
drop_last (bool, optional) – 如果数据集大小不能被batch size整除,则设置为True后可删除最后一个不完整的batch。如果设为False并且数据集的大小不能被batch size整除,则最后一个batch将更小。(默认: False)
timeout,是用来设置数据读取的超时时间的,但超过这个时间还没读取到数据的话就会报错
查阅文档
- 当我们想知道一个模块里面提供了哪些可以调用的函数和类的时候,可以使用 dir 函数。下面我们打印 nd.random 模块中所有的成员或属性
from mxnet import nd
print(dir(nd.random))
['NDArray', '_Null', '__all__', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__spec__', '_internal', '_random_helper', 'current_context', 'exponential', 'exponential_like', 'gamma', 'gamma_like', 'generalized_negative_binomial', 'generalized_negative_binomial_like', 'multinomial', 'negative_binomial', 'negative_binomial_like', 'normal', 'normal_like', 'numeric_types', 'poisson', 'poisson_like', 'randint', 'randn', 'shuffle', 'uniform', 'uniform_like']
查找特定函数和类的使用
- 想了解某个函数或者类的具体用法时,可以使用 help 函数。让我们以NDArray中的ones_like函数为例,查阅它的用法。
help(nd.ones_like)
stdout重定向
import sys
# 将print输出重定向到文件
sys.stdout = open('output.txt', 'w')
print('Hello, World!')
print('This is a test.')
# 恢复标准输出
sys.stdout = sys.__stdout__
获取当前目录下的所有文件
import os
# 获取当前目录下的所有文件和文件夹
files = os.listdir('.')
# 遍历所有文件和文件夹
for file in files:
# 判断是否为文件
if os.path.isfile(file):
# 打印文件名
print(file)
本文来自博客园,作者:InsiApple,转载请注明原文链接:https://www.cnblogs.com/InsiApple/p/17300302.html