第八节 CSV文件读取

import tensorflow as tf
import os

'''
tensorflow读取文件的流程,每一步每一种数据都有对应封装好的API进行处理:
    1、构造一个文件队列:A文件,B文件,C文件,每个文件内有100个样本
    2、读取队列内容:一个一个样本读取,二进制文件就是指定一个样本的byte读取,图片就是一张一张
    3、进行解码
    4、批处理:将样本一个一个的放入一个队列中,达到一定数量后,一次性进行处理
'''

def readcsv(filelist):
    """读取csv文件"""
    # 1.构造文件队列
    file_queue = tf.train.string_input_producer(filelist)

    # 2.构造CSV阅读器读取队列数据(按一行),read返回一个元组,一个是路径一个是样本内容
    reader = tf.TextLineReader()
    key, value = reader.read(file_queue)

    # 对每行内容进行解码,field_delim分隔符默认“,”,record_defaults指定每一个样本的每一列类型,并设置默认值对缺失值进行填充
    # CSV数据中有几列就应该有几个列表,"None"表示字符串并同时指定默认值是None,1表示int类型,并同时指定默认值是1,如果是4.5则是float类型,默认值就是4.5
    # decode_csv返回的是每一个样本每一个的值,返回的是op列表
    records = [["None"],[1]]
    example, label = tf.decode_csv(value, field_delim=',', record_defaults=records)
    

    # 批处理,batch_size从队列读取的批处理大小,num_threads使用几个线程处理,capacity批处理队列大小,tf.train.batch返回的是两个元素的op,一个op存储着一列九行数据
    example_batch, label_batch = tf.train.batch([example, label], batch_size=5, num_threads=1, capacity=10)

    return example_batch, label_batch

if __name__ == "__main__":
    # 构造文件列表
    file_name = os.listdir("./data/csvdata")
    filelist = [os.path.join("./data/csvdata", file) for file in file_name]
    example_batch, label_batch = readcsv(filelist)

    # 开启会话
    with tf.Session() as sess:
        # 定义线程协调器
        coord = tf.train.Coordinator()

        # 开启读取文件的线程
        thd = tf.train.start_queue_runners(sess, coord=coord, start=True)

        # 打印读取内容
        print(sess.run([example_batch, label_batch]))

        # 回收子线程
        coord.request_stop()
        coord.join(thd)

 

posted @ 2020-03-30 13:03  kog_maw  阅读(187)  评论(0编辑  收藏  举报