Methods for Visualizing a Network Structure
Method 1: Output as a PDF document (using graphviz)
from graphviz import Digraph
import torch
from torch.autograd import Variable


def make_dot(var, params=None):
    """Produces a Graphviz representation of the PyTorch autograd graph.

    Blue nodes are the Variables that require grad; orange nodes are the
    Tensors saved for backward in torch.autograd.Function.

    Args:
        var: output Variable
        params: dict of (name, Variable) used to label nodes that require grad
    """
    if params is not None:
        # dict_values cannot be indexed in Python 3, so validate with all()
        assert all(isinstance(v, Variable) for v in params.values())
        param_map = {id(v): k for k, v in params.items()}

    node_attr = dict(style='filled', shape='box', align='left',
                     fontsize='12', ranksep='0.1', height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
    seen = set()

    def size_to_str(size):
        return '(' + ', '.join('%d' % v for v in size) + ')'

    def add_nodes(var):
        if var not in seen:
            if torch.is_tensor(var):
                # Tensor saved for backward
                dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
            elif hasattr(var, 'variable'):
                # Leaf Variable (e.g. a parameter)
                u = var.variable
                name = param_map.get(id(u), '') if params is not None else ''
                node_name = '%s\n %s' % (name, size_to_str(u.size()))
                dot.node(str(id(var)), node_name, fillcolor='lightblue')
            else:
                # Intermediate autograd function
                dot.node(str(id(var)), str(type(var).__name__))
            seen.add(var)
            if hasattr(var, 'next_functions'):
                for u in var.next_functions:
                    if u[0] is not None:
                        dot.edge(str(id(u[0])), str(id(var)))
                        add_nodes(u[0])
            if hasattr(var, 'saved_tensors'):
                for t in var.saved_tensors:
                    dot.edge(str(id(t)), str(id(var)))
                    add_nodes(t)

    add_nodes(var.grad_fn)
    return dot
itLEP_pil, itLEP_np = get_image(real_face_name, imsize)

net = skip(input_depth, itLEP_np.shape[0],
           num_channels_down=[128] * 5,
           num_channels_up=[128] * 5,
           num_channels_skip=[128] * 5,
           filter_size_up=3, filter_size_down=3,
           upsample_mode='nearest', filter_skip_size=1,
           need_sigmoid=True, need_bias=True,
           pad=pad, act_fun='LeakyReLU').type(dtype)

dummy_input = get_noise(input_depth, INPUT, itLEP_np.shape[1:]).type(dtype)
# The code above defines the network structure and its input;
# the code below renders the network structure graph
y = net(dummy_input)
g = make_dot(y)
g.view()   # renders the graph and opens it as a PDF
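The example above depends on the get_image, skip, and get_noise helpers from the deep-image-prior project. If those are not available, make_dot can be tried on any model; below is a minimal sketch using a hypothetical small fully connected network (the model and output file name are placeholders, not part of the original example).

import torch
import torch.nn as nn

# Hypothetical stand-in model; any nn.Module works with make_dot
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
x = torch.randn(1, 8)
y = model(x)

# Pass the named parameters so the blue leaf nodes are labelled
g = make_dot(y, params=dict(model.named_parameters()))
g.render('simple_net')   # writes simple_net.pdf; g.view() would also open it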
Method 2: Using tensorboardX
import torch
import torch.nn as nn
from tensorboardX import SummaryWriter


class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Sequential(                    # input_size = (1, 28, 28)
            nn.Conv2d(1, 6, 5, 1, 2),
            nn.ReLU(),                                 # (6, 28, 28)
            nn.MaxPool2d(kernel_size=2, stride=2),     # output_size = (6, 14, 14)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),                                 # (16, 10, 10)
            nn.MaxPool2d(2, 2)                         # output_size = (16, 5, 5)
        )
        self.fc1 = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU()
        )
        self.fc2 = nn.Sequential(
            nn.Linear(120, 84),
            nn.ReLU()
        )
        self.fc3 = nn.Linear(84, 10)

    # Forward pass; the input is x
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        # nn.Linear expects a flat feature vector per sample,
        # so flatten the multi-dimensional tensor to one dimension
        x = x.view(x.size()[0], -1)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x


dummy_input = torch.rand(13, 1, 28, 28)   # dummy batch of 13 images of size 1x28x28
model = LeNet()
with SummaryWriter(comment='LeNet') as w:
    w.add_graph(model, (dummy_input, ))
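As a small variation (a sketch, not part of the original code): the event files can be written to a fixed directory instead of the auto-named folder that comment='LeNet' creates under runs/, which makes the --logdir argument used below predictable. Depending on the tensorboardX version, the keyword is logdir or log_dir.

# Hypothetical variation: choose the log directory explicitly
with SummaryWriter(logdir='runs/lenet_graph') as w:   # older tensorboardX versions use log_dir=
    w.add_graph(model, (dummy_input, ))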
Running this script creates a runs folder. Change to the directory that contains runs and run the command tensorboard --logdir runs; it prints a URL to open in a browser (if one browser shows a blank page, try a different one).
Double-click the graph node to expand the detailed network structure.