Python|深度学习可视化工具wandb的使用
基础知识
1 安装库
pip install wandb
2 创建账户并登录
wandb login
3 初始化
# Inside the model training code: import wandb and start a run.
# The `project` argument names the project this run is filed under.
import wandb
wandb.init(project="my-project")
4 声明超参数
# Declare hyperparameters as attributes on wandb.config so they are
# recorded together with the run.
wandb.config.dropout = 0.2
wandb.config.hidden_layer_size = 128
5 记录日志
def my_train_loop():
    """Log a placeholder loss to wandb once per epoch, for 10 epochs.

    Each call to ``wandb.log`` records one dictionary containing the
    current epoch index and the loss value.
    """
    for epoch in range(10):
        loss = 0  # placeholder: replace with the real training loss
        wandb.log({'epoch': epoch, 'loss': loss})
6 保存文件
# os is needed for os.path.join below; this snippet did not import it.
import os

# by default, this will save to a new subfolder for files associated
# with your run, created in wandb.run.dir (which is ./wandb by default)
wandb.save("mymodel.h5")
# Alternatively, pass the full path inside the run directory to the Keras
# model API. NOTE(review): `model` is a placeholder for an existing Keras
# model object; it is not defined in this snippet.
model.save(os.path.join(wandb.run.dir, "mymodel.h5"))
使用 wandb 以后,模型输出、日志和要保存的文件都会同步到云端。
PyTorch应用wandb
这里以一个生成对抗网络(GAN)为例展示wandb的用法:
首先导入必要的库:
import torch.nn as nn
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import os
import numpy as np
import argparse
import math
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn.functional as F
import torch
import wandb
初始化一个wandb run,并设置超参数:
# Start a wandb run under the given entity (account/team) and project.
wandb.init(entity="weltm", project="GAN-project")
# wandb.config is the object that holds and saves hyperparameters and inputs.
config = wandb.config  # short alias for the run's config object
config.batch_size = 256  # input batch size for training
config.epochs = 300  # number of epochs to train
config.lr = 0.0002  # learning rate
# NOTE(review): originally described as "SGD momentum"; the optimizer setup is
# not shown here -- confirm which optimizer actually consumes this value.
config.momentum = 0.1
# Described as the number of batches to wait between status logs.
# NOTE(review): confirm where this value is consumed.
config.log_interval = 10
定义训练函数,使用wandb记录G_Loss 与 D_Loss
def train(d, g, criterion, d_optimizer, g_optimizer, epochs=1, show_every=100, print_every=10):
    """Train a GAN discriminator/generator pair and log both losses to wandb.

    Batches are read from the module-level ``train_loader``. For every batch
    the discriminator is updated first (real batch vs. generated batch), then
    the generator is updated so that the discriminator labels its samples as
    real. ``G_Loss`` and ``D_Loss`` are logged to wandb once per iteration.

    Args:
        d: Discriminator; called on a batch of samples.
        g: Generator; called on a noise tensor of shape ``(batch, 100)``.
        criterion: Loss, called as ``criterion(outputs, labels)``.
        d_optimizer: Optimizer stepped after the discriminator loss.
        g_optimizer: Optimizer stepped after the generator loss.
        epochs: Number of passes over ``train_loader``.
        show_every: Every ``show_every`` iterations, print the losses and
            display a grid of the first five generated samples via the
            module-level ``img_show`` helper.
        print_every: Every ``print_every`` iterations, print the losses.
    """
    iter_count = 0
    for epoch in range(epochs):
        for inputs, _ in train_loader:
            real_inputs = inputs  # real samples
            # Size the noise batch from the data itself so the last (possibly
            # smaller) batch works and no external batch-size global is needed.
            batch_size = real_inputs.size(0)
            device = real_inputs.device

            # ---- discriminator update ----
            noise = torch.randn(batch_size, 100, device=device)
            # detach(): the discriminator step must not backpropagate into g.
            fake_inputs = g(noise).detach()  # generated samples
            real_outputs = d(real_inputs)
            # *_like keeps labels on the same shape/dtype/device as the outputs.
            real_labels = torch.ones_like(real_outputs)
            d_loss_real = criterion(real_outputs, real_labels)
            fake_outputs = d(fake_inputs)
            fake_labels = torch.zeros_like(fake_outputs)
            d_loss_fake = criterion(fake_outputs, fake_labels)
            d_loss = d_loss_real + d_loss_fake
            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()

            # ---- generator update ----
            fake_inputs = g(torch.randn(batch_size, 100, device=device))
            outputs = d(fake_inputs)
            # The generator is rewarded when d labels its samples as real.
            g_loss = criterion(outputs, torch.ones_like(outputs))
            g_optimizer.zero_grad()
            g_loss.backward()
            g_optimizer.step()

            # Log plain Python floats rather than graph-attached tensors.
            d_loss_value = d_loss.item()
            g_loss_value = g_loss.item()
            wandb.log({"G_Loss": g_loss_value,
                       "D_Loss": d_loss_value})

            show_now = iter_count % show_every == 0
            # Print once even when both intervals fire on the same iteration.
            if show_now or iter_count % print_every == 0:
                print('Epoch:{}, Iter:{}, D:{}, G:{}'.format(epoch, iter_count, d_loss_value, g_loss_value))
            if show_now:
                img_show(torchvision.utils.make_grid(fake_inputs.detach()[0:5]))
            iter_count += 1