MMClassification Practice Notes

 


1. Environment Setup

Reference documentation: https://mmclassification.readthedocs.io/zh_CN/dev-1.x/get_started.html

git clone -b 1.x https://github.com/open-mmlab/mmclassification.git
## The clone often fails due to network issues; see https://blog.csdn.net/good_good_xiu/article/details/118567249 for workarounds
cd mmclassification
conda install pytorch torchvision -c pytorch
pip install -U openmim && mim install -e . -i https://pypi.tuna.tsinghua.edu.cn/simple
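
To confirm the editable install succeeded, a quick sanity check (a minimal sketch) is to import the package and print its version:

python

import mmcls
print(mmcls.__version__)  # prints the installed 1.x version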

2. Verification

Check whether the GPU is available:

python

import torch
torch.cuda.is_available() # True

Step 1: Download the config file and the model weights.

mim download mmcls --config resnet50_8xb32_in1k --dest .

Step 2: Verify the demo inference pipeline.

python demo/image_demo.py demo/demo.JPEG resnet50_8xb32_in1k.py resnet50_8xb32_in1k_20210831-ea4938fc.pth --device cpu
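
The same inference can also be run from the Python API. Below is a minimal sketch, assuming the config and checkpoint downloaded in step 1 are in the current directory (the exact fields of the returned dict may vary between 1.x releases):

from mmcls.apis import init_model, inference_model

model = init_model('resnet50_8xb32_in1k.py',
                   'resnet50_8xb32_in1k_20210831-ea4938fc.pth',
                   device='cpu')
result = inference_model(model, 'demo/demo.JPEG')
print(result)  # predicted label, score and class name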

3. Training on a Custom Dataset

Data organization: the dataset must follow the directory structure below (a helper script for producing this split is sketched after the tree).

.
└── xx_data
    ├── train
    │   ├── class1  (xx images)
    │   └── class2  (xx images)
    └── val
        ├── class1  (xx images)
        └── class2  (xx images)
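
If the raw images are not yet split, a small script can build this layout. The sketch below assumes a hypothetical flat source directory with one folder per class and copies roughly 10% of each class into val:

import random
import shutil
from pathlib import Path

src_root = Path('/data/xxx/datasets/age_classify/raw')       # hypothetical source layout: raw/<class>/*.jpg
dst_root = Path('/data/xxx/datasets/age_classify/age_data')  # matches data_prefix in the config below

random.seed(0)
for class_dir in sorted(p for p in src_root.iterdir() if p.is_dir()):
    images = sorted(class_dir.glob('*.jpg'))
    random.shuffle(images)
    n_val = max(1, len(images) // 10)  # ~10% validation split
    for i, img in enumerate(images):
        split = 'val' if i < n_val else 'train'
        out_dir = dst_root / split / class_dir.name
        out_dir.mkdir(parents=True, exist_ok=True)
        shutil.copy(img, out_dir / img.name)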

Config file:

Note: adapt num_classes=7 to the number of classes in your dataset, and point data_prefix (e.g. '/data/xxx/datasets/age_classify/age_data/train') at your own data.

model = dict(
    type='ImageClassifier',
    backbone=dict(type='MobileNetV2', widen_factor=1.0),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=7,
        in_channels=1280,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, )))
load_from = 'mobilenet_v2_batch256_imagenet_20200708-3b2dc3af.pth'
# dataset_type = 'ImageNet'
# data_preprocessor = dict(
#     num_classes=1000,
#     mean=[123.675, 116.28, 103.53],
#     std=[58.395, 57.12, 57.375],
#     to_rgb=True)
# train_pipeline = [
#     dict(type='LoadImageFromFile'),
#     dict(type='RandomResizedCrop', scale=224, backend='pillow'),
#     dict(type='RandomFlip', prob=0.5, direction='horizontal'),
#     dict(type='PackClsInputs')
# ]
# test_pipeline = [
#     dict(type='LoadImageFromFile'),
#     dict(type='ResizeEdge', scale=256, edge='short', backend='pillow'),
#     dict(type='CenterCrop', crop_size=224),
#     dict(type='PackClsInputs')
# ]
train_dataloader = dict(
    batch_size=64,
    num_workers=5,
    dataset=dict(
        type='CustomDataset',
        # data_root='data/imagenet',
        # ann_file='meta/train.txt',
        data_prefix='/data/xxx/datasets/age_classify/age_data/train',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='RandomResizedCrop', scale=168, backend='pillow'),
            dict(type='RandomFlip', prob=0.5, direction='horizontal'),
            dict(type='PackClsInputs')
        ]),
    sampler=dict(type='DefaultSampler', shuffle=True))
val_dataloader = dict(
    batch_size=64,
    num_workers=5,
    dataset=dict(
        type='CustomDataset',
        # data_root='data/imagenet',
        # ann_file='meta/val.txt',
        data_prefix='/data/xxx/datasets/age_classify/age_data/val',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='ResizeEdge', scale=224, edge='short', backend='pillow'),
            dict(type='CenterCrop', crop_size=168),
            dict(type='PackClsInputs')
        ]),
    sampler=dict(type='DefaultSampler', shuffle=False))
val_evaluator = dict(type='Accuracy', topk=(1, ))
test_dataloader = dict(
    batch_size=32,
    num_workers=5,
    dataset=dict(
        type='CustomDataset',
        # data_root='data/imagenet',
        # ann_file='meta/val.txt',
        data_prefix='/data/xxx/datasets/age_classify/age_data/val',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='ResizeEdge', scale=224, edge='short', backend='pillow'),
            dict(type='CenterCrop', crop_size=168),
            dict(type='PackClsInputs')
        ]),
    sampler=dict(type='DefaultSampler', shuffle=False))
test_evaluator = dict(type='Accuracy', topk=(1, ))
optim_wrapper = dict(
    optimizer=dict(type='SGD', lr=0.005, momentum=0.9, weight_decay=4e-05))
param_scheduler = dict(type='StepLR', by_epoch=True, step_size=1, gamma=0.98)
train_cfg = dict(by_epoch=True, max_epochs=100, val_interval=5)
val_cfg = dict()
test_cfg = dict()
auto_scale_lr = dict(base_batch_size=256)
default_scope = 'mmcls'
default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=10),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(type='CheckpointHook', interval=5),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    visualization=dict(type='VisualizationHook', enable=False))
env_cfg = dict(
    cudnn_benchmark=False,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    dist_cfg=dict(backend='nccl'))
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
    type='ClsVisualizer', vis_backends=[dict(type='LocalVisBackend')])
log_level = 'INFO'
resume = False
randomness = dict(seed=None, deterministic=False)

Training command:

python tools/train.py /path/to/your_config.py   # pass the path of the config file above (positional argument in 1.x)
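
Training can also be launched programmatically through MMEngine's Runner. A minimal sketch, assuming the config above is saved to a file (the work_dir path here is only an illustration):

from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile('/path/to/your_config.py')  # the config file above
cfg.work_dir = 'work_dirs/age_classify'           # where logs and checkpoints are written
runner = Runner.from_cfg(cfg)
runner.train()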

Reference: https://github.com/wangruohui/sjtu-openmmlab-tutorial/blob/main/cls-2-train.ipynb

Further details: how to write config files?

https://www.bilibili.com/video/BV1J341127nQ?p=9&vd_source=2ed6e8af02f9ba8cb90b90e99bd4ccee
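
In practice, configs are usually not written from scratch like the full dump above; instead they inherit the base files under configs/_base_ and override only what changes. A minimal sketch, assuming the config sits one level below configs/ (the exact base file names are assumptions; check the repo for the real ones):

_base_ = [
    '../_base_/models/mobilenet_v2_1x.py',
    '../_base_/datasets/imagenet_bs32_pil_resize.py',
    '../_base_/schedules/imagenet_bs256_epochstep.py',
    '../_base_/default_runtime.py',
]
# override only the fields that differ from the base files
model = dict(head=dict(num_classes=7, topk=(1, )))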
