【Tensorboard】在PyTorch中使用Tensorboard
在PyToch中使用Tensorboard
The SummaryWriter class is your main entry to log data for consumption and visualization by TensorBoard. For example:
import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
# Writer will output to ./runs/ directory by default
writer = SummaryWriter() # 可以指定log_dir, 即log的保存路径
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = datasets.MNIST('mnist_train', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
model = torchvision.models.resnet50(False)
# Have ResNet model take in grayscale rather than RGB
model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
images, labels = next(iter(trainloader))
grid = torchvision.utils.make_grid(images)
writer.add_image('images', grid, 0)
writer.add_graph(model, images)
writer.close()
在PyTorch中使用TensorFlow中的可视化工具Tensorboard
,具体流程如下:👇
-
安装
pip install tensorboard
-
实例化一个SummaryWriter(记录器)对象, 用于后续保存标量(scalar)/图像(image)/图(graph)等日志文件
# from tensorboardX import SummaryWriter from torch.utils.tensorboard import SummaryWriter #SummaryWriter Encapsultes everything log_dir = "./my_log_dir" writer = SummaryWriter(log_dir) #实例化对象时指定存放log的目录
-
保存sth(something)
通用的API格式:add_sth(tag_name, object, iter_num)
举例说明:保存标量writer.add_scalar('loss', value, iteration)
-
可视化网络 Add graph
https://tensorboardx.readthedocs.io/en/latest/tutorial.html#add-graph -
监测训练过程
tensorboard --logdir your_log_dir --bind_all
注意是日志目录,而不是要指定日志文件;指定--bind_all
后:服务器训练模型,本地浏览器可以打开tensorboard