Fork me on github

基于 Hugging Face Accelerate 的 DDP（分布式数据并行）训练

# -*- coding: utf-8 -*-

"""" This document is a simple Demo for DDP Image Classification """

from typing import Callable
from argparse import ArgumentParser, Namespace

import torch
from torch.backends import cudnn
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torchvision import transforms
from torchvision.datasets.cifar import CIFAR10
from torchvision.models import resnet
from tqdm import tqdm
from accelerate import Accelerator
from accelerate.utils import set_seed


def parse_args() -> Namespace:
    """Read the training options from the command line.

    Returns:
        Namespace with ``dataset`` (folder), ``epochs``, ``batch_size``,
        ``optimizer`` (``"Adam"`` or ``"SGD"``), ``learning_rate`` and ``seed``.
    """
    # Each entry: (short flag, long flag, default, type, help text, extra kwargs).
    # Options are registered in this order, which is also their ``--help`` order.
    option_table = (
        # Dataset folder
        ("-d", "--dataset", "/dev/shm/dataset", str, "Dataset folder.", {}),
        # Number of training epochs
        ("-e", "--epochs", 248, int, "Number of epochs to train.", {}),
        # Mini-batch size
        ("-bs", "--batch-size", 128, int, "Size of mini batch.", {}),
        # Optimizer choice
        (
            "-opt",
            "--optimizer",
            "SGD",
            str,
            "Optimizer used to train the model.",
            {"choices": ["Adam", "SGD"]},
        ),
        # Initial learning rate
        ("-lr", "--learning-rate", 2e-3, float, "Learning rate.", {}),
        # Random seed
        ("-s", "--seed", 0, int, "Random Seed.", {}),
    )

    cli = ArgumentParser()
    for short_flag, long_flag, default, value_type, help_text, extra in option_table:
        cli.add_argument(
            short_flag,
            long_flag,
            action="store",
            default=default,
            type=value_type,
            help=help_text,
            **extra,
        )
    return cli.parse_args()


def prepare_model(num_classes: int = 1000) -> torch.nn.Module:
    """Build a pretrained ResNet-18 with a 3x3 stem and, if needed, a new FC head.

    Args:
        num_classes: Output width of the classifier; the original 1000-way
            FC layer is kept when this is 1000, otherwise replaced.

    Returns:
        The (not yet accelerator-prepared) model.
    """
    # The local main process loads (and, if necessary, downloads) the
    # pretrained weights before the other local processes do.
    with accelerator.local_main_process_first():
        net: resnet.ResNet = resnet.resnet18(weights=resnet.ResNet18_Weights.DEFAULT)

    # For CIFAR-sized inputs the 7x7 stride-2 stem is replaced by a freshly
    # initialised 3x3 stride-1 convolution (the parameter count barely changes).
    # NOTE(review): the max-pool after the stem is kept — confirm that is intended.
    net.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    if num_classes != 1000:
        net.fc = torch.nn.Linear(512, num_classes)

    n_params = sum(weight.nelement() for weight in net.parameters())
    accelerator.print(f"#params: {n_params / 1e6}M")
    return net


def prepare_dataset(folder: str):
    """Load CIFAR-10 from ``folder`` as three datasets.

    Returns:
        ``(train_data, train_eval_data, test_data)``: the augmented training
        split, the same training split with only tensor conversion and
        normalisation (for evaluation), and the normalised test split.
    """
    normalize = transforms.Normalize(
        (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
    )

    def _plain_pipeline():
        # Deterministic preprocessing: tensor conversion + normalisation only.
        return transforms.Compose([transforms.ToTensor(), normalize])

    augmented_pipeline = transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(0.25),
            transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10),
            transforms.ToTensor(),
            normalize,
        ]
    )

    # Only the local main process may download; the others wait for it.
    may_download = accelerator.is_local_main_process
    with accelerator.local_main_process_first():
        train_data = CIFAR10(
            folder, train=True, transform=augmented_pipeline, download=may_download
        )
        test_data = CIFAR10(
            folder, train=False, transform=_plain_pipeline(), download=may_download
        )

    # The files exist by now, so this copy of the training split needs no download.
    train_eval_data = CIFAR10(folder, train=True, transform=_plain_pipeline())
    return train_data, train_eval_data, test_data


def get_data_loader(
    batch_size: int,
    train_data: Dataset,
    train_eval_data: Dataset,
    test_data: Dataset,
) -> tuple[DataLoader, DataLoader, DataLoader]:
    """Wrap the three datasets in accelerator-prepared DataLoaders.

    The training loader shuffles with ``batch_size``; the two evaluation
    loaders keep order and use ``2 * batch_size``.

    Returns:
        ``(train_loader, train_eval_loader, test_loader)``.
    """
    # Worker subprocesses are used only in single-process runs.
    worker_count = 2 if accelerator.num_processes == 1 else 0

    def _loader(data: Dataset, size: int, shuffle: bool) -> DataLoader:
        return DataLoader(
            data,
            size,
            shuffle=shuffle,
            pin_memory=True,
            num_workers=worker_count,
        )

    eval_batch = batch_size * 2
    raw_loaders = (
        _loader(train_data, batch_size, True),
        _loader(train_eval_data, eval_batch, False),
        _loader(test_data, eval_batch, False),
    )
    return accelerator.prepare(*raw_loaders)


@torch.enable_grad()
def train_epoch(
    model: torch.nn.Module,
    loss_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
) -> None:
    """Run one optimisation pass over ``dataloader``.

    Each batch: clear gradients, forward, back-propagate through
    ``accelerator.backward`` (so Accelerate can handle gradient syncing and
    scaling), then take one optimizer step.  The progress bar is shown on the
    local main process only.
    """
    model.train()
    show_bar = accelerator.is_local_main_process
    for inputs, labels in tqdm(dataloader, disable=not show_bar):
        optimizer.zero_grad()
        logits: torch.Tensor = model(inputs)
        accelerator.backward(loss_func(logits, labels))
        optimizer.step()


@torch.no_grad()
def eval_epoch(
    model: torch.nn.Module,
    loss_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    dataloader: DataLoader,
) -> tuple[float, float]:
    """Evaluate ``model`` on ``dataloader``; return ``(mean loss, accuracy)``.

    The outputs and targets of every batch are gathered from all processes
    with ``accelerator.gather_for_metrics``, which also drops the samples
    duplicated to pad the final batch.  The loss is then computed on the
    gathered batch and weighted by its size.  Both returned numbers are
    therefore sample-weighted averages over the whole dataset, identical on
    every process.  This assumes ``loss_func`` averages over the batch, as the
    default ``CrossEntropyLoss`` reduction does.
    """
    model.eval()
    dataloader_with_bar = tqdm(
        dataloader, disable=(not accelerator.is_local_main_process)
    )
    correct_sum, loss_sum, cnt_samples = 0, 0.0, 0
    for source, targets in dataloader_with_bar:
        output: torch.Tensor = model(source)
        # Collective call: every process must reach it for every batch.
        all_output, all_targets = accelerator.gather_for_metrics(
            (output, targets)
        )  # type: ignore
        batch_size = len(all_targets)
        if batch_size == 0:
            # Nothing left after removing padding duplicates.
            continue

        loss = loss_func(all_output, all_targets)
        # Re-weight the batch mean so short batches count per sample.
        loss_sum += loss.item() * batch_size
        correct_sum += (all_output.argmax(dim=1) == all_targets).sum().item()
        cnt_samples += batch_size
    return loss_sum / cnt_samples, correct_sum / cnt_samples


def main(args: Namespace):
    """Train ResNet-18 on CIFAR-10, logging and checkpointing every epoch.

    Epoch 0 only evaluates the freshly built model.  Each later epoch:
    trains one pass over the training set, evaluates on the training and
    test sets, appends a row to ``log.csv``, saves ``./weights/last``, and
    saves ``./weights/best`` on a new best test accuracy.  ``log.csv`` and
    the checkpoints are written by the local main process only.

    Args:
        args: Parsed command-line options (see ``parse_args``).
    """
    set_seed(args.seed)
    model = prepare_model(10)
    train_data, train_eval_data, test_data = prepare_dataset(args.dataset)
    train_loader, train_eval_loader, test_loader = get_data_loader(
        args.batch_size, train_data, train_eval_data, test_data
    )

    # `--optimizer SGD` selects SGD (with momentum and weight decay);
    # `--optimizer Adam` selects Adam.
    optimizer: torch.optim.Optimizer = (
        torch.optim.SGD(
            model.parameters(),
            args.learning_rate,
            momentum=0.90,
            weight_decay=2e-2,
        )
        if args.optimizer == "SGD"
        else torch.optim.Adam(model.parameters(), args.learning_rate)
    )

    loss_func = torch.nn.CrossEntropyLoss(label_smoothing=0.05)
    # Cosine annealing with warm restarts: the first period is 8 epochs and
    # each later period is twice as long.  Stepped once per training epoch.
    scheduler: CosineAnnealingWarmRestarts = CosineAnnealingWarmRestarts(
        optimizer, 8, 2
    )
    # The scheduler is deliberately NOT passed to accelerator.prepare(). The
    # Accelerate wrapper assumes per-batch stepping and advances the schedule
    # `num_processes` times per step() call, which would compress this
    # per-epoch schedule by the number of GPUs.
    model, optimizer, loss_func = accelerator.prepare(model, optimizer, loss_func)

    best_acc = 0.0

    # Only the local main process owns log.csv.  Opening it with "wt" on
    # every process would make each one truncate the file the others write.
    log_file = open("log.csv", "wt") if accelerator.is_local_main_process else None
    try:
        if log_file is not None:
            print(
                "epoch,train_loss,train_acc,val_loss,val_acc,learning_rate",
                file=log_file,
            )
            log_file.flush()

        for epoch in range(args.epochs + 1):
            accelerator.print(
                f"Epoch {epoch}/{args.epochs}",
                f"(lr={optimizer.param_groups[-1]['lr']}):",
            )

            # Train (epoch 0 is an evaluation-only baseline).
            if epoch != 0:
                train_epoch(model, loss_func, train_loader, optimizer)

            accelerator.wait_for_everyone()

            # Evaluate on the non-augmented training set and on the test set.
            train_loss, train_acc = eval_epoch(model, loss_func, train_eval_loader)
            val_loss, val_acc = eval_epoch(model, loss_func, test_loader)
            accelerator.print(
                f"[ Training ] Acc: {train_acc * 100:.2f}% Loss: {train_loss:.4f}"
            )

            # Log the epoch and save the latest / best weights.
            accelerator.wait_for_everyone()
            if log_file is not None:  # i.e. on the local main process
                print(
                    epoch,
                    train_loss,
                    train_acc,
                    val_loss,
                    val_acc,
                    optimizer.param_groups[-1]["lr"],
                    sep=",",
                    file=log_file,
                )
                log_file.flush()
                accelerator.save_model(model, "./weights/last")
                if val_acc > best_acc:
                    best_acc = val_acc
                    accelerator.save_model(model, "./weights/best")
            accelerator.wait_for_everyone()

            accelerator.print(
                f"[Validation] Acc: {val_acc * 100:.2f}%",
                f"Loss: {val_loss:.4f}",
                f"Best: {best_acc * 100:.2f}%",
            )

            if epoch != 0:
                scheduler.step()
    finally:
        if log_file is not None:
            log_file.close()


if __name__ == "__main__":
    # Let cuDNN benchmark candidate convolution algorithms and cache the
    # fastest; this pays off when input shapes stay the same across batches.
    cudnn.benchmark = True
    # The functions above read `accelerator` as a module-level global, so it
    # must be created before main() runs.
    accelerator = Accelerator()
    main(parse_args())

训练ImageNet的例子:

# -*- coding: utf-8 -*-

# Usage: accelerate launch trainImageNet.py

import torch
import numpy as np
from sklearn import metrics
from os import path, makedirs
from datetime import datetime
from accelerate.utils import tqdm
from accelerate import Accelerator
from torchvision.models import resnet18
from torchvision.datasets import ImageNet
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import v2 as transforms


# DataLoader batch size (see make_dataloader). Presumably per process, since
# Accelerate does not split batches by default — confirm.
BATCH_SIZE = 256
# SGD learning rate reached at the end of the linear warm-up (see main()).
LEARNING_RATE = 0.5
# SGD weight decay.
WEIGHT_DECAY = 5e-5
# Number of training epochs; main() additionally runs an evaluation-only epoch 0.
EPOCHS = 300
# Label smoothing applied to the training loss.
LABEL_SMOOTHING = 0.05
# Epochs of linear LR warm-up before the cosine decay starts.
WARM_UP_EPOCHS = 10
# Per GPU; worker count of the training DataLoader (the validation loader uses 14).
NUM_WORKERS = 16  # Per GPU
# Root directory passed to torchvision's ImageNet dataset.
DATASET_ROOT = "~/dataset/ImageNet-1k"
# Shared Accelerator used by every function below for device placement,
# process coordination and metric gathering.
accelerator = Accelerator()


class ClassificationMeter:
    """Accumulate classifier outputs batch by batch and report metrics.

    ``record`` stores labels, top-1 predictions and (optionally) softmax
    probabilities as NumPy chunks.  The ``labels`` / ``prediction`` /
    ``logits`` attributes concatenate those chunks lazily on access.
    Buffering chunks keeps recording linear in the number of samples;
    re-concatenating the whole history on every batch would be quadratic.

    Args:
        num_classes: Number of output classes (width of the logits).
        record_logits: If true, also keep the per-sample softmax
            probabilities, exposed as ``logits``.
    """

    def __init__(self, num_classes: int, record_logits: bool = False) -> None:
        self.num_classes = num_classes
        self.total_loss = 0.0  # summed (not averaged) cross-entropy
        self.acc5_cnt = 0  # samples whose label is among the top-5 predictions
        self.record_logits = record_logits
        # Seed every buffer with an empty int32 / float64 chunk so the
        # concatenated arrays have the same dtypes even before any record().
        self._label_chunks: list[np.ndarray] = [np.zeros((0,), dtype=np.int32)]
        self._prediction_chunks: list[np.ndarray] = [np.zeros((0,), dtype=np.int32)]
        if self.record_logits:
            self._logit_chunks: list[np.ndarray] = [np.zeros((0, num_classes))]

    @staticmethod
    def _merge(chunks: list[np.ndarray]) -> np.ndarray:
        """Collapse ``chunks`` in place into a single array and return it."""
        if len(chunks) > 1:
            chunks[:] = [np.concatenate(chunks)]
        return chunks[0]

    @property
    def labels(self) -> np.ndarray:
        """Ground-truth labels of every recorded sample."""
        return self._merge(self._label_chunks)

    @property
    def prediction(self) -> np.ndarray:
        """Top-1 predicted class of every recorded sample."""
        return self._merge(self._prediction_chunks)

    @property
    def logits(self) -> np.ndarray:
        """Softmax probabilities of every recorded sample (``record_logits=True`` only)."""
        if not self.record_logits:
            raise AttributeError("logits are only kept when record_logits=True")
        return self._merge(self._logit_chunks)

    def record(self, y_true: torch.Tensor, logits: torch.Tensor) -> None:
        """Add one batch.

        Args:
            y_true: Integer class labels, one per sample.
            logits: Raw (pre-softmax) scores, one row per sample and one
                column per class.
        """
        # Copy, because .numpy() shares memory with CPU tensors and the
        # caller could later modify the tensor in place.
        self._label_chunks.append(y_true.cpu().numpy().copy())
        # Record softmax probabilities
        if self.record_logits:
            probabilities = torch.nn.functional.softmax(logits, dim=1)
            self._logit_chunks.append(probabilities.detach().cpu().numpy())

        # Summed loss; divided by the sample count in `loss`.
        self.total_loss += float(
            torch.nn.functional.cross_entropy(logits, y_true, reduction="sum").item()
        )

        # Top-5 accuracy.  k is clamped so that models with fewer than five
        # classes do not make topk() fail.
        k = min(5, logits.size(1))
        y_pred = logits.topk(k, dim=1, largest=True).indices.to(torch.int)
        hits = (y_pred == y_true[:, None]).any(dim=-1)
        self.acc5_cnt += int(hits.sum().item())

        # The first top-k column is the arg-max, i.e. the top-1 prediction.
        self._prediction_chunks.append(y_pred[:, 0].cpu().numpy().copy())

    @property
    def accuracy(self) -> float:
        return float(metrics.accuracy_score(self.labels, self.prediction))

    @property
    def youden_score(self) -> float:
        result = metrics.balanced_accuracy_score(
            self.labels, self.prediction, adjusted=True
        )
        return float(result)

    @property
    def f1_micro(self) -> float:
        result = metrics.f1_score(self.labels, self.prediction, average="micro")
        return float(result)

    @property
    def f1_macro(self) -> float:
        result = metrics.f1_score(self.labels, self.prediction, average="macro")
        return float(result)

    @property
    def accuracy5(self) -> float:
        """Top-5 accuracy, or NaN if nothing has been recorded."""
        count = len(self.labels)
        return self.acc5_cnt / count if count else float("nan")

    @property
    def loss(self) -> float:
        """Mean (unsmoothed) cross-entropy, or NaN if nothing has been recorded."""
        count = len(self.labels)
        return float(self.total_loss / count) if count else float("nan")


def make_dataloader(
    dataset: Dataset, shuffle: bool = False, num_workers: int = 32
) -> DataLoader:
    """Wrap ``dataset`` in a DataLoader yielding ``BATCH_SIZE`` samples per batch.

    Args:
        dataset: Dataset to iterate.
        shuffle: Whether to reshuffle every epoch.
        num_workers: Number of loader worker processes.

    Pinned host memory is always requested, to speed up host-to-GPU copies.
    """
    loader_options = dict(
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
    )
    return DataLoader(dataset, **loader_options)


def prepare_datasets() -> tuple[DataLoader, DataLoader]:
    """Build the ImageNet training and validation data loaders.

    Training images go through the augmentation pipeline: random resized
    176px crop, horizontal flip, TrivialAugmentWide and random erasing.
    Validation images go through the deterministic resize-232 /
    center-crop-224 pipeline.  Note that ``main`` also evaluates on the
    training loader, so its "training" metrics are measured on augmented
    images.

    Returns:
        ``(train_loader, test_loader)``; only the training loader shuffles.
    """
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    intp = transforms.InterpolationMode.BILINEAR
    # Deterministic evaluation preprocessing.
    basic_transform = transforms.Compose(
        [
            transforms.Resize(232, interpolation=intp),
            transforms.CenterCrop(224),
            transforms.PILToTensor(),
            transforms.ToDtype(torch.float32, scale=True),
            transforms.Normalize(mean, std, inplace=True),
            transforms.ToPureTensor(),
        ]
    )

    # Random training-time augmentation.
    augment_transform = transforms.Compose(
        [
            transforms.RandomResizedCrop(176, interpolation=intp),
            transforms.RandomHorizontalFlip(0.5),
            transforms.TrivialAugmentWide(interpolation=intp),
            transforms.ToImage(),
            transforms.ToDtype(torch.float32, scale=True),
            transforms.Normalize(mean, std, inplace=True),
            transforms.RandomErasing(0.1),
            transforms.ToPureTensor(),
        ]
    )

    # Augmentation belongs on the training split only; validation must be
    # deterministic.  main_process_first presumably lets the main process do
    # any one-time dataset metadata preparation before the others — confirm.
    with accelerator.main_process_first():
        dataset_train = ImageNet(DATASET_ROOT, "train", transform=augment_transform)
    with accelerator.main_process_first():
        dataset_test = ImageNet(DATASET_ROOT, "val", transform=basic_transform)

    train_loader = make_dataloader(dataset_train, shuffle=True, num_workers=NUM_WORKERS)
    test_loader = make_dataloader(dataset_test, shuffle=False, num_workers=14)
    return train_loader, test_loader


@torch.no_grad()
def validate(model: torch.nn.Module, data_loader: DataLoader):
    """Evaluate ``model`` on ``data_loader`` across all processes.

    The labels and logits of each batch are collected from every process
    with ``accelerator.gather_for_metrics``.  Only the main process feeds
    them into the returned 1000-class ``ClassificationMeter``, so the meter
    stays empty on every other process.
    """
    model.eval()
    meter = ClassificationMeter(1000)
    # NOTE(review): the positional `True` is accelerate.utils.tqdm's
    # `main_process_only` flag in the Accelerate version this was written
    # for; newer releases expect it as a keyword — confirm against the
    # installed version.
    batches = tqdm(True, data_loader, "Validating")
    for inputs, labels in batches:
        # gather_for_metrics is a collective call that every process must
        # make, so the main-process check comes only after gathering.
        labels_all: torch.Tensor = accelerator.gather_for_metrics(labels)  # type: ignore
        logits_all: torch.Tensor = accelerator.gather_for_metrics(model(inputs))  # type: ignore
        if not accelerator.is_main_process:
            continue
        meter.record(labels_all, logits_all)
    return meter


def main() -> None:
    """Train ResNet-18 on ImageNet and log/checkpoint after every epoch.

    Epoch 0 only evaluates the randomly initialised model.  Epochs
    1..EPOCHS each train one pass, step the epoch-level LR schedule and
    then evaluate on the training and validation loaders.  Console output,
    the CSV log and the checkpoint are produced by the main process only.
    """
    with accelerator.main_process_first():
        train_loader, test_loader = prepare_datasets()
    model = resnet18()
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=LEARNING_RATE,
        momentum=0.9,
        weight_decay=WEIGHT_DECAY,
    )

    # Linear warm-up from 1% of LEARNING_RATE over WARM_UP_EPOCHS epochs,
    # then cosine decay to 1e-5.  The construction order is kept because
    # each scheduler constructor already performs an initial step on the
    # optimizer.
    cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=EPOCHS - WARM_UP_EPOCHS, eta_min=1e-5
    )
    warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=1e-2, total_iters=WARM_UP_EPOCHS
    )
    scheduler = torch.optim.lr_scheduler.SequentialLR(
        optimizer, [warmup_scheduler, cosine_scheduler], [WARM_UP_EPOCHS]
    )

    criterion = torch.nn.CrossEntropyLoss(label_smoothing=LABEL_SMOOTHING)

    model = accelerator.prepare_model(model)
    train_loader = accelerator.prepare_data_loader(train_loader)
    test_loader = accelerator.prepare_data_loader(test_loader)
    optimizer = accelerator.prepare_optimizer(optimizer)
    # The scheduler is stepped once per epoch, so it is intentionally not
    # prepared: Accelerate's wrapper would advance it once per process.
    criterion = accelerator.prepare_model(criterion)

    # Validation top-1 accuracy at the epoch with the lowest validation loss.
    best_acc = 0.0
    best_loss = 1e10

    saving_root = path.join(
        "saved_models",
        datetime.now().isoformat(timespec="seconds"),
    )
    # Only the main process creates the run directory and the CSV log.  On
    # every process the "w" open would truncate the shared file and
    # duplicate the header.
    logging_file = None
    if accelerator.is_main_process:
        makedirs(saving_root, exist_ok=True)
        logging_file = open(
            path.join(saving_root, "base_training.csv"), "w", buffering=1
        )
        print(
            "epoch",
            "best_acc@1",
            "loss",
            "acc@1",
            "acc@5",
            "f1-micro",
            "training_loss",
            "training_acc@1",
            "training_acc@5",
            "training_f1-micro",
            "training_learning-rate",
            file=logging_file,
            sep=",",
        )

    try:
        for epoch in range(EPOCHS + 1):
            if epoch != 0:
                model.train()
                for X, y in tqdm(True, train_loader, desc=f"Epoch {epoch}/{EPOCHS}"):
                    optimizer.zero_grad(set_to_none=True)
                    logits = model(X)
                    loss: torch.Tensor = criterion(logits, y)
                    accelerator.backward(loss)
                    optimizer.step()
                accelerator.wait_for_everyone()
                scheduler.step()

            # Both calls gather across processes, so every process runs them.
            train_meter = validate(model, train_loader)
            val_meter = validate(model, test_loader)
            if not accelerator.is_main_process:
                continue

            # Validation on training set
            print(
                f"loss: {train_meter.loss:.4f}",
                f"acc@1: {train_meter.accuracy * 100:.3f}%",
                f"acc@5: {train_meter.accuracy5 * 100:.3f}%",
                f"f1-micro: {train_meter.f1_micro * 100:.3f}%",
                sep="    ",
            )

            if val_meter.loss < best_loss:
                best_acc = val_meter.accuracy
                best_loss = val_meter.loss
                if epoch != 0:
                    # unwrap_model returns the bare ResNet, which is not
                    # indexable.  The state_dict is saved instead of the
                    # pickled module, so the checkpoint can be loaded with
                    # torch.load(..., weights_only=True).
                    backbone = accelerator.unwrap_model(model)
                    backbone_path = path.join(saving_root, "backbone.pth")
                    accelerator.save(backbone.state_dict(), backbone_path)

            # Validation on testing set
            print(
                f"loss: {val_meter.loss:.4f}",
                f"acc@1: {val_meter.accuracy * 100:.3f}%",
                f"acc@5: {val_meter.accuracy5 * 100:.3f}%",
                f"f1-micro: {val_meter.f1_micro * 100:.3f}%",
                f"best_acc@1: {best_acc * 100:.3f}%",
                sep="    ",
            )
            print(
                epoch,
                best_acc,
                val_meter.loss,
                val_meter.accuracy,
                val_meter.accuracy5,
                val_meter.f1_micro,
                train_meter.loss,
                train_meter.accuracy,
                train_meter.accuracy5,
                train_meter.f1_micro,
                optimizer.param_groups[0]["lr"],
                file=logging_file,
                sep=",",
            )
    finally:
        if logging_file is not None:
            logging_file.close()


if __name__ == "__main__":
    # Script entry point; intended to be started with `accelerate launch`
    # so that one process runs per device.
    main()
posted @ 2024-02-08 16:31  fang-d  阅读(197)  评论(0编辑  收藏  举报