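TensorFlow (Part 4)
TensorFlow data processing methods
1. Input datasets
Small datasets can be loaded into memory in a single pass and processed there.
Large datasets usually consist of many data files. Because the dataset is too large to load into memory at once, data has to be loaded at each training step, and a pipeline can be used to read it in parallel.
Reading data through a parallel pipeline takes four steps: (1) create the list of filenames, (2) create the filename queue, (3) create the Reader and the Decoder, (4) create the example queue.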
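As a rough analogy only, the four pipeline steps can be sketched in plain Python with the standard `queue` and `threading` modules. This is not how TensorFlow implements its queues; the in-memory `FILES` table, `reader_worker`, and `run_pipeline` are hypothetical names introduced for illustration.

```python
import queue
import threading

# Hypothetical in-memory "files" so the sketch runs without touching disk.
FILES = {
    'stat0.csv': ['1,20', '2,31'],
    'stat1.csv': ['3,42', '4,53'],
}


def reader_worker(filename_queue, example_queue):
    """Step 3: take filenames, read and decode records, emit examples."""
    while True:
        filename = filename_queue.get()
        if filename is None:  # sentinel: no more files for this worker
            break
        for line in FILES[filename]:
            example = [int(field) for field in line.split(',')]  # "decode"
            example_queue.put(example)


def run_pipeline(filenames, num_readers=2):
    filename_queue = queue.Queue()  # step 2: filename queue
    example_queue = queue.Queue()   # step 4: example queue
    for name in filenames:          # step 1: the filename list
        filename_queue.put(name)
    for _ in range(num_readers):    # one stop sentinel per reader thread
        filename_queue.put(None)

    threads = [
        threading.Thread(target=reader_worker,
                         args=(filename_queue, example_queue))
        for _ in range(num_readers)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    examples = []
    while not example_queue.empty():
        examples.append(example_queue.get())
    return examples


examples = run_pipeline(['stat0.csv', 'stat1.csv'])
print(sorted(examples))
```

The order in which examples arrive depends on thread scheduling, which is why the result is sorted before printing.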
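    import tensorflow as tf  # TensorFlow 1.x queue-based input API
    # Steps 1-2: the filename list and the filename queue
    filename_queue = tf.train.string_input_producer(['stat0.csv', 'stat1.csv'])
    # Step 3: the Reader returns one text line per read; the Decoder parses it
    reader = tf.TextLineReader()
    _, value = reader.read(filename_queue)
    # One default per CSV column; each default also fixes that column's dtype
    record_defaults = [[0], [0], [0.0], [0.0]]
    # decode_csv returns one tensor per column, so all four must be unpacked;
    # the two float columns are not used here
    id, age, _, _ = tf.decode_csv(value, record_defaults=record_defaults)
    features = tf.stack([id, age])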
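To make the role of `record_defaults` concrete, here is a minimal pure-Python sketch of the idea: each default supplies both the column's type and the value used when a field is empty. `decode_csv_line` is a hypothetical helper written for this note, not part of TensorFlow, and it ignores details such as quoting and required columns that the real `tf.decode_csv` handles.

```python
def decode_csv_line(line, record_defaults):
    """Parse one CSV line using per-column defaults for type and fallback."""
    fields = line.split(',')
    if len(fields) != len(record_defaults):
        raise ValueError('expected %d fields, got %d'
                         % (len(record_defaults), len(fields)))
    columns = []
    for field, default in zip(fields, record_defaults):
        if field == '':
            # Empty field: fall back to the column's default value.
            columns.append(default[0])
        else:
            # Non-empty field: cast to the type of the column's default.
            columns.append(type(default[0])(field))
    return columns


record_defaults = [[0], [0], [0.0], [0.0]]
row = decode_csv_line('7,35,,1200.5', record_defaults)
print(row)
```

The third field of the sample line is empty, so it takes the float default from `record_defaults`.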
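    def get_my_example(filename_queue):
        # tf.SomeReader and tf.decode_some are placeholders, not real TensorFlow
        # APIs: substitute the reader/decoder pair that matches your file format
        reader = tf.SomeReader()
        _, value = reader.read(filename_queue)
        features = tf.decode_some(value)
        # some_processing is likewise a placeholder for your own preprocessing
        processed_example = some_processing(features)
        return processed_example

    def input_pipeline(filenames, batch_size, num_epochs=None):
        filename_queue = tf.train.string_input_producer(
            filenames, num_epochs=num_epochs, shuffle=True)
        example = get_my_example(filename_queue)
        # min_after_dequeue: how many elements must remain in the queue after a
        # dequeue, which controls how well the examples are mixed
        min_after_dequeue = 10000
        capacity = min_after_dequeue + 3 * batch_size
        example_batch = tf.train.shuffle_batch(
            [example], batch_size=batch_size, capacity=capacity,
            min_after_dequeue=min_after_dequeue)
        return example_batch

    x_batch = input_pipeline(['stat.tfrecord'], batch_size=20)
    # train_op stands for a training op built on top of x_batch; it is not defined in this snippet
    sess = tf.Session()
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    sess.run(init_op)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        for _ in range(1000):
            if coord.should_stop():
                break
            # Run one training step and fetch the batch it consumed
            _, batch = sess.run([train_op, x_batch])
            print(batch)
    except Exception as e:
        print('catch exception:', e)
    finally:
        coord.request_stop()
        coord.join(threads)
        sess.close()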
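The effect of `min_after_dequeue` on mixing can be illustrated with a small single-threaded shuffle buffer in plain Python. This is a sketch under simplifying assumptions, not TensorFlow's implementation: `shuffle_batches` is a hypothetical function, there are no producer threads (so `capacity` is not modeled), and the final partial batch is emitted rather than handled the way `tf.train.shuffle_batch` handles it.

```python
import random


def shuffle_batches(examples, batch_size, min_after_dequeue, seed=0):
    """Yield shuffled batches drawn at random from a bounded buffer."""
    rng = random.Random(seed)
    source = iter(examples)
    buffer = []
    batch = []
    exhausted = False
    while True:
        # Refill so that, while input remains, at least min_after_dequeue
        # elements are still buffered after the next dequeue.
        while not exhausted and len(buffer) <= min_after_dequeue:
            try:
                buffer.append(next(source))
            except StopIteration:
                exhausted = True
        if not buffer:
            break
        # Dequeue one element chosen uniformly at random from the buffer.
        index = rng.randrange(len(buffer))
        buffer[index], buffer[-1] = buffer[-1], buffer[index]
        batch.append(buffer.pop())
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch  # simplification: emit the final partial batch


batches = list(shuffle_batches(range(10), batch_size=4, min_after_dequeue=3))
print(batches)
```

A larger `min_after_dequeue` means each random draw is taken from a bigger pool, so the output order is closer to a full shuffle, at the cost of more memory and a longer warm-up before the first batch.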
2. Model parameters
Model parameters are the weights and biases of the model. They are created with tf.Variable.
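    W = tf.Variable(0.0, name='W')
    double = tf.multiply(2.0, W)
    # Build the increment op once; calling tf.assign_add inside the loop
    # would add a new op to the graph on every iteration
    increment = tf.assign_add(W, 1.0)
    with tf.Session() as sess:
        # Variables must be initialized before they are read
        sess.run(tf.global_variables_initializer())
        for i in range(4):
            sess.run(increment)
            print(sess.run(W))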
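3. Saving and restoring model parameters
tf.train.Saver is a helper class for training. It implements the reads and writes between the variables that hold the model parameters and checkpoint files.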
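    W = tf.Variable(0.0, name='W')
    double = tf.multiply(2.0, W)
    # The dict maps the name stored in the checkpoint ('weights') to the variable
    saver = tf.train.Saver({'weights': W})
    increment = tf.assign_add(W, 1.0)  # build the op once, outside the loop
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for i in range(4):
            sess.run(increment)
            print(sess.run(W))
        # Write the current value of W to the checkpoint
        saver.save(sess, '/tmp/text/ckpt')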
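The restoring half can be sketched as follows. This is a minimal example under stated assumptions: it uses `tensorflow.compat.v1` so that it also runs on TensorFlow 2.x installations, it writes to a temporary directory instead of `/tmp/text/ckpt`, and `save_then_restore` is a hypothetical helper name. The second graph rebuilds `W` and maps the checkpoint name `'weights'` back onto it; no initializer is run because `saver.restore` assigns the saved value.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # use TF 1.x graph-and-session semantics


def save_then_restore(ckpt_dir):
    ckpt_path = os.path.join(ckpt_dir, 'ckpt')

    # Graph 1: increment W four times, then save it under the name 'weights'.
    with tf.Graph().as_default():
        W = tf.Variable(0.0, name='W')
        increment = tf.assign_add(W, 1.0)
        saver = tf.train.Saver({'weights': W})
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            for _ in range(4):
                sess.run(increment)
            saver.save(sess, ckpt_path)

    # Graph 2: a fresh graph; restore W from the checkpoint instead of
    # running the initializer.
    with tf.Graph().as_default():
        W = tf.Variable(0.0, name='W')
        saver = tf.train.Saver({'weights': W})
        with tf.Session() as sess:
            saver.restore(sess, ckpt_path)
            return float(sess.run(W))


restored = save_then_restore(tempfile.mkdtemp())
print(restored)
```

The key in the dict passed to `tf.train.Saver` has to match between the saving and restoring graphs, since it is the name under which the value is stored in the checkpoint.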