基于卷积神经网络(CNN)的花卉分类代码
CNN.py
import torch
import torch.nn as nn


class CNN(nn.Module):
    """Small two-stage convolutional classifier for 3-channel 224x224 images.

    Shape trace for a 224x224 input: conv1 (5x5, stride 1, pad 2) keeps
    224x224, pool1 (8x8) -> 28x28, conv2 keeps 28x28, pool2 (4x4) -> 7x7,
    so the linear classifier sees 32 * 7 * 7 features per sample.

    Args:
        num_classes: number of output logits (defaults to 5, the value the
            model previously hard-coded).
    """

    def __init__(self, num_classes: int = 5):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 5, 1, 2)
        self.pool1 = nn.MaxPool2d(8)
        self.conv2 = nn.Conv2d(16, 32, 5, 1, 2)
        self.pool2 = nn.MaxPool2d(4)
        self.fc = nn.Linear(32 * 7 * 7, num_classes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return unnormalised class logits of shape (batch, num_classes).

        Expects a (batch, 3, 224, 224) tensor; other spatial sizes make the
        final linear layer raise a RuntimeError.
        """
        x = self.pool1(self.relu(self.conv1(x)))
        x = self.pool2(self.relu(self.conv2(x)))
        # Flatten per sample. Unlike view(-1, 32*7*7), this never silently
        # changes the batch size when the input is not 224x224; the Linear
        # layer raises a shape error instead.
        x = torch.flatten(x, 1)
        return self.fc(x)
Main
import os

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

import My_dataset
from CNN import CNN

writer = SummaryWriter(log_dir='runs/flowers_experiment')

USE_GPU = True
LR = 0.0001
TIMES = 20  # number of training epochs
batch_size = 8
# os.cpu_count() may return None, so fall back to 1; no worker processes when batch_size <= 1.
num_worker = min([os.cpu_count() or 1, batch_size if batch_size > 1 else 0, 8])

# Channel mean/std passed to transforms.Normalize in both pipelines.
_MEAN = [0.485, 0.456, 0.406]
_STD = [0.229, 0.224, 0.225]


def _build_transforms():
    """Return the preprocessing pipelines for the "train" and "test" splits."""
    return {
        # Training adds random horizontal/vertical flips as augmentation.
        "train": transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(_MEAN, _STD)]),
        "test": transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(_MEAN, _STD)]),
    }


def _train(cnn, train_loader, loss_function, optimizer, device):
    """Train ``cnn`` for TIMES epochs, logging one mean loss and the learning rate per epoch."""
    cnn.train()
    for times in range(TIMES):
        loss_sum = 0.0
        num_batches = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            output = cnn(images)
            loss = loss_function(output, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # .item() detaches the value so no autograd graph is kept around for logging.
            loss_sum += loss.item()
            num_batches += 1
        # One averaged point per epoch: logging every batch at step=times would
        # stack many values on the same x-position.
        writer.add_scalar('训练损失值', loss_sum / max(num_batches, 1), times)
        # NOTE(review): the tag '梯度' means "gradient", but the value logged is the learning rate.
        writer.add_scalar('梯度', optimizer.param_groups[0]["lr"], times)


@torch.no_grad()
def _evaluate(cnn, test_loader, device):
    """Count misclassified samples in ``test_loader``, printing each mismatch.

    Returns:
        (wrong_num, total): number of wrong predictions and number of samples seen.
    """
    cnn.eval()
    wrong_num = 0
    total = 0
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        # argmax of the logits equals argmax of their softmax (softmax is monotonic).
        pred_y = torch.argmax(cnn(images), dim=1)
        mismatch = pred_y != labels
        if mismatch.any():
            print(pred_y[mismatch], labels[mismatch])
        # Sum over the batch so the count is correct for any batch size.
        wrong_num += int(mismatch.sum().item())
        total += labels.size(0)
    return wrong_num, total


def main(root: str):
    """Train the flower CNN on the class-per-folder dataset at ``root`` and report test errors.

    Args:
        root: dataset directory, split 90/10 into train/test by My_dataset.read_split.
    """
    use_cuda = USE_GPU and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print(device)

    train_path, train_label, test_path, test_label = My_dataset.read_split(root=root, test_rate=0.1)
    data_transforms = _build_transforms()

    train_set = My_dataset.My_Dataset(img_path=train_path, img_label=train_label,
                                      transforms=data_transforms["train"])
    test_set = My_dataset.My_Dataset(img_path=test_path, img_label=test_label,
                                     transforms=data_transforms["test"])
    train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True,
                              num_workers=num_worker, collate_fn=train_set.collate_fn)
    # Evaluation order does not affect the result, so keep it deterministic.
    test_loader = DataLoader(dataset=test_set, shuffle=False, collate_fn=test_set.collate_fn)

    cnn = CNN().to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)

    try:
        _train(cnn, train_loader, loss_function, optimizer, device)
        wrong_num, i = _evaluate(cnn, test_loader, device)
        print("wrong num:{} , sum:{}".format(wrong_num, i))
    finally:
        # Flush pending TensorBoard events even if training fails.
        writer.close()


a = './flower_photos'
if __name__ == '__main__':
    main(a)
My_dataset.py
import ast
import os
import random

import matplotlib.pyplot as plt
import numpy
import torch
from PIL import Image
from torch.utils.data import Dataset

# Mean/std used to undo transforms.Normalize when plotting; they must match the
# values the data loaders were built with.
_MEAN = [0.485, 0.456, 0.406]
_STD = [0.229, 0.224, 0.225]


class My_Dataset(Dataset):
    """Image-classification dataset backed by parallel lists of file paths and labels.

    Args:
        img_path: image file paths.
        img_label: class label for each entry of ``img_path`` (same order).
        transforms: optional callable applied to each loaded PIL image.
    """

    def __init__(self, img_path: list, img_label: list, transforms=None):
        self.img_path = img_path
        self.img_label = img_label
        self.transforms = transforms

    def __len__(self):
        return len(self.img_path)

    def __getitem__(self, item):
        """Return ``(image, label)`` for index ``item``.

        Raises:
            ValueError: if the image is not in RGB mode (only RGB images are handled).
        """
        with Image.open(self.img_path[item]) as f:
            if f.mode != 'RGB':
                raise ValueError("image:{} is not the RGB".format(self.img_path[item]))
            # copy() loads the pixels into an independent image so the file
            # handle can be closed here instead of leaking.
            img = f.copy()
        label = self.img_label[item]
        if self.transforms is not None:
            img = self.transforms(img)
        return img, label

    @staticmethod
    def collate_fn(batch):
        """Turn a list of (image_tensor, label) pairs into stacked (images, labels) tensors."""
        images, labels = tuple(zip(*batch))
        images = torch.stack(images, dim=0)
        labels = torch.as_tensor(labels)
        return images, labels


def read_split(root: str, test_rate: float = 0.2):
    """Split a class-per-subdirectory dataset into train/test path and label lists.

    Every sub-directory of ``root`` is treated as one class. Classes are indexed
    in sorted name order, and the ``{class_name: index}`` mapping is written to
    ``./index.index`` (read back by plot_load_image). From each class,
    ``int(n * test_rate)`` files are randomly sampled into the test set; the
    rest go to the training set.

    Args:
        root: dataset root directory.
        test_rate: fraction of each class's files used for testing (default 0.2).

    Returns:
        (train_path, train_label, test_path, test_label)

    Raises:
        FileNotFoundError: if ``root`` does not exist.
    """
    if not os.path.exists(root):
        print("--the dataset does not exict.--")
        # Raise instead of exit(), which terminated with a success status (0).
        raise FileNotFoundError(root)

    # Each sub-directory name is a class name; sorting makes the indices stable.
    classes = sorted(cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla)))
    class_index = {cla: i for i, cla in enumerate(classes)}
    with open('./index.index', 'w', encoding='utf-8') as file:
        file.write(str(class_index))

    train_path, train_label, test_path, test_label = [], [], [], []
    for cla in classes:
        cla_path = os.path.join(root, cla)
        # Sort the listing so the split depends only on the RNG state, not on filesystem order.
        img_path = [os.path.join(cla_path, name) for name in sorted(os.listdir(cla_path))]
        img_class = class_index[cla]
        # A set gives O(1) membership tests in the loop below.
        test_members = set(random.sample(img_path, k=int(len(img_path) * test_rate)))
        for path in img_path:
            if path in test_members:
                test_path.append(path)
                test_label.append(img_class)
            else:
                train_path.append(path)
                train_label.append(img_class)
    return train_path, train_label, test_path, test_label


def plot_load_image(data_loader):
    """Show up to four de-normalised images, titled with class names, from every batch.

    Reads the ``{class_name: index}`` mapping written by read_split from
    ``./index.index``. Calls plt.show() once per batch.

    Raises:
        FileNotFoundError: if ``./index.index`` does not exist.
    """
    plot_num = min(data_loader.batch_size, 4)
    path = './index.index'
    if not os.path.exists(path):
        raise FileNotFoundError(path + ' does not exist!!')
    with open(path, 'r', encoding='utf-8') as file:
        # literal_eval only parses literals; eval would execute arbitrary code from the file.
        class_index = ast.literal_eval(file.read())
    # Invert to {index: class_name} for labelling.
    class_index = {index: name for name, index in class_index.items()}
    print(class_index)

    mean = numpy.array(_MEAN)
    std = numpy.array(_STD)
    for images, labels in data_loader:
        # The last batch may hold fewer than plot_num images.
        for i in range(min(plot_num, len(images))):
            img = images[i].numpy().transpose(1, 2, 0)
            # Undo Normalize, scale to [0, 255] and clip float round-off before the uint8 cast.
            img = numpy.clip((img * std + mean) * 255, 0, 255)
            plt.subplot(1, plot_num, i + 1)
            plt.xlabel(class_index[labels[i].item()])
            plt.xticks([])
            plt.yticks([])
            plt.imshow(img.astype('uint8'))
        plt.show()
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· AI与.NET技术实操系列(二):开始使用ML.NET
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
· 全程不用写代码,我用AI程序员写了一个飞机大战
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· .NET10 - 预览版1新功能体验(一)