Keras猫狗大战三:加载模型,预测目录中图片,画混淆矩阵

版权声明:本文为博主原创文章,欢迎转载,并请注明出处。联系方式:460356155@qq.com

 一、加载模型,预测测试集

%matplotlib inline
import matplotlib.pyplot as plt

import os
import itertools
import cv2

import numpy as np
from sklearn.metrics import confusion_matrix

from keras.preprocessing.image import ImageDataGenerator
from keras.models import load_model

dst_path = r'D:\BaiduNetdiskDownload\small'
model_file = r"D:\fastai\projects\cats_and_dogs_small_1.h5"
test_dir = os.path.join(dst_path, 'test')

batch_size = 20

model = load_model(model_file)

test_datagen = ImageDataGenerator(rescale=1. / 255)

test_generator = test_datagen.flow_from_directory(
    test_dir,
    target_size=(150, 150),
    batch_size=batch_size,
    class_mode='binary')

test_loss, test_acc = model.evaluate_generator(test_generator, steps=test_generator.samples / batch_size)
print('test acc: %.3f%%' % test_acc)
Found 400 images belonging to 2 classes.
test acc: 0.747%

二、预测测试集,画混淆矩阵
def get_input_xy(src=[]):
    pre_x = []
    true_y = []

    class_indices = {'cat': 0, 'dog': 1}

    for s in src:
        input = cv2.imread(s)
        input = cv2.resize(input, (150, 150))
        input = cv2.cvtColor(input, cv2.COLOR_BGR2RGB)
        pre_x.append(input)

        _, fn = os.path.split(s)
        y = class_indices.get(fn[:3])
        true_y.append(y)

    pre_x = np.array(pre_x) / 255.0

    return pre_x, true_y


def plot_sonfusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

    thresh = cm.max() / 2.0
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, cm[i, j], horizontalalignment='center', color='white' if cm[i, j] > thresh else 'black')

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predict label')


test = os.listdir(test_dir)

images = []

# 获取每张图片的地址,并保存在列表images中
for testpath in test:
    for fn in os.listdir(os.path.join(test_dir, testpath)):
        if fn.endswith('jpg'):
            fd = os.path.join(test_dir, testpath, fn)
            images.append(fd)

# 得到规范化图片及true label
pre_x, true_y = get_input_xy(images)

# 预测
pred_y = model.predict_classes(pre_x)

# 画混淆矩阵
confusion_mat = confusion_matrix(true_y, pred_y)
plot_sonfusion_matrix(confusion_mat, classes=range(2))

 

posted @ 2019-06-22 19:18  碧水青山  阅读(4364)  评论(1编辑  收藏  举报