An example of single-machine multi-GPU model training
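
The script below trains a tiny linear model on synthetic data. It shows the usual moving parts of a single-machine multi-GPU setup: full deterministic seeding, a multi-worker DataLoader with per-worker seeding, nn.DataParallel for multi-GPU execution, and checkpoint saving and loading. By default it runs on CPU; pass --device cuda together with a --cuda_visible_devices list to use one or more GPUs.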

# """My demo train script."""

import argparse
import logging
import os
import random
import time
import numpy as np
import torch

from torch import nn, optim, Tensor
from torch.utils.data import DataLoader, Dataset


def parse_args() -> argparse.Namespace:
    """Parse arguments."""
    parser = argparse.ArgumentParser(description="Training")
    parser.add_argument("--seed", type=int, help="Fix random seed", default=123)
    parser.add_argument(
        "--log_file", type=str, help="Log file", default="test_train.log"
    )
    parser.add_argument(
        "--log_path", type=str, help="Model path", default="./training_log/"
    )
    parser.add_argument(
        "--train_epochs", type=int, help="Epochs of training", default=5
    )
    parser.add_argument("--batch_size", type=int, help="Batch size", default=32)
    parser.add_argument(
        "--learning_rate",
        type=float,
        help="Learning rate",
        default=1e-3,
    )
    parser.add_argument("--device", type=str, help="Run on which device", default="cpu")
    parser.add_argument(
        "--cuda_visible_devices", type=str, help="Cuda visible devices", default="0"
    )
    return parser.parse_args()


def init_logging(log_file: str, level: str = "INFO") -> None:
    """Initialize logging."""
    logging.basicConfig(
        filename=log_file,
        filemode="w",
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        level=level,
    )
    logging.getLogger().addHandler(logging.StreamHandler())


def set_seed(seed: int) -> None:
    """Set seed for reproducibility."""
    os.environ["PYTHONHASHSEED"] = str(seed)
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.use_deterministic_algorithms(True)


def seed_worker(worker_id: int) -> None:
    """Seed NumPy and random in each DataLoader worker process."""
    # Derive the worker seed from the DataLoader's base seed, as recommended in
    # the PyTorch reproducibility notes. The worker_id argument is required by
    # the worker_init_fn signature but is not needed here.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


class DatasetClass(Dataset):
    """My demo dataset class."""

    def __init__(self):
        self.input = np.random.rand(1000000, 2).astype(np.float32)
        # self.input[:, 1] = 0.0
        self.target = np.zeros([1000000, 1], dtype=np.float32)  # float32 to match the inputs
        self.target[:, 0] = self.input[:, 0] + 1.0

    def __len__(self):
        return len(self.input)

    def __getitem__(self, idx: int) -> tuple:
        return self.input[idx], self.target[idx]


class ModelClass(torch.nn.Module):
    """My demo model class."""

    def __init__(self):
        super().__init__()
        self.my_layer = nn.Linear(2, 1)

    def forward(self, inputs: Tensor) -> Tensor:
        """My demo forward function."""
        outputs = self.my_layer(inputs)
        return outputs


def get_loss(model_output: Tensor, target: Tensor) -> Tensor:
    """My demo loss function."""
    loss = torch.norm(model_output - target, dim=-1).sum()
    return loss


def training() -> None:
    """My demo training function."""
    train_set = DatasetClass()
    g = torch.Generator()
    g.manual_seed(args.seed)
    train_loader = DataLoader(
        dataset=train_set,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=os.cpu_count(),
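        # pin_memory speeds up host-to-GPU copies; it is a no-op on CPU-only runs.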
        pin_memory=True,
        worker_init_fn=seed_worker,
        generator=g,
    )
    model = ModelClass()
    if args.device == "cuda":
        model = nn.DataParallel(model)
    model.to(args.device)
    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
    for epoch in range(args.train_epochs):
        model.train()
        for batch_index, (features, labels) in enumerate(train_loader):
            features = features.to(args.device)
            labels = labels.to(args.device)
            model_outputs = model(features)
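            # set_to_none=True frees the old gradient tensors instead of zeroing them in place.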
            optimizer.zero_grad(set_to_none=True)
            loss = get_loss(model_outputs, labels)
            loss.backward()
            optimizer.step()
            if batch_index % 1000 == 0:
                logging.info(
                    "Epoch: %s, Batch index: %s, Mean loss: %s",
                    epoch,
                    batch_index,
                    loss.item() / features.size(0),
                )
    # Unwrap DataParallel before saving so the checkpoint keys carry no "module." prefix.
    model_to_save = model.module if isinstance(model, nn.DataParallel) else model
    os.makedirs(args.log_path, exist_ok=True)
    torch.save(model_to_save.state_dict(), f"{args.log_path}/trained_model.pth")


def testing() -> None:
    """My demo testing function."""
    test_set = DatasetClass()
    g = torch.Generator()
    g.manual_seed(args.seed)
    test_loader = DataLoader(
        dataset=test_set,
        batch_size=args.batch_size,
        shuffle=False,  # evaluation does not need shuffling
        num_workers=os.cpu_count(),
        pin_memory=True,
        worker_init_fn=seed_worker,
        generator=g,
    )
    model = ModelClass()
    # Load the plain state dict first, then wrap in DataParallel if needed,
    # so the checkpoint works regardless of the device it was trained on.
    model.load_state_dict(
        torch.load(f"{args.log_path}/trained_model.pth", map_location=args.device)
    )
    if args.device == "cuda":
        model = nn.DataParallel(model)
    model.to(args.device)
    model.eval()
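    # Disable autograd bookkeeping during evaluation; saves memory and time.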
    with torch.no_grad():
        for batch_index, (features, labels) in enumerate(test_loader):
            features = features.to(args.device)
            labels = labels.to(args.device)
            model_outputs = model(features)
            loss = get_loss(model_outputs, labels)
            if batch_index % 1000 == 0:
                logging.info(
                    "Batch index: %s, Mean loss: %s",
                    batch_index,
                    loss.item() / features.size(0),
                )


if __name__ == "__main__":
    args = parse_args()
    set_seed(args.seed)
    init_logging(args.log_file)
    os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_visible_devices
    main_start_time = time.time()
    training()
    main_end_time = time.time()
    logging.info("Main time: %s", main_end_time - main_start_time)
    testing()
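
To launch on multiple GPUs, point --cuda_visible_devices at the GPUs you want and set --device cuda; DataParallel then splits each batch across them. A sketch of a typical invocation, assuming the script is saved as test_train.py (the post does not name the file):

python test_train.py --device cuda --cuda_visible_devices 0,1 --batch_size 64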

  
