第四讲 网络八股拓展 -- mnist_app_ex

from PIL import Image
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt


# Location of the checkpoint whose weights are restored below.
model_save_path = "./checkpoint/mnist.ckpt"

# Network: flatten the input, a 128-unit ReLU hidden layer, then a 10-way
# softmax output. NOTE(review): this stack presumably has to mirror the
# network the checkpoint was saved from — keep the two in sync.
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(128, activation='relu'))
model.add(tf.keras.layers.Dense(10, activation='softmax'))

# Restore the trained weights from disk.
model.load_weights(model_save_path)
# Grayscale intensities (0-255) strictly below this value are treated as ink.
INK_THRESHOLD = 200


def binarize_digit(gray, threshold=INK_THRESHOLD):
    """Turn a grayscale digit picture into a 0.0/1.0 float array.

    Every pixel darker than ``threshold`` becomes 1.0 and every other pixel
    becomes 0.0. Dark strokes on a light background therefore become bright
    strokes on a dark background, presumably to match the data the model was
    trained on.

    This gives the same values as the original per-pixel loop (set to 255 or
    0, then divide by 255). It does not modify the input. The old in-place
    ``/= 255.0`` on a uint8 array raised a NumPy casting error.

    Args:
        gray: 2-D array-like of grayscale intensities in 0-255.
        threshold: intensities strictly below this value count as ink.

    Returns:
        A float64 ``numpy.ndarray`` of the same shape holding only 0.0 and 1.0.
    """
    gray = np.asarray(gray)
    return (gray < threshold).astype(np.float64)


if __name__ == "__main__":
    num_pictures = int(input("input the number of test pictures:" ))

    for _ in range(num_pictures):
        image_path = input("the path of test picture:")

        # Load, shrink to the 28x28 input size and convert to grayscale. The
        # context manager closes the underlying file handle afterwards.
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        with Image.open(image_path) as src:
            small = src.resize((28, 28), Image.LANCZOS)
            img_arr = np.array(small.convert("L"))

        # Show the picture being classified.
        image = plt.imread(image_path)
        plt.set_cmap('gray')
        plt.imshow(image)

        # Add a leading batch axis of size 1 before predicting.
        x_predict = binarize_digit(img_arr)[tf.newaxis, ...]
        result = model.predict(x_predict)
        pred = tf.argmax(result, axis=1)

        print('\n')
        tf.print(pred)

        # Keep the figure on screen briefly, then close it before the next picture.
        plt.pause(1)
        plt.close()

 

posted @ 2020-05-07 21:21  WWBlog  阅读(172)  评论(0)  编辑  收藏  举报