DDP Training with Hugging Face Accelerate
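Hugging Face Accelerate hides most of PyTorch's DistributedDataParallel boilerplate behind a small API: build the model, optimizer, and DataLoaders as usual, pass them through `accelerator.prepare(...)`, replace `loss.backward()` with `accelerator.backward(loss)`, and let the library handle device placement and process groups. The first example below fine-tunes a pretrained ResNet-18 on CIFAR-10 in exactly this way, using `gather_for_metrics` during evaluation to collect predictions from every process.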
```python
# -*- coding: utf-8 -*-
"""This document is a simple Demo for DDP Image Classification"""
from typing import Callable
from argparse import ArgumentParser, Namespace

import torch
from torch.backends import cudnn
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torchvision import transforms
from torchvision.datasets.cifar import CIFAR10
from torchvision.models import resnet
from tqdm import tqdm
from accelerate import Accelerator
from accelerate.utils import set_seed


def parse_args() -> Namespace:
    """Handle command-line input."""
    parser = ArgumentParser()
    # Dataset path
    parser.add_argument(
        "-d",
        "--dataset",
        action="store",
        default="/dev/shm/dataset",
        type=str,
        help="Dataset folder.",
    )
    # Number of training epochs
    parser.add_argument(
        "-e",
        "--epochs",
        action="store",
        default=248,
        type=int,
        help="Number of epochs to train.",
    )
    # Mini-batch size
    parser.add_argument(
        "-bs",
        "--batch-size",
        action="store",
        default=128,
        type=int,
        help="Size of mini batch.",
    )
    # Optimizer choice
    parser.add_argument(
        "-opt",
        "--optimizer",
        action="store",
        default="SGD",
        type=str,
        choices=["Adam", "SGD"],
        help="Optimizer used to train the model.",
    )
    # Initial learning rate
    parser.add_argument(
        "-lr",
        "--learning-rate",
        action="store",
        default=2e-3,
        type=float,
        help="Learning rate.",
    )
    # Random seed
    parser.add_argument(
        "-s",
        "--seed",
        action="store",
        default=0,
        type=int,
        help="Random Seed.",
    )
    return parser.parse_args()


def prepare_model(num_classes: int = 1000) -> torch.nn.Module:
    """ResNet-18 with the FC layer replaced."""
    # Let the local main process download the weights first; the other
    # ranks then load them from the cache.
    with accelerator.local_main_process_first():
        model: resnet.ResNet = resnet.resnet18(
            weights=resnet.ResNet18_Weights.DEFAULT
        )
    # For CIFAR, swap ResNet-18's 7x7 stem convolution for a 3x3 one
    # (the parameter count stays roughly the same)
    model.conv1 = torch.nn.Conv2d(
        3, 64, kernel_size=3, stride=1, padding=1, bias=False
    )
    if num_classes != 1000:
        model.fc = torch.nn.Linear(512, num_classes)
    total_params = sum(param.nelement() for param in model.parameters())
    accelerator.print(f"#params: {total_params / 1e6}M")
    return model


def prepare_dataset(folder: str):
    """Use the CIFAR-10 dataset."""
    normalize_transform = transforms.Normalize(
        (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
    )
    with accelerator.local_main_process_first():
        train_data = CIFAR10(
            folder,
            train=True,
            transform=transforms.Compose(
                [
                    transforms.RandomCrop(32, padding=4),
                    transforms.RandomHorizontalFlip(0.25),
                    transforms.AutoAugment(
                        transforms.AutoAugmentPolicy.CIFAR10
                    ),
                    transforms.ToTensor(),
                    normalize_transform,
                ]
            ),
            download=accelerator.is_local_main_process,
        )
        test_data = CIFAR10(
            folder,
            train=False,
            transform=transforms.Compose(
                [transforms.ToTensor(), normalize_transform]
            ),
            download=accelerator.is_local_main_process,
        )
        # The training set again, but without augmentation, for evaluation
        train_eval_data = CIFAR10(
            folder,
            train=True,
            transform=transforms.Compose(
                [transforms.ToTensor(), normalize_transform]
            ),
        )
    return train_data, train_eval_data, test_data


def get_data_loader(
    batch_size: int,
    train_data: Dataset,
    train_eval_data: Dataset,
    test_data: Dataset,
) -> tuple[DataLoader, DataLoader, DataLoader]:
    """Build the DataLoaders."""
    # Use worker processes only in the single-process case to avoid
    # oversubscribing the CPU under DDP.
    train_loader: DataLoader = DataLoader(
        train_data,
        batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=2 if accelerator.num_processes == 1 else 0,
    )
    train_eval_loader: DataLoader = DataLoader(
        train_eval_data,
        batch_size * 2,
        shuffle=False,
        pin_memory=True,
        num_workers=2 if accelerator.num_processes == 1 else 0,
    )
    test_loader: DataLoader = DataLoader(
        test_data,
        batch_size * 2,
        shuffle=False,
        pin_memory=True,
        num_workers=2 if accelerator.num_processes == 1 else 0,
    )
    return accelerator.prepare(train_loader, train_eval_loader, test_loader)


@torch.enable_grad()
def train_epoch(
    model: torch.nn.Module,
    loss_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
) -> None:
    """Train for one epoch."""
    model.train()
    dataloader_with_bar = tqdm(
        dataloader, disable=(not accelerator.is_local_main_process)
    )
    for source, targets in dataloader_with_bar:
        optimizer.zero_grad()
        output: torch.Tensor = model(source)
        loss = loss_func(output, targets)
        accelerator.backward(loss)
        optimizer.step()


@torch.no_grad()
def eval_epoch(
    model: torch.nn.Module,
    loss_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    dataloader: DataLoader,
) -> tuple[float, float]:
    """Evaluate the model's loss and accuracy on the given dataset."""
    model.eval()
    dataloader_with_bar = tqdm(
        dataloader, disable=(not accelerator.is_local_main_process)
    )
    correct_sum, loss_sum, cnt_samples = 0, 0.0, 0
    for source, targets in dataloader_with_bar:
        output: torch.Tensor = model(source)
        loss = loss_func(output, targets)
        # Gather per-sample correctness from all processes; padded
        # duplicates are dropped automatically.
        prediction: torch.Tensor = accelerator.gather_for_metrics(
            output.argmax(dim=1) == targets
        )  # type: ignore
        correct_sum += prediction.sum().item()
        loss_sum += loss.item()
        cnt_samples += len(prediction)
    # Note: the loss is averaged over this process's batches only, while
    # the accuracy uses the gathered samples.
    return loss_sum / len(dataloader), correct_sum / cnt_samples


def main(args: Namespace):
    """Main training function."""
    set_seed(args.seed)
    model = prepare_model(10)
    train_data, train_eval_data, test_data = prepare_dataset(args.dataset)
    train_loader, train_eval_loader, test_loader = get_data_loader(
        args.batch_size, train_data, train_eval_data, test_data
    )
    optimizer: torch.optim.Optimizer = (
        torch.optim.SGD(
            model.parameters(),
            args.learning_rate,
            momentum=0.90,
            weight_decay=2e-2,
        )
        if args.optimizer == "SGD"
        else torch.optim.Adam(model.parameters(), args.learning_rate)
    )
    loss_func = torch.nn.CrossEntropyLoss(label_smoothing=0.05)
    scheduler: CosineAnnealingWarmRestarts = CosineAnnealingWarmRestarts(
        optimizer, 8, 2
    )
    model, optimizer, loss_func, scheduler = accelerator.prepare(
        model, optimizer, loss_func, scheduler
    )
    best_acc = 0.0
    # Open the log only on the local main process so the other ranks
    # do not truncate the file concurrently.
    log_file = (
        open("log.csv", "wt") if accelerator.is_local_main_process else None
    )
    if log_file is not None:
        print(
            "epoch,train_loss,train_acc,val_loss,val_acc,learning_rate",
            file=log_file,
        )
        log_file.flush()
    for epoch in range(args.epochs + 1):
        accelerator.print(
            f"Epoch {epoch}/{args.epochs}",
            f"(lr={optimizer.param_groups[-1]['lr']}):",
        )
        # Train the model (epoch 0 is an evaluation-only baseline)
        if epoch != 0:
            train_epoch(model, loss_func, train_loader, optimizer)
        accelerator.wait_for_everyone()
        # Evaluate the model on the training and test sets
        train_loss, train_acc = eval_epoch(model, loss_func, train_eval_loader)
        val_loss, val_acc = eval_epoch(model, loss_func, test_loader)
        accelerator.print(
            f"[ Training ] Acc: {train_acc * 100:.2f}% Loss: {train_loss:.4f}"
        )
        # Log metrics and save the latest/best weights
        accelerator.wait_for_everyone()
        if log_file is not None:
            print(
                epoch,
                train_loss,
                train_acc,
                val_loss,
                val_acc,
                optimizer.param_groups[-1]["lr"],
                sep=",",
                file=log_file,
            )
            log_file.flush()
        accelerator.save_model(model, "./weights/last")
        if val_acc > best_acc:
            best_acc = val_acc
            accelerator.save_model(model, "./weights/best")
        accelerator.wait_for_everyone()
        accelerator.print(
            f"[Validation] Acc: {val_acc * 100:.2f}%",
            f"Loss: {val_loss:.4f}",
            f"Best: {best_acc * 100:.2f}%",
        )
        if epoch != 0:
            scheduler.step()
    if log_file is not None:
        log_file.close()


if __name__ == "__main__":
    cudnn.benchmark = True
    accelerator = Accelerator()
    main(parse_args())
```
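The script is launched through the Accelerate CLI rather than plain `python`: run `accelerate config` once to describe the machine, then something like `accelerate launch train_cifar10.py -bs 128` (the filename here is illustrative; the original post does not name the file). One detail worth highlighting is `gather_for_metrics` in `eval_epoch`: when the evaluation set does not split evenly across processes, the prepared DataLoader pads the final batch with duplicated samples, and unlike plain `Accelerator.gather`, `gather_for_metrics` drops those duplicates so the metric denominators stay exact. A minimal sketch of that pattern, assuming `accelerator`, `model`, and `loader` have already been prepared as above:

```python
# Minimal sketch: exact distributed accuracy via gather_for_metrics.
# Assumes `accelerator`, `model`, and `loader` were prepared as in the
# script above; padded duplicate samples are dropped automatically.
correct, total = 0, 0
for inputs, targets in loader:
    hits = accelerator.gather_for_metrics(model(inputs).argmax(dim=1) == targets)
    correct += int(hits.sum().item())
    total += hits.numel()
print(f"accuracy: {correct / total:.4f}")
```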
And here is a fuller example: training ResNet-18 on ImageNet-1k with the same workflow, this time using torchvision's v2 transforms and a linear warm-up plus cosine-annealing schedule:
```python
# -*- coding: utf-8 -*-
# Usage: accelerate launch trainImageNet.py
import torch
import numpy as np
from sklearn import metrics
from os import path, makedirs
from datetime import datetime
from accelerate.utils import tqdm
from accelerate import Accelerator
from torchvision.models import resnet18
from torchvision.datasets import ImageNet
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import v2 as transforms

BATCH_SIZE = 256
LEARNING_RATE = 0.5
WEIGHT_DECAY = 5e-5
EPOCHS = 300
LABEL_SMOOTHING = 0.05
WARM_UP_EPOCHS = 10
NUM_WORKERS = 16  # Per GPU
DATASET_ROOT = "~/dataset/ImageNet-1k"

accelerator = Accelerator()


class ClassificationMeter:
    def __init__(self, num_classes: int, record_logits: bool = False) -> None:
        self.num_classes = num_classes
        self.total_loss = 0.0
        self.labels = np.zeros((0,), dtype=np.int32)
        self.prediction = np.zeros((0,), dtype=np.int32)
        self.acc5_cnt = 0
        self.record_logits = record_logits
        if self.record_logits:
            self.logits = np.zeros((0, num_classes))

    def record(self, y_true: torch.Tensor, logits: torch.Tensor) -> None:
        self.labels = np.concatenate([self.labels, y_true.cpu().numpy()])
        # Record logits
        if self.record_logits:
            logits_softmax = (
                torch.nn.functional.softmax(logits, dim=1).cpu().numpy()
            )
            self.logits = np.concatenate([self.logits, logits_softmax])
        # Loss
        self.total_loss += float(
            torch.nn.functional.cross_entropy(
                logits, y_true, reduction="sum"
            ).item()
        )
        # Top-5 accuracy
        y_pred = logits.topk(5, largest=True).indices.to(torch.int)
        acc5_judge = (y_pred == y_true[:, None]).any(dim=-1)
        self.acc5_cnt += int(acc5_judge.sum().item())
        # Record the predictions
        self.prediction = np.concatenate(
            [self.prediction, y_pred[:, 0].cpu().numpy()]
        )

    @property
    def accuracy(self) -> float:
        return float(metrics.accuracy_score(self.labels, self.prediction))

    @property
    def youden_score(self) -> float:
        result = metrics.balanced_accuracy_score(
            self.labels, self.prediction, adjusted=True
        )
        return float(result)

    @property
    def f1_micro(self) -> float:
        result = metrics.f1_score(self.labels, self.prediction, average="micro")
        return float(result)

    @property
    def f1_macro(self) -> float:
        result = metrics.f1_score(self.labels, self.prediction, average="macro")
        return float(result)

    @property
    def accuracy5(self) -> float:
        return self.acc5_cnt / len(self.labels)

    @property
    def loss(self) -> float:
        return float(self.total_loss / len(self.labels))


def make_dataloader(
    dataset: Dataset, shuffle: bool = False, num_workers: int = 32
) -> DataLoader:
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
    )


def prepare_datasets() -> tuple[DataLoader, DataLoader]:
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    intp = transforms.InterpolationMode.BILINEAR
    # Evaluation pipeline: resize to 232, center-crop to 224
    basic_transform = transforms.Compose(
        [
            transforms.Resize(232, interpolation=intp),
            transforms.CenterCrop(224),
            transforms.PILToTensor(),
            transforms.ToDtype(torch.float32, scale=True),
            transforms.Normalize(mean, std, inplace=True),
            transforms.ToPureTensor(),
        ]
    )
    # Training pipeline: 176-pixel random crops plus TrivialAugment
    augment_transform = transforms.Compose(
        [
            transforms.RandomResizedCrop(176, interpolation=intp),
            transforms.RandomHorizontalFlip(0.5),
            transforms.TrivialAugmentWide(interpolation=intp),
            transforms.ToImage(),
            transforms.ToDtype(torch.float32, scale=True),
            transforms.Normalize(mean, std, inplace=True),
            transforms.RandomErasing(0.1),
            transforms.ToPureTensor(),
        ]
    )
    with accelerator.main_process_first():
        dataset_train = ImageNet(
            DATASET_ROOT, "train", transform=augment_transform
        )
    with accelerator.main_process_first():
        dataset_test = ImageNet(DATASET_ROOT, "val", transform=basic_transform)
    train_loader = make_dataloader(
        dataset_train, shuffle=True, num_workers=NUM_WORKERS
    )
    test_loader = make_dataloader(dataset_test, shuffle=False, num_workers=14)
    return train_loader, test_loader


@torch.no_grad()
def validate(model: torch.nn.Module, data_loader: DataLoader):
    model.eval()
    meter = ClassificationMeter(1000)
    for X, y in tqdm(True, data_loader, "Validating"):
        y: torch.Tensor = accelerator.gather_for_metrics(y)  # type: ignore
        logits: torch.Tensor = accelerator.gather_for_metrics(model(X))  # type: ignore
        if accelerator.is_main_process:
            meter.record(y, logits)
    return meter


def main() -> None:
    with accelerator.main_process_first():
        train_loader, test_loader = prepare_datasets()
    model = resnet18()
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=LEARNING_RATE,
        momentum=0.9,
        weight_decay=WEIGHT_DECAY,
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=EPOCHS - WARM_UP_EPOCHS, eta_min=1e-5
    )
    warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=1e-2, total_iters=WARM_UP_EPOCHS
    )
    scheduler = torch.optim.lr_scheduler.SequentialLR(
        optimizer, [warmup_scheduler, scheduler], [WARM_UP_EPOCHS]
    )
    criterion = torch.nn.CrossEntropyLoss(label_smoothing=LABEL_SMOOTHING)
    model = accelerator.prepare_model(model)
    train_loader = accelerator.prepare_data_loader(train_loader)
    test_loader = accelerator.prepare_data_loader(test_loader)
    optimizer = accelerator.prepare_optimizer(optimizer)
    # scheduler = accelerator.prepare_scheduler(scheduler)
    # CrossEntropyLoss has no trainable parameters, so prepare_model
    # simply moves it to the right device.
    criterion = accelerator.prepare_model(criterion)
    best_acc = 0.0
    best_loss = 1e10
    saving_root = path.join(
        "saved_models",
        datetime.now().isoformat(timespec="seconds"),
    )
    makedirs(saving_root, exist_ok=True)
    logging_file_path = path.join(saving_root, "base_training.csv")
    logging_file = open(logging_file_path, "w", buffering=1)
    print(
        "epoch",
        "best_acc@1",
        "loss",
        "acc@1",
        "acc@5",
        "f1-micro",
        "training_loss",
        "training_acc@1",
        "training_acc@5",
        "training_f1-micro",
        "training_learning-rate",
        file=logging_file,
        sep=",",
    )
    for epoch in range(EPOCHS + 1):
        # Epoch 0 is an evaluation-only baseline
        if epoch != 0:
            model.train()
            for X, y in tqdm(True, train_loader, desc=f"Epoch {epoch}/{EPOCHS}"):
                optimizer.zero_grad(set_to_none=True)
                logits = model(X)
                loss: torch.Tensor = criterion(logits, y)
                accelerator.backward(loss)
                optimizer.step()
            accelerator.wait_for_everyone()
            scheduler.step()
        train_meter = validate(model, train_loader)
        val_meter = validate(model, test_loader)
        if accelerator.is_main_process:
            # Validation on the training set
            print(
                f"loss: {train_meter.loss:.4f}",
                f"acc@1: {train_meter.accuracy * 100:.3f}%",
                f"acc@5: {train_meter.accuracy5 * 100:.3f}%",
                f"f1-micro: {train_meter.f1_micro * 100:.3f}%",
                sep=" ",
            )
            if val_meter.loss < best_loss:
                best_acc = val_meter.accuracy
                best_loss = val_meter.loss
                if epoch != 0:
                    backbone_path = path.join(saving_root, "backbone.pth")
                    # unwrap_model already returns the underlying module
                    backbone = accelerator.unwrap_model(model)
                    accelerator.save(backbone, backbone_path)
            # Validation on the test set
            print(
                f"loss: {val_meter.loss:.4f}",
                f"acc@1: {val_meter.accuracy * 100:.3f}%",
                f"acc@5: {val_meter.accuracy5 * 100:.3f}%",
                f"f1-micro: {val_meter.f1_micro * 100:.3f}%",
                f"best_acc@1: {best_acc * 100:.3f}%",
                sep=" ",
            )
            print(
                epoch,
                best_acc,
                val_meter.loss,
                val_meter.accuracy,
                val_meter.accuracy5,
                val_meter.f1_micro,
                train_meter.loss,
                train_meter.accuracy,
                train_meter.accuracy5,
                train_meter.f1_micro,
                optimizer.state_dict()["param_groups"][0]["lr"],
                file=logging_file,
                sep=",",
            )


if __name__ == "__main__":
    main()
```
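Two closing notes on the ImageNet script. First, `BATCH_SIZE` is per process, so the effective global batch is 256 times the number of GPUs; the aggressive 0.5 learning rate with linear warm-up, TrivialAugment, RandomErasing, and 176-pixel training crops echo torchvision's updated ResNet training recipe. Second, the best checkpoint is saved as a whole unwrapped module with `accelerator.save`, so it can later be restored with a plain `torch.load`. A minimal loading sketch (the timestamped directory name is generated at run time):

```python
import torch

# The training script pickles the entire unwrapped ResNet-18 module,
# so torch.load returns a ready-to-use model. On PyTorch >= 2.6 the
# weights_only=True default must be overridden for full-module pickles.
backbone = torch.load(
    "saved_models/<timestamp>/backbone.pth",  # replace with the actual run directory
    map_location="cpu",
    weights_only=False,
)
backbone.eval()
```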
Except for the portions marked as quotations, the copyright of this article belongs to the author. Commercial reproduction is strictly prohibited; non-commercial reproduction must credit the author and link to the original article in a prominent place on the page.