使用tensorboard绘制pytorch网络模型

本文简单介绍一下如何使用tensorboard绘制pytorch网络模型

版本
torch == 1.10.1
tensorboardX == 2.5

代码

import torch
import torch.nn as nn
from tensorboardX import SummaryWriter

# 定义超参数
batch_size = 64
learning_rate = 1e-2
num_epoches = 20


# 定义网络
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3),
            nn.BatchNorm2d(16),
            nn.ReLU(True))

        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3),
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.layer3 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3),
            nn.BatchNorm2d(64),
            nn.ReLU(True))

        self.layer4 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.fc = nn.Sequential(
            nn.Linear(128 * 4 * 4, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 128),
            nn.ReLU(True),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


model = CNN()
dummy_input = torch.rand(20, 1, 28, 28)
writer = SummaryWriter('log')
with SummaryWriter(comment='LeNet') as w:
    w.add_graph(model, (dummy_input,))

在pychram中执行完代码后,可以看到在项目的根目录下有一个名为runs的目录,该目录下有一个刚生成的目录,例如 Apr04_16-33-37_LAPTOP-QGED210TLeNet

在pycharm的命令行中执行命令
tensorboard --logdir ./runs/Apr04_16-33-37_LAPTOP-QGED210TLeNet

然后pycharm会给出tensorboard的访问路径,例如http://localhost:6006/,接下来就可以用浏览器进行访问了

参考: https://zhuanlan.zhihu.com/p/58961505

posted @ 2022-04-04 16:41  Bill_H  阅读(253)  评论(0编辑  收藏  举报