第十六节,卷积神经网络之AlexNet网络实现(六)
目录
上一节内容已经详细介绍了AlexNet的网络结构。这节主要通过Tensorflow来实现AlexNet。
这里做测试我们使用的是CIFAR-10数据集介绍数据集,关于该数据集的具体信息可以通过以下链接查看:https://blog.csdn.net/davincil/article/details/78793067下面粗略的介绍一下CIFAR-10数据集。
一 CIFAR-10数据集
1.1 CIFAR-10数据集介绍
CIFAR-10数据集由10类32x32的彩色图片组成,一共包含60000张图片,每一类包含6000图片。其中50000张图片作为训练集,10000张图片作为测试集。
CIFAR-10数据集被划分成了5个训练的batch和1个测试的batch,每个batch均包含10000张图片。测试集batch的图片是从每个类别中随机挑选的1000张图片组成的,训练集batch以随机的顺序包含剩下的50000张图片。不过一些训练集batch可能出现包含某一类图片比其他类的图片数量多的情况。训练集batch包含来自每一类的5000张图片,一共50000张训练图片。
数据集下载地址:http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
文件下载后,解压cifar-10-python.tar.gz,得到cifar-10-batches-py文件夹,打开该文件夹,我们会看到有如下文件:
其中每个文件的作用如下:
- batches.meta 程序中不需要使用该文件;
- data_batch_1 训练集的第一个batch,含有10000张图片;
- data_batch_2 训练集的第二个batch,含有10000张图片;
- data_batch_3 训练集的第三个batch,含有10000张图片;
- data_batch_4 训练集的第四个batch,含有10000张图片;
- data_batch_5 训练集的第五个batch,含有10000张图片;
- readme.html 网页文件,程序中不需要使用该文件;
- test_batch 测试集的batch,含有10000张图片;
上述文件结构中,每一个batch文件包含一个python的字典(dict)结构,结构如下:
- b'data’ 是一个10000x3072的array,每一行的元素组成了一个32x32的3通道图片,共10000张;
- b'labels’ 一个长度为10000的list,对应包含data中每一张图片的 label;
- b'batch_label' 这一份batch的名称;
- b'filenames' 一个长度为10000的list,对应包含data中每一张图片的名称;
由于数据集比较大,在训练的时候如果把所有数据一次性加载到内存训练,会出现内容不足的问题,因此先从batch中读取所有图片的数据,以及每一张图片对应的标签,然后我们创建一个文件夹叫做CIFAR-10-data。
在这个文件夹下面创建train和test文件夹,然后在每个文件夹下面创建名称从0-9的文件夹,我们利用OpenCV把每一张图片保存在对应文件夹下面。
后面再创建两个文件,一个叫做CIFAR-10-test-label.pkl,另一个叫做CIFAR-10-train-label.pkl,均保存由如下元组:(测试集或训练集的图片路径,以及对应标签)组成的list集合。
1.2 数据集处理
datagenerator.py文件代码如下:
# -*- coding: utf-8 -*- """ Created on Wed Apr 11 14:51:27 2018 @author: Administrator """ ''' 用于加载数据集合 数据集下载地址:http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz CIFAR-10数据集介绍:https://blog.csdn.net/davincil/article/details/78793067 一、CIFAR-10 CIFAR-10数据集由10类32x32的彩色图片组成,一共包含60000张图片,每一类包含6000图片。其中50000张图片作为训练集,10000张图片作为测试集。 CIFAR-10数据集被划分成了5个训练的batch和1个测试的batch,每个batch均包含10000张图片。 测试集batch的图片是从每个类别中随机挑选的1000张图片组成的,一共10000张测试图片, 训练集batch包含来自每一类的5000张图片,一共50000张训练图片。 训练集batch以随机的顺序包含剩下的50000张图片。 不过一些训练集batch可能出现包含某一类图片比其他类的图片数量多的情况。 ''' ''' 文件下载之后,解压 主要包括以下文件 名称 作用 batches.meta 程序中不需要使用该文件 data_batch_1 训练集的第一个batch,含有10000张图片 data_batch_2 训练集的第二个batch,含有10000张图片 data_batch_3 训练集的第三个batch,含有10000张图片 data_batch_4 训练集的第四个batch,含有10000张图片 data_batch_5 训练集的第五个batch,含有10000张图片 readme.html 网页文件,程序中不需要使用该文件 test_batch 测试集的batch,含有10000张图片 上述文件结构中,每一个batch文件包含一个python的字典(dict)结构,结构如下: 名称 作用 b'data’ 是一个10000x3072的array,每一行的元素组成了一个32x32的3通道图片,共10000张 b'labels’ 一个长度为10000的list,对应包含data中每一张图片的label b'batch_label' 这一份batch的名称 b'filenames' 一个长度为10000的list,对应包含data中每一张图片的名称 ''' import pickle import numpy as np import cv2 from skimage import io class datagenerator(object): def __init__(self): pass def unpickle(self,filename): ''' batch文件中真正重要的两个关键字是data和labels 反序列化出对象 每一个batch文件包含一个python的字典(dict)结构,结构如下: 名称 作用 b'data’ 是一个10000x3072的array,每一行的元素组成了一个32x32的3通道图片,共10000张 b'labels’ 一个长度为10000的list,对应包含data中每一张图片的label b'batch_label' 这一份batch的名称 b'filenames' 一个长度为10000的list,对应包含data中每一张图片的名称 ''' with open(filename,'rb') as f: #默认把字节转换为ASCII编码 这里设置encoding='bytes'直接读取字节数据 因为里面含有图片像素数据 大小从0-255 不能解码为ascii编码,因此先转换成字节类型 后面针对不同项数据再解码,转换为字符串 dic = pickle.load(f,encoding='bytes') return dic def get_image(self,image): ''' 提取每个通道的数据,进行重新排列,最后返回一张32x32的3通道的图片 在字典结构中,每一张图片是以被展开的形式存储(即一张32x32的3通道图片被展开成了3072长度的list), 每一个数据的格式为uint8,前1024个数据表示红色通道,接下来的1024个数据表示绿色通道,最后的1024 个通道表示蓝色通道。 image:每一张图片的数据 数据按照R,G,B通道依次排列 长度为3072 ''' assert len(image) == 3072 #对list进行切片操作,然后reshape r = image[:1024].reshape(32,32,1) g = image[1024:2048].reshape(32,32,1) b = image[2048:].reshape(32,32,1) #numpy提供了numpy.concatenate((a1,a2,...), axis=0)函数。能够一次完成多个数组的拼接。其中a1,a2,...是数组类型的参数 #沿着某个轴拼接,默认为列方向(axis=0) img = np.concatenate((r,g,b),-1) return img def get_data_by_keyword(self,keyword,filelist=[],normalized=False,size=(32,32),one_hot=False): ''' 按照给出的关键字提取batch中的数据(默认是训练集的所有数据) args: keyword:'data’ 或 'labels’ 或 'batch_label' 或 'filenames' 表示需要返回的项 filelist:list 表示要读取的文件集合 normalized:当keyword = 'data',表示是否需要归一化 size:当keyword = 'data',表示需要返回的图片的尺寸 one_hot:当keyword = 'labels'时,one_hot=Flase,返回实际标签 True时返回二值化后的标签 return: keyword = 'data' 返回像素数据 keyword = 'labels' 返回标签数据 keyword = 'batch_label' 返回batch的名称 keyword = 'filenames' 返回图像文件名 ''' #keyword编码为字节 keyword = keyword.encode('ascii') assert keyword in [b'data',b'labels',b'batch_label',b'filenames'] assert type(filelist) is list and len(filelist) != 0 assert type(normalized) is bool assert type(size) is tuple or type(size) is list ret = [] for i in range(len(filelist)): #反序列化出对象 dic = self.unpickle(filelist[i]) if keyword == b'data': #b'data’ 是一个10000x3072的array,每一行的元素组成了一个32x32的3通道图片,共10000张 #合并成一个数组 for item in dic[b'data']: ret.append(item) print('总长度:',len(ret)) elif keyword == b'labels': #b'labels’ 一个长度为10000的list,对应包含data中每一张图片的label #合并成一个数组 for item in dic[b'labels']: ret.append(item) elif keyword == b'batch_label': #b'batch_label' 这一份batch的名称 #合并成一个数组 for item in dic[b'batch_label']: ret.append(item.decode('ascii')) #把数据转换为ascii编码 else: #b'filenames' 一个长度为10000的list,对应包含data中每一张图片的名称 #合并成一个数组 for item in dic[b'filenames']: ret.append(item.decode('ascii')) #把数据转换为ascii编码 if keyword == b'data': if normalized == False: array = np.ndarray([len(ret),size[0],size[1],3],dtype = np.float32) #遍历每一张图片数据 for i in range(len(ret)): #图像进行缩放 array[i] = cv2.resize(self.get_image(ret[i]),size) return array else: array = np.ndarray([len(ret),size[0],size[1],3],dtype = np.float32) #遍历每一张图片数据 for i in range(len(ret)): array[i] = cv2.resize(self.get_image(ret[i]),size)/255 return array pass elif keyword == b'labels': #二值化标签 if one_hot == True: #类别 depth = 10 m = np.zeros([len(ret),depth]) for i in range(len(ret)): m[i][ret[i]] = 1 return m pass #其它keyword直接返回 return ret import os import pickle def save_images(): ''' 报CIFAR-10数据集图片提取出来保存下来 1.创建一个文件夹 CIFAR-10-data 包含两个子文件夹test,train 2.在文革子文件夹创建10个文件夹 文件名依次为0-9 对应10个类别 3.训练集数据生成bmp格式文件,存在对应类别的文件下 4.测试集数据生成bmp格式文件,存在对应类别的文件下 5 生成两个文件train_label.pkl,test_label.pkl 分别保存相应的图片文件路径以及对应的标签 ''' #根目录 root = 'CIFAR-10-data' #如果存在该目录 说明数据存在 if os.path.isdir(root): print(root+'目录已经存在!') return ''' 如果文件夹不存在 创建文件夹 ''' #'data'目录不存在,创建目录 os.mkdir(root) #创建文件失败 if not os.path.isdir(root): print(root+'目录创建失败!') return #创建'test'和'train'目录 以及子文件夹 train = os.path.join(root,'train') os.mkdir(train) if os.path.isdir(train): for i in range(10): name = os.path.join(train,str(i)) os.mkdir(name) test = os.path.join(root,'test') os.mkdir(test) if os.path.isdir(test): for i in range(10): name = os.path.join(test,str(i)) os.mkdir(name) ''' 把训练集数据转换为图片 ''' data_dir = 'cifar-10-batches-py' #数据所在目录 filelist = [] for i in range(5): name = os.path.join(data_dir,str('data_batch_%d'%(i+1))) filelist.append(name) data = datagenerator() #获取训练集数据 train_x = data.get_data_by_keyword('data',filelist, normalized=True,size=(32,32)) #标签 train_y = data.get_data_by_keyword('labels',filelist) #读取图片文件名 train_filename = data.get_data_by_keyword('filenames',filelist) #保存训练集的文件名和标签 train_file_labels = [] #保存图片 for i in range(len(train_x)): #获取图片标签 y = int(train_y[i]) #文件保存目录 dir_name = os.path.join(train,str(y)) #获取文件名 file_name = train_filename[i] #文件的保存路径 file_path = os.path.join(dir_name,file_name) #保存图片 io.imsave(file_path,train_x[i]) #追加第i张图片路径和标签 (文件路径,标签) train_file_labels.append((file_path,y)) if i % 1000 == 0: print('训练集完成度{0}/{1}'.format(i,len(train_x))) for i in range(10): print('训练集前10张图片:',train_file_labels[i]) #保存训练集的文件名和标签 with open('CIFAR-10-train-label.pkl','wb') as f: pickle.dump(train_file_labels,f) print('训练集图片保存成功!\n') ''' 把测试集数据转换为图片 ''' filelist = [os.path.join(data_dir,'test_batch')] #获取训练集数据 数据标准化为0-1之间 test_x = data.get_data_by_keyword('data',filelist, normalized=True,size=(32,32)) #标签 test_y = data.get_data_by_keyword('labels',filelist) #读取图片文件名 test_filename = data.get_data_by_keyword('filenames',filelist) #保存测试卷的文件名和标签 test_file_labels = [] #保存图片 for i in range(len(test_x)): #获取图片标签 y = int(test_y[i]) #文件保存目录 dir_name = os.path.join(test,str(y)) #获取文件名 file_name = test_filename[i] #文件的保存路径 file_path = os.path.join(dir_name,file_name) #保存图片 这里要求图片像素值在-1-1之间,所以在获取数据的时候做了标准化 io.imsave(file_path,test_x[i]) #追加第i张图片路径和标签 (文件路径,标签) test_file_labels.append((file_path,y)) if i % 1000 == 0: print('测试集完成度{0}/{1}'.format(i,len(test_x))) print('测绘集图片保存成功!\n') #保存测试卷的文件名和标签 with open('CIFAR-10-test-label.pkl','wb') as f: pickle.dump(test_file_labels,f) for i in range(10): print('测试集前10张图片:',test_file_labels[i]) def load_data(): ''' 加载数据集 返回训练集数据和测试卷数据 training_data 由(x,y)元组组成的list集合 x:图片路径 y:对应标签 ''' #加载使用的训练集文件名和标签 [(文件路径,标签),....] with open('CIFAR-10-train-label.pkl','rb') as f: training_data = pickle.load(f) #加载使用的测试集文件名和标签 with open('CIFAR-10-test-label.pkl','rb') as f: test_data = pickle.load(f) return training_data,test_data def get_one_hot_label(labels,depth): ''' 把标签二值化 返回numpy.array类型 args: labels:标签的集合 depth:标签总共有多少类 ''' m = np.zeros([len(labels),depth]) for i in range(len(labels)): m[i][labels[i]] = 1 return m def get_image_data_and_label(value,image_size='NONE',depth=10,one_hot = False): ''' 获取图片数据,以及标签数据 注意每张图片维度为 n_w x n_h x n_c args: value:由(x,y)元组组成的numpy.array类型 x:图片路径 y:对应标签 image_size:图片大小 'NONE':不改变图片尺寸 one_hot:把标签二值化 depth:数据类别个数 ''' #图片数据集合 x_batch = [] #图片对应的标签集合 y_batch = [] #遍历每一张图片 for image in value: if image_size == 'NONE': x_batch.append(cv2.imread(image[0])/255) #标准化0-1之间 else: x_batch.append(cv2.resize(cv2.imread(image[0]),image_size)/255) y_batch.append(image[1]) if one_hot == True: #标签二值化 y_batch = get_one_hot_label(y_batch,depth) return np.asarray(x_batch,dtype=np.float32),np.asarray(y_batch,dtype=np.float32) ''' 测试 保存所有图片 ''' save_images()
save_image()函数执行上面所述的功能:
- 创建一个文件夹 CIFAR-10-data 包含两个子文件夹test,train;
- 在子文件夹创建10个文件夹 文件名依次为0-9,对应10个类别;
- 训练集数据生成bmp格式文件,存在对应类别的文件下;
- 测试集数据生成bmp格式文件,存在对应类别的文件下;
- 生成两个文件train_label.pkl,test_label.pkl 分别保存相应的图片文件路径以及对应的标签;
执行完save_image()函数,会生成如下文件:
二 使用传统神经网络训练
2.1 传统神经网络
在使用AlexNet网络进行训练之前,我们先使用传统network进行训练,这里我们设置网络为4层,包括输入层在内,每一层神经元个数如下3072,7200,1024,10。
我们在训练的时候,每次随机读取batch_size大小的图片数据进行训练。
传统network的实现,我是通过定义一个单独的类来完成该功能。network.py文件代码如下
# -*- coding: utf-8 -*- """ Created on Mon Apr 2 10:32:10 2018 @author: Administrator """ ''' 定义一个network类,实现全连接网络 ''' import datagenerator import os from tensorflow.python import pywrap_tensorflow def get_one_hot_label(labels,depth): ''' 把标签二值化 args: labels:标签的集合 depth:标签总共有多少类 ''' m = np.zeros([len(labels),depth]) for i in range(len(labels)): m[i][labels[i]] = 1 return m import tensorflow as tf import numpy as np import random import pickle class network(object): ''' 全连接神经网络 ''' def __init__(self,sizes,param_path= None): ''' 注意程序中op变量只需要初始化一遍就可以,在fit()中初始化 args: sizes:list传入每层神经元个数 param_path:是否从指定文件加载模型,None:重新训练 否则指定模型路径 必须指定路径./或者绝对路径 ''' #保存参数 self.__sizes = sizes #神经网络每一层的神经元个数数组类型 self.sizes = tf.placeholder(tf.int64,shape=[1,len(sizes)]) #计算神经网络层数 包括输入层 self.num_layer = tf.size(self.sizes) #输入样本和输出类别变量 self.x_ = tf.placeholder(tf.float32,shape=[None,sizes[0]]) self.y_ = tf.placeholder(tf.float32,shape=[None,sizes[-1]]) #设置tensorflow对GPU使用按需分配 config = tf.ConfigProto() config.gpu_options.allow_growth = True self.sess = tf.InteractiveSession(config=config) file_exist = False #如果已经存在保存的模型 加载之前保存的w和b if not param_path is None: if os.path.isfile(param_path): with open(param_path,'rb') as f: dic = pickle.load(f) weights = dic['weightes'] biases = dic['biases'] file_exist = True if file_exist: #使用保存的数据初始化 第i层和i+1层之间的权重向量 self.weights = [self.weight_variable(shape=(x,y),value = weights[i]) for x,y,i in zip(sizes[:-1],sizes[1:],range(len(sizes)-1))] #使用保存的数据初始化 第i层的偏置向量 i=1...num_layers 注意不可以设置shape=(x,1) self.biases = [self.bias_variable(shape=[x,],value = biases[i]) for x,i in zip(sizes[1:],range(len(sizes)-1))] print('成功加载参数数据!') else: #使用高斯正态分布初始化 第i层和i+1层之间的权重向量 self.weights = [self.weight_variable(shape=(x,y)) for x,y in zip(sizes[:-1],sizes[1:])] #使用高斯正态分布初始化 第i层的偏置向量 i=1...num_layers 注意不可以设置shape=(x,1) self.biases = [self.bias_variable(shape=[x,]) for x in sizes[1:]] print('重新初始化模型参数!') '''这一段代码是使用tensorflow Saver对象保存 但是在取数据时候总是失败,因此不再使用这中方法 #如果已经存在保存的模型 加载之前保存的w和b if not model_path is None: file = model_path+'.meta' #print(file) #文件存在 则直接加载 if os.path.isfile(file): #加载以前保存的网络 将保存在.meta文件中的图添加到当前的图中 self.new_saver = tf.train.import_meta_graph(file) #从指定目录下获取最近一次检查点 self.new_saver.restore(self.sess,tf.train.latest_checkpoint(os.path.dirname(file))) #使用加载的模型 graph = tf.get_default_graph() #恢复w和b self.weights = [graph.get_tensor_by_name( 'network_w'+str(i)+':0') for i in range(1,len(sizes))] self.biases = [graph.get_tensor_by_name( 'network_b'+str(i)+':0') for i in range(1,len(sizes))] print('成功从模型恢复数据!') sh = [self.sess.run(i).shape for i in self.weights] print('权重维度:',sh) file_exist = True #不存在 再新创建 if not file_exist: #随机初始化权重 第i层和i+1层之间的权重向量 self.weights = [self.weight_variable(shape=(x,y),name='network_w'+str(i)) for x,y,i in zip(sizes[:-1],sizes[1:],range(1,len(sizes)))] #随机初始化偏置 第i层的偏置向量 i=1...num_layers 注意不可以设置shape=(x,1) self.biases = [self.bias_variable(shape=[x,],name='network_b'+str(i)) for x,i in zip(sizes[1:],range(1,len(sizes)))] print('重新初始化模型参数!') #创建Saver op用于保存新的训练参数 指定保存哪些变量,如果全部保存,在恢复时,会有错误 主要是由于偏置和权重我保存在了一个dict中引起的 param =[w for w in self.weights] for i in self.biases: param.append(i) self.saver = tf.train.Saver(param)''' def weight_variable(self,shape,value=None,name=None): ''' 初始化权值 ''' if value is None: #使用截断式正太分布初始化权值 截断式即在正态分布基础上加以限制,以使生产的数据在一定范围上 value = tf.truncated_normal(shape,mean=0.0,stddev= 1.0/shape[0]) #方差为1/nin if name is None: return tf.Variable(value) else: return tf.Variable(value,name=name) def bias_variable(self,shape,value=None,name=None): ''' #初始化偏重 ''' if value is None: value = tf.truncated_normal(shape,mean=0.0,stddev= 1.0/shape[0]) #方差为1/nin if name is None: return tf.Variable(value) else: return tf.Variable(value,name=name) def feedforward(self,x): ''' 构建阶段:前向反馈 x:变量op,tf.placeholder()类型变量 返回一个op ''' #计算隐藏层 output = x for i in range(len(self.__sizes)-1): b = self.biases[i] w = self.weights[i] if i != len(self.__sizes)-2 : output = tf.nn.relu(tf.matmul(output,w) + b) else: output = tf.nn.softmax(tf.matmul(output,w) + b) return output def fit(self,training_data,learning_rate=0.001,batch_size=64,epoches=10): ''' 训练神经网络 training_data (x,y)元祖组成的list x:训练集图片数据路径 y:训练集样本对应的标签 learning_rate:学习率 batch_size:批量大小 epoches:迭代轮数 ''' #计算输出层 output = self.feedforward(self.x_) #代价函数 J =-(Σy.logaL)/n .表示逐元素乘 cost = tf.reduce_mean( -tf.reduce_sum(self.y_*tf.log(output),axis = 1)) #求解 train = tf.train.AdamOptimizer(learning_rate).minimize(cost) #使用会话执行图 #初始化变量 必须在train之后 变量出事后化之后才可以显示相关属性 self.sess.run(tf.global_variables_initializer()) sh = [self.sess.run(i).shape for i in self.weights] print('权重维度:',sh) #训练集长度 n_train =len(training_data) #开始迭代 for i in range(epoches): #打乱训练集 random.shuffle(training_data) #分组 mini_batchs = [training_data[k:k+batch_size] for k in range(0,n_train,batch_size)] #遍历每一个mini_batch for mini_batch in mini_batchs: x_batch,y_batch = datagenerator.get_image_data_and_label(mini_batch,one_hot=True) #每张图片数据展开正一列 x_batch = x_batch.reshape(len(mini_batch),-1) train.run(feed_dict={self.x_:x_batch,self.y_:y_batch}) ''' #计算每一轮迭代后在整个训练集的误差 并打印 ''' train_cost_sum = [] train_accuracy_sum = [] #遍历每一个mini_batch for mini_batch in mini_batchs: train_cost = cost.eval(feed_dict={self.x_:x_batch,self.y_:y_batch}) train_cost_sum.append(train_cost) train_accuracy = self.accuracy(x_batch,y_batch) train_accuracy_sum.append(train_accuracy) print('Epoch {0} Training set cost {1} accuracy {2}:'.format(i,np.mean(train_cost),np.mean(train_accuracy))) def predict(self,test_x): ''' 对输入test_x样本进行预测(图片数据) nxm维 n为样本个数 m为数据长度 ''' output = self.feedforward(self.x_) #使用会话执行图 return output.eval(feed_dict={self.x_:test_x}) def accuracy(self,x,y): ''' 返回值准确率 x:测试样本集合(图片数据) nxm维 n为样本个数 m为数据长度 y:测试类别集合 也是二值化的标签 ''' output = self.feedforward(self.x_) correct = tf.equal(tf.argmax(output,1),tf.argmax(self.y_,1)) #返回一个数组 表示统计预测正确或者错误 accuracy = tf.reduce_mean(tf.cast(correct,tf.float32)) #求准确率 #使用会话执行图 return accuracy.eval(feed_dict={self.x_:x,self.y_:y}) def cost(self,x,y): ''' 计算代价值 x:样本集合(图片数据) nxm维 n为样本个数 m为数据长度 y:类别集合 也是二值化的标签 ''' #计算输出层 output = self.feedforward(self.x_) #代价函数 J =-(Σy.logaL)/n .表示逐元素乘 cost = tf.reduce_mean( -tf.reduce_sum(self.y_*tf.log(output),axis = 1)) #使用会话执行图 return cost.eval(feed_dict={self.x_:x,self.y_:y}) def remove_model_file(self,file): ''' 删除保存的模型文件 删除成功返回True 否则返回 False args: file:模型文件名 ''' path = [file+item for item in ('.index','.meta')] path.append(os.path.join( os.path.dirname(path[0]),'checkpoint')) for f in path: if os.path.isfile(f): os.remove(f) ret = True #检查文件是否存在 for f in path: if os.path.isfile(f): ret = False return ret def save_model(self,file): ''' 保存模型 args: file:保存文件名 ''' ''' #先删除之前的模型数据 if self.remove_model_file(file) == True: self.saver.save(self.sess,file) print('模型保存成功!')''' #把参数序列化并保存 weights = [item.eval() for item in self.weights] biases = [item.eval() for item in self.biases] dic = {'weightes':weights,'biases':biases} with open(file,'wb') as f: pickle.dump(dic,f) print('模型参数保存成功!') def global_variables_initializer(self): ''' 全局张量初始化,由于全局变量只在fit()训练时进行初始化,如果我们直接加载已经训练好的数据 就不需要调用fit()函数,因此只有收到调用这个函数就行初始化权重和偏置以及其他张量 然后才可以执行图 如accuracy()等等 ''' self.sess.run(tf.global_variables_initializer())
然后我们使用该网络进行测试,并且把测试后的权重和偏置保存在指定文件,下次可以使用该数据初始化,然后继续训练
- 我们先读取CIFAR-10-train-label.pkl,获取每张的训练图片路径以及标签组成的元组,保存在一个字典中;
- 我们先把该字典打乱,然后选择batch_size张图片,读取图片,以及one_hot(二值化)标签,组成一个mini_batch;
- 使用mini_batch进行训练;
- 迭代字典中的所有数据,迭代完一轮,我们继续迭代,直至到我们设置的迭代轮数;
在这里有个地方需要注意:由于我们每次都是从指定路径读取图片,因此速度比较慢,其实我们可以利用get_data_by_keyword()函数一次性读取所有数据和标签,然后保存在内存中,然后每次选择batch_size数据进行训练。这样就避免了每次从硬盘读取图片数据。这里我们取数据的时候可以先设置一个训练集大小的索引字典,然后打乱该字典,我们每次把batch_size个字典元素,传给数据,这样我们就可以达到随机取数据的目的。
# 生成并打乱训练集的顺序。 indices = np.arange(50000) random.shuffle(indices)
2.2 测试代码
测试代码如下:
# -*- coding: utf-8 -*- """ Created on Wed Apr 11 16:42:09 2018 @author: Administrator """ ''' 用于训练network网络 由于内存有限,不能一次读取所有的数据,因此采用每次读取小批量的图片 1.在datagenerator.py文件中 把所有图片保存成bmp格式文件 2.存储一个dict字典,每个元素为 (图片的相对路径,标签),将这个字典序列化保存到文件 3.读取该文件,把元组顺序打乱,分成多组,每次加载mini_batch个图片进行训练 ''' import tensorflow as tf from alexnet import alexnet import datagenerator import numpy as np #import cv2 import random import network import os def remove_model_file(file): ''' 删除保存的模型文件 删除成功返回True 否则返回 False args: file:模型文件名 ''' path = [file+item for item in ('.index','.meta')] path.append(os.path.join( os.path.dirname(path[0]),'checkpoint')) for f in path: if os.path.isfile(f): os.remove(f) ret = True #检查文件是否存在 for f in path: if os.path.isfile(f): ret = False return ret def network_main(): ''' 使用network网络测试 由于数据量比较大,图片不能一次全部加载进来,因此采用分批读取图片数据 ''' ''' 一 加载数据 ''' training_data,test_data = datagenerator.load_data() param_path = './network_param/network_param.pkl' #模型参数保存所在文件 learning_rate = 1e-4 #学习率 training_epoches = 10 #训练轮数 batch_size = 256 #小批量大小 ''' 二 创建网络 ''' #如果已经存在训练好的数据,直接加载初始化 nn = network.network([3072,7200,1024,10],param_path=param_path) ''' 三 开始训练 ''' nn.fit(training_data,learning_rate=learning_rate,batch_size=batch_size,epoches=training_epoches) nn.save_model(param_path) #保存参数 ''' 四 校验 ''' mini_batchs = [test_data[k:k+batch_size] for k in range(0,len(test_data),batch_size)] test_accuracy_sum = [] #如果不训练网络,需要手动初始化张量 #nn.global_variables_initializer() for mini_batch in mini_batchs: x_batch,y_batch = datagenerator.get_image_data_and_label(mini_batch,one_hot=True) #每张图片数据展开正一列 x_batch = x_batch.reshape(len(mini_batch),-1) test_accuracy_sum.append(nn.accuracy(x_batch,y_batch)) print('准确率:',np.mean(test_accuracy_sum)) ''' 测试 ''' if __name__ == '__main__': network_main()
运行结果如下,我们可以看到迭代10次后,在测试集的准确率大概为52%,如果我们使用卷积神经网络的话或者迭代次数更多一些的话,准确率可能会更高,你也可以尝试使用LeNet进行训练,详情可以参考:Tensorflow深度学习之二十一:LeNet的实现(CIFAR-10数据集)https://blog.csdn.net/davincil/article/details/78794044。
三 使用AlexNet网络训练
亲爱的读者和支持者们,自动博客加入了打赏功能,陆陆续续收到了各位老铁的打赏。在此,我想由衷地感谢每一位对我们博客的支持和打赏。你们的慷慨与支持,是我们前行的动力与源泉。
日期 | 姓名 | 金额 |
---|---|---|
2023-09-06 | *源 | 19 |
2023-09-11 | *朝科 | 88 |
2023-09-21 | *号 | 5 |
2023-09-16 | *真 | 60 |
2023-10-26 | *通 | 9.9 |
2023-11-04 | *慎 | 0.66 |
2023-11-24 | *恩 | 0.01 |
2023-12-30 | I*B | 1 |
2024-01-28 | *兴 | 20 |
2024-02-01 | QYing | 20 |
2024-02-11 | *督 | 6 |
2024-02-18 | 一*x | 1 |
2024-02-20 | c*l | 18.88 |
2024-01-01 | *I | 5 |
2024-04-08 | *程 | 150 |
2024-04-18 | *超 | 20 |
2024-04-26 | .*V | 30 |
2024-05-08 | D*W | 5 |
2024-05-29 | *辉 | 20 |
2024-05-30 | *雄 | 10 |
2024-06-08 | *: | 10 |
2024-06-23 | 小狮子 | 666 |
2024-06-28 | *s | 6.66 |
2024-06-29 | *炼 | 1 |
2024-06-30 | *! | 1 |
2024-07-08 | *方 | 20 |
2024-07-18 | A*1 | 6.66 |
2024-07-31 | *北 | 12 |
2024-08-13 | *基 | 1 |
2024-08-23 | n*s | 2 |
2024-09-02 | *源 | 50 |
2024-09-04 | *J | 2 |
2024-09-06 | *强 | 8.8 |
2024-09-09 | *波 | 1 |
2024-09-10 | *口 | 1 |
2024-09-10 | *波 | 1 |
2024-09-12 | *波 | 10 |
2024-09-18 | *明 | 1.68 |
2024-09-26 | B*h | 10 |
2024-09-30 | 岁 | 10 |
2024-10-02 | M*i | 1 |
2024-10-14 | *朋 | 10 |
2024-10-22 | *海 | 10 |
2024-10-23 | *南 | 10 |
2024-10-26 | *节 | 6.66 |
2024-10-27 | *o | 5 |
2024-10-28 | W*F | 6.66 |
2024-10-29 | R*n | 6.66 |
2024-11-02 | *球 | 6 |
2024-11-021 | *鑫 | 6.66 |
2024-11-25 | *沙 | 5 |
2024-11-29 | C*n | 2.88 |

【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 记一次.NET内存居高不下排查解决与启示
· 探究高空视频全景AR技术的实现原理
· 理解Rust引用及其生命周期标识(上)
· 浏览器原生「磁吸」效果!Anchor Positioning 锚点定位神器解析
· 没有源码,如何修改代码逻辑?
· 全程不用写代码,我用AI程序员写了一个飞机大战
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· DeepSeek 开源周回顾「GitHub 热点速览」
· 记一次.NET内存居高不下排查解决与启示
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了