Image classification with pretrained models in Keras
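Keras ships pretrained weights for several networks, so they can be used directly with very little code.
Installation and usage mainly follow the official guides: https://keras.io/zh/applications/ https://keras-cn.readthedocs.io/en/latest/other/application/
The official documentation gives an example of ImageNet classification with ResNet50: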
```python
from keras.applications.resnet50 import ResNet50
from keras.preprocessing import image
from keras.applications.resnet50 import preprocess_input, decode_predictions
import numpy as np

model = ResNet50(weights='imagenet')

img_path = 'elephant.jpg'
img = image.load_img(img_path, target_size=(224, 224))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)

preds = model.predict(x)
# decode the results into a list of tuples (class, description, probability)
# (one such list for each sample in the batch)
print('Predicted:', decode_predictions(preds, top=3)[0])
# Predicted: [(u'n02504013', u'Indian_elephant', 0.82658225), (u'n01871265', u'tusker', 0.1122357), (u'n02504458', u'African_elephant', 0.061040461)]
```
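The `preprocess_input` step matters because each network expects inputs prepared the way it was trained. For ResNet50 (and VGG), Keras uses its "caffe" mode: it converts RGB to BGR and subtracts the per-channel ImageNet mean, without scaling to [0, 1]. The NumPy sketch below only mimics that behaviour to make it visible; `caffe_preprocess` is a hypothetical name, not a Keras function, and in real code you should keep calling the library's `preprocess_input`.

```python
import numpy as np

# Per-channel ImageNet means in BGR order, as used by Keras' "caffe" preprocessing mode
IMAGENET_MEAN_BGR = np.array([103.939, 116.779, 123.68], dtype=np.float32)

def caffe_preprocess(x):
    """Hypothetical re-implementation of "caffe"-mode preprocessing.

    x: array of shape (N, H, W, 3), channels-last, RGB values in [0, 255].
    """
    x = np.asarray(x, dtype=np.float32)
    x = x[..., ::-1]              # reverse the channel axis: RGB -> BGR
    return x - IMAGENET_MEAN_BGR  # zero-center each channel; no rescaling

# A batch holding one pure-red 224x224 image (R=255, G=0, B=0)
batch = np.zeros((1, 224, 224, 3), dtype=np.float32)
batch[..., 0] = 255.0
out = caffe_preprocess(batch)
print(out.shape)     # (1, 224, 224, 3)
print(out[0, 0, 0])  # red moved to the last (B, G, R) slot, then mean-subtracted
```

Because the channels are swapped before the mean is subtracted, the red value ends up in the last position.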
The same code can be adapted for other networks.
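Across networks, the pipeline stays the same: add a batch axis, preprocess, predict, then pick the highest-scoring classes. Only the preprocessing function and the model change. The sketch below makes that explicit with a hypothetical `classify` helper that takes both as parameters and returns the top-k `(index, score)` pairs. Keras' `decode_predictions` does the extra step of mapping indices to WordNet IDs and class names, which is omitted here. `fake_predict` is a stand-in for `model.predict`, so the example runs without downloading weights.

```python
import numpy as np

def classify(img_array, preprocess, predict, top=3):
    """Hypothetical helper: run one HxWx3 image through a model, return top-k (index, score).

    preprocess and predict are the network-specific parts, e.g. preprocess_input
    and model.predict for ResNet50, VGG19 or MobileNet.
    """
    x = np.expand_dims(np.asarray(img_array, dtype=np.float32), axis=0)  # add batch axis
    x = preprocess(x)
    scores = predict(x)[0]                # scores for the single image in the batch
    idx = np.argsort(scores)[::-1][:top]  # indices of the highest scores, descending
    return [(int(i), float(scores[i])) for i in idx]

def fake_predict(x):
    """Stand-in for model.predict: fixed scores over 5 'classes' for one image."""
    return np.array([[0.1, 0.5, 0.05, 0.3, 0.05]])

img = np.zeros((224, 224, 3))
print(classify(img, preprocess=lambda x: x / 127.5 - 1.0, predict=fake_predict))
# [(1, 0.5), (3, 0.3), (0, 0.1)]
```

Switching networks then means passing a different `preprocess_input` and `model.predict`, with no other changes to the pipeline.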
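First, VGG19. Note that this example does not output class labels: it cuts the network at the `fc2` layer and returns the 4096-dimensional feature vector for the image.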
```python
# coding: utf-8
from keras.applications.vgg19 import VGG19
from keras.preprocessing import image
from keras.applications.vgg19 import preprocess_input
from keras.models import Model
import numpy as np

# Full VGG19 including its ImageNet classifier head
base_model = VGG19(weights='imagenet', include_top=True)
# New model that stops at the 'fc2' layer and outputs its activations
model = Model(inputs=base_model.input, outputs=base_model.get_layer('fc2').output)

img_path = '../mdataset/img_test/p2.jpg'
img = image.load_img(img_path, target_size=(224, 224))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)  # add the batch dimension: (1, 224, 224, 3)
x = preprocess_input(x)        # VGG19-specific preprocessing

fc2 = model.predict(x)
print(fc2.shape)  # (1, 4096)
```
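`np.expand_dims(x, axis=0)` turns a single image into a batch of size 1, which is why `fc2.shape` is `(1, 4096)`. To extract features for several images in one call, you can stack them into a batch of shape `(N, 224, 224, 3)`; after `preprocess_input`, `model.predict` on that batch would return one 4096-dimensional row per image, i.e. shape `(N, 4096)`. The sketch below shows only the batching step in NumPy; `make_batch` is a hypothetical helper, and the random arrays stand in for images loaded with `image.img_to_array`.

```python
import numpy as np

def make_batch(images):
    """Hypothetical helper: stack several HxWx3 arrays into one (N, H, W, 3) batch."""
    return np.stack([np.asarray(im, dtype=np.float32) for im in images], axis=0)

# Three dummy 224x224 RGB "images" (placeholder data, not real photos)
rng = np.random.default_rng(0)
imgs = [rng.integers(0, 256, size=(224, 224, 3)) for _ in range(3)]

batch = make_batch(imgs)
print(batch.shape)  # (3, 224, 224, 3)

# A batch of one image is exactly what np.expand_dims(x, axis=0) produces
single = make_batch([imgs[0]])
print(np.array_equal(single, np.expand_dims(imgs[0].astype(np.float32), axis=0)))  # True
```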
Next, MobileNet:
```python
# coding: utf-8
from keras.applications.mobilenet import MobileNet
from keras.preprocessing import image
from keras.applications.mobilenet import preprocess_input, decode_predictions
import numpy as np
import time

model = MobileNet(weights='imagenet', include_top=True, classes=1000)

start = time.time()  # the timer starts after the model has been loaded
img_path = '../mdataset/img_test/dog.jpg'
img = image.load_img(img_path, target_size=(224, 224))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)  # MobileNet-specific preprocessing

preds = model.predict(x)
# decode the results into a list of tuples (class, description, probability)
# (one such list for each sample in the batch)
print('Predicted:', decode_predictions(preds, top=15)[0])
end = time.time()
print('time: %.3f s' % (end - start))
```
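This is why the MobileNet example imports `preprocess_input` from `keras.applications.mobilenet` rather than reusing the ResNet50/VGG one. MobileNet uses Keras' "tf" mode, which scales pixels from [0, 255] to [-1, 1] instead of subtracting a mean. The NumPy sketch below only illustrates that mapping; `tf_preprocess` is a hypothetical name, not part of Keras.

```python
import numpy as np

def tf_preprocess(x):
    """Hypothetical re-implementation of "tf"-mode preprocessing: [0, 255] -> [-1, 1]."""
    x = np.asarray(x, dtype=np.float32)
    return x / 127.5 - 1.0

pixels = np.array([0.0, 127.5, 255.0])
print(tf_preprocess(pixels))  # maps to -1, 0, 1
```

Feeding a network input prepared with another network's preprocessing would shift or scale the pixel values away from what it was trained on.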
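The timing above does not include loading the model, because the timer starts after `MobileNet(...)` returns. Loading the image, preprocessing, prediction and decoding together take a little under 1 s; including model loading, the total is about 3 s.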
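To report model loading and prediction separately instead of adjusting where `start` is placed, you can time each call on its own. The sketch below uses a small hypothetical `timed` helper built on `time.perf_counter()` (a high-resolution clock suited to measuring intervals). `load_stub` stands in for model construction so the example runs without Keras. With Keras, the calls would look like `model, t_load = timed(MobileNet, weights='imagenet')` and `preds, t_pred = timed(model.predict, x)`.

```python
import time

def timed(fn, *args, **kwargs):
    """Hypothetical helper: call fn(*args, **kwargs) and return (result, elapsed_seconds)."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    return result, time.perf_counter() - start

def load_stub():
    """Stand-in for building/loading a model (placeholder, sleeps briefly)."""
    time.sleep(0.05)
    return "model"

model, t_load = timed(load_stub)
result, t_pred = timed(sum, [1, 2, 3])  # stand-in for model.predict(x)
print(f"load: {t_load:.3f} s, predict: {t_pred:.6f} s, result={result}")
```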