# ===== TFRecord generation (tfrecord生成) =====
import os
import xmltodict
import tensorflow as tf
import numpy as np
# Locations of the Pascal VOC 2012 annotations / images and of the TFRecord
# file to produce. Raw strings keep the Windows backslashes literal instead of
# relying on unrecognised escape sequences such as "\V" or "\A", which Python
# deprecates (SyntaxWarning on recent versions). The resulting path values are
# unchanged.
dir_path = r'F:\数据存储\VOCdevkit\VOC2012\Annotations'
dirs = os.listdir(dir_path)
imgs_dir = r"F:\数据存储\VOCdevkit\VOC2012\JPEGImages"
out_path = r'F:\数据存储\VOCdevkit\voc2012.tfrecord'

# Class names; an object's label is stored as the index of its name in this
# list, so "background" occupies index 0.
classes = [
    "background", "aeroplane", "bicycle", "bird", "boat", "bottle", "bus",
    "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike",
    "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor",
]

# Shared TF1 session used to evaluate the image-processing ops.
sess = tf.Session()
def get_and_resize_img(img_file):
    """Load a JPEG from ``imgs_dir`` and resize it to 224x224.

    The read/decode/resize ops are built once, cached on the function object
    and fed through a placeholder. Creating them on every call would add new
    nodes to the default graph for each image, and evaluating the decoded and
    the resized tensors separately would decode every file twice.

    Args:
        img_file: File name of the image, relative to ``imgs_dir``.

    Returns:
        A tuple ``(resized_img_str, w_scale, h_scale, shape_new)``:
        the raw uint8 bytes of the resized image, the horizontal and vertical
        scale factors (new size / original size) and the shape of the resized
        array.
    """
    ops = getattr(get_and_resize_img, '_ops', None)
    if ops is None:
        path_ph = tf.placeholder(tf.string, shape=[])
        decoded = tf.image.decode_jpeg(tf.read_file(path_ph))
        # method=0 is kept from the original call (bilinear in TF 1.x).
        resized = tf.image.resize_images(decoded, [224, 224], method=0)
        ops = (path_ph, tf.shape(decoded), resized)
        get_and_resize_img._ops = ops
    path_ph, shape_op, resized_op = ops

    # One run yields both the original shape and the resized pixels.
    shape_old, resized_img = sess.run(
        [shape_op, resized_op],
        feed_dict={path_ph: imgs_dir + '/' + img_file},
    )
    resized_img = np.asarray(resized_img, dtype='uint8')
    resized_img_str = resized_img.tobytes()
    shape_new = resized_img.shape

    # Decoded images are laid out as (height, width, channels).
    old_height, old_width = int(shape_old[0]), int(shape_old[1])
    w_scale = shape_new[1] / old_width
    h_scale = shape_new[0] / old_height
    return resized_img_str, w_scale, h_scale, shape_new
def _annotation_objects(annotation):
    """Return the ``object`` entries of a parsed VOC annotation as a list."""
    objects = annotation.get('object') or []
    # xmltodict yields a single mapping when there is one <object> element and
    # a list when there are several. Depending on the xmltodict version the
    # mapping is an OrderedDict or a plain dict, so test for dict in general.
    if isinstance(objects, dict):
        return [objects]
    return objects


# Convert every annotation file into one tf.train.Example holding the resized
# image bytes plus the class ids and rescaled bounding boxes of its objects.
try:
    with tf.python_io.TFRecordWriter(out_path) as writer:
        for file in dirs:
            with open(dir_path + '/' + file) as xml_txt:
                doc = xmltodict.parse(xml_txt.read())
            img_file_name = file.split('.')[0]
            resized_img_str, w_scale, h_scale, shape = get_and_resize_img(
                img_file_name + '.jpg')

            img_obtain_classes = []
            y_mins = []
            x_mins = []
            y_maxes = []
            x_maxes = []
            for one_object in _annotation_objects(doc['annotation']):
                if one_object['name'] not in classes:
                    continue
                box = one_object['bndbox']
                img_obtain_classes.append(classes.index(one_object['name']))
                # float() rather than int(): a coordinate written with a
                # fractional part would make int() raise ValueError.
                y_mins.append(float(h_scale * float(box['ymin'])))
                x_mins.append(float(w_scale * float(box['xmin'])))
                y_maxes.append(float(h_scale * float(box['ymax'])))
                x_maxes.append(float(w_scale * float(box['xmax'])))

            example = tf.train.Example(features=tf.train.Features(feature={
                'filename': tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[img_file_name.encode('utf8')])),
                'shape': tf.train.Feature(
                    int64_list=tf.train.Int64List(value=[shape[0], shape[1], shape[2]])),
                'classes': tf.train.Feature(
                    int64_list=tf.train.Int64List(value=img_obtain_classes)),
                # One entry per object, in the same order as 'classes'.
                'y_mins': tf.train.Feature(float_list=tf.train.FloatList(value=y_mins)),
                'x_mins': tf.train.Feature(float_list=tf.train.FloatList(value=x_mins)),
                'y_maxes': tf.train.Feature(float_list=tf.train.FloatList(value=y_maxes)),
                'x_maxes': tf.train.Feature(float_list=tf.train.FloatList(value=x_maxes)),
                'encoded': tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[resized_img_str])),
            }))
            writer.write(example.SerializeToString())
finally:
    # The writer is closed by the ``with`` block; release the session even if
    # a file fails to convert.
    sess.close()
print('ok')
# ===== TFRecord reading (tfrecord读取) =====
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt
# import sys
#
# sys.path.append("..")
# Class names indexed by the label id stored in the TFRecord file. The writer
# derives ids from a list whose first entry is "background", so that entry
# must be present here too; without it every lookup is off by one and id 20
# falls outside the list.
classes = [
    "background", "aeroplane", "bicycle", "bird", "boat", "bottle", "bus",
    "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike",
    "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor",
]
# 'filename': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_file_name])),
# 'shape': tf.train.Feature(int64_list=tf.train.Int64List(value=[shape[0], shape[1], shape[2]])),
# 'classes': tf.train.Feature(int64_list=tf.train.Int64List(value=np.array(img_obtain_classes))),
# 'y_mins': tf.train.Feature(float_list=tf.train.FloatList(value=y_mins)), # 各个 object 的 ymin
# 'x_mins': tf.train.Feature(float_list=tf.train.FloatList(value=x_mins)),
# 'y_maxes': tf.train.Feature(float_list=tf.train.FloatList(value=y_maxes)),
# 'x_maxes': tf.train.Feature(float_list=tf.train.FloatList(value=x_maxes)),
# 'encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[resized_img_str]))
def _parse_record(example_proto):
    """Parse one serialized example into a dict of tensors.

    Scalar string features: ``filename`` and ``encoded``. ``shape`` is a
    fixed-length int64 feature of three values. ``classes`` (int64) and the
    four box-coordinate features (float32) are variable-length.
    """
    schema = {
        'filename': tf.FixedLenFeature([], tf.string),
        'shape': tf.FixedLenFeature([3], tf.int64),
        'classes': tf.VarLenFeature(tf.int64),
    }
    # The four box-coordinate lists share the same variable-length float spec.
    for box_key in ('y_mins', 'x_mins', 'y_maxes', 'x_maxes'):
        schema[box_key] = tf.VarLenFeature(tf.float32)
    schema['encoded'] = tf.FixedLenFeature((), tf.string)
    return tf.parse_single_example(example_proto, features=schema)
def read_test(input_file, num_records=2):
    """Print and display the first records of a TFRecord file.

    Args:
        input_file: Path of the TFRecord file to read.
        num_records: Maximum number of records to show. Reading stops early
            if the file holds fewer records.
    """
    # Read the TFRecord file through the dataset API.
    dataset = tf.data.TFRecordDataset(input_file).map(_parse_record)
    iterator = dataset.make_initializable_iterator()
    # Create the get_next op once; calling get_next() inside the loop would
    # add a new op to the graph on every iteration.
    next_record = iterator.get_next()
    with tf.Session() as sess:
        sess.run(iterator.initializer)
        for _ in range(num_records):
            try:
                features = sess.run(next_record)
            except tf.errors.OutOfRangeError:
                break
            name = features['filename'].decode()
            shape = features['shape']
            img_data = features['encoded']
            print(len(img_data))
            print('=======')
            print("shape", shape)
            print("name", name)
            # Variable-length features come back as sparse values; .values
            # holds the stored entries.
            print("classes", features['classes'].values)
            print("y_mins", features['y_mins'].values)
            print("x_mins", features['x_mins'].values)
            print("y_maxes", features['y_maxes'].values)
            print("x_maxes", features['x_maxes'].values)
            # Rebuild the image array from the raw bytes and the stored shape.
            img_data = np.frombuffer(img_data, dtype=np.uint8)
            image_data = np.reshape(img_data, shape)
            print("img_data", image_data)
            # Show the image.
            plt.imshow(image_data)
            plt.show()
# Raw string: keeps the backslashes literal without relying on unrecognised
# escape sequences; the path value is unchanged.
read_test(r'F:\数据存储\VOCdevkit\voc2012.tfrecord')
# ===== Storing and reading variable-size arrays (尺寸不固定矩阵的存储和读取) =====
import json
import jieba
import tensorflow as tf
# Load the word -> id table stored under the "all_words_word2id" key.
with open('../data_save/words_info.txt', 'r', encoding='utf-8') as file:
    dic = json.load(file)
all_words_word2id = dic["all_words_word2id"]

# One stop word per line. Strip only the line terminator: slicing off the last
# character would truncate a final entry that has no trailing newline.
with open('./stop_words.txt', encoding='utf-8') as f:
    stop_words = {line.rstrip('\n') for line in f}
print('停用词读取完毕,共{n}个单词'.format(n=len(stop_words)))

# Raw strings keep the Windows backslashes literal instead of relying on
# unrecognised escape sequences; the path values are unchanged.
dir_path = r'F:\数据存储\新闻语料\news2016zh_train.json'
dir_path_test = r'F:\数据存储\新闻语料\news2016zh_valid.json'
out_path = r'F:\数据存储\新闻语料\news2016zh_train_new.tfrecord'
def getCutSequnce(line):
    """Segment *line* with jieba and return the tokens that are kept.

    A token is dropped when it is in ``stop_words`` or is one of the URL
    fragments ``www``, ``com`` or ``http``.
    """
    url_fragments = ['www', 'com', 'http']
    return [
        token
        for token in jieba.cut(line, cut_all=False)
        if not (token in stop_words or token in url_fragments)
    ]
def _words_to_ids(words):
    """Map words to vocabulary ids, skipping words missing from the table."""
    return [all_words_word2id[word] for word in words if word in all_words_word2id]


# Convert the first 10000 lines of the news corpus (one JSON object per line)
# into TFRecord examples holding the word ids of the title and the content.
# The ``with`` block closes the writer so buffered records reach the disk;
# previously the writer was never closed.
with tf.python_io.TFRecordWriter(out_path) as writer, \
        open(dir_path, encoding='utf-8') as txt:
    for i, one_dic in enumerate(txt, start=1):
        if i > 10000:
            break
        if (i % 1000) == 0:
            print(i)
        if not one_dic.strip():
            # Blank line: nothing to parse.
            continue
        one_dic_json = json.loads(one_dic)
        title = one_dic_json['title']
        content = one_dic_json['content']
        # Skip articles with an empty title/content or an overly long content.
        if not title or not content or len(content) > 3000:
            continue
        title_list_index = _words_to_ids(getCutSequnce(title))
        content_list_index = _words_to_ids(getCutSequnce(content))
        example = tf.train.Example(features=tf.train.Features(feature={
            'title': tf.train.Feature(
                int64_list=tf.train.Int64List(value=title_list_index)),
            'content': tf.train.Feature(
                int64_list=tf.train.Int64List(value=content_list_index)),
        }))
        writer.write(example.SerializeToString())
import tensorflow as tf
import numpy as np
def _parse_record(example_proto):
    """Parse one serialized example into its ``title`` and ``content`` features.

    Both features are variable-length int64 lists.
    """
    schema = dict(
        title=tf.VarLenFeature(tf.int64),
        content=tf.VarLenFeature(dtype=tf.int64),
    )
    return tf.parse_single_example(example_proto, features=schema)
def read_test(input_file, num_records=5):
    """Print the ``content`` feature of the first records of a TFRecord file.

    Args:
        input_file: Path of the TFRecord file to read.
        num_records: Maximum number of records to print. Reading stops early
            if the file holds fewer records.
    """
    # Read the TFRecord file through the dataset API.
    dataset = tf.data.TFRecordDataset(input_file).map(_parse_record)
    iterator = dataset.make_initializable_iterator()
    # Create the get_next op once instead of once per loop iteration.
    next_record = iterator.get_next()
    with tf.Session() as sess:
        sess.run(iterator.initializer)
        for _ in range(num_records):
            try:
                features = sess.run(next_record)
            except tf.errors.OutOfRangeError:
                break
            content = features['content']
            print("xx", content)
            # The variable-length feature comes back as a sparse value; its
            # .values array holds the stored ids, so report that array's
            # shape rather than wrapping the whole sparse value in np.array.
            print("xx", np.array(content.values).shape)
# Raw string: keeps the backslashes literal without relying on unrecognised
# escape sequences; the path value is unchanged.
read_test(r'F:\数据存储\新闻语料\news2016zh_train_new.tfrecord')
# ===== Counting the number of records (统计数据条数) =====
import tensorflow as tf
def total_sample(file_name):
    """Return the number of records stored in the TFRecord file *file_name*."""
    return sum(1 for _ in tf.python_io.tf_record_iterator(file_name))
# Raw string: keeps the backslashes literal without relying on unrecognised
# escape sequences; the path value is unchanged.
result = total_sample(r'F:\数据存储\新闻语料\news2016zh_train_new.tfrecord')
print(result)