Visualizing PyTorch with tensorboardX
Visualizing loss and accuracy
Reference: https://www.jianshu.com/p/46eb3004beca
-
Environment setup (tensorflow is installed here for the tensorboard command it pulls in):
conda activate xxx
pip install tensorboardX
pip install tensorflow
-
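from tensorboardX import SummaryWriter
writer = SummaryWriter('runs/001')  # event files are written to runs/001
# Called once per epoch, after the batch loop; train_loss, correct and total are
# accumulated over batches, and batch_idx is assumed to come from enumerate() (0-based)
writer.add_scalar('Train/Loss', train_loss / (batch_idx + 1), epoch)  # mean loss per batch
writer.add_scalar('Train/Acc', 100.0 * correct / total, epoch)  # accuracy in percent
writer.close()  # flush pending events to disk
-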
On the server (TensorBoard listens on port 6006 by default):
conda activate xxx
tensorboard --logdir=runs/001
-
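On the local machine:
In a terminal, run: ssh -p 22222 -L 6006:localhost:6006 yinwenbin@192.168.2.237
This forwards local port 6006 to port 6006 on the server, so TensorBoard can then be opened at http://localhost:6006 in a local browser.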
Visualizing the model
Reference: https://blog.csdn.net/sunqiande88/article/details/80155925?utm_source=copy
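import torch
import torchvision.models as models
from tensorboardX import SummaryWriter

model = models.resnet18()
# Dummy batch of 13 RGB 224x224 images; only its shape matters for tracing
dummy_input = torch.rand(13, 3, 224, 224)
# With no logdir, the writer logs to runs/<datetime>_<hostname> followed by the comment
with SummaryWriter(comment='resnet18') as w:
    w.add_graph(model, (dummy_input, ))  # trace the model and write its graph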
conda activate xxx
tensorboard --logdir runs
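If you get the error: 'torch._C.Value' object has no attribute 'debugName'
(the installed PyTorch does not provide the debugName method that tensorboardX 1.9 calls),
downgrade tensorboardX from 1.9 to 1.8: pip install tensorboardX==1.8
Reference: https://blog.csdn.net/East_Plain/article/details/103073311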
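Before calling add_graph, one can check whether the installed PyTorch provides the method the error complains about. A minimal sketch; it relies on the private torch._C.Value class named in the error message, which is an internal API and may change between releases. graph_api_has_debug_name is a hypothetical helper name.

```python
import torch


def graph_api_has_debug_name():
    """Return True if this PyTorch build exposes torch._C.Value.debugName,
    the method named in the error message (private API, may change)."""
    return hasattr(torch._C.Value, 'debugName')


if graph_api_has_debug_name():
    print('torch._C.Value.debugName is available')
else:
    print('debugName missing: pin tensorboardX==1.8 or upgrade PyTorch')
```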