pytorh分布式训练
DataParallel & DistributedDataParallel分布式训练
参考博客 《DataParallel & DistributedDataParallel分布式训练》:
细节参考博客(推荐)
###DDP
# 引入包
import argparse
import torch.distributed as dist
# 设置可选参数
parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', default=0, type=int,
help='node rank for distributed training')
args = parser.parse_args()
# print(args.local_rank)
dist.init_process_group(backend='nccl')
# 1.上面讲到的初始化进程组
dist.init_process_group(backend='nccl')
torch.cuda.set_device(args.local_rank)
# 2.使用DistributedSampler
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle = (train_sampler is None), sampler=train_sampler, pin_memory=False)
# 3.创建DDP模型进行分布式训练
model = nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True)
# 4.命令行开始训练 --nproc_per_node参数指定为当前主机创建的进程数(比如我当前可用但卡数是2 那就为这个主机创建两个进程,每个进程独立执行训练脚本)
# 我是单机多卡, 所以nnode=1, 就是一台主机, 一台主机上--nproc_per_node个进程
python -m torch.distributed.run --nnodes=1 --nproc_per_node=2 --node_rank=0 --master_port=6005 train.py
使用DP或者DDP在保存和使用模型时需要注意的地方
《使用DP或者DDP在保存和使用模型时需要注意的地方》:
在保存模型的时候建议用net.module.state_dict(),这是因为如果裁剪了DP或者DDP,网络结构变为nn.Sequential()这种数据类型了,而完美常用的保存方式是:
net = torch.nn.Linear(10,1)
# 先构造一个网络
net = torch.nn.DataParallel(net, device_ids=[0,3])
torch.save(net.module.state_dict(), './tmp.pth')
有了上述的知识基础,在加载模型的时候建议先用
def get_bare_model(net):
if isinstance(net, (nn.DataParallel, nn.parallel.DistributedDataParalleled)):
net = net.module
return net
清澈的爱,只为中国
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· winform 绘制太阳,地球,月球 运作规律
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 【硬核科普】Trae如何「偷看」你的代码?零基础破解AI编程运行原理
· 超详细:普通电脑也行Windows部署deepseek R1训练数据并当服务器共享给他人
· 上周热点回顾(3.3-3.9)
2018-01-18 mysql数据库表结构与表约束