# 一个单机多卡训练模型的例子 (Example: training a model on one machine with multiple GPUs)
# """My demo train script.""" import argparse import logging import os import random import time import numpy as np import torch from torch import nn, optim, Tensor from torch.utils.data import DataLoader, Dataset def parse_args() -> argparse.Namespace: """Parse arguments.""" parser = argparse.ArgumentParser(description="Training") parser.add_argument("--seed", type=int, help="Fix random seed", default=123) parser.add_argument( "--log_file", type=str, help="Log file", default="test_train.log" ) parser.add_argument( "--log_path", type=str, help="Model path", default="./training_log/" ) parser.add_argument( "--train_epochs", type=int, help="Epochs of training", default=5 ) parser.add_argument("--batch_size", type=int, help="Batch size", default=32) parser.add_argument( "--learning_rate", type=float, help="Learning rate", default=1e-3, ) parser.add_argument("--device", type=str, help="Run on which device", default="cpu") parser.add_argument( "--cuda_visible_devices", type=str, help="Cuda visible devices", default="0" ) return parser.parse_args() def init_logging(log_file: str, level: str = "INFO") -> None: """Initialize logging.""" logging.basicConfig( filename=log_file, filemode="w", format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=level, ) logging.getLogger().addHandler(logging.StreamHandler()) def set_seed(seed: int) -> None: """Set seed for reproducibility.""" os.environ["PYTHONHASHSEED"] = str(seed) os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True torch.use_deterministic_algorithms(True) def seed_worker(work_id: int) -> None: """Set seed for worker.""" np.random.seed(work_id) random.seed(work_id) class DatasetClass(Dataset): """My demo dataset class.""" def __init__(self): self.input = np.random.rand(1000000, 2).astype(np.float32) # self.input[:, 1] = 0.0 self.target = 
np.zeros([1000000, 1]) self.target[:, 0] = self.input[:, 0] + 1.0 def __len__(self): return len(self.input) def __getitem__(self, idx: int) -> tuple: return self.input[idx], self.target[idx] class ModelClass(torch.nn.Module): """My demo model class.""" def __init__(self): super().__init__() self.my_layer = nn.Linear(2, 1) def forward(self, inputs: Tensor) -> Tensor: """My demo forward function.""" outputs = self.my_layer(inputs) return outputs def get_loss(model_output: Tensor, target: Tensor) -> Tensor: """My demo loss function.""" loss = torch.norm(model_output - target, dim=-1).sum() return loss def training() -> None: """My demo training function.""" train_set = DatasetClass() g = torch.Generator() g.manual_seed(args.seed) train_loader = DataLoader( dataset=train_set, batch_size=args.batch_size, shuffle=True, num_workers=os.cpu_count(), pin_memory=True, worker_init_fn=seed_worker, generator=g, ) model = ModelClass() if args.device == "cuda": model = nn.DataParallel(model) model.to(args.device) optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate) for epoch in range(args.train_epochs): model.train() for batch_index, (features, labels) in enumerate(train_loader): features = features.to(args.device) labels = labels.to(args.device) model_outputs = model(features) optimizer.zero_grad(set_to_none=True) loss = get_loss(model_outputs, labels) loss.backward() optimizer.step() if batch_index % 1000 == 0: logging.info( "Epoch: %s, Batch index: %s, Loss: %s", epoch, batch_index, loss.item(), ) torch.save(model.state_dict(), f"{args.log_path}/trained_model.pth") def testing() -> None: """My demo testing function.""" test_set = DatasetClass() g = torch.Generator() g.manual_seed(args.seed) test_loader = DataLoader( dataset=test_set, batch_size=args.batch_size, shuffle=True, num_workers=os.cpu_count(), pin_memory=True, worker_init_fn=seed_worker, generator=g, ) model = ModelClass() if args.device == "cuda": model = nn.DataParallel(model) 
model.load_state_dict(torch.load(f"{args.log_path}/trained_model.pth")) model.to(args.device) model.eval() with torch.no_grad(): for batch_index, (features, labels) in enumerate(test_loader): features = features.to(args.device) labels = labels.to(args.device) model_outputs = model(features) loss = get_loss(model_outputs, labels) if batch_index % 1000 == 0: logging.info( "Batch index: %s, Loss: %s", batch_index, loss.item() / args.batch_size, ) if __name__ == "__main__": args = parse_args() set_seed(args.seed) init_logging(args.log_file) os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_visible_devices main_start_time = time.time() training() main_end_time = time.time() logging.info("Main time: %s", main_end_time - main_start_time) testing()