4-6 TF之TFRecord数据打包案例
import numpy as np
import tensorflow as tf
import cv2
import numpy as np
classification = [
'airplane',
'automobile',
'bird',
'cat',
'deer',
'dog',
'frog',
'horse',
'ship',
'truck']
import glob # 获取类别文件夹图片的获取
###读入图片的src,并且相应的在im_labels标注图片类别
idx = 0
im_data = []
im_labels = []
for path in classification:
path = 'data/image/train/' + path
im_list = glob.glob(path + '/*') # get images url
im_label = [idx for i in range(im_list.__len__())] # 对于每一个i,都加入idx
idx += 1
im_data += im_list
im_labels += im_label
print(im_labels)
print(im_data)
tfrecord_file = 'data/train.tfrecord'
writer = tf.python_io.TFRecordWriter(tfrecord_file) # 定义写入实例
index=[i for i in range(im_data.__len__())]#打乱图片顺序
np.random.shuffle(index)#实际上是把数字打乱,然后根据数字来取图片,达到乱序取图
##循环把每张图片都改变储存结构
##value=的值需要转换为适当类型,在tf,train.BytesList是byte列表转换函数,因为value可能是多维
##cv2和tf都有读取图片的函数,区别在于:
## tf的图片读取后就是byte型,所以value不需要转换类型,并且tf读取图片会重新编码,减小内存,但图片输出需要被解压
for i in range(im_data.__len__()):
im_d = im_data[index[i]]
im_l = im_labels[index[i]]
data = cv2.imread(im_d) # 从图片url获取到真实数据
#tf.gfile.FastGFile(src,'rb').read()#tf的图片读取方式,优点是读取的图片本身就是byte型,下面就不需要类型转换
ex = tf.train.Example( #主要用在将数据处理成二进制方面
features=tf.train.Features(
feature={
'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[data.tobytes()])),
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[im_l])),
#'height':tf.train.Feature(int64_list=tf.train.Int64List(value=[data.shape[1]])),
#'width':tf.train.Feature(int64_list=tf.train.Int64List(value=[data.shape[2]])),
##这里也可以记录图片大小,因为cifar图像都是32*32,所以这里不记录
##对于图片尺寸可以在这里feature记录,也可以在opcv处理时归一化
}
)
)
writer.write(ex.SerializeToString())
writer.close()
ex = tf.train.Example
tf.train.Example有一个属性为features
tf.train.Example还有一个方法SerializeToString()需要说一下,这个方法的作用是把tf.train.Example对象序列化为字符串,因为我们写入文件的时候不能直接处理对象,需要将其转化为字符串才能处理。
当然,既然有对象序列化为字符串的方法,那么肯定有从字符串反序列化到对象的方法,该方法是FromString(),需要传递一个tf.train.Example对象序列化后的字符串进去做为参数才能得到反序列化的对象。