MMClassification Practice Notes
1. Setting up the environment
Reference: https://mmclassification.readthedocs.io/zh_CN/dev-1.x/get_started.html
# The clone often fails; if it does, see https://blog.csdn.net/good_good_xiu/article/details/118567249
git clone -b 1.x https://github.com/open-mmlab/mmclassification.git
cd mmclassification
conda install pytorch torchvision -c pytorch
pip install -U openmim && mim install -e . -i https://pypi.tuna.tsinghua.edu.cn/simple
2. Verification
First check whether the GPU is available:
python
import torch
torch.cuda.is_available()  # True
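The same check can be wrapped in a small script that falls back to the CPU when CUDA is not usable. This is only a minimal sketch: the helper name pick_device is hypothetical, and the import guard is there just so the script also runs on a machine where PyTorch is not installed yet.

```python
def pick_device() -> str:
    """Return 'cuda' when PyTorch can see a GPU, otherwise 'cpu'."""
    try:
        import torch
    except ImportError:
        # PyTorch is not installed, so only the CPU can be used.
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"


if __name__ == "__main__":
    print(pick_device())
```

The returned string is the kind of value that the --device option of demo/image_demo.py expects.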
Step 1: download the config file and the model weights.
mim download mmcls --config resnet50_8xb32_in1k --dest .
Step 2: run the demo to verify the inference pipeline.
python demo/image_demo.py demo/demo.JPEG resnet50_8xb32_in1k.py resnet50_8xb32_in1k_20210831-ea4938fc.pth --device cpu
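3. Training on a custom dataset
Data layout: the dataset must be organized exactly like this, with one subfolder per class under train and val.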
.
└── xx_data
    ├── train
    │   ├── class1 (xx images)
    │   └── class2 (xx images)
    └── val
        ├── class1 (xx images)
        └── class2 (xx images)
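Before training, it can help to confirm that the folders really follow this layout. The following is a minimal sketch using only the standard library; the function names and the list of image suffixes are assumptions made for illustration and are not part of MMClassification.

```python
from pathlib import Path

# Assumed set of image suffixes; adjust it to match your data.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".bmp"}


def count_images(split_dir):
    """Map each class subfolder of split_dir to its number of image files."""
    counts = {}
    for class_dir in sorted(Path(split_dir).iterdir()):
        if class_dir.is_dir():
            counts[class_dir.name] = sum(
                1 for f in class_dir.rglob("*")
                if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
            )
    return counts


def check_layout(data_root):
    """Check that train/ and val/ contain the same, non-empty class folders."""
    root = Path(data_root)
    train = count_images(root / "train")
    val = count_images(root / "val")
    if set(train) != set(val):
        raise ValueError(f"class folders differ: {sorted(set(train) ^ set(val))}")
    empty = [name for name, n in train.items() if n == 0]
    if empty:
        raise ValueError(f"empty training classes: {empty}")
    return train, val
```

Calling check_layout("xx_data") returns the per-class image counts for both splits; the number of keys in the first dictionary is the value to use for num_classes in the config.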
Config file:
Note the two dataset-specific settings: num_classes=7 and data_prefix='/data/xxx/datasets/age_classify/age_data/train'.
model = dict(
    type='ImageClassifier',
    backbone=dict(type='MobileNetV2', widen_factor=1.0),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=7,
        in_channels=1280,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, )))
load_from = 'mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth'

# dataset_type = 'ImageNet'
# data_preprocessor = dict(
#     num_classes=1000,
#     mean=[123.675, 116.28, 103.53],
#     std=[58.395, 57.12, 57.375],
#     to_rgb=True)
# train_pipeline = [
#     dict(type='LoadImageFromFile'),
#     dict(type='RandomResizedCrop', scale=224, backend='pillow'),
#     dict(type='RandomFlip', prob=0.5, direction='horizontal'),
#     dict(type='PackClsInputs')
# ]
# test_pipeline = [
#     dict(type='LoadImageFromFile'),
#     dict(type='ResizeEdge', scale=256, edge='short', backend='pillow'),
#     dict(type='CenterCrop', crop_size=224),
#     dict(type='PackClsInputs')
# ]

train_dataloader = dict(
    batch_size=64,
    num_workers=5,
    dataset=dict(
        type='CustomDataset',
        # data_root='data/imagenet',
        # ann_file='meta/train.txt',
        data_prefix='/data/xxx/datasets/age_classify/age_data/train',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='RandomResizedCrop', scale=168, backend='pillow'),
            dict(type='RandomFlip', prob=0.5, direction='horizontal'),
            dict(type='PackClsInputs')
        ]),
    sampler=dict(type='DefaultSampler', shuffle=True))

val_dataloader = dict(
    batch_size=64,
    num_workers=5,
    dataset=dict(
        type='CustomDataset',
        # data_root='data/imagenet',
        # ann_file='meta/val.txt',
        data_prefix='/data/xxx/datasets/age_classify/age_data/val',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='ResizeEdge', scale=224, edge='short', backend='pillow'),
            dict(type='CenterCrop', crop_size=168),
            dict(type='PackClsInputs')
        ]),
    sampler=dict(type='DefaultSampler', shuffle=False))
val_evaluator = dict(type='Accuracy', topk=(1, ))

test_dataloader = dict(
    batch_size=32,
    num_workers=5,
    dataset=dict(
        type='CustomDataset',
        # data_root='data/imagenet',
        # ann_file='meta/val.txt',
        data_prefix='/data/xxx/datasets/age_classify/age_data/val',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='ResizeEdge', scale=224, edge='short', backend='pillow'),
            dict(type='CenterCrop', crop_size=168),
            dict(type='PackClsInputs')
        ]),
    sampler=dict(type='DefaultSampler', shuffle=False))
test_evaluator = dict(type='Accuracy', topk=(1, ))

optim_wrapper = dict(
    optimizer=dict(type='SGD', lr=0.005, momentum=0.9, weight_decay=4e-05))
param_scheduler = dict(type='StepLR', by_epoch=True, step_size=1, gamma=0.98)
train_cfg = dict(by_epoch=True, max_epochs=100, val_interval=5)
val_cfg = dict()
test_cfg = dict()
auto_scale_lr = dict(base_batch_size=256)

default_scope = 'mmcls'
default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=10),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(type='CheckpointHook', interval=5),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    visualization=dict(type='VisualizationHook', enable=False))
env_cfg = dict(
    cudnn_benchmark=False,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    dist_cfg=dict(backend='nccl'))
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
    type='ClsVisualizer', vis_backends=[dict(type='LocalVisBackend')])
log_level = 'INFO'
resume = False
randomness = dict(seed=None, deterministic=False)
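The optimizer and scheduler settings in this config (lr=0.005, StepLR with step_size=1 and gamma=0.98, max_epochs=100) mean the learning rate is multiplied by 0.98 after every epoch. The sketch below reproduces that arithmetic in plain Python so the schedule can be inspected without launching a run. It does not call MMEngine; the function names are hypothetical, and the second helper only illustrates the linear scaling rule that auto_scale_lr is based on, under the assumption that automatic scaling is enabled and the total batch size is 64.

```python
BASE_LR = 0.005   # optimizer lr from the config
GAMMA = 0.98      # StepLR decay factor
STEP_SIZE = 1     # decay once per epoch


def step_lr(epoch, base_lr=BASE_LR, gamma=GAMMA, step_size=STEP_SIZE):
    """Learning rate used during the given 0-based epoch under step decay."""
    return base_lr * gamma ** (epoch // step_size)


def linearly_scaled_lr(base_lr, batch_size, base_batch_size=256):
    """Linear scaling rule: lr changes in proportion to the total batch size."""
    return base_lr * batch_size / base_batch_size


if __name__ == "__main__":
    for epoch in (0, 1, 10, 50, 99):
        print(f"epoch {epoch:3d}: lr = {step_lr(epoch):.6f}")
    print("scaled lr for batch size 64:", linearly_scaled_lr(BASE_LR, 64))
```

With these values the learning rate in the final epoch is roughly 14% of its starting value, which is a fairly gentle decay for a 100-epoch fine-tuning run.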
Training command (run from the repository root; the config path is a positional argument):
python tools/train.py <path to the config file above>
Reference: https://github.com/wangruohui/sjtu-openmmlab-tutorial/blob/main/cls-2-train.ipynb
Further details: how to write a config file?
https://www.bilibili.com/video/BV1J341127nQ?p=9&vd_source=2ed6e8af02f9ba8cb90b90e99bd4ccee