PyTorch Image Classification Training Framework
Training an image classifier with PyTorch is a process in which most of the code can be reused. I extracted the training code I wrote for the Kaggle competition Paddy Doctor into a framework that can be reused for future image classification tasks.
Full source code: pytorch_trainer
1. Dependencies
pip install tensorboard
conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch pyyaml
2. File Organization
.
├── checkpoints
├── configs
├── dataset
├── submissions
├── run_predict.sh
├── run_split_dataset.sh
├── run_tensorboard.sh
├── run_train.sh
├── LICENSE
├── README.md
└── src
└── main
├── data
├── engine
├── model
├── options
├── predict.py
├── split_dataset.py
└── train.py
- checkpoints: stores checkpoint files
- configs: stores the default YAML configuration files
- dataset: stores the dataset
- submissions: stores the final submission files
- run_predict.sh: prediction script
- run_split_dataset.sh: dataset-splitting script
- run_tensorboard.sh: script that launches TensorBoard
- run_train.sh: training script
- src: source code
  - main: main source tree
    - data: data preprocessing code, including data augmentation and data-loader construction
    - engine: code for training, prediction, and dataset splitting
    - model: neural network model definitions
    - options: argument-handling code
    - predict.py: prediction entry point
    - split_dataset.py: dataset-splitting entry point
    - train.py: training entry point
3. Entry Points
Shell scripts serve as the program entry points, which makes it convenient to set arguments:
- run_predict.sh: prediction script
- run_split_dataset.sh: dataset-splitting script
- run_tensorboard.sh: script that launches TensorBoard
- run_train.sh: training script
Training script:
if [ ! -d "checkpoints" ];then
mkdir checkpoints;
fi
cd ./src/main/ && \
python ./train.py \
--config_file_path ../../configs/train_config.yaml \
--epochs 500 \
--batch_size 8 \
--dataset_dir ../../dataset/paddy-disease-classification/ \
--model_type base_model \
| tee ../../checkpoints/output.txt
Prediction script:
cd ./src/main/ && \
python ./predict.py \
--config_file_path ../../configs/eval_config.yaml \
--model_type efficient_model \
--dataset_dir ../../dataset/ \
--submission_file_path ../../submissions/submission.csv
Dataset-splitting script:
cd ./src/main/ && \
python ./split_dataset.py \
--config_file_path ../../configs/split_config.yaml \
--dataset_dir ../../dataset/paddy-disease-classification/
4. Argument Handling
By extending Python's argparse.ArgumentParser class, default arguments are loaded from a YAML file and can be overridden on the command line.
import argparse

import yaml


class ConfigArgumentParser(argparse.ArgumentParser):
    def __init__(self, *args, **kwargs):
        self.config_parser = argparse.ArgumentParser()
        self.config_parser.add_argument("-c", "--config_file_path", default=None, metavar="FILE",
                                        help="where to load YAML configuration")
        self.option_names = []
        super(ConfigArgumentParser, self).__init__(*args, **kwargs)

    def add_override_argument(self, *args, **kwargs):
        arg = super().add_argument(*args, **kwargs)
        self.option_names.append(arg.dest)
        return arg

    def parse_args(self, args=None):
        # parse the config file path first, then use the YAML values as defaults
        res, remaining_argv = self.config_parser.parse_known_args(args)
        if res.config_file_path is not None:
            with open(res.config_file_path, "r") as f:
                config_vars = yaml.safe_load(f)
            for key in config_vars:
                if key not in self.option_names:
                    self.error(f"unexpected configuration entry: {key}")
            self.set_defaults(**config_vars)
        return super().parse_args(remaining_argv)
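As a quick illustration of the precedence, assume a hypothetical file demo.yaml containing only the line epochs: 100; YAML values become defaults, and explicit command-line flags win:
# Hypothetical usage sketch: demo.yaml contains the single line "epochs: 100".
parser = ConfigArgumentParser()
parser.add_override_argument('--epochs', type=int, help='total epochs')

args = parser.parse_args(['-c', 'demo.yaml'])
print(args.epochs)  # 100, taken from the YAML file

args = parser.parse_args(['-c', 'demo.yaml', '--epochs', '10'])
print(args.epochs)  # 10, the command line overrides the YAML default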
The training arguments are defined as follows:
def prepare_train_args():
    train_parser = ConfigArgumentParser()
    train_parser.add_override_argument('--seed', type=int,
                                       help='a random seed')
    train_parser.add_override_argument('--gpus', nargs='+', type=int,
                                       help='numbers of GPU')
    train_parser.add_override_argument('--epochs', type=int,
                                       help='total epochs')
    train_parser.add_override_argument('--batch_size', type=int,
                                       help='batch size')
    train_parser.add_override_argument('--lr', type=float,
                                       help='learning rate')
    train_parser.add_override_argument('--momentum', type=float,
                                       help='momentum for sgd, alpha parameter for adam')
    train_parser.add_override_argument('--beta', default=0.999, type=float,
                                       help='beta parameter for adam')
    train_parser.add_override_argument('--weight_decay', '--wd', type=float,
                                       help='weight decay')
    train_parser.add_override_argument('--save_prefix', type=str,
                                       help='some comment for model or test result dir')
    train_parser.add_override_argument('--model_type', type=str,
                                       help='choose a model type, which is defined in model folder')
    train_parser.add_override_argument('--loss_type', type=str,
                                       help='choose a loss function, which is defined in loss folder')
    train_parser.add_override_argument('--acc_type', type=str,
                                       help='choose an acc function, which is defined in metrics folder')
    # passing --is_load_strict on the command line sets this to False,
    # which allows loading only the state dicts common to both models
    train_parser.add_override_argument('--is_load_strict', action='store_false',
                                       help='allow to load only common state dicts')
    train_parser.add_override_argument('--is_load_pretrained_weight', action='store_true',
                                       help='True means try to load pretrained weights')
    train_parser.add_override_argument('--pretrained_weights_path', type=str,
                                       help='pretrained weights path')
    train_parser.add_override_argument('--is_resuming_training', action='store_true',
                                       help='True means try to resume previous train')
    train_parser.add_override_argument('--checkpoint_path', type=str,
                                       help='checkpoints path')
    train_parser.add_override_argument('--dataset_dir', type=str,
                                       help='dataset directory')
    train_parser.add_override_argument('--checkpoints_dir', type=str,
                                       help='checkpoints directory')
    args = train_parser.parse_args()
    get_train_model_dir(args)
    save_args(args, args.checkpoints_dir)
    return args
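get_train_model_dir() and save_args() are small helpers defined elsewhere in the repo; the args.txt file mentioned in section 9 is written by save_args(). A minimal sketch of what they might look like (the paths and naming scheme here are assumptions, not the repo's exact implementation):
import json
import os
import time


def get_train_model_dir(args):
    # Hypothetical sketch: derive a per-run checkpoint directory from
    # save_prefix plus a timestamp, and create it if it does not exist.
    args.checkpoints_dir = os.path.join(
        '../../checkpoints', f"{args.save_prefix}_{time.strftime('%Y%m%d_%H%M%S')}")
    os.makedirs(args.checkpoints_dir, exist_ok=True)


def save_args(args, save_dir):
    # Persist the resolved arguments to args.txt so each run is reproducible.
    with open(os.path.join(save_dir, 'args.txt'), 'w') as f:
        json.dump(vars(args), f, indent=2)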
Default training YAML configuration:
seed: 42
gpus: [0]
epochs: 100
batch_size: 128
lr: 1e-3
momentum: 0.9
beta: 0.999
weight_decay: 0
save_prefix: "test"
model_type: "base_model"
loss_type: "focal_loss"
acc_type: "classification_acc"
is_load_strict: true
is_load_pretrained_weight: false
pretrained_weights_path: ''
is_resuming_training: false
checkpoint_path: ''
dataset_dir: ../../dataset/
checkpoints_dir:
The prediction arguments are defined as follows:
def prepare_eval_args():
    eval_parser = ConfigArgumentParser()
    eval_parser.add_override_argument('--seed', type=int,
                                      help='a random seed')
    eval_parser.add_override_argument('--gpus', nargs='+', type=int,
                                      help='numbers of GPU')
    eval_parser.add_override_argument('--model_type', type=str,
                                      help='used in model_interface.py')
    eval_parser.add_override_argument('--weights_path', type=str,
                                      help='weights path')
    eval_parser.add_override_argument('--dataset_dir', type=str,
                                      help='dataset directory')
    eval_parser.add_override_argument('--submission_file_path', type=str,
                                      help='submission.csv path')
    args = eval_parser.parse_args()
    return args
Default prediction YAML configuration:
seed: 42
gpus: [0]
model_type: "base_model"
weights_path:
dataset_dir: ../../dataset/
submission_file_path: ../../submissions/submission.csv
The dataset-splitting arguments are defined as follows:
def prepare_split_dataset_args():
    split_parser = ConfigArgumentParser()
    split_parser.add_override_argument('--seed', type=int,
                                       help='a random seed')
    split_parser.add_override_argument('--valid_ratio', type=float,
                                       help='valid ratio')
    split_parser.add_override_argument('--dataset_dir', type=str,
                                       help='dataset directory')
    args = split_parser.parse_args()
    save_args(args, args.dataset_dir)
    return args
Default dataset-splitting YAML configuration:
seed: 42
valid_ratio: 0.2
dataset_dir: ../../dataset/
5. Splitting the Dataset
For image classification, PyTorch provides the torchvision.datasets.ImageFolder class for dataset loading. To use this API, the dataset must be organized in the following directory structure:
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
This project uses the same API, so you should adapt the splitting function in Splitter.py to your own dataset so that it produces the directory structure above; a sketch of such a function follows.
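A minimal sketch of such a splitting function, assuming the raw images live under train_images/<label>/ as in the Paddy Doctor dataset (the actual Splitter.py will differ depending on your dataset):
import os
import random
import shutil


def split_dataset(dataset_dir, valid_ratio, seed):
    # Hypothetical sketch: copy each class folder from train_images/ into
    # train_valid_test/train and train_valid_test/valid, holding out
    # valid_ratio of the images of each class for validation.
    random.seed(seed)
    src_root = os.path.join(dataset_dir, 'train_images')
    for label in os.listdir(src_root):
        files = os.listdir(os.path.join(src_root, label))
        random.shuffle(files)
        n_valid = int(len(files) * valid_ratio)
        for i, name in enumerate(files):
            split = 'valid' if i < n_valid else 'train'
            dst_dir = os.path.join(dataset_dir, 'train_valid_test', split, label)
            os.makedirs(dst_dir, exist_ok=True)
            shutil.copy(os.path.join(src_root, label, name),
                        os.path.join(dst_dir, name))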
6. Data Loading
Data is fed to the model with PyTorch's torch.utils.data.DataLoader.
Training data loader:
import os

import torchvision
from torch.utils.data import DataLoader

# transform_train / transform_eval are the pipelines defined in this module


def select_train_loader(args):
    train_dataset = torchvision.datasets.ImageFolder(os.path.join(args.dataset_dir, 'train_valid_test', 'train'),
                                                     transform=transform_train)
    print(train_dataset.class_to_idx)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4,
                              pin_memory=True, drop_last=False)
    return train_loader
Validation data loader:
def select_eval_loader(args):
    eval_dataset = torchvision.datasets.ImageFolder(os.path.join(args.dataset_dir, 'train_valid_test', 'valid'),
                                                    transform=transform_eval)
    val_loader = DataLoader(eval_dataset, batch_size=1, shuffle=False, num_workers=1, pin_memory=True,
                            drop_last=False)
    return val_loader
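transform_train and transform_eval are the augmentation pipelines defined in the data module. A plausible minimal sketch with torchvision.transforms, using the standard ImageNet normalization constants (the repo's actual pipelines may differ):
import torchvision.transforms as transforms

# Training pipeline: random augmentation; evaluation pipeline: deterministic.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])
transform_eval = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
])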
7. Building Models
Create a new model folder under src/main/model and write the model code there. Then register the model in the dictionary inside the select_model() function in model_interface.py:
def select_model(args):
    type2model = {
        'base_model': base_model(),
        'better_model': better_model(),
        'efficient_model': efficient_model(),
    }
    model = type2model[args.model_type]
    return model
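A model entry can be as simple as a thin wrapper around a torchvision backbone. A minimal sketch of what base_model might look like, assuming 10 output classes as in the Paddy Doctor competition (the real definitions live in the model folder):
import torch.nn as nn
import torchvision


def base_model(num_classes=10):
    # Hypothetical sketch: ResNet-18 backbone with its final fully
    # connected layer replaced to match the number of classes.
    model = torchvision.models.resnet18(pretrained=True)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model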
8. Metrics
Define loss functions under src/main/engine/metrics/loss. Then register each loss in the dictionary inside the select_loss() function in metrics_interface.py:
def select_loss(args):
    type2lossFunction = {
        'focal_loss': FocalLoss(num_class=10),
    }
    loss_function = type2lossFunction[args.loss_type]
    return loss_function
Likewise, the classification accuracy function can be redefined; a minimal sketch is given below.
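A sketch of what such an accuracy function might look like, assuming it is registered through a select_acc() in metrics_interface.py that mirrors select_loss():
import torch


def classification_acc(output_batch, label_batch):
    # Fraction of samples whose argmax over the class logits
    # matches the ground-truth label.
    preds = torch.argmax(output_batch, dim=1)
    return (preds == label_batch).float().mean().item()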
9. Training
Run the run_train.sh script directly. The command-line output is also saved to output.txt in the checkpoints folder, which makes it easy to track down errors. The arguments used for the run are saved to an args.txt file under the checkpoint directory as well.
The trainer class:
import time

import torch
from torch.autograd import Variable  # legacy wrapper; kept for compatibility, a no-op on modern PyTorch

# Logger, select_loss and select_acc come from the repo's engine modules.


class Trainer(object):
    def __init__(self, args, model, train_loader, val_loader):
        torch.manual_seed(args.seed)
        self.__args = args
        self.__logger = Logger(args)
        self.__loss_function = select_loss(args)
        self.__acc_function = select_acc(args)
        self.__train_loader = train_loader
        self.__val_loader = val_loader
        self.__start_epoch = 0
        train_status = 'Normal'
        train_status_logs = []
        # loading model
        self.__model = model
        if args.is_load_pretrained_weight:
            train_status = 'Continuance'
            self.__model.load_state_dict(torch.load(args.pretrained_weights_path), strict=args.is_load_strict)
            train_status_logs.append('Log Output: Loaded pretrained weights successfully')
        if args.is_resuming_training:
            train_status = 'Restoration'
            checkpoint = torch.load(args.checkpoint_path)
            self.__start_epoch = checkpoint['epoch'] + 1
            self.__model.load_state_dict(checkpoint['model_state_dict'], strict=args.is_load_strict)
            train_status_logs.append('Log Output: Resumed previous model state successfully')
        if args.gpus == [0]:
            gpu_status = 'Single-GPU'
            device = torch.device("cuda:0")
            self.__model.to(device)
        else:
            gpu_status = 'Multi-GPU'
            # DataParallel expects the parameters to live on the first listed device
            self.__model.to(torch.device('cuda:{}'.format(args.gpus[0])))
            self.__model = torch.nn.DataParallel(self.__model, device_ids=args.gpus, output_device=args.gpus[0])
        # initialize the optimizer
        self.__optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.__model.parameters()),
                                            self.__args.lr,
                                            betas=(self.__args.momentum, self.__args.beta),
                                            weight_decay=self.__args.weight_decay)
        if args.is_resuming_training:
            checkpoint = torch.load(args.checkpoint_path)
            self.__optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            train_status_logs.append('Log Output: Resumed previous optimizer state successfully')
        # print status
        print('*' * 100)
        print('Model:')
        print(self.__model)
        print('*' * 100)
        print('Params To Learn:')
        for name, param in self.__model.named_parameters():
            if param.requires_grad:
                print('\t', name)
        print('*' * 100)
        print('Train Status: ' + train_status)
        print('GPU Status: ' + gpu_status)
        for train_status_log in train_status_logs:
            print(train_status_log)
        print('*' * 100)

    def train(self):
        for epoch in range(self.__start_epoch, self.__args.epochs):
            # train for one epoch
            since = time.time()
            self.__train_per_epoch()
            self.__val_per_epoch()
            self.__logger.save_curves(epoch)
            self.__logger.save_checkpoint(epoch, self.__model, self.__optimizer)
            self.__logger.print_logs(epoch, time.time() - since)
            self.__logger.clear_scalar_cache()

    def __train_per_epoch(self):
        # switch to train mode
        self.__model.train()
        for i, data_batch in enumerate(self.__train_loader):
            input_batch, output_batch, label_batch = self.__step(data_batch)
            # compute loss and acc
            loss, metrics = self.__compute_metrics(output_batch, label_batch, is_train=True)
            # compute gradient and do Adam step
            self.__optimizer.zero_grad()
            loss.backward()
            self.__optimizer.step()
            # logger record
            for key in metrics.keys():
                self.__logger.record_scalar(key, metrics[key])

    def __val_per_epoch(self):
        # switch to eval mode
        self.__model.eval()
        with torch.no_grad():
            for i, data_batch in enumerate(self.__val_loader):
                input_batch, output_batch, label_batch = self.__step(data_batch)
                # compute loss and acc
                loss, metrics = self.__compute_metrics(output_batch, label_batch, is_train=False)
                for key in metrics.keys():
                    self.__logger.record_scalar(key, metrics[key])

    def __step(self, data_batch):
        input_batch, label_batch = data_batch
        # wrap the input and move it to the GPU
        input_batch = Variable(input_batch).cuda()
        label_batch = Variable(label_batch).cuda()
        # compute output
        output_batch = self.__model(input_batch)
        return input_batch, output_batch, label_batch

    def __compute_metrics(self, output_batch, label_batch, is_train):
        # you can call functions in metrics_interface.py
        loss = self.__calculate_loss(output_batch, label_batch)
        acc = self.__evaluate_accuracy(output_batch, label_batch)
        prefix = 'train/' if is_train else 'val/'
        metrics = {
            prefix + 'loss': loss.item(),
            prefix + 'accuracy': acc,
        }
        return loss, metrics

    def __calculate_loss(self, output_batch: torch.Tensor, label_batch: torch.Tensor) -> torch.Tensor:
        loss = self.__loss_function(output_batch, label_batch)
        return loss

    def __evaluate_accuracy(self, output_batch: torch.Tensor, label_batch: torch.Tensor) -> float:
        acc = self.__acc_function(output_batch, label_batch)
        return acc

    @staticmethod
    def __gen_imgs_to_write(img, is_train):
        # override this method according to your visualization
        prefix = 'train/' if is_train else 'val/'
        return {
            prefix + 'img': img[0],
        }
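The Logger used by the Trainer wraps a TensorBoard SummaryWriter; its interface can be inferred from the calls above. A minimal sketch, assuming checkpoints are written into args.checkpoints_dir (the repo's implementation may differ):
import os

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter


class Logger(object):
    # Hypothetical sketch of the interface the Trainer relies on.
    def __init__(self, args):
        self.args = args
        self.writer = SummaryWriter(log_dir=args.checkpoints_dir)
        self.scalar_cache = {}  # metric name -> list of per-step values

    def record_scalar(self, key, value):
        self.scalar_cache.setdefault(key, []).append(value)

    def save_curves(self, epoch):
        # write the per-epoch mean of every cached metric to tensorboard
        for key, values in self.scalar_cache.items():
            self.writer.add_scalar(key, float(np.mean(values)), epoch)

    def save_checkpoint(self, epoch, model, optimizer):
        state = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }
        torch.save(state, os.path.join(self.args.checkpoints_dir, f'epoch_{epoch}.pth'))

    def print_logs(self, epoch, elapsed):
        means = {key: float(np.mean(values)) for key, values in self.scalar_cache.items()}
        print(f'epoch {epoch} ({elapsed:.1f}s): {means}')

    def clear_scalar_cache(self):
        self.scalar_cache.clear()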
10. Prediction
Adapt src/main/engine/predictor.py to your dataset, then run the run_predict.sh script.
The predictor class:
import os

import numpy as np
import pandas as pd
import PIL.Image
import torch
from torch.autograd import Variable  # legacy wrapper; kept for compatibility

# labels_map (defined alongside the predictor) maps a class index back to
# its label string.


class Predictor(object):
    def __init__(self, args, model, transform):
        self.__args = args
        self.__model = model
        self.__transform = transform
        self.__model.load_state_dict(torch.load(args.weights_path), strict=True)
        if args.gpus == [0]:
            gpu_status = 'Single-GPU'
            device = torch.device("cuda:0")
            self.__model.to(device)
        else:
            gpu_status = 'Multi-GPU'
            # DataParallel expects the parameters to live on the first listed device
            self.__model.to(torch.device('cuda:{}'.format(args.gpus[0])))
            self.__model = torch.nn.DataParallel(self.__model, device_ids=args.gpus, output_device=args.gpus[0])
        print('*' * 100)
        print('Model:')
        print(self.__model)
        print('*' * 100)
        print('GPU Status: ' + gpu_status)
        print('*' * 100)
        self.__model.eval()

    def predict_csv(self):
        df = pd.read_csv(self.__args.submission_file_path)
        for index, row in df.iterrows():
            test_file_dir = os.path.join(self.__args.dataset_dir, 'train_valid_test', 'test', 'unknown', row[0])
            img = PIL.Image.open(test_file_dir)
            input_test = self.__transform(img).unsqueeze(0)
            input_test = Variable(input_test).cuda()
            with torch.no_grad():
                output_test = self.__model(input_test)
            softmax = torch.nn.Softmax(dim=1)
            output_test = softmax(output_test)
            output_test = output_test.cpu().detach().numpy()
            label_test = np.argmax(output_test)
            df.iloc[index, 1] = labels_map[label_test.item()]
        print(df)
        df.to_csv(self.__args.submission_file_path, index=None)
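For completeness, a sketch of how predict.py might wire these pieces together; the module paths (data.data_entry, engine.predictor, model.model_interface, options.option) are assumptions, not necessarily the repo's actual layout:
# Hypothetical sketch of the predict.py entry point.
from data.data_entry import transform_eval
from engine.predictor import Predictor
from model.model_interface import select_model
from options.option import prepare_eval_args


def main():
    args = prepare_eval_args()
    model = select_model(args)
    predictor = Predictor(args, model, transform_eval)
    predictor.predict_csv()


if __name__ == '__main__':
    main()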