图像检索实践

1、Related_functions.py

from PIL import Image
import os
import numpy as np
import cv2
import warnings
# NOTE(review): `Warning` is the base class of every warning category, so this
# silences all warnings process-wide (including deprecation warnings raised by
# libraries); consider narrowing the category if only one source is noisy.
warnings.filterwarnings("ignore", category=Warning)

def get_feature(image_dir, size=None):
    """Extract a flattened HOG descriptor from the image at ``image_dir``.

    Uses OpenCV's default ``cv2.HOGDescriptor`` evaluated over the whole image
    with an 8x8 window stride and 8x8 padding.

    Args:
        image_dir: Path of the image file, read with ``cv2.imread``.
        size: Optional ``(width, height)``. When given, the image is resized
            to it first, so every image yields a descriptor of the same
            length (required to stack features into one 2-D array and compare
            them). When ``None`` (the default) the image is used as read and
            the descriptor length depends on the image size.

    Returns:
        1-D numpy array holding the HOG descriptor.

    Raises:
        OSError: If OpenCV cannot read the file. ``cv2.imread`` signals this
            by returning ``None`` rather than raising.
    """
    im = cv2.imread(image_dir)
    if im is None:
        # Fail here with a clear message instead of inside hog.compute.
        raise OSError('cannot read image: {}'.format(image_dir))
    if size is not None:
        im = cv2.resize(im, tuple(size))
    hog = cv2.HOGDescriptor()
    winStride = (8, 8)
    padding = (8, 8)
    hist = hog.compute(im, winStride, padding)
    return hist.reshape((-1,))
'''
import torch
from torchvision import models, transforms
def get_feature(image_dir):
    vgg_model = models.vgg19(pretrained=True)
    new_classifier = torch.nn.Sequential(*list(vgg_model.children())[-1][:6])
    vgg_model.classifier = new_classifier
    trans = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    im = Image.open(image_dir).convert('RGB')
    im = trans(im)
    im.unsqueeze_(dim=0)
    vgg_model = vgg_model.eval()
    y = vgg_model(im).data.numpy().tolist()
    feature  = y[0]
    return feature
'''
def get_img_feature(img_dir):
    """Return the feature vector of the image stored at ``img_dir``.

    Thin wrapper around ``get_feature``; dataset indexing and the GUI both go
    through this entry point.
    """
    return get_feature(img_dir)

def get_Datasets_feature(Datasets_dir):
    """Extract features for every image in a two-level dataset directory.

    Expected layout: ``Datasets_dir/<class_name>/<image_file>``. Top-level
    entries that are not directories are skipped. Both levels are visited in
    sorted order, so the output order is reproducible across runs.

    Args:
        Datasets_dir: Root directory of the dataset. A trailing '/' is added
            if missing.

    Returns:
        ``(all_feature, paths)``. ``all_feature`` is the list of per-image
        features from ``get_img_feature``. ``paths`` is the matching list of
        image paths, built as ``Datasets_dir + class_name + '/' + file_name``.

    Raises:
        OSError: If ``Datasets_dir`` cannot be listed. A hint is printed
            first.
    """
    if not Datasets_dir.endswith('/'):
        Datasets_dir = Datasets_dir + '/'
    try:
        class_dirs = sorted(os.listdir(Datasets_dir))
    except OSError:
        # Print the hint, then stop here instead of failing again on a second
        # listdir call with a less helpful traceback.
        print('请检查数据库的路径')
        raise
    all_feature = []
    paths = []
    for fi in class_dirs:
        img_paths = Datasets_dir + fi + '/'
        if not os.path.isdir(img_paths):
            # Stray files at the top level (e.g. a README) are not classes.
            continue
        for fj in sorted(os.listdir(img_paths)):
            img_dir = img_paths + fj
            img_feature = get_img_feature(img_dir)
            all_feature.append(img_feature)
            paths.append(img_dir)
            print('正在提特征的图像是:', img_dir)
    return all_feature, paths
def calEuclideanDistance(x, y):
    """Return the Euclidean distance between two 1-D feature vectors.

    Computed with numpy rather than a Python-level generator, which matters
    for descriptors with thousands of elements. When the lengths differ, only
    the first ``min(len(x), len(y))`` elements are compared, the same
    truncation ``zip`` performs.

    Args:
        x, y: 1-D sequences or arrays of numbers.

    Returns:
        The distance as a numpy float64 (0.0 for empty inputs).
    """
    a = np.asarray(x, dtype=np.float64)
    b = np.asarray(y, dtype=np.float64)
    n = min(len(a), len(b))
    diff = a[:n] - b[:n]
    return np.sqrt(np.sum(diff * diff))
def query_sim_img(query_img_feature, Datasets_features, Datasets_paths, top_num):
    """Return the paths of the ``top_num`` dataset images closest to a query.

    Euclidean distances to every dataset row are computed in one vectorized
    numpy expression. Only the first ``min(len(query), n_columns)`` feature
    elements are compared, matching the truncation of the zip-based
    ``calEuclideanDistance``.

    Args:
        query_img_feature: 1-D feature of the query image (list or array).
        Datasets_features: 2-D array with one feature row per dataset image.
        Datasets_paths: Sequence of image paths aligned with those rows.
        top_num: Maximum number of paths to return.

    Returns:
        List of up to ``top_num`` paths, nearest first. Equal distances keep
        dataset order (stable sort). Returns ``[]`` for an empty dataset.
    """
    feats = np.asarray(Datasets_features, dtype=np.float64)
    if feats.shape[0] == 0:
        return []
    query = np.asarray(query_img_feature, dtype=np.float64)
    dim = min(query.shape[0], feats.shape[1])
    diffs = feats[:, :dim] - query[:dim]
    distances = np.sqrt(np.sum(diffs * diffs, axis=1))
    order = np.argsort(distances, kind='stable')
    return [Datasets_paths[index] for index in order[:top_num]]

 

 

 

import torch
from torchvision import models, transforms
from PIL import Image
import os
import numpy as np

import warnings
# NOTE(review): `Warning` is the base class of every warning category, so this
# silences all warnings process-wide (including deprecation warnings raised by
# libraries); consider narrowing the category if only one source is noisy.
warnings.filterwarnings("ignore", category=Warning)

# Lazily built (model, transform) pair shared by every get_feature call.
_vgg_cache = {}


def _get_vgg_extractor():
    """Build the truncated VGG19 model and its preprocessing once, then reuse them."""
    if 'model' not in _vgg_cache:
        vgg_model = models.vgg19(pretrained=True)
        # [:6] drops the last layer of the classifier (presumably the final
        # 1000-way Linear of torchvision's VGG19), so the output is the
        # activation of the layer before it.
        new_classifier = torch.nn.Sequential(*list(vgg_model.children())[-1][:6])
        vgg_model.classifier = new_classifier
        _vgg_cache['model'] = vgg_model.eval()
        _vgg_cache['trans'] = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    return _vgg_cache['model'], _vgg_cache['trans']


def get_feature(image_dir):
    """Extract a VGG19 feature vector from the image at ``image_dir``.

    The pretrained network is built on the first call and reused afterwards.
    Previously it was rebuilt, with weights reloaded, for every image.

    Args:
        image_dir: Path of the image file; it is converted to RGB and resized
            to 224x224 before the forward pass.

    Returns:
        The feature as a plain Python list of floats (first row of the
        model output).
    """
    vgg_model, trans = _get_vgg_extractor()
    # Close the file handle as soon as the pixels have been read.
    with Image.open(image_dir) as img:
        im = trans(img.convert('RGB'))
    im.unsqueeze_(dim=0)
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        y = vgg_model(im).numpy().tolist()
    feature = y[0]
    return feature

def get_img_feature(img_dir):
    """Return the feature vector of the image stored at ``img_dir``.

    Thin wrapper around ``get_feature``; dataset indexing and the GUI both go
    through this entry point.
    """
    return get_feature(img_dir)

def get_Datasets_feature(Datasets_dir):
    """Extract features for every image in a two-level dataset directory.

    Expected layout: ``Datasets_dir/<class_name>/<image_file>``. Top-level
    entries that are not directories are skipped. Both levels are visited in
    sorted order, so the output order is reproducible across runs.

    Args:
        Datasets_dir: Root directory of the dataset. A trailing '/' is added
            if missing.

    Returns:
        ``(all_feature, paths)``. ``all_feature`` is the list of per-image
        features from ``get_img_feature``. ``paths`` is the matching list of
        image paths, built as ``Datasets_dir + class_name + '/' + file_name``.

    Raises:
        OSError: If ``Datasets_dir`` cannot be listed. A hint is printed
            first.
    """
    if not Datasets_dir.endswith('/'):
        Datasets_dir = Datasets_dir + '/'
    try:
        class_dirs = sorted(os.listdir(Datasets_dir))
    except OSError:
        # Print the hint, then stop here instead of failing again on a second
        # listdir call with a less helpful traceback.
        print('请检查数据库的路径')
        raise
    all_feature = []
    paths = []
    for fi in class_dirs:
        img_paths = Datasets_dir + fi + '/'
        if not os.path.isdir(img_paths):
            # Stray files at the top level (e.g. a README) are not classes.
            continue
        for fj in sorted(os.listdir(img_paths)):
            img_dir = img_paths + fj
            img_feature = get_img_feature(img_dir)
            all_feature.append(img_feature)
            paths.append(img_dir)
            print('正在提特征的图像是:', img_dir)
    return all_feature, paths
def calEuclideanDistance(x, y):
    """Return the Euclidean distance between two 1-D feature vectors.

    Computed with numpy rather than a Python-level generator, which matters
    for descriptors with thousands of elements. When the lengths differ, only
    the first ``min(len(x), len(y))`` elements are compared, the same
    truncation ``zip`` performs.

    Args:
        x, y: 1-D sequences or arrays of numbers.

    Returns:
        The distance as a numpy float64 (0.0 for empty inputs).
    """
    a = np.asarray(x, dtype=np.float64)
    b = np.asarray(y, dtype=np.float64)
    n = min(len(a), len(b))
    diff = a[:n] - b[:n]
    return np.sqrt(np.sum(diff * diff))
def query_sim_img(query_img_feature, Datasets_features, Datasets_paths, top_num):
    """Return the paths of the ``top_num`` dataset images closest to a query.

    Euclidean distances to every dataset row are computed in one vectorized
    numpy expression. Only the first ``min(len(query), n_columns)`` feature
    elements are compared, matching the truncation of the zip-based
    ``calEuclideanDistance``.

    Args:
        query_img_feature: 1-D feature of the query image (list or array).
        Datasets_features: 2-D array with one feature row per dataset image.
        Datasets_paths: Sequence of image paths aligned with those rows.
        top_num: Maximum number of paths to return.

    Returns:
        List of up to ``top_num`` paths, nearest first. Equal distances keep
        dataset order (stable sort). Returns ``[]`` for an empty dataset.
    """
    feats = np.asarray(Datasets_features, dtype=np.float64)
    if feats.shape[0] == 0:
        return []
    query = np.asarray(query_img_feature, dtype=np.float64)
    dim = min(query.shape[0], feats.shape[1])
    diffs = feats[:, :dim] - query[:dim]
    distances = np.sqrt(np.sum(diffs * diffs, axis=1))
    order = np.argsort(distances, kind='stable')
    return [Datasets_paths[index] for index in order[:top_num]]

 

2、img_dataset_feature.py

# -*- coding:utf-8 -*-
import os
from Related_functions import *

if __name__ == '__main__':
    import sys

    # Index a dataset and save its features and paths as two aligned .npy files.
    # The dataset root may be passed as the first command-line argument; the
    # author's local path stays as the default.
    default_dir = 'D:/My_work/python_code/03_lianxi/C00304/Animals_with_Attributes2/JPEGImages'
    Datasets_dir = sys.argv[1] if len(sys.argv) > 1 else default_dir
    features, paths = get_Datasets_feature(Datasets_dir)
    # The two file names differ only in 'features' vs 'paths'; the GUI relies
    # on that to locate one file from the other.
    np.save('Datasets_features.npy', features)
    np.save('Datasets_paths.npy', paths)

 

 

3、main_gui.py

import sys, cv2
from PyQt5 import QtCore, QtGui, QtWidgets
from PyQt5.QtWidgets import QApplication
from PyQt5.QtCore import Qt
from Related_functions import *
import numpy  as np

class Main_Ui_Dialog(object):
    """Widget layout of the image-retrieval dialog, written in Qt Designer style.

    ``setupUi`` creates the following as attributes of ``self``:

    - the title ``label_0``;
    - two chooser rows: ``label_1``/``lineEdit1``/``toolButton1`` and
      ``label_2``/``lineEdit2``/``toolButton2``;
    - the ``toolButton_query`` and ``toolButton_end`` buttons;
    - nine 270x270 image slots (``query``, then ``sim1`` .. ``sim8``), each a
      ``graphicsView_<slot>`` with a same-sized, content-scaled
      ``label_<slot>`` created on top of it;
    - the captions ``label_3`` and ``label_4``.
    """

    # Side length, in pixels, of every square image slot.
    _SLOT_SIZE = 270
    # (name suffix, x, y) of each image slot, in creation order: the query
    # image first, then two rows of four retrieved images.
    _IMAGE_SLOTS = (
        ("query", 65, 400),
        ("sim1", 360, 400),
        ("sim2", 660, 400),
        ("sim3", 960, 400),
        ("sim4", 1255, 400),
        ("sim5", 360, 750),
        ("sim6", 660, 750),
        ("sim7", 960, 750),
        ("sim8", 1255, 750),
    )

    @staticmethod
    def _make_widget(widget_cls, parent, name, rect, font=None):
        """Create ``widget_cls(parent)`` at ``rect`` = (x, y, w, h) with object name ``name``.

        ``font`` is applied only when given; the image-slot widgets are
        created without an explicit font.
        """
        widget = widget_cls(parent)
        widget.setGeometry(QtCore.QRect(*rect))
        if font is not None:
            widget.setFont(font)
        widget.setObjectName(name)
        return widget

    def setupUi(self, main_Dialog):
        """Create and lay out every widget of the dialog on ``main_Dialog``."""
        main_Dialog.setObjectName("main_Dialog")
        main_Dialog.resize(1600, 1200)  # the whole window is shown at this size
        main_Dialog.setFixedSize(main_Dialog.width(), main_Dialog.height())

        # Gray background. QPalette.Window is the current name of the role that
        # PyQt5 also exposes under the obsolete alias QPalette.Background.
        palette = QtGui.QPalette()
        palette.setColor(QtGui.QPalette.Window, Qt.gray)
        main_Dialog.setPalette(palette)

        # 12 pt font shared by labels, line edits and buttons.
        font = QtGui.QFont()
        font.setFamily("Times New Roman")
        font.setPointSize(12)

        make = self._make_widget

        # Title, in a larger 32 pt font.
        font0 = QtGui.QFont()
        font0.setFamily("Times New Roman")
        font0.setPointSize(32)
        self.label_0 = make(QtWidgets.QLabel, main_Dialog, "label_0",
                            (570, 40, 800, 100), font0)

        # Row 1: dataset chooser (caption, path box, "..." button).
        self.label_1 = make(QtWidgets.QLabel, main_Dialog, "label_1",
                            (65, 180, 220, 40), font)
        self.lineEdit1 = make(QtWidgets.QLineEdit, main_Dialog, "lineEdit1",
                              (300, 180, 1165, 40), font)
        self.toolButton1 = make(QtWidgets.QToolButton, main_Dialog, "toolButton1",
                                (1425, 180, 100, 40), font)

        # Row 2: query-image chooser.
        self.label_2 = make(QtWidgets.QLabel, main_Dialog, "label_2",
                            (65, 240, 220, 40), font)
        self.lineEdit2 = make(QtWidgets.QLineEdit, main_Dialog, "lineEdit2",
                              (300, 240, 1165, 40), font)
        self.toolButton2 = make(QtWidgets.QToolButton, main_Dialog, "toolButton2",
                                (1425, 240, 100, 40), font)

        # Action buttons: run the query / clear everything.
        self.toolButton_query = make(QtWidgets.QToolButton, main_Dialog,
                                     "toolButton_query", (1170, 320, 160, 60), font)
        self.toolButton_end = make(QtWidgets.QToolButton, main_Dialog,
                                   "toolButton_end", (1370, 320, 160, 60), font)

        # Image slots. No scene is set on the views in this file; the label
        # created after each view sits on top and shows the scaled pixmap.
        size = self._SLOT_SIZE
        for suffix, x, y in self._IMAGE_SLOTS:
            rect = (x, y, size, size)
            view = make(QtWidgets.QGraphicsView, main_Dialog,
                        "graphicsView_" + suffix, rect)
            label = make(QtWidgets.QLabel, main_Dialog, "label_" + suffix, rect)
            label.setScaledContents(True)
            setattr(self, "graphicsView_" + suffix, view)
            setattr(self, "label_" + suffix, label)

        # Captions for the query image and the results.
        self.label_3 = make(QtWidgets.QLabel, main_Dialog, "label_3",
                            (131, 751, 170, 40), font)
        self.label_4 = make(QtWidgets.QLabel, main_Dialog, "label_4",
                            (850, 1050, 170, 40), font)

        self.retranslateUi(main_Dialog)
        QtCore.QMetaObject.connectSlotsByName(main_Dialog)

    def retranslateUi(self, main_Dialog):
        """Set the window title and all user-visible widget texts."""
        _translate = QtCore.QCoreApplication.translate
        main_Dialog.setWindowTitle(_translate("main_Dialog", "Image_Retrieval_System"))

        self.label_0.setText(_translate("main_Dialog", "Image_Retrieval"))
        self.label_1.setText(_translate("main_Dialog", "Choose_Datasets"))
        self.lineEdit1.setText(_translate("main_Dialog", "D:/"))
        self.toolButton1.setText(_translate("main_Dialog", "..."))
        self.label_2.setText(_translate("main_Dialog", "Choose_TestImg"))
        self.lineEdit2.setText(_translate("main_Dialog", "D:/"))
        self.toolButton2.setText(_translate("main_Dialog", "..."))
        self.toolButton_query.setText(_translate("main_Dialog", "query"))
        self.toolButton_end.setText(_translate("main_Dialog", "clear all"))
        self.label_3.setText(_translate("main_Dialog", "Query_Img"))
        self.label_4.setText(_translate("main_Dialog", "Result_Img"))

class Image_Processing(QtWidgets.QWidget, Main_Ui_Dialog):
    """Main window: connects the dialog's buttons to loading, querying and display.

    State set by the slots:
        Datasets_features / Datasets_paths: arrays loaded from the chosen
            ``.npy`` pair, or ``None`` until a pair loads successfully.
        query_img_path / query_img_feature: the chosen query image and its
            feature from ``get_img_feature``. The feature is ``None`` when
            nothing usable has been chosen.
        similar_img_path: result paths of the last query.
    """

    def __init__(self):
        super(Image_Processing, self).__init__()
        self.setupUi(self)
        # Start empty so query_img can detect missing input instead of
        # raising AttributeError inside a Qt slot.
        self.Datasets_features = None
        self.Datasets_paths = None
        self.query_img_feature = None
        self.toolButton1.clicked.connect(self.ChooseDatasetPath)
        self.toolButton2.clicked.connect(self.ChooseImgPath)
        self.toolButton_query.clicked.connect(self.query_img)
        self.toolButton_end.clicked.connect(self.clear_img_and_result)

    def _result_labels(self):
        """Return the eight result labels in display order."""
        return [getattr(self, 'label_sim%d' % i) for i in range(1, 9)]

    def ChooseDatasetPath(self):
        """Let the user pick the features or the paths ``.npy`` file and load both.

        The companion file is found by swapping 'features' <-> 'paths' in the
        chosen file's name only, so directory names containing those words do
        not interfere. The state is replaced only when both files load.
        """
        file_name = QtWidgets.QFileDialog.getOpenFileName(self, "open file dialog", "E:/")
        if not file_name[0]:
            return  # dialog cancelled
        print(file_name[0])
        self.test_datasets_path = file_name[0]
        self.lineEdit1.setText(self.test_datasets_path)
        base = file_name[0].replace('\\', '/').split('/')[-1]
        folder = self.test_datasets_path[:len(self.test_datasets_path) - len(base)]
        if 'feature' in base:
            features_path = self.test_datasets_path
            paths_path = folder + base.replace('features', 'paths')
        elif 'path' in base:
            paths_path = self.test_datasets_path
            features_path = folder + base.replace('paths', 'features')
        else:
            return
        try:
            features = np.load(features_path)
            paths = np.load(paths_path)
        except Exception as err:
            # Slot boundary: report instead of letting the exception abort Qt.
            print("test_datasets_path is error")
            print(err)
            return
        self.Datasets_features = features
        self.Datasets_paths = paths

    def ChooseImgPath(self):
        """Let the user pick a query image, display it and extract its feature."""
        file_name = QtWidgets.QFileDialog.getOpenFileName(self, "open file dialog", "E:/")
        if not file_name[0]:
            return  # dialog cancelled
        print(file_name[0])
        self.query_img_path = file_name[0]
        self.lineEdit2.setText(self.query_img_path)
        self.label_query.setPixmap(QtGui.QPixmap(self.query_img_path))  # show the query image
        try:
            self.query_img_feature = get_img_feature(self.query_img_path)
        except Exception as err:
            # Slot boundary: an unreadable image must not abort the app.
            self.query_img_feature = None
            print("query image feature extraction failed:", err)

    def query_img(self):
        """Retrieve the images most similar to the query and display them."""
        if self.Datasets_features is None or self.Datasets_paths is None:
            QtWidgets.QMessageBox.warning(
                self, "Image_Retrieval_System", "Please choose the dataset files first.")
            return
        if self.query_img_feature is None:
            QtWidgets.QMessageBox.warning(
                self, "Image_Retrieval_System", "Please choose a query image first.")
            return
        self.top_num = len(self._result_labels())
        self.similar_img_path = query_sim_img(
            self.query_img_feature, self.Datasets_features,
            self.Datasets_paths, self.top_num)
        self.show_img()

    def show_img(self):
        """Show ``self.similar_img_path`` in the result slots, clearing unused ones."""
        for i, label in enumerate(self._result_labels()):
            if i < len(self.similar_img_path):
                label.setPixmap(QtGui.QPixmap(self.similar_img_path[i]))
            else:
                label.clear()

    def clear_img_and_result(self):
        """Clear the query path, the query image and all result images."""
        self.lineEdit2.clear()
        self.label_query.clear()
        for label in self._result_labels():
            label.clear()
        # Also forget the query feature, so 'query' cannot silently reuse the
        # image that was just cleared from the screen.
        self.query_img_feature = None
     
if __name__ == '__main__':
    # Standard Qt bootstrap: create the application, show the main window,
    # run the event loop and exit with its return code.
    qt_app = QtWidgets.QApplication(sys.argv)
    window = Image_Processing()
    window.show()
    exit_code = qt_app.exec_()
    sys.exit(exit_code)

 

 

posted @ 2024-06-05 15:02  皮卡皮卡妞  阅读(107)  评论(0)  编辑  收藏  举报