使用tensorboard绘制pytorch网络模型
本文简单介绍一下如何使用tensorboard绘制pytorch网络模型
版本
torch == 1.10.1
tensorboardX == 2.5
代码
import torch
import torch.nn as nn
from tensorboardX import SummaryWriter
# 定义超参数
batch_size = 64
learning_rate = 1e-2
num_epoches = 20
# 定义网络
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.layer1 = nn.Sequential(
nn.Conv2d(1, 16, kernel_size=3),
nn.BatchNorm2d(16),
nn.ReLU(True))
self.layer2 = nn.Sequential(
nn.Conv2d(16, 32, kernel_size=3),
nn.BatchNorm2d(32),
nn.ReLU(True),
nn.MaxPool2d(kernel_size=2, stride=2)
)
self.layer3 = nn.Sequential(
nn.Conv2d(32, 64, kernel_size=3),
nn.BatchNorm2d(64),
nn.ReLU(True))
self.layer4 = nn.Sequential(
nn.Conv2d(64, 128, kernel_size=3),
nn.BatchNorm2d(128),
nn.ReLU(True),
nn.MaxPool2d(kernel_size=2, stride=2)
)
self.fc = nn.Sequential(
nn.Linear(128 * 4 * 4, 1024),
nn.ReLU(True),
nn.Linear(1024, 128),
nn.ReLU(True),
nn.Linear(128, 10)
)
def forward(self, x):
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
model = CNN()
dummy_input = torch.rand(20, 1, 28, 28)
writer = SummaryWriter('log')
with SummaryWriter(comment='LeNet') as w:
w.add_graph(model, (dummy_input,))
在pychram中执行完代码后,可以看到在项目的根目录下有一个名为runs的目录,该目录下有一个刚生成的目录,例如 Apr04_16-33-37_LAPTOP-QGED210TLeNet
在pycharm的命令行中执行命令
tensorboard --logdir ./runs/Apr04_16-33-37_LAPTOP-QGED210TLeNet
然后pycharm会给出tensorboard的访问路径,例如http://localhost:6006/,接下来就可以用浏览器进行访问了