【Tensorboard】在PyTorch中使用Tensorboard

在PyToch中使用Tensorboard

The SummaryWriter class is your main entry to log data for consumption and visualization by TensorBoard. For example:

import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms

# Writer will output to ./runs/ directory by default
writer = SummaryWriter()  # 可以指定log_dir, 即log的保存路径

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = datasets.MNIST('mnist_train', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
model = torchvision.models.resnet50(False)
# Have ResNet model take in grayscale rather than RGB
model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
images, labels = next(iter(trainloader))

grid = torchvision.utils.make_grid(images)
writer.add_image('images', grid, 0)
writer.add_graph(model, images)
writer.close()

在PyTorch中使用TensorFlow中的可视化工具Tensorboard,具体流程如下:👇

  1. 安装 pip install tensorboard

  2. 实例化一个SummaryWriter(记录器)对象, 用于后续保存标量(scalar)/图像(image)/图(graph)等日志文件

    # from tensorboardX import SummaryWriter 
    from torch.utils.tensorboard import SummaryWriter  #SummaryWriter Encapsultes everything      
    log_dir = "./my_log_dir" 
    writer = SummaryWriter(log_dir) #实例化对象时指定存放log的目录
    
  3. 保存sth(something)
    通用的API格式:add_sth(tag_name, object, iter_num)
    举例说明:保存标量 writer.add_scalar('loss', value, iteration)

  4. 可视化网络 Add graph
    https://tensorboardx.readthedocs.io/en/latest/tutorial.html#add-graph

  5. 监测训练过程
    tensorboard --logdir your_log_dir --bind_all 注意是日志目录,而不是要指定日志文件;指定--bind_all后:服务器训练模型,本地浏览器可以打开tensorboard

参考链接

posted @ 2020-09-14 22:44  达可奈特  阅读(771)  评论(0编辑  收藏  举报