yolo源码解析(二)
五 读取数据pascal_voc.py文件解析
我们在YOLENet类中定义了两个占位符,一个是输入图片占位符,一个是图片对应的标签占位符,如下:
#输入图片占位符 [NONE,image_size,image_size,3] self.images = tf.placeholder( tf.float32, [None, self.image_size, self.image_size, 3], name='images') #设置标签占位符 [None,S,S,5+C] 即[None,7,7,25] self.labels = tf.placeholder( tf.float32, [None, self.cell_size, self.cell_size, 5 + self.num_class])
而pascal_voc.py文件的目的就是为了准备数据,赋值给占位符。在pascal_voc.py文件中定义了一个pascal_voc,该类包含了类初始化函数(__init__()),准备数据函数(prepare()),读取batch大小的图片以及图片对应的标签(get())等函数。
import os import xml.etree.ElementTree as ET import numpy as np import cv2 import pickle import copy import yolo.config as cfg ''' VOC2012数据集处理 ''' class pascal_voc(object):
1、类初始化函数
''' VOC2012数据集处理 ''' class pascal_voc(object): ''' VOC2012数据集处理的类,主要用来获取训练集图片文件,以及生成对应的标签文件 ''' def __init__(self, phase, rebuild=False): ''' 准备训练或者测试的数据 args: phase:传入字符串 'train':表示训练 'test':测试 rebuild:是否重新创建数据集的标签文件,保存在缓存文件夹下 ''' #VOCdevkit文件夹路径 self.devkil_path = os.path.join(cfg.PASCAL_PATH, 'VOCdevkit') #VOC2012文件夹路径 self.data_path = os.path.join(self.devkil_path, 'VOC2012') #catch文件所在路径 self.cache_path = cfg.CACHE_PATH #批大小 self.batch_size = cfg.BATCH_SIZE #图像大小 self.image_size = cfg.IMAGE_SIZE #单元格大小S self.cell_size = cfg.CELL_SIZE #VOC 2012数据集类别名 self.classes = cfg.CLASSES #类别名->索引的dict self.class_to_ind = dict(zip(self.classes, range(len(self.classes)))) ##图片是否采用水平镜像扩充训练集? self.flipped = cfg.FLIPPED #训练或测试? self.phase = phase #是否重新创建数据集标签文件 self.rebuild = rebuild #从gt_labels加载数据,cursor表明当前读取到第几个 self.cursor = 0 #存放当前训练的轮数 self.epoch = 1 #存放数据集的标签 是一个list 每一个元素都是一个dict,对应一个图片 #如果我们在配置文件中指定flipped=True,则数据集会扩充一倍,每一张原始图片都有一个水平对称的镜像文件 # imname:图片路径 # label:图片标签 # flipped:图片水平镜像? self.gt_labels = None #加载数据集标签 初始化gt_labels self.prepare()
2、prepare()所有数据准备函数
prepare()函数调用load_labels()函数,加载所有数据集的标签,保存在遍历gt_labels集合中,如果在配置文件中指定了水平镜像,则追加一倍的训练数据集。
def prepare(self): ''' 初始化数据集的标签,保存在变量gt_labels中 return: gt_labels:返回数据集的标签 是一个list 每一个元素对应一张图片,是一个dict imname:图片文件路径 label:图片文件对应的标签 [7,7,25]的矩阵 flipped:是否使用水平镜像? 设置为False ''' #加载数据集的标签 gt_labels = self.load_labels() #如果水平镜像,则追加一倍的训练数据集 if self.flipped: print('Appending horizontally-flipped training examples ...') #深度拷贝 gt_labels_cp = copy.deepcopy(gt_labels) #遍历每一个图片标签 for idx in range(len(gt_labels_cp)): #设置flipped属性为True gt_labels_cp[idx]['flipped'] = True #目标所在格子也进行水平镜像 [7,7,25] gt_labels_cp[idx]['label'] =\ gt_labels_cp[idx]['label'][:, ::-1, :] for i in range(self.cell_size): for j in range(self.cell_size): #置信度==1,表示这个格子有目标 if gt_labels_cp[idx]['label'][i, j, 0] == 1: #中心的x坐标水平镜像 gt_labels_cp[idx]['label'][i, j, 1] = \ self.image_size - 1 -\ gt_labels_cp[idx]['label'][i, j, 1] #追加数据集的标签 后面的是由原数据集标签扩充的水平镜像数据集标签 gt_labels += gt_labels_cp #打乱数据集的标签 np.random.shuffle(gt_labels) self.gt_labels = gt_labels return gt_labels
3、get()批量数据读取函数
get()函数用在训练的时候,每次从gt_labels集合随机读取batch大小的图片以及图片对应的标签。
def get(self): ''' 加载数据集 每次读取batch大小的图片以及图片对应的标签 return: images:读取到的图片数据 [45,448,448,3] labels:对应的图片标签 [45,7,7,25] ''' #[45,448,448,3] images = np.zeros( (self.batch_size, self.image_size, self.image_size, 3)) #[45,7,7,25] labels = np.zeros( (self.batch_size, self.cell_size, self.cell_size, 25)) count = 0 #一次加载batch_size个图片数据 while count < self.batch_size: #获取图片路径 imname = self.gt_labels[self.cursor]['imname'] #是否使用水平镜像? flipped = self.gt_labels[self.cursor]['flipped'] #读取图片数据 images[count, :, :, :] = self.image_read(imname, flipped) #读取图片标签 labels[count, :, :, :] = self.gt_labels[self.cursor]['label'] count += 1 self.cursor += 1 #如果读取完一轮数据,则当前cursor置为0,当前训练轮数+1 if self.cursor >= len(self.gt_labels): #打乱数据集 np.random.shuffle(self.gt_labels) self.cursor = 0 self.epoch += 1 return images, labels
4、image_read()函数读取图片
图片读取函数,先读取图片,然后缩放,转换为RGB格式,再对数据进行归一化处理。
def image_read(self, imname, flipped=False): ''' 读取图片 args: imname:图片路径 flipped:图片是否水平镜像处理? return: image:图片数据 [448,448,3] ''' #读取图片数据 image = cv2.imread(imname) #缩放处理 image = cv2.resize(image, (self.image_size, self.image_size)) #BGR->RGB uint->float32 image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32) #归一化处理 [-1.0,1.0] image = (image / 255.0) * 2.0 - 1.0 #宽倒序 即水平镜像 if flipped: image = image[:, ::-1, :] return image
5、load_labels()加载标签函数
def load_labels(self): ''' 加载数据集标签 return: gt_labels:是一个list 每一个元素对应一张图片,是一个dict imname:图片文件路径 label:图片文件对应的标签 [7,7,25]的矩阵 flipped:是否使用水平镜像? 设置为False ''' #缓冲文件名:即用来保存数据集标签的文件 cache_file = os.path.join( self.cache_path, 'pascal_' + self.phase + '_gt_labels.pkl') #文件存在,且不重新创建则直接读取 if os.path.isfile(cache_file) and not self.rebuild: print('Loading gt_labels from: ' + cache_file) with open(cache_file, 'rb') as f: gt_labels = pickle.load(f) return gt_labels print('Processing gt_labels from: ' + self.data_path) #如果缓冲文件目录不存在,创建 if not os.path.exists(self.cache_path): os.makedirs(self.cache_path) #获取训练测试集的数据文件名 if self.phase == 'train': txtname = os.path.join( self.data_path, 'ImageSets', 'Main', 'trainval.txt') #获取测试集的数据文件名 else: txtname = os.path.join( self.data_path, 'ImageSets', 'Main', 'test.txt') with open(txtname, 'r') as f: self.image_index = [x.strip() for x in f.readlines()] #存放图片的标签,图片路径,是否使用水平镜像? gt_labels = [] #遍历每一张图片的信息 for index in self.image_index: #读取每一张图片的标签label [7,7,25] label, num = self.load_pascal_annotation(index) if num == 0: continue #图片文件路径 imname = os.path.join(self.data_path, 'JPEGImages', index + '.jpg') #保存该图片的信息 gt_labels.append({'imname': imname, 'label': label, 'flipped': False}) print('Saving gt_labels to: ' + cache_file) #保存 with open(cache_file, 'wb') as f: pickle.dump(gt_labels, f) return gt_labels
6、load_pascal_annotation()函数
def load_pascal_annotation(self, index): """ Load image and bounding boxes info from XML file in the PASCAL VOC format. args: index:图片文件的index return : label:标签 [7,7,25] 0:1:置信度,表示这个地方是否有目标 1:5:目标边界框 目标中心,宽度和高度(这里是实际值,没有归一化) 5:25:目标的类别 len(objs):objs对象长度 """ #获取图片文件名路径 imname = os.path.join(self.data_path, 'JPEGImages', index + '.jpg') #读取数据 im = cv2.imread(imname) #宽和高缩放比例 h_ratio = 1.0 * self.image_size / im.shape[0] w_ratio = 1.0 * self.image_size / im.shape[1] # im = cv2.resize(im, [self.image_size, self.image_size]) #用于保存图片文件的标签 label = np.zeros((self.cell_size, self.cell_size, 25)) #图片文件的标注xml文件 filename = os.path.join(self.data_path, 'Annotations', index + '.xml') tree = ET.parse(filename) objs = tree.findall('object') for obj in objs: bbox = obj.find('bndbox') # Make pixel indexes 0-based 当图片缩放到image_size时,边界框也进行同比例缩放 x1 = max(min((float(bbox.find('xmin').text) - 1) * w_ratio, self.image_size - 1), 0) y1 = max(min((float(bbox.find('ymin').text) - 1) * h_ratio, self.image_size - 1), 0) x2 = max(min((float(bbox.find('xmax').text) - 1) * w_ratio, self.image_size - 1), 0) y2 = max(min((float(bbox.find('ymax').text) - 1) * h_ratio, self.image_size - 1), 0) #根据图片的分类名 ->类别index 转换 cls_ind = self.class_to_ind[obj.find('name').text.lower().strip()] #计算边框中心点x,y,w,h(没有归一化) boxes = [(x2 + x1) / 2.0, (y2 + y1) / 2.0, x2 - x1, y2 - y1] #计算当前物体的中心在哪个格子中 x_ind = int(boxes[0] * self.cell_size / self.image_size) y_ind = int(boxes[1] * self.cell_size / self.image_size) #表明该图片已经初始化过了 if label[y_ind, x_ind, 0] == 1: continue #置信度,表示这个地方有物体 label[y_ind, x_ind, 0] = 1 #物体边界框 label[y_ind, x_ind, 1:5] = boxes #物体的类别 label[y_ind, x_ind, 5 + cls_ind] = 1 return label, len(objs)
六 训练网络
模型训练包含于train.py文件,Solver类的train()方法之中,训练部分只需要看懂了初始化参数,整个结构就很清晰了。
import os import argparse import datetime import tensorflow as tf import yolo.config as cfg from yolo.yolo_net import YOLONet from utils.timer import Timer from utils.pascal_voc import pascal_voc slim = tf.contrib.slim ''' 用来训练YOLO网络模型 ''' class Solver(object): ''' 求解器的类,用于训练YOLO网络 '''
1、类初始化函数
def __init__(self, net, data): ''' 构造函数,加载训练参数 args: net:YOLONet对象 data:pascal_voc对象 ''' #yolo网络 self.net = net #voc2012数据处理 self.data = data #检查点文件路径 self.weights_file = cfg.WEIGHTS_FILE #训练最大迭代次数 self.max_iter = cfg.MAX_ITER #初始学习率 self.initial_learning_rate = cfg.LEARNING_RATE ##退化学习率衰减步数 self.decay_steps = cfg.DECAY_STEPS #衰减率 self.decay_rate = cfg.DECAY_RATE self.staircase = cfg.STAIRCASE ##日志文件保存间隔步 self.summary_iter = cfg.SUMMARY_ITER ##模型保存间隔步 self.save_iter = cfg.SAVE_ITER #输出文件夹路径 self.output_dir = os.path.join( cfg.OUTPUT_DIR, datetime.datetime.now().strftime('%Y_%m_%d_%H_%M')) if not os.path.exists(self.output_dir): os.makedirs(self.output_dir) #保存配置信息 self.save_cfg() #指定保存的张量 这里指定所有变量 self.variable_to_restore = tf.global_variables() self.saver = tf.train.Saver(self.variable_to_restore, max_to_keep=None) #指定保存的模型名称 self.ckpt_file = os.path.join(self.output_dir, 'yolo.cpkt') #合并所有的summary self.summary_op = tf.summary.merge_all() #创建writer,指定日志文件路径,用于写日志文件 self.writer = tf.summary.FileWriter(self.output_dir, flush_secs=60) #创建变量,保存当前迭代次数 self.global_step = tf.train.create_global_step() #退化学习率 self.learning_rate = tf.train.exponential_decay( self.initial_learning_rate, self.global_step, self.decay_steps, self.decay_rate, self.staircase, name='learning_rate') #创建求解器 self.optimizer = tf.train.GradientDescentOptimizer( learning_rate=self.learning_rate) # create_train_op that ensures that when we evaluate it to get the loss, # the update_ops are done and the gradient updates are computed. self.train_op = slim.learning.create_train_op( self.net.total_loss, self.optimizer, global_step=self.global_step) #设置GPU使用资源 gpu_options = tf.GPUOptions() #按需分配GPU使用的资源 config = tf.ConfigProto(gpu_options=gpu_options) self.sess = tf.Session(config=config) #运行图之前,初始化变量 self.sess.run(tf.global_variables_initializer()) #恢复模型 if self.weights_file is not None: print('Restoring weights from: ' + self.weights_file) self.saver.restore(self.sess, self.weights_file) #将图写入日志文件 self.writer.add_graph(self.sess.graph)
2、train()训练函数
def train(self): ''' 开始训练 ''' #训练时间 train_timer = Timer() #数据集加载时间 load_timer = Timer() #开始迭代 for step in range(1, self.max_iter + 1): #计算每次迭代加载数据的起始时间 load_timer.tic() #加载数据集 每次读取batch大小的图片以及图片对应的标签 images, labels = self.data.get() #计算这次迭代加载数据集所使用的时间 load_timer.toc() feed_dict = {self.net.images: images, self.net.labels: labels} #迭代summary_iter次,保存一次日志文件,迭代summary_iter*10次,输出一次的迭代信息 if step % self.summary_iter == 0: if step % (self.summary_iter * 10) == 0: #计算每次迭代训练的起始时间 train_timer.tic() loss = 0.0001 #开始迭代训练,每一次迭代后global_step自加1 summary_str, loss, _ = self.sess.run( [self.summary_op, self.net.total_loss, self.train_op], feed_dict=feed_dict) #输出信息 log_str = '{} Epoch: {}, Step: {}, Learning rate: {}, Loss: {:5.3f}\nSpeed: {:.3f}s/iter,Load: {:.3f}s/iter, Remain: {}'.format( datetime.datetime.now().strftime('%m-%d %H:%M:%S'), self.data.epoch, int(step), round(self.learning_rate.eval(session=self.sess), 6), loss, train_timer.average_time, load_timer.average_time, train_timer.remain(step, self.max_iter)) print(log_str) else: #计算每次迭代训练的起始时间 train_timer.tic() #开始迭代训练,每一次迭代后global_step自加1 summary_str, _ = self.sess.run( [self.summary_op, self.train_op], feed_dict=feed_dict) #计算这次迭代训练所使用的时间 train_timer.toc() #将summary写入文件 self.writer.add_summary(summary_str, step) else: #计算每次迭代训练的起始时间 train_timer.tic() #开始迭代训练,每一次迭代后global_step自加1 self.sess.run(self.train_op, feed_dict=feed_dict) #计算这次迭代训练所使用的时间 train_timer.toc() #没迭代save_iter次,保存一次模型 if step % self.save_iter == 0: print('{} Saving checkpoint file to: {}'.format( datetime.datetime.now().strftime('%m-%d %H:%M:%S'), self.output_dir)) self.saver.save( self.sess, self.ckpt_file, global_step=self.global_step)
3、保存配置参数
def save_cfg(self): ''' 保存配置信息 ''' with open(os.path.join(self.output_dir, 'config.txt'), 'w') as f: cfg_dict = cfg.__dict__ for key in sorted(cfg_dict.keys()): if key[0].isupper(): cfg_str = '{}: {}\n'.format(key, cfg_dict[key]) f.write(cfg_str)
train.py文件除了上面介绍的求解器Solver这个类外,还包含了两个函数,一个是update_config_paths()函数,这个函数主要使用了设定数据集路径,以及检查点文件路径。
def update_config_paths(data_dir, weights_file): ''' 数据集路径,和模型检查点文件路径 args: data_dir:数据文件夹 数据集放在pascal_voc目录下 weights_file:检查点文件名 该文件放在数据集目录下的weights文件夹下 ''' cfg.DATA_PATH = data_dir #数据所在文件夹 cfg.PASCAL_PATH = os.path.join(data_dir, 'pascal_voc') #VOC2012数据所在文件夹 cfg.CACHE_PATH = os.path.join(cfg.PASCAL_PATH, 'cache') #保存生成的数据集标签缓冲文件所在文件夹 cfg.OUTPUT_DIR = os.path.join(cfg.PASCAL_PATH, 'output') #保存生成的网络模型和日志文件所在的文件夹 cfg.WEIGHTS_DIR = os.path.join(cfg.PASCAL_PATH, 'weights') #检查点文件所在的目录 cfg.WEIGHTS_FILE = os.path.join(cfg.WEIGHTS_DIR, weights_file)
我们主要来说一下另一个函数main()函数,先解析命令行参数,然后创建YOLONet、pascal_voc、Solver对象,最后开始训练。
def main(): #创建一个解析器对象,并告诉它将会有些什么参数。当程序运行时,该解析器就可以用于处理命令行参数。 #https://www.cnblogs.com/lovemyspring/p/3214598.html parser = argparse.ArgumentParser() #定义参数 parser.add_argument('--weights', default="YOLO_small.ckpt", type=str) #权重文件名 parser.add_argument('--data_dir', default="data", type=str) #数据集路径 parser.add_argument('--threshold', default=0.2, type=float) parser.add_argument('--iou_threshold', default=0.5, type=float) parser.add_argument('--gpu', default='', type=str) #定义了所有参数之后,你就可以给 parse_args() 传递一组参数字符串来解析命令行。默认情况下,参数是从 sys.argv[1:] 中获取 #parse_args() 的返回值是一个命名空间,包含传递给命令的参数。该对象将参数保存其属性 args = parser.parse_args() #判断是否是使用gpu if args.gpu is not None: cfg.GPU = args.gpu #设定数据集路径,以及检查点文件路径 if args.data_dir != cfg.DATA_PATH and args.data_dir is not None: update_config_paths(args.data_dir, args.weights) #设置环境变量 os.environ['CUDA_VISIBLE_DEVICES'] = cfg.GPU #创建YOLO网络对象 yolo = YOLONet() #数据集对象 pascal = pascal_voc('train') #求解器对象 solver = Solver(yolo, pascal) print('Start training ...') #开始训练 solver.train() print('Done training.')
我们执行如下代码,开始训练网络:
if __name__ == '__main__': tf.reset_default_graph() # python train.py --weights YOLO_small.ckpt --gpu 0 main()
如果这篇文章帮助到了你,你可以请作者喝一杯咖啡