以图搜图实现

此次需要使用到的工具:

IDE:eclipse,pydev
Python:3.10
Packages:Keras + TensorFlow + Pillow + Numpy

keras

Keras是一个高层神经网络API,Keras由纯Python编写而成并基​ ​Tensorflow​​​、​ ​Theano​​​以及​ ​CNTK​​后端。简单来说,keras就是对TF等框架的再一次封装,使得使用起来更加方便。

基于vgg16网络提取图像特征 我们都知道,vgg网络在图像领域有着广泛的应用,后续许多层次更深,网络更宽的模型都是基于此扩展的,vgg网络能很好的提取到图片的有用特征,本次实现是基于Keras实现的,提取的是最后一层卷积特征。


思路

主要思路是基于CVPR2015的论文​ ​《Deep Learning of Binary Hash Codes for Fast Image Retrieval》​​实现的海量数据下的基于内容图片检索系统。简单说来就是对图片数据库的每张图片抽取特征(一般形式为特征向量),

存储于数据库中,对于待检索图片,抽取同样的特征向量,然后并对该向量和数据库中向量的距离(相似度计算),找出最接近的一些特征向量,其对应的图片即为检索结果

 

from keras.applications.vgg16 import VGG16
from keras.preprocessing import image
from keras.utils.image_utils import load_img,img_to_array
from keras.applications.vgg16 import preprocess_input
import numpy as np
from numpy import linalg as LA
import os
import h5py
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
from PIL import Image

class soutu2:
    def __init__(self):
        self.input_shape = (244, 244, 3)
        self.weights = 'imagenet'
        self.pooling = 'max'
        self.model = VGG16(weights=self.weights, input_shape=self.input_shape, pooling=self.pooling, include_top=False)
        self.model.predict(np.zeros((1, 244, 244, 3)))
        self.h5f_index = 'models/vgg_featureCNN.h5'
        
    #抽取某个目录中的图片特征并保存
    def xunlian(self, dirpath):
        print("开始特征训练...")
        feats = []
        names = []
        img_list = self.get_imglist(dirpath)
        for i, img_path in enumerate(img_list):
            norm_feat = self.extract_feat(img_path)
            img_name = os.path.split(img_path)[1]
            feats.append(norm_feat)
            names.append(img_name)
            print("正在处理%s/%s图片" % (i+1, len(img_list)))
        feats = np.array(feats)
        
        h5f = h5py.File(self.h5f_index, 'w')
        h5f.create_dataset('database_1', data=feats)
        h5f.create_dataset('database_2', data=np.string_(names))
        h5f.close()
        
        print("训练完毕\n")
        
    #查找相似图
    def chazhao(self, img_path):
        print("开始按输入图片特征查找...")
        h5f = h5py.File(self.h5f_index, 'r')
        feats = h5f['database_1'][:]
        names = h5f['database_2'][:]
        h5f.close()
        
        img_paths = []
        img_paths.append(img_path)
        
        query_feat = self.extract_feat(img_path)
        scores = np.dot(query_feat, feats.T)
        rank_id = np.argsort(scores)[::-1]
        rank_score = scores[rank_id]
        
        print("查找到以下图片")
        img_list = []
        for i, index in enumerate(rank_id[0:3]):
            img_list.append(names[index])
            img_path_in_db = 'database/%s' % str(names[index], 'utf-8')
            print("图片名称:%s,得分:%s" % (img_path_in_db, rank_score[i]))
            img_paths.append(img_path_in_db)
            
        self.show_imgs(img_paths)
        print("查找完毕\n")
        
    #提取图片的特征向量
    def extract_feat(self, img_path):
        img = load_img(img_path, False, "rgb", target_size=(224, 224))
        img = img_to_array(img)
        #扩展维度,因为preprocess_input需要4D的格式
        img = np.expand_dims(img, axis=0)
        #对张量进行预处理
        img = preprocess_input(img)
        feat = self.model.predict(img)
        norm_feat = feat[0] / LA.norm(feat[0])
        
        return norm_feat
    
    def get_imglist(self, path):
        return [os.path.join(path, f) for f in os.listdir(path) if f.endswith('.jpg') or f.endswith('.jpeg') or f.endswith('.png') or f.endswith('.gif')]
    
    def show_img(self, img_path):
        query_img = mpimg.imread(img_path)
        plt.imshow(query_img)
        plt.show()
        
    
    def show_imgs(self, img_paths):
        fig = plt.figure(figsize=(12,4))
        fig.canvas.manager.set_window_title("第一张为输入的搜索图,其余3张为搜索结果")
        for i,img_path in enumerate(img_paths):
            query_img = mpimg.imread(img_path)
            img_name = os.path.split(img_path)[1]
            ax = fig.add_subplot(2, 4, i + 1, xticks=[], yticks=[])
            ax.set_title(img_name, color=("black" ), fontsize=6, ha='center')
            plt.subplots_adjust(wspace=0.05, hspace=0)
            plt.imshow(query_img)
        plt.show()
    
if __name__ == '__main__':
    soutu2obj = soutu2()
    #soutu2obj.xunlian('database/')
    
    soutu2obj.chazhao('query/3.jpg')

 

训练:

 

 

 

 

搜索验证:

 

 

 

 

 

 

 

 

 

 

 

还挺好用

 

参考:

https://blog.51cto.com/captainbed/5572330

https://blog.csdn.net/starter_____/article/details/79340715

 

posted @ 2023-03-07 14:08  河北大学-徐小波  阅读(115)  评论(0编辑  收藏  举报