TFRecord data storage

Why use TFRecord?

Normally a training setup produces train, test, or val folders holding thousands of scattered image or text files. These take up disk space, are slow and tedious to read one by one, and can consume a lot of memory (some large datasets cannot be loaded in a single pass at all). The TFRecord format stores this data much more sensibly: internally it uses the Protocol Buffers binary encoding, so the data lives in one binary file that is loaded sequentially in a single pass. This is simple and fast, and especially friendly to large training sets. When the training data is very large, it can also be split across multiple TFRecord files to improve processing throughput.

From <https://www.jianshu.com/p/b480e5fcb638>

Reading TFRecords is queue-based, so there is no need for placeholders or hand-built FIFO queues.

TFRecord write

import os
import sys

import cv2
import numpy as np
import tensorflow as tf

#generate one tfrecord file per dataset split
def tfrecord_write(args):
    #names of the txt files produced by the dataset split
    txt_names = [txt_name for txt_name in os.listdir(args.out_path)
                 if os.path.splitext(txt_name)[1] == '.txt']
    #txt paths
    txt_paths = [os.path.join(args.out_path, txt_name) for txt_name in txt_names]

    #tfrecord file names
    tfrecord_names = [os.path.splitext(name)[0] for name in txt_names]
    #tfrecord paths
    tfrecord_paths = [os.path.join(args.out_path, name + '.tfrecord')
                      for name in tfrecord_names]

    #produce one tfrecord per txt file
    for txt_path, tfrecord_path in zip(txt_paths, tfrecord_paths):
        print(tfrecord_path)
        #reuse one writer per file and push as much data as possible through
        #it in one go; creating a writer per example would be very slow
        writer = tf.python_io.TFRecordWriter(tfrecord_path)
        with open(txt_path, 'r') as f:
            for line in f:
                name, num = line.strip().split('\t')
                if name == 'George_W_Bush':
                    print("Skipping George_W_Bush: 530 pictures, excluded to save time")
                    continue
                pics, len_pics = _get_pics(os.path.join(args.path, name))
                for i, pic in enumerate(pics):
                    _store2tfrecord(pic, i, writer)
                assert int(num) == len_pics
        writer.close()

#read all faces under one person's folder, plus the image count
def _get_pics(path):
    pic_list = [os.path.join(path, pic) for pic in os.listdir(path)]
    pics = []

    for pic_path in pic_list:
        pics.append(cv2.imread(pic_path))

    return np.asarray(pics), len(pic_list)

#save one image from a folder (with its index) into the tfrecord
def _store2tfrecord(pic, index, writer):
    pic_shape = list(pic.shape)
    pic_string = pic.tostring()

    example = tf.train.Example(features=tf.train.Features(
        feature={
            'index': tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
            'shape': tf.train.Feature(int64_list=tf.train.Int64List(value=pic_shape)),
            'pic': tf.train.Feature(bytes_list=tf.train.BytesList(value=[pic_string]))
        }
    ))
    serialized = example.SerializeToString()
    writer.write(serialized)
    #do not call writer.close() here, or the file gets closed mid-write
    #one writer makes one tfrecord file; if the file exists it is recreated

tfrecord_write(parsed)

1、Set the location where the .tfrecord file will be stored
2、Create the tfrecord writer at that location
    Do not create a new writer per example, otherwise only the last record survives.
    • writer = tf.python_io.TFRecordWriter(path)  #method 1: close manually later
    writer.close()
    • with tf.python_io.TFRecordWriter(path) as writer:  #method 2: context manager
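The "only the last record survives" pitfall is not specific to TFRecordWriter: each new writer opens the target path in write mode and truncates whatever is already there. A plain-Python sketch of the same behavior (hypothetical byte-string records, no TensorFlow needed):

```python
import os
import tempfile

path = os.path.join(tempfile.mkdtemp(), 'demo.records')

# Wrong: a new writer per record -- each open('wb') truncates the file,
# so only the last record survives
for record in [b'example-0', b'example-1', b'example-2']:
    with open(path, 'wb') as writer:
        writer.write(record)
with open(path, 'rb') as f:
    print(f.read())   # b'example-2'

# Right: one writer for the whole file
with open(path, 'wb') as writer:
    for record in [b'example-0', b'example-1', b'example-2']:
        writer.write(record)
with open(path, 'rb') as f:
    print(f.read())   # b'example-0example-1example-2'
```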
3、Compose the example
example = tf.train.Example(features=tf.train.Features(
    feature={
        'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[i])),
        'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
    }))
Note the data formats that may be assigned into an example. From the definition of tf.train.Example above, tfrecord supports three types: int64, float, and bytes; in every case value must be a list:
tf.train.Feature(int64_list=tf.train.Int64List(value=[int_scalar]))
tf.train.Feature(bytes_list=tf.train.BytesList(value=[array_string_or_byte]))
tf.train.Feature(float_list=tf.train.FloatList(value=[float_scalar]))
    
If saving a single list length, use value=[len_list].
If saving an array's shape, use value=shape_array; when parsing, the element count must then be given inside the [] of FixedLenFeature.
Array data (such as an image) can be converted to a string with array_data.tostring() and stored via bytes_list; this saves space, but the matrix loses its dimensions, so the shape information must be saved as well.
If reading in batches, every 'label' must have the same shape; otherwise the records can only be read one at a time.
If the data can be stored as one (25,160,160,3) array, do not store it as 25 separate (160,160,3) arrays: the latter takes more space and more time.
A tfrecord file of around 10 GB reads reasonably fast on a 1080Ti.
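The third note above (tostring() discards dimensions) is easy to verify with plain NumPy, independent of TensorFlow. The modern spelling of tostring() is tobytes(); both produce the same flat byte string:

```python
import numpy as np

pic = np.zeros((160, 160, 3), dtype=np.uint8)
raw = pic.tobytes()           # same bytes as the older pic.tostring()

# The byte string carries no shape: decoding yields a 1-D array...
flat = np.frombuffer(raw, dtype=np.uint8)
print(flat.shape)             # (76800,)

# ...so the shape must be stored alongside and reapplied after decoding
restored = flat.reshape((160, 160, 3))
print(restored.shape)         # (160, 160, 3)
```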

4、Serialize, reducing memory
example = example.SerializeToString()
5、Write into the tfrecord
writer.write(example)
6、Close the tfrecord file
writer.close()
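The serialize-then-write pattern of steps 3-5 can be illustrated without TensorFlow. The sketch below packs a hypothetical record (an index, a 3-D shape, and raw uint8 pixels) into one byte string with the standard struct module; this is only an analogy for what Example/SerializeToString does, not the actual protobuf wire format:

```python
import struct

import numpy as np

def serialize_record(index, pic):
    """Pack index, shape, and raw pixels into one byte string (toy format)."""
    header = struct.pack('<4q', index, *pic.shape)  # 4 little-endian int64s
    return header + pic.tobytes()

def parse_record(blob):
    """Inverse of serialize_record."""
    index, h, w, c = struct.unpack_from('<4q', blob)
    pic = np.frombuffer(blob, dtype=np.uint8, offset=32).reshape(h, w, c)
    return index, pic

pic = np.arange(12, dtype=np.uint8).reshape(2, 2, 3)
blob = serialize_record(7, pic)
index, restored = parse_record(blob)
print(index, np.array_equal(restored, pic))   # 7 True
```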

TFRecord read

1、Create the file queue
files_queue = tf.train.string_input_producer(tfrecord_paths)
2、Create the reader
reader = tf.TFRecordReader()
3、Read the serialized filename and example from the queue
_, serialized_example = reader.read(files_queue)
4、Deserialize
features = tf.parse_single_example(
    serialized_example,
    features={
        'a': tf.FixedLenFeature([], tf.float32),
        'b': tf.FixedLenFeature([2], tf.int64),
        'c': tf.FixedLenFeature([], tf.string)
    }
)
    • If the feature was serialized with value=['something'], leave the [] empty here.
    • If it was serialized with value=some_array, the [] must state how many elements the array contains.
5、Fetch the values
a = features['a']
b = features['b']
c_raw = features['c']
    If a value was stored via tostring(), three more steps are required:
    • pic = tf.decode_raw(pic, tf.uint8)  #convert back to the original dtype
    • pic = tf.reshape(pic, pic_shape)
    • pic.set_shape([182, 182, 3])  #required for tf.train.batch or shuffle_batch; unnecessary when reading one record at a time
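The decode chain above can be mimicked outside TensorFlow with NumPy, which also shows why the dtype given to decode_raw must match the dtype the array was written with: a mismatch silently changes the element count. A minimal sketch, with a toy 2x2x3 uint8 "image":

```python
import numpy as np

pic = np.arange(12, dtype=np.uint8).reshape(2, 2, 3)
raw = pic.tobytes()                       # as stored in the tfrecord

# NumPy analog of tf.decode_raw + tf.reshape
decoded = np.frombuffer(raw, dtype=np.uint8).reshape(2, 2, 3)
print(np.array_equal(decoded, pic))       # True

# Decoding with the wrong dtype changes the element count
print(np.frombuffer(raw, dtype=np.uint8).size)    # 12
print(np.frombuffer(raw, dtype=np.float32).size)  # 3
```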
    

Session usage
Reading one record at a time
    sess = tf.Session()
    glo = tf.global_variables_initializer()
    loc = tf.local_variables_initializer()
    sess.run(glo)
    sess.run(loc)
    coord = tf.train.Coordinator()
    tf.train.start_queue_runners(sess=sess, coord=coord)
    print('put the mouse in the window, press "q" to continue.')
    for i in range(20):
        real_pic, real_index, real_shape = sess.run([pic, index,shape_pic])
Reading a batch at a time
with tf.Graph().as_default():
    pic, index, shape = _read_from_record(['./brief_test_name.tfrecord'])
    #do NOT move this tf.train.batch call inside the session block below,
    #or TensorFlow hangs: it neither executes nor reports an error
    pic_batch, indexs, shapes = tf.train.batch([pic, index, shape],
                                               batch_size=16,
                                               num_threads=2,
                                               capacity=16 * 2)
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord)
        for i in range(32):
            print(sess.run([pic_batch, indexs, shapes]))
        coord.request_stop()
        coord.join(threads)

---------------------------- Code ----------------------------
#read one batch at a time
def tfrecord_read_batch(pic, index, batch_size, num_threads, capacity):
    pic_batch, indexs = tf.train.batch([pic, index],
                                       batch_size=batch_size,
                                       num_threads=num_threads,
                                       capacity=capacity)
    return pic_batch, indexs


#read tfrecord images one by one, mainly for testing
def tfrecord_read_one():
    pic, index, shape_pic = _read_from_record(['./brief_test_name.tfrecord'])

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    coord = tf.train.Coordinator()
    tf.train.start_queue_runners(sess=sess, coord=coord)
    print('put the mouse in the window, press "q" to continue.')
    for i in range(20):
        real_pic, real_index, real_shape = sess.run([pic, index, shape_pic])
        cv2.imshow('%s %s' % (real_index, list(real_shape)), real_pic)
        if cv2.waitKey(0) & 0xff == ord('q'):
            continue

    cv2.destroyAllWindows()

#read tfrecord data
def _read_from_record(tfrecord_paths):
    files_queue = tf.train.string_input_producer(tfrecord_paths)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(files_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'index': tf.FixedLenFeature([], tf.int64),
            'shape': tf.FixedLenFeature([3], tf.int64),
            'pic':   tf.FixedLenFeature([], tf.string)
        }
    )
    index = features['index']
    pic_shape = features['shape']
    pic = features['pic']
    pic = tf.decode_raw(pic, tf.uint8)
    #tf.reshape with the dynamic shape alone is not enough for batching;
    #without set_shape below, tf.train.batch fails with
    #"All shapes must be fully defined"
    pic = tf.reshape(pic, pic_shape)
    pic.set_shape([182, 182, 3])
    return pic, index, pic_shape

if __name__ == "__main__":
    parsed = parse(sys.argv[1:])

    with tf.Graph().as_default():
        pic, index, shape = _read_from_record(['./brief_test_name.tfrecord'])
        pic_batch, indexs, shapes = tf.train.batch([pic, index, shape],
                                                   batch_size=16,
                                                   num_threads=2,
                                                   capacity=16 * 2)
        init = tf.global_variables_initializer()
        with tf.Session() as sess:
            sess.run(init)
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess, coord)
            try:
                for i in range(32):
                    print(sess.run([pic_batch, indexs, shapes]))
            finally:
                coord.request_stop()
                coord.join(threads)

 

posted @ 2020-09-03 23:33  yunshangyue