ResNet50

1、样本集下载与划分

下载,https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_5340.zip

划分,将Cat和Dog文件夹里的图片按3:1:1复制到各自的train、val、test文件夹

目录样式如下:

  --train
      |--Cat
      |--Dog
  --val
      |--Cat
      |--Dog
  --test
      |--Cat
      |--Dog

由于里边有些非RGB的文件,剔除掉:

"""
    删除非RGB的文件
"""

import os
from PIL import Image

import warnings #如果有警告,当成错误,以便try、except捕捉错误
warnings.filterwarnings("error", category=UserWarning)


def remove(src_dir):
    remove_count = 0 #统计删除了多少
    file_path = ''
    file_list = os.listdir(src_path)
    try:
        for file_name in file_list:
            file_path = os.path.join(src_dir,file_name) #文件全路径
            if os.path.isdir(file_path): #如果是目录
                print('请使用存放图片的根文件夹')
                return
            #删除非图片文件,和图片不是RGB格式的
            if os.path.splitext(file_name)[1].lower() in ['.jpg','.jpeg','.png','.bmp']:
                img = Image.open(file_path)
                if img.mode!='RGB':
                    os.remove(file_path)
                    remove_count += 1
            else:
                os.remove(file_path)
                remove_count += 1
    except Exception as e:
        print('异常文件{}'.format(file_path))
        os.remove(file_path)
        remove_count += 1
    print('删除了{}个文件'.format(remove_count))

if __name__ == '__main__':
    src_path = r"/home/jv/datasets/src/cat_dog/Cat"      # 输入路径,注意写到图片的根文件夹
    remove(src_path)

之后划分:

import os
import random
from shutil import copy2

"""
    读取源数据文件夹,生成划分好的文件夹,分为train、val、test三个文件夹进行
    :param src_dir:             源文件夹
    :param dst_dir:             目标文件夹
    :param train_scale:         训练集比例
    :param val_scale:           验证集比例
    :param test_scale:          测试集比例
    :return:
"""
def data_set_split(src_dir, dst_dir, train_scale=0.6, val_scale=0.2, test_scale=0.2):
    print("开始数据集划分")
    class_names = os.listdir(src_dir)
    split_names = ['train', 'val', 'test']  # 在目标目录下创建文件夹
    for split_name in split_names:
        split_path = os.path.join(dst_dir, split_name)
        if os.path.isdir(split_path):
            pass
        else:
            os.mkdir(split_path)
        for class_name in class_names:      # 然后在split_path的目录下创建类别文件夹
            class_split_path = os.path.join(split_path, class_name)
            if os.path.isdir(class_split_path):
                pass
            else:
                os.mkdir(class_split_path)
    for class_name in class_names:          # 首先进行分类遍历, 按照比例划分数据集, 并进行数据图片的复制
        current_class_data_path = os.path.join(src_dir, class_name)
        current_all_data = os.listdir(current_class_data_path)
        current_data_length = len(current_all_data)
        current_data_index_list = list(range(current_data_length))
        random.shuffle(current_data_index_list)
        train_folder = os.path.join(os.path.join(dst_dir, 'train'), class_name)
        val_folder = os.path.join(os.path.join(dst_dir, 'val'), class_name)
        test_folder = os.path.join(os.path.join(dst_dir, 'test'), class_name)
        train_stop_flag = current_data_length * train_scale
        val_stop_flag = current_data_length * (train_scale + val_scale)
        current_idx = 0
        train_num = 0
        val_num = 0
        test_num = 0
        for i in current_data_index_list:
            src_img_path = os.path.join(current_class_data_path, current_all_data[i])
            if current_idx <= train_stop_flag:
                copy2(src_img_path, train_folder)   # print("{}复制到了{}".format(src_img_path, train_folder))
                train_num = train_num + 1
            elif (current_idx > train_stop_flag) and (current_idx <= val_stop_flag):
                copy2(src_img_path, val_folder)     # print("{}复制到了{}".format(src_img_path, val_folder))
                val_num = val_num + 1
            else:
                copy2(src_img_path, test_folder)    # print("{}复制到了{}".format(src_img_path, test_folder))
                test_num = test_num + 1
            current_idx = current_idx + 1
        print("*********************************{}*************************************".format(class_name))
        print("{}类按照{}:{}:{}的比例划分完成,一共{}张图片".format(
                class_name, train_scale, val_scale, test_scale, current_data_length))
        print("训练集{}:{}张".format(train_folder, train_num))
        print("验证集{}:{}张".format(val_folder, val_num))
        print("测试集{}:{}张".format(test_folder, test_num))

if __name__ == '__main__':
    src_path = r"/home/jv/datasets/src/cat_dog"     # 输入路径
    dst_path = r"/home/jv/datasets/dst/cat_dog"     # 输出路径
    data_set_split(src_path, dst_path)

[参考] 划分到文件夹或txt里 https://blog.51cto.com/u_16213607/9455774

2、查找最优的ResNet50预训练版本

 

从具体的预训练模型目录 <no title> — Torchvision 0.20 documentation 中,可以知道表现最好的ResNet50版本

 

2、加载ResNet50预训练模型

加载预训练模型使用TorchVision方式,torch.hub.load方式本文不再研究。

from torchvision import models
resnet50 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)

 

会自动下载指定的模型

 

 

 

关于训练的样本尺寸:

 

ResNet50用的是3*224*224尺寸的图像训练的,因为原始样本集基本在300左右尺寸,7*2的幂次 最接近300的就是224。为什么用 7*2的幂次,因为卷积核尺寸7、步长2。

 

所以如果自己的样本集尺寸是640*384,可以设置成 3*448*448 训练

 

 

训练

resnet50训练部署教程 - 知乎

Pytorch迁移学习使用Resnet50进行模型训练预测猫狗二分类-阿里云开发者社区

ResNet50训练自己的分类模型(pt、onnx)_resnet50 onnx-CSDN博客

 

保存训练出来的最优模型,这方面有两种方式(不同方式保存对应的加载方式也不同)

  两种方式的优缺点 torch保存和加载 模型、参数_torch加载模型-CSDN博客

  torch.save()与torch.jit.script()  参考PyTorch模型保存的两种方式-百度开发者中心

  torch.save(model, path) 与 torch.save(model.state_dict(), path)  参考 pytorch保存模型参数 pytorch 保存整个模型_mob6454cc6c8549的技术博客_51CTO博客

  PyTorch中通过torch.save保存模型和torch.load加载模型介绍_pytorch_fengbingchun-华为开发者空间

  (从0-1带你了解)Pytorch之模型的读取_pytorch 读取模型-CSDN博客

 

部署

模型部署|ResNet50基于TensorRT FP16生成Engnie文件的C++工程_《自动驾驶中的深度学习模型量化、部署、加速实战》(源代码)-CSDN专栏

 

 

【ResNet50相关内容】

How to Train State-Of-The-Art Models Using TorchVision’s Latest Primitives | PyTorch

 

posted @ 2024-10-23 18:23  夕西行  阅读(69)  评论(0编辑  收藏  举报