PyTorch - Hub 模块


什么是 hub

hub(modelzoo)主要用来调用其他人训练好的模型和参数
Facebook官方博客表示,PyTorch Hub是一个简易API和工作流程,为复现研究提供了基本构建模块,包含预训练模型库。
并且,PyTorch Hub还支持Colab,能与论文代码结合网站Papers With Code集成,用于更广泛的研究。

github:https://github.com/pytorch/hub
模型:https://pytorch.org/hub/research-models


使用示例

import torch
model = torch.hub.load('pytorch/vision:v0.4.2', 'deeplabv3_resnet101', pretrained=True)
model.eval()
# 下载会显示下载数据和进度
# Downloading: "https://github.com/pytorch/vision/archive/v0.4.2.zip" to /Users/xx/.cache/torch/hub/v0.4.2.zip
# Downloading: "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth" to /Users/xx/.cache/torch/hub/checkpoints/resnet101-5d3b4d8f.pth
Downloading: "https://github.com/pytorch/vision/archive/v0.4.2.zip" to /Users/shushu/.cache/torch/hub/v0.4.2.zip
Downloading: "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth" to /Users/shushu/.cache/torch/hub/checkpoints/resnet101-5d3b4d8f.pth

HBox(children=(IntProgress(value=0, max=178728960), HTML(value='')))
torch.hub.list('pytorch/vision:v0.4.2')
Using cache found in C:\Users\Administrator/.cache\torch\hub\pytorch_vision_v0.4.2

['alexnet',
 'deeplabv3_resnet101',
 'densenet121',
 'densenet161',
 'densenet169',
 'densenet201',
 'fcn_resnet101',
 'googlenet',
 'inception_v3',
 'mobilenet_v2',
 'resnet101',
 'resnet152',
 'resnet18',
 'resnet34',
 'resnet50',
 'resnext101_32x8d',
 'resnext50_32x4d',
 'shufflenet_v2_x0_5',
 'shufflenet_v2_x1_0',
 'squeezenet1_0',
 'squeezenet1_1',
 'vgg11',
 'vgg11_bn',
 'vgg13',
 'vgg13_bn',
 'vgg16',
 'vgg16_bn',
 'vgg19',
 'vgg19_bn',
 'wide_resnet101_2',
 'wide_resnet50_2']

# Download an example image from the pytorch website
import urllib
url, filename = ("https://github.com/pytorch/hub/raw/master/dog.jpg", "dog.jpg")
try: urllib.URLopener().retrieve(url, filename)
except: urllib.request.urlretrieve(url, filename)
# sample execution (requires torchvision)
from PIL import Image
from torchvision import transforms

input_image = Image.open(filename)
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model

# move the input and model to GPU for speed if available
if torch.cuda.is_available():
    input_batch = input_batch.to('cuda')
    model.to('cuda')

with torch.no_grad():
    output = model(input_batch)['out'][0]
output_predictions = output.argmax(0)
# create a color pallette, selecting a color for each class
palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
colors = torch.as_tensor([i for i in range(21)])[:, None] * palette
colors = (colors % 255).numpy().astype("uint8")

# plot the semantic segmentation predictions of 21 classes in each color
r = Image.fromarray(output_predictions.byte().cpu().numpy()).resize(input_image.size)
r.putpalette(colors)

import matplotlib.pyplot as plt
plt.imshow(r)
plt.show()


1、查询可用的模型

用户可以使用torch.hub.list()这个API列出repo中所有可用的入口点。比如你想知道PyTorch Hub中有哪些可用的计算机视觉模型:

>>> torch.hub.list('pytorch/vision')
>>>
['alexnet',
'deeplabv3_resnet101',
'densenet121',
...
'vgg16',
'vgg16_bn',
'vgg19',
 'vgg19_bn']

2、加载模型

在上一步中能看到所有可用的计算机视觉模型,如果想调用其中的一个,也不必安装,只需一句话就能加载模型。

model = torch.hub.load('pytorch/vision', 'deeplabv3_resnet101', pretrained=True)

至于如何获得此模型的详细帮助信息,可以使用下面的API:

print(torch.hub.help('pytorch/vision', 'deeplabv3_resnet101'))

如果模型的发布者后续加入错误修复和性能改进,用户也可以非常简单地获取更新,确保自己用到的是最新版本:

model = torch.hub.load(..., force_reload=True)
对于另外一部分用户来说,稳定性更加重要,他们有时候需要调用特定分支的代码。例如pytorch_GAN_zoo的hub分支:

model = torch.hub.load('facebookresearch/pytorch_GAN_zoo:hub', 'DCGAN', pretrained=True, useGPU=False)

3、查看模型可用方法

从PyTorch Hub加载模型后,你可以用dir(model)查看模型的所有可用方法。以bertForMaskedLM模型为例:

>>> dir(model)
>>>
['forward'
...
'to'
'state_dict',
]

forward

如果你对forward方法感兴趣,使用help(model.forward) 了解运行运行该方法所需的参数。

>>> help(model.forward)
>>>
Help on method forward in module pytorch_pretrained_bert.modeling:
forward(input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None)
...

支持 Colab

PyTorch Hub中提供的模型也支持Colab。

进入每个模型的介绍页面后,你不仅可以看到GitHub代码页的入口,甚至可以一键进入Colab运行模型Demo。



对于模型发布者

如果你希望把自己的模型发布到PyTorch Hub上供所有用户使用,可以去PyTorch Hub的GitHub页发送拉取请求。若你的模型符合高质量、易重复、最有利的要求,Facebook官方将会与你合作。

一旦拉取请求被接受,你的模型将很快出现在PyTorch Hub官方网页上,供所有用户浏览。

目前该网站上已经有18个提交的模型,英伟达率先提供支持,他们在PyTorch Hub已经发布了Tacotron2和WaveGlow两个TTS模型。

图片

发布模型的方法也是比较简单的,开发者只需在自己的GitHub存储库中添加一个简单的hubconf.py文件,在其中枚举运行模型所需的依赖项列表即可。

比如,torchvision中的hubconf.py文件是这样的:

# Optional list of dependencies required by the package
dependencies = ['torch']

from torchvision.models.alexnet import alexnet
from torchvision.models.densenet import densenet121, densenet169, densenet201, densenet161
from torchvision.models.inception import inception_v3
from torchvision.models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152,\
resnext50_32x4d, resnext101_32x8d
from torchvision.models.squeezenet import squeezenet1_0, squeezenet1_1
from torchvision.models.vgg import vgg11, vgg13, vgg16, vgg19, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn
from torchvision.models.segmentation import fcn_resnet101, deeplabv3_resnet101
from torchvision.models.googlenet import googlenet
from torchvision.models.shufflenetv2 import shufflenet_v2_x0_5, shufflenet_v2_x1_0
from torchvision.models.mobilenet import mobilenet_v2

Facebook官方向模型发布者提出了以下三点要求:

1、每个模型文件都可以独立运行和执行
2、不需要PyTorch以外的任何包
3、不需要单独的入口点,让模型在创建时可以无缝地开箱即用

Facebook还建议发布者最小化对包的依赖性,减少用户加载模型进行实验的阻力。


更多资料

posted @ 2021-03-01 22:54  月思  阅读(1366)  评论(0编辑  收藏  举报