day12-TensorFlow读取文件

读取csv文件


# coding=utf-8
import tensorflow as tf
import os

def readcsv(fileList):

    """
    csv文件的读取
    :param fileList: 文件路径+文件名 列表
    :return:
    """

    # 1、构造文件读取队列
    file_queue = tf.train.string_input_producer(fileList)

    # 2、构造csv文件读取器
    csvReader = tf.TextLineReader()

    key,value = csvReader.read(file_queue)

    # 3、对每行内容解码
    y1,y2 = tf.decode_csv(value,record_defaults=[["None"],["None"]])

    # 4、批处理读取多个数据
    # num_threads为线程数,batch_size为读取的数据数,capacity为读取多少个数据后传递给主线程,第一个参数为要读取的张量
    y1_batch,y2_batch = tf.train.batch([y1,y2],batch_size=4,capacity=4,num_threads=1)


    return y1_batch,y2_batch


if __name__ == '__main__':

    # 构造待读取的文件列表
    files = os.listdir("../data/day05/")
    fileList = [os.path.join("../data/day05/",file) for file in files]

    # 读取函数
    y1_batch,y2_batch = readcsv(fileList)

    # 操作
    with tf.Session() as sess:
        # 线程管理器
        coord = tf.train.Coordinator()

        # 开启线程
        threads = tf.train.start_queue_runners(sess,coord=coord)

        # 打印y1,y2
        print(sess.run([y1_batch,y2_batch]))

        # 回收线程
        coord.request_stop()
        coord.join(threads)

读取图片文件


# coding=utf-8
import tensorflow as tf
import os

def readphoto(fileList):

    """
    读取图片文件
    :param fileList:
    :return:
    """

    # 1、构造文件读取队列
    file_queue = tf.train.string_input_producer(fileList)

    # 2、构造图片读取器
    file_reader = tf.WholeFileReader()

    # 3、读取文件
    key,value = file_reader.read(file_queue)

    # 4、解码
    image = tf.image.decode_jpeg(value)

    # 5、统一图片的大小
    image_resize = tf.image.resize_images(image,[200,200])
    # 因为批处理需要张量的所有维数都固定,所以需要将图片指定为彩色图片,即为三通道
    image_resize.set_shape([200,200,3])

    # 6、批处理
    image_batch = tf.train.batch([image_resize],batch_size=2,num_threads=1,capacity=2)


    return image_batch



if __name__ == '__main__':

    # 构造文件列表
    files = os.listdir("../data/day05/image/")
    fileList = [os.path.join("../data/day05/image/",file) for file in files]

    image_batch = readphoto(fileList)

    # 进行会话操作
    with tf.Session() as sess:
        # 创建线程管理器
        coord = tf.train.Coordinator()

        # 开启线程
        threads = tf.train.start_queue_runners(coord=coord)

        # 打印图片数据
        print(sess.run([image_batch]))

        # 回收线程
        coord.request_stop()
        coord.join(threads)

图片文件读取时,格式为[长,宽,通道数]
通道数为1时为黑白图片
通道数为3时为彩色图片

posted @ 2021-01-20 21:06  Nevesettle  阅读(70)  评论(0编辑  收藏  举报