TensorFlow 加载 MNIST 数据集并显示

今天用 TensorFlow 加载了 MNIST 数据集，并用 matplotlib 把其中的图片显示出来。

import tensorflow as tf
import matplotlib.pyplot as plt

# Load MNIST through Keras: one (images, labels) pair for training, one for test.
(x_train_all, y_train_all), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Hold out the first 5000 training samples as a validation set; the
# remainder becomes the actual training set.
x_valid, y_valid = x_train_all[:5000], y_train_all[:5000]
x_train, y_train = x_train_all[5000:], y_train_all[5000:]

# Print the image/label array shapes of each split: validation, train, test.
for images, labels in ((x_valid, y_valid), (x_train, y_train), (x_test, y_test)):
    print(images.shape, labels.shape)
# Show a single image
def show_single_img(img_arr):
    """Render ``img_arr`` with the "binary" colormap and show the figure.

    Args:
        img_arr: image data in any form accepted by ``plt.imshow``.
    """
    render_options = {"cmap": "binary"}
    plt.imshow(img_arr, **render_options)
    plt.show()
# Show several images in a grid
def show_imgs(n_rows, n_cols, x_data, y_data):
    """Plot the first ``n_rows * n_cols`` images of ``x_data`` in a grid.

    The grid is filled row by row. Each image gets its own subplot with the
    axes hidden, and is titled with the matching entry of ``y_data``.

    Args:
        n_rows: number of rows in the grid.
        n_cols: number of columns in the grid.
        x_data: indexable collection of images accepted by ``plt.imshow``.
        y_data: labels for the images; must have the same length as ``x_data``.

    Raises:
        AssertionError: if ``x_data`` and ``y_data`` differ in length, or if
            the grid needs more images than ``x_data`` contains.
    """
    assert len(x_data) == len(y_data), "x_data and y_data must have the same length"
    # A grid that uses every available image exactly once is valid, so the
    # bound is inclusive (the previous strict "<" rejected that case).
    assert n_rows * n_cols <= len(x_data), "not enough images to fill the grid"

    plt.figure(figsize=(n_cols * 1.4, n_rows * 1.6))
    # Subplot positions are 1-based and run row by row, so a flat index over
    # the cells visits them in the same order as nested row/column loops.
    for index in range(n_rows * n_cols):
        plt.subplot(n_rows, n_cols, index + 1)
        plt.imshow(x_data[index], cmap="binary", interpolation="nearest")
        plt.axis("off")
        # Label each cell so the image can be checked against its target.
        plt.title(str(y_data[index]))
    plt.show()
# Display the first four training images in a 2x2 grid.
show_imgs(n_rows=2, n_cols=2, x_data=x_train, y_data=y_train)
# Alternative: display only the first training image.
# show_single_img(x_train[0])

 

posted @ 2021-01-08 17:41  .HAHA  阅读(1023)  评论(0)  编辑  收藏  举报