以图搜图实现
此次需要使用到的工具:
IDE:eclipse,pydev
Python:3.10
Packages:Keras + TensorFlow + Pillow + Numpy
keras
Keras是一个高层神经网络API,由纯Python编写而成,并基于TensorFlow、Theano以及CNTK后端。简单来说,Keras就是对TF等框架的再一次封装,使用起来更加方便。
基于vgg16网络提取图像特征 我们都知道,vgg网络在图像领域有着广泛的应用,后续许多层次更深、网络更宽的模型都是基于此扩展的。vgg网络能很好地提取到图片的有用特征,本次实现基于Keras,提取的是最后一层卷积特征(去掉全连接层,并对最后一层卷积输出做全局最大池化)。
思路
主要思路是基于CVPR2015的论文 《Deep Learning of Binary Hash Codes for Fast Image Retrieval》实现的海量数据下的基于内容图片检索系统。简单说来就是对图片数据库的每张图片抽取特征(一般形式为特征向量),
存储于数据库中;对于待检索图片,抽取同样的特征向量,然后计算该向量与数据库中各向量的距离(相似度计算),找出最接近的一些特征向量,其对应的图片即为检索结果。
from keras.applications.vgg16 import VGG16
from keras.preprocessing import image
from keras.utils.image_utils import load_img,img_to_array
from keras.applications.vgg16 import preprocess_input
import numpy as np
from numpy import linalg as LA
import os
import h5py
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
from PIL import Image


class soutu2:
    """Content-based image search built on VGG16 convolutional features.

    ``xunlian`` extracts one L2-normalised feature vector per image in a
    directory and stores the vectors plus the file names in an HDF5 index.
    ``chazhao`` extracts the same feature from a query image, ranks the
    indexed images by dot product with it, and displays the best matches.
    """

    # File extensions (lower-case) that are treated as images.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.gif')

    def __init__(self):
        # The model is built with a fixed input size, so every image fed to
        # predict() must be resized to exactly this shape; extract_feat()
        # derives its resize target from this attribute.
        self.input_shape = (224, 224, 3)
        self.weights = 'imagenet'
        self.pooling = 'max'
        self.model = VGG16(weights=self.weights,
                           input_shape=self.input_shape,
                           pooling=self.pooling,
                           include_top=False)
        # Run one dummy prediction up front, as the original code did.
        self.model.predict(np.zeros((1, *self.input_shape)))
        self.h5f_index = 'models/vgg_featureCNN.h5'

    def xunlian(self, dirpath):
        """Extract features for every image in *dirpath* and save the index.

        Writes two datasets to ``self.h5f_index``: ``database_1`` holds the
        feature vectors and ``database_2`` the matching file names encoded
        as UTF-8 bytes.
        """
        print("开始特征训练...")
        feats = []
        names = []
        img_list = self.get_imglist(dirpath)
        total = len(img_list)
        for i, img_path in enumerate(img_list):
            feats.append(self.extract_feat(img_path))
            names.append(os.path.split(img_path)[1])
            print("正在处理%s/%s图片" % (i + 1, total))

        # Make sure the directory that holds the index file exists.
        index_dir = os.path.dirname(self.h5f_index)
        if index_dir:
            os.makedirs(index_dir, exist_ok=True)

        # Encode names explicitly: chazhao() decodes them as UTF-8, and a
        # fixed-length bytes array is what h5py stores directly.
        encoded_names = np.array([n.encode('utf-8') for n in names], dtype='S')
        with h5py.File(self.h5f_index, 'w') as h5f:
            h5f.create_dataset('database_1', data=np.array(feats))
            h5f.create_dataset('database_2', data=encoded_names)
        print("训练完毕\n")

    def chazhao(self, img_path, db_dir='database', top_k=3):
        """Find and display the indexed images most similar to *img_path*.

        Args:
            img_path: Path of the query image.
            db_dir: Directory containing the indexed images; used to build
                the paths of the results for display.
            top_k: Number of best matches to report and display.
        """
        print("开始按输入图片特征查找...")
        with h5py.File(self.h5f_index, 'r') as h5f:
            feats = h5f['database_1'][:]
            names = h5f['database_2'][:]

        query_feat = self.extract_feat(img_path)
        # Both sides are L2-normalised by extract_feat, so the dot product
        # is the cosine similarity.
        scores = np.dot(query_feat, feats.T)
        best = np.argsort(scores)[::-1][:top_k]

        print("查找到以下图片")
        img_paths = [img_path]
        for index in best:
            img_path_in_db = os.path.join(db_dir, names[index].decode('utf-8'))
            print("图片名称:%s,得分:%s" % (img_path_in_db, scores[index]))
            img_paths.append(img_path_in_db)
        self.show_imgs(img_paths)
        print("查找完毕\n")

    def extract_feat(self, img_path):
        """Return the L2-normalised VGG16 feature vector of one image."""
        height, width = self.input_shape[:2]
        img = load_img(img_path, color_mode="rgb", target_size=(height, width))
        # preprocess_input and the model expect a 4-D batch.
        batch = np.expand_dims(img_to_array(img), axis=0)
        batch = preprocess_input(batch)
        feat = self.model.predict(batch)[0]
        norm = LA.norm(feat)
        # Avoid producing NaNs for an all-zero feature vector.
        if norm == 0:
            return feat
        return feat / norm

    def get_imglist(self, path):
        """Return the sorted paths of the image files directly inside *path*."""
        return [os.path.join(path, f)
                for f in sorted(os.listdir(path))
                if f.lower().endswith(self.IMAGE_EXTENSIONS)]

    def show_img(self, img_path):
        """Display a single image."""
        plt.imshow(mpimg.imread(img_path))
        plt.show()

    def show_imgs(self, img_paths):
        """Display the query image followed by the result images in a grid."""
        fig = plt.figure(figsize=(12, 4))
        fig.canvas.manager.set_window_title(
            "第一张为输入的搜索图,其余%s张为搜索结果" % (len(img_paths) - 1))
        cols = 4
        # Keep the original 2x4 layout, adding rows only when needed.
        rows = max(2, -(-len(img_paths) // cols))
        for i, img_path in enumerate(img_paths):
            ax = fig.add_subplot(rows, cols, i + 1, xticks=[], yticks=[])
            ax.set_title(os.path.split(img_path)[1], color="black",
                         fontsize=6, ha='center')
            ax.imshow(mpimg.imread(img_path))
        plt.subplots_adjust(wspace=0.05, hspace=0)
        plt.show()


if __name__ == '__main__':
    soutu2obj = soutu2()
    # Build the index once before the first search:
    #soutu2obj.xunlian('database/')
    soutu2obj.chazhao('query/3.jpg')
训练:
搜索验证:
还挺好用
参考:
https://blog.51cto.com/captainbed/5572330
https://blog.csdn.net/starter_____/article/details/79340715
本文来自博客园,作者:河北大学-徐小波,转载请注明原文链接:https://www.cnblogs.com/xuxiaobo/p/17187948.html