MMClassification Practice Notes
1. Setting up the environment
Reference documentation: https://mmclassification.readthedocs.io/zh_CN/dev-1.x/get_started.html
git clone -b 1.x https://github.com/open-mmlab/mmclassification.git
## This step is prone to failure; see https://blog.csdn.net/good_good_xiu/article/details/118567249
cd mmclassification
conda install pytorch torchvision -c pytorch
pip install -U openmim && mim install -e . -i https://pypi.tuna.tsinghua.edu.cn/simple
2. Verification
Check whether the GPU is available:
python
import torch
torch.cuda.is_available() # True
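The interactive check above can also be scripted. The sketch below reports which of the relevant packages are installed and whether CUDA is usable. The helper name `report_environment` is made up for these notes, and the distribution names (`torch`, `torchvision`, `mmengine`, `mmcv`, `mmcls`) are an assumption about what the install step registers.

```python
from importlib import metadata


def report_environment(packages=("torch", "torchvision", "mmengine", "mmcv", "mmcls")):
    """Return {package: version or None} plus a 'cuda' entry (True/False/None)."""
    report = {}
    for name in packages:
        try:
            report[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = None  # not installed in this environment
    try:
        import torch
        report["cuda"] = torch.cuda.is_available()
    except ImportError:
        report["cuda"] = None  # torch is missing, so CUDA cannot be checked
    return report


if __name__ == "__main__":
    for key, value in report_environment().items():
        print(f"{key}: {value}")
```

A `None` next to a package name means it was not found in the active environment, which is usually worth fixing before moving on to the demo.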
Step 1: download the config file and the model weights.
mim download mmcls --config resnet50_8xb32_in1k --dest .
Step 2: run the demo inference to verify the installation.
python demo/image_demo.py demo/demo.JPEG resnet50_8xb32_in1k.py resnet50_8xb32_in1k_20210831-ea4938fc.pth --device cpu
3. Training on a custom dataset
Data layout: the dataset must follow exactly this structure.
.
└── xx_data
    ├── train
    │   ├── class1 (xx images)
    │   ├── class2 (xx images)
    └── val
        ├── class1 (xx images)
        ├── class2 (xx images)
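Since training depends on this layout, it can be worth checking the folders before writing the config. The following is a minimal sketch, not part of MMClassification: the function names and the list of image suffixes are assumptions made for these notes.

```python
from pathlib import Path

IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp"}  # assumed set of image suffixes


def summarize_split(split_dir):
    """Map each class folder under split_dir to its number of image files."""
    counts = {}
    for class_dir in sorted(p for p in Path(split_dir).iterdir() if p.is_dir()):
        counts[class_dir.name] = sum(
            1 for f in class_dir.rglob("*")
            if f.is_file() and f.suffix.lower() in IMAGE_EXTS
        )
    return counts


def check_dataset(root):
    """Verify that train/ and val/ hold the same class folders.

    Raises FileNotFoundError if a split is missing and ValueError if the
    class folders differ or a training class has no images.
    """
    root = Path(root)
    train = summarize_split(root / "train")
    val = summarize_split(root / "val")
    if set(train) != set(val):
        raise ValueError(f"class folders differ: {sorted(set(train) ^ set(val))}")
    empty = [name for name, n in train.items() if n == 0]
    if empty:
        raise ValueError(f"classes without training images: {empty}")
    return {"num_classes": len(train), "train": train, "val": val}
```

Calling `check_dataset("xx_data")` returns the per-class image counts, and its `num_classes` entry is the value that belongs in the classification head of the config. As far as I know, `CustomDataset` assigns label indices by sorted folder name, which is why the sketch sorts the class folders too.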
Config file:
Note: num_classes=7 and data_prefix='/data/xxx/datasets/age_classify/age_data/train' are specific to this dataset; change both to match your own data.
model = dict(
    type='ImageClassifier',
    backbone=dict(type='MobileNetV2', widen_factor=1.0),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=7,
        in_channels=1280,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, )))
load_from = 'mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth'
# dataset_type = 'ImageNet'
# data_preprocessor = dict(
#     num_classes=1000,
#     mean=[123.675, 116.28, 103.53],
#     std=[58.395, 57.12, 57.375],
#     to_rgb=True)
# train_pipeline = [
#     dict(type='LoadImageFromFile'),
#     dict(type='RandomResizedCrop', scale=224, backend='pillow'),
#     dict(type='RandomFlip', prob=0.5, direction='horizontal'),
#     dict(type='PackClsInputs')
# ]
# test_pipeline = [
#     dict(type='LoadImageFromFile'),
#     dict(type='ResizeEdge', scale=256, edge='short', backend='pillow'),
#     dict(type='CenterCrop', crop_size=224),
#     dict(type='PackClsInputs')
# ]
train_dataloader = dict(
    batch_size=64,
    num_workers=5,
    dataset=dict(
        type='CustomDataset',
        # data_root='data/imagenet',
        # ann_file='meta/train.txt',
        data_prefix='/data/xxx/datasets/age_classify/age_data/train',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='RandomResizedCrop', scale=168, backend='pillow'),
            dict(type='RandomFlip', prob=0.5, direction='horizontal'),
            dict(type='PackClsInputs')
        ]),
    sampler=dict(type='DefaultSampler', shuffle=True))
val_dataloader = dict(
    batch_size=64,
    num_workers=5,
    dataset=dict(
        type='CustomDataset',
        # data_root='data/imagenet',
        # ann_file='meta/val.txt',
        data_prefix='/data/xxx/datasets/age_classify/age_data/val',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='ResizeEdge', scale=224, edge='short', backend='pillow'),
            dict(type='CenterCrop', crop_size=168),
            dict(type='PackClsInputs')
        ]),
    sampler=dict(type='DefaultSampler', shuffle=False))
val_evaluator = dict(type='Accuracy', topk=(1, ))
test_dataloader = dict(
    batch_size=32,
    num_workers=5,
    dataset=dict(
        type='CustomDataset',
        # data_root='data/imagenet',
        # ann_file='meta/val.txt',
        data_prefix='/data/xxx/datasets/age_classify/age_data/val',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='ResizeEdge', scale=224, edge='short', backend='pillow'),
            dict(type='CenterCrop', crop_size=168),
            dict(type='PackClsInputs')
        ]),
    sampler=dict(type='DefaultSampler', shuffle=False))
test_evaluator = dict(type='Accuracy', topk=(1,))
optim_wrapper = dict(
    optimizer=dict(type='SGD', lr=0.005, momentum=0.9, weight_decay=4e-05))
param_scheduler = dict(type='StepLR', by_epoch=True, step_size=1, gamma=0.98)
train_cfg = dict(by_epoch=True, max_epochs=100, val_interval=5)
val_cfg = dict()
test_cfg = dict()
auto_scale_lr = dict(base_batch_size=256)
default_scope = 'mmcls'
default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=10),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(type='CheckpointHook', interval=5),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    visualization=dict(type='VisualizationHook', enable=False))
env_cfg = dict(
    cudnn_benchmark=False,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    dist_cfg=dict(backend='nccl'))
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
    type='ClsVisualizer', vis_backends=[dict(type='LocalVisBackend')])
log_level = 'INFO'
resume = False
randomness = dict(seed=None, deterministic=False)
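The learning-rate settings in the config can be sanity-checked with a little arithmetic. The sketch below assumes that `StepLR` with `step_size=1` multiplies the rate by `gamma` once per epoch, and that `auto_scale_lr`, when it is enabled, applies the linear scaling rule (rate times actual total batch size divided by `base_batch_size`). The function names are made up for these notes and are not part of MMClassification.

```python
def step_lr(base_lr, gamma, step_size, epoch):
    """Learning rate after `epoch` finished epochs under a step decay."""
    return base_lr * gamma ** (epoch // step_size)


def linearly_scaled_lr(base_lr, batch_size_per_gpu, num_gpus, base_batch_size):
    """Linear scaling rule: scale lr by (actual total batch / reference batch)."""
    return base_lr * (batch_size_per_gpu * num_gpus) / base_batch_size


if __name__ == "__main__":
    # Values taken from the config: lr=0.005, gamma=0.98, step_size=1.
    for epoch in (0, 50, 99):
        print(epoch, f"{step_lr(0.005, 0.98, 1, epoch):.6f}")
    # One GPU with batch_size=64 against base_batch_size=256 gives a ratio of 0.25.
    print(f"{linearly_scaled_lr(0.005, 64, 1, 256):.6f}")
```

With one GPU the scaled rate would be 0.005 × 64 / 256 = 0.00125. As far as I know the scaling is only applied when it is explicitly enabled, so with the config as written the optimizer starts at 0.005.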
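Training command (run from the mmclassification directory):
python tools/train.py CONFIG_FILE
Here CONFIG_FILE is the path to the config file above, passed as a positional argument.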
Reference: https://github.com/wangruohui/sjtu-openmmlab-tutorial/blob/main/cls-2-train.ipynb
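When training is launched from a script rather than typed by hand, the command can be assembled in Python. This is a minimal sketch that assumes `tools/train.py` takes the config path as its first positional argument and accepts `--work-dir` for the output folder. The helper names and the example paths are hypothetical.

```python
import shlex
import subprocess
import sys


def build_train_command(config_path, work_dir=None, script="tools/train.py"):
    """Assemble the argv list for launching training from the repo root."""
    cmd = [sys.executable, script, str(config_path)]
    if work_dir is not None:
        cmd += ["--work-dir", str(work_dir)]
    return cmd


def launch(cmd, dry_run=True):
    """Print the command; run it only when dry_run is False."""
    print(shlex.join(cmd))
    if not dry_run:
        subprocess.run(cmd, check=True)


if __name__ == "__main__":
    # Hypothetical paths; replace them with your own config and output folder.
    launch(build_train_command("my_mobilenet_v2.py", work_dir="work_dirs/age"))
```

The default `dry_run=True` only prints the command, which makes it easy to check the paths before starting a long run.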
Further reading: how to write a config file?
https://www.bilibili.com/video/BV1J341127nQ?p=9&vd_source=2ed6e8af02f9ba8cb90b90e99bd4ccee