关于卷积神经网络(CNN)花卉分类的一些代码

CNN.py

复制代码
import torch.nn as nn



class CNN(nn.Module):
    """Small two-stage convolutional classifier for 3-channel 224x224 images.

    Shapes for a 224x224 input:
        conv1 (5x5, pad 2)  -> 16x224x224 -> ReLU -> MaxPool(8) -> 16x28x28
        conv2 (5x5, pad 2)  -> 32x28x28   -> ReLU -> MaxPool(4) -> 32x7x7
        flatten             -> Linear(32*7*7, num_classes)

    The fully connected layer is sized for 224x224 inputs only.

    Args:
        num_classes: Number of output logits (default 5).
    """

    def __init__(self, num_classes=5):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 5, 1, 2)
        self.pool1 = nn.MaxPool2d(8)
        self.conv2 = nn.Conv2d(16, 32, 5, 1, 2)
        self.pool2 = nn.MaxPool2d(4)
        self.fc = nn.Linear(32 * 7 * 7, num_classes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return unnormalised class logits of shape (batch, num_classes).

        Raises:
            RuntimeError: if the input's spatial size does not reduce to 7x7.
        """
        x = self.pool1(self.relu(self.conv1(x)))
        x = self.pool2(self.relu(self.conv2(x)))
        # flatten(1) keeps the batch dimension intact. The former
        # view(-1, 32*7*7) could silently regroup a wrong-sized input into a
        # different number of "samples" instead of failing.
        x = x.flatten(1)
        return self.fc(x)
复制代码

Main

复制代码
import torch
import os
from torchvision import transforms
import My_dataset
from torch.utils.data import DataLoader
from CNN import CNN
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

# Module-level TensorBoard writer used by main(); its event files are written
# under runs/flowers_experiment.
writer = SummaryWriter(
    log_dir='runs/flowers_experiment',
)
USE_GPU = True   # run on CUDA when it is also available
LR = 0.0001      # Adam learning rate
TIMES = 20       # number of training epochs
batch_size = 8
# DataLoader worker processes: at most the CPU count and at most 8, and 0
# (load in the main process) when batch_size <= 1. os.cpu_count() may return
# None, so fall back to 1 instead of letting min() raise TypeError.
num_worker = min([os.cpu_count() or 1, batch_size if batch_size > 1 else 0, 8])
def main(root: str):
    """Train the flower-classification CNN on ``root`` and report test errors.

    The dataset is split by ``My_dataset.read_split`` with ``test_rate=0.1``.
    The model is trained for ``TIMES`` epochs with Adam at learning rate ``LR``.
    The per-epoch mean training loss and the learning rate are logged to the
    module-level TensorBoard ``writer``. Finally, every test image is
    classified, and the number of misclassified samples is printed.

    Args:
        root: Dataset directory, forwarded to ``My_dataset.read_split``.
    """
    # Honour USE_GPU: only run on CUDA when it is both requested and available.
    device = torch.device("cuda" if USE_GPU and torch.cuda.is_available() else "cpu")
    print(device)

    train_loader, test_loader = _build_loaders(root)

    # Uncomment to preview a few training images with their class names:
    # My_dataset.plot_load_image(train_loader)

    cnn = CNN().to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)

    for times in range(TIMES):
        mean_loss = _train_one_epoch(cnn, train_loader, loss_function, optimizer, device)
        writer.add_scalar('训练损失值', mean_loss, times)
        # This value is the learning rate (previously mislabelled as '梯度', "gradient").
        writer.add_scalar('学习率', optimizer.param_groups[0]["lr"], times)
    # Make sure the event file is on disk even if the process exits right away.
    writer.flush()

    wrong_num, total = _evaluate(cnn, test_loader, device)
    print("wrong num:{} , sum:{}".format(wrong_num, total))


def _build_loaders(root):
    """Split ``root`` into train/test sets and wrap each in a DataLoader."""
    train_path, train_label, test_path, test_label = My_dataset.read_split(root=root, test_rate=0.1)

    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    data_transforms = {
        "train": transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ToTensor(),
            normalize]),
        "test": transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize]),
    }
    train_set = My_dataset.My_Dataset(img_path=train_path,
                                      img_label=train_label,
                                      transforms=data_transforms["train"])
    test_set = My_dataset.My_Dataset(img_path=test_path,
                                     img_label=test_label,
                                     transforms=data_transforms["test"])

    train_loader = DataLoader(dataset=train_set,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=num_worker,
                              collate_fn=train_set.collate_fn)
    # Order is irrelevant for counting errors, so the test set is not shuffled.
    test_loader = DataLoader(dataset=test_set,
                             batch_size=1,
                             shuffle=False,
                             collate_fn=test_set.collate_fn)
    return train_loader, test_loader


def _train_one_epoch(cnn, loader, loss_function, optimizer, device):
    """Run one optimisation pass over ``loader``; return the mean per-sample loss (nan if empty)."""
    cnn.train()
    loss_sum, seen = 0.0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        loss = loss_function(cnn(images), labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # CrossEntropyLoss averages over the batch; re-weight by batch size so
        # a short final batch does not skew the epoch mean.
        loss_sum += loss.item() * labels.size(0)
        seen += labels.size(0)
    return loss_sum / seen if seen else float("nan")


def _evaluate(cnn, loader, device):
    """Classify every sample in ``loader``; return (misclassified count, sample count).

    The predicted and true labels of each mismatching batch are printed.
    """
    cnn.eval()
    wrong_num, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            # argmax of the logits equals argmax of their softmax.
            pred_y = cnn(images).argmax(dim=1)
            # Element-wise comparison works for any batch size (the old
            # `pred_y != labels` in an `if` only worked for batch size 1).
            mismatch = pred_y != labels
            if mismatch.any():
                print(pred_y[mismatch], labels[mismatch])
            wrong_num += int(mismatch.sum())
            total += labels.size(0)
    return wrong_num, total

# Dataset root passed to main(); read_split expects one sub-directory per class.
a = "./flower_photos"

if __name__ == "__main__":
    main(root=a)
复制代码

My_dataset.py

复制代码
from torch.utils.data import Dataset
from PIL import Image
import torch
import os
import random
import numpy
import matplotlib.pyplot as plt

class My_Dataset(Dataset):
    """Map-style dataset over image files paired with integer class labels.

    Args:
        img_path: Image file paths.
        img_label: Class index for each entry of ``img_path`` (same order).
        transforms: Optional callable applied to each loaded PIL image.

    Raises:
        ValueError: if ``img_path`` and ``img_label`` differ in length.
    """

    def __init__(self, img_path: list, img_label: list, transforms=None):
        if len(img_path) != len(img_label):
            raise ValueError("img_path and img_label must have the same length "
                             "({} != {})".format(len(img_path), len(img_label)))
        self.img_path = img_path
        self.img_label = img_label
        self.transforms = transforms

    def __len__(self):
        return len(self.img_path)

    def __getitem__(self, item):
        """Load sample ``item`` and return ``(image, label)``.

        Raises:
            ValueError: if the image is not in RGB mode (only RGB images are handled).
        """
        path = self.img_path[item]
        # Open through our own file object and force-load the pixels, so the
        # handle is always closed here, including when the mode check raises
        # or when no transform reads the (lazily loaded) image data.
        with open(path, 'rb') as f:
            img = Image.open(f)
            if img.mode != 'RGB':
                raise ValueError("image:{} is not the RGB".format(path))
            img.load()
        label = self.img_label[item]

        if self.transforms is not None:
            img = self.transforms(img)

        return img, label

    @staticmethod
    def collate_fn(batch):
        """Stack a list of ``(image_tensor, label)`` pairs into ``(images, labels)`` tensors."""
        images, labels = tuple(zip(*batch))
        images = torch.stack(images, dim=0)
        labels = torch.as_tensor(labels)
        return images, labels


def read_split(root: str, test_rate: float = 0.2, seed=None):
    """Split an image-folder dataset into train and test lists.

    ``root`` must contain one sub-directory per class. Class names are sorted
    and mapped to indices 0..C-1, and that mapping is written as ``str(dict)``
    to ``./index.index``. From each class, ``int(n * test_rate)`` files are
    sampled at random into the test set; the rest go to the training set.

    Args:
        root: Dataset directory.
        test_rate: Fraction of each class held out for testing (default 0.2).
        seed: Optional seed for a reproducible split. ``None`` keeps using the
            global ``random`` state.

    Returns:
        ``(train_path, train_label, test_path, test_label)`` lists.

    Raises:
        SystemExit: with status 1 if ``root`` does not exist.
        ValueError: if ``test_rate`` is outside [0, 1].
    """
    if not os.path.exists(root):
        print("--the dataset does not exist.--")
        # Non-zero status: the former bare exit() reported success on this error.
        raise SystemExit(1)
    if not 0 <= test_rate <= 1:
        raise ValueError("test_rate must be in [0, 1], got {}".format(test_rate))
    rng = random if seed is None else random.Random(seed)

    # Each sub-directory of root is one class; sorting gives a stable index.
    Myclass = sorted(cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla)))
    Myclass_index = {cla: i for i, cla in enumerate(Myclass)}
    with open('./index.index', 'w') as file:
        file.write(str(Myclass_index))

    train_path, train_label = [], []
    test_path, test_label = [], []

    for cla in Myclass:
        cla_path = os.path.join(root, cla)
        # Sorted so that a given seed selects the same files on every filesystem.
        img_path = [os.path.join(cla_path, name) for name in sorted(os.listdir(cla_path))]
        img_class = Myclass_index[cla]
        # A set gives O(1) membership tests (a list made this loop quadratic).
        test_path_tmp = set(rng.sample(img_path, k=int(len(img_path) * test_rate)))

        for path in img_path:
            if path in test_path_tmp:
                test_path.append(path)
                test_label.append(img_class)
            else:
                train_path.append(path)
                train_label.append(img_class)

    return train_path, train_label, test_path, test_label



def plot_load_image(data_loader):
    """Show up to four images, labelled with class names, from each batch of ``data_loader``.

    Class names are read from ``./index.index`` (the ``str(dict)`` written by
    ``read_split``). The pixel de-normalisation assumes the images were
    normalised with mean [0.485, 0.456, 0.406] and std [0.229, 0.224, 0.225].
    TODO: confirm this if other transforms are used.

    Raises:
        FileNotFoundError: if ``./index.index`` is missing.
    """
    import ast

    path = './index.index'
    if not os.path.exists(path):
        raise FileNotFoundError(path + ' does not exist!!')
    with open(path, 'r') as file:
        # literal_eval parses the dict literal without executing arbitrary
        # code, unlike eval() on file contents.
        class_index = ast.literal_eval(file.read())
    # Invert {class_name: index} to {index: class_name}.
    class_index = {idx: name for name, idx in class_index.items()}
    print(class_index)

    mean = numpy.array([0.485, 0.456, 0.406])
    std = numpy.array([0.229, 0.224, 0.225])
    for images, labels in data_loader:
        # Use the actual batch length: a short final batch (or a loader whose
        # batch_size is None) may hold fewer images than batch_size.
        plot_num = min(len(images), 4)
        for i in range(plot_num):
            img = images[i].numpy().transpose(1, 2, 0)
            img = (img * std + mean) * 255

            plt.subplot(1, plot_num, i + 1)
            plt.xlabel(class_index[labels[i].item()])
            plt.xticks([])
            plt.yticks([])
            # Clip before the uint8 cast so values just outside [0, 255] do not wrap around.
            plt.imshow(numpy.clip(img, 0, 255).astype('uint8'))
        plt.show()
复制代码

 

posted @   ZeroHzzzz  阅读(10)  评论(0)  编辑  收藏  举报
编辑推荐:
· AI与.NET技术实操系列(二):开始使用ML.NET
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
阅读排行:
· 全程不用写代码,我用AI程序员写了一个飞机大战
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· .NET10 - 预览版1新功能体验(一)
点击右上角即可分享
微信分享提示