图像检索实践
1、Related_functions.py
"""Feature extraction and nearest-neighbour lookup helpers for image retrieval.

This variant describes every image with an OpenCV HOG descriptor.
"""
from PIL import Image
import os
import numpy as np
import cv2
import warnings

warnings.filterwarnings("ignore", category=Warning)


def get_feature(image_dir):
    """Return the flattened HOG descriptor of one image.

    Args:
        image_dir: Path of the image file to describe.

    Returns:
        A 1-D numpy array holding the HOG descriptor.

    Raises:
        ValueError: If OpenCV cannot read the file as an image.
    """
    im = cv2.imread(image_dir)
    if im is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise ValueError('cannot read image: {}'.format(image_dir))
    hog = cv2.HOGDescriptor()
    winStride = (8, 8)
    padding = (8, 8)
    # NOTE(review): the image is not resized first, so the descriptor length
    # presumably depends on the image size -- confirm that all dataset images
    # share one size, otherwise the features cannot be compared.
    hist = hog.compute(im, winStride, padding)
    return hist.reshape((-1,))


# Alternative VGG19-based implementation of get_feature, kept disabled inside
# a string literal exactly as in the original module.
'''
import torch
from torchvision import models, transforms

def get_feature(image_dir):
    vgg_model = models.vgg19(pretrained=True)
    new_classifier = torch.nn.Sequential(*list(vgg_model.children())[-1][:6])
    vgg_model.classifier = new_classifier

    trans = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    im = Image.open(image_dir).convert('RGB')
    im = trans(im)
    im.unsqueeze_(dim=0)

    vgg_model = vgg_model.eval()
    y = vgg_model(im).data.numpy().tolist()
    feature = y[0]
    return feature
'''


def get_img_feature(img_dir):
    """Return the feature vector of the image at ``img_dir``."""
    return get_feature(img_dir)


def get_Datasets_feature(Datasets_dir):
    """Extract features for every image of a two-level dataset directory.

    The dataset root is expected to contain one sub-directory per group of
    images; entries of the root that are not directories are skipped.

    Args:
        Datasets_dir: Path of the dataset root, with or without trailing '/'.

    Returns:
        A tuple ``(all_feature, paths)`` of two parallel lists: the feature
        of each image and the path it was read from.

    Raises:
        OSError: If the dataset root cannot be listed (a hint is printed
            before the error is re-raised).
    """
    if not Datasets_dir.endswith('/'):
        Datasets_dir = Datasets_dir + '/'
    try:
        sub_dirs = os.listdir(Datasets_dir)
    except OSError:
        print('请检查数据库的路径')
        raise

    all_feature = []
    paths = []
    for fi in sub_dirs:
        img_paths = Datasets_dir + fi + '/'
        if not os.path.isdir(img_paths):
            # A stray file in the root would make the nested listdir fail.
            continue
        for fj in os.listdir(img_paths):
            img_dir = img_paths + fj
            all_feature.append(get_img_feature(img_dir))
            paths.append(img_dir)
            print('正在提特征的图像是:', img_dir)
    return all_feature, paths


def calEuclideanDistance(x, y):
    """Return the Euclidean distance between two equally long vectors.

    Args:
        x: First vector (any 1-D sequence of numbers).
        y: Second vector, same length as ``x``.

    Returns:
        The distance as a numpy float.

    Raises:
        ValueError: If the vectors differ in length (numpy cannot subtract
            them); they are no longer silently truncated to the shorter one.
    """
    # Work in float so unsigned or integer inputs cannot wrap around.
    diff = np.asarray(x, dtype=float) - np.asarray(y, dtype=float)
    return np.linalg.norm(diff)


def query_sim_img(query_img_feature, Datasets_features, Datasets_paths, top_num):
    """Return the paths of the ``top_num`` dataset images closest to a query.

    Args:
        query_img_feature: Feature vector of the query image.
        Datasets_features: Sequence of dataset feature vectors, one per image.
        Datasets_paths: Image paths, indexed like ``Datasets_features``.
        top_num: Maximum number of paths to return.

    Returns:
        A list of at most ``top_num`` paths ordered by increasing distance.
    """
    distances = np.array([
        calEuclideanDistance(query_img_feature, img_feature)
        for img_feature in Datasets_features
    ])
    # A stable sort keeps equally distant images in dataset order.
    nearest = distances.argsort(kind='stable')[:top_num]
    return [Datasets_paths[index] for index in nearest]
"""Feature extraction and nearest-neighbour lookup helpers for image retrieval.

This variant describes every image with the activations of a truncated
torchvision VGG19 classifier.
"""
import functools
import torch
from torchvision import models, transforms
from PIL import Image
import os
import numpy as np
import warnings

warnings.filterwarnings("ignore", category=Warning)


@functools.lru_cache(maxsize=1)
def _load_vgg_extractor():
    """Build the VGG19 feature extractor and its preprocessing once.

    Returns:
        A tuple ``(model, transform)``: the pretrained VGG19 in eval mode with
        its classifier cut down to its first six layers, and the transform
        that resizes, tensorises and normalises an input image.
    """
    vgg_model = models.vgg19(pretrained=True)
    # The last child of the model is its classifier; keep only its first six
    # layers so the network outputs an intermediate feature vector.
    new_classifier = torch.nn.Sequential(*list(vgg_model.children())[-1][:6])
    vgg_model.classifier = new_classifier

    trans = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    return vgg_model.eval(), trans


def get_feature(image_dir):
    """Return the VGG19 feature vector of one image as a list of floats.

    Args:
        image_dir: Path of the image file to describe.

    Returns:
        The network output for the image, converted to a plain Python list.
    """
    # The model is cached so it is not rebuilt for every single image.
    vgg_model, trans = _load_vgg_extractor()

    with Image.open(image_dir) as img_file:
        im = img_file.convert('RGB')
    im = trans(im)
    im.unsqueeze_(dim=0)  # add the batch dimension expected by the model

    with torch.no_grad():  # inference only, no autograd graph needed
        y = vgg_model(im).numpy().tolist()
    return y[0]


def get_img_feature(img_dir):
    """Return the feature vector of the image at ``img_dir``."""
    return get_feature(img_dir)


def get_Datasets_feature(Datasets_dir):
    """Extract features for every image of a two-level dataset directory.

    The dataset root is expected to contain one sub-directory per group of
    images; entries of the root that are not directories are skipped.

    Args:
        Datasets_dir: Path of the dataset root, with or without trailing '/'.

    Returns:
        A tuple ``(all_feature, paths)`` of two parallel lists: the feature
        of each image and the path it was read from.

    Raises:
        OSError: If the dataset root cannot be listed (a hint is printed
            before the error is re-raised).
    """
    if not Datasets_dir.endswith('/'):
        Datasets_dir = Datasets_dir + '/'
    try:
        sub_dirs = os.listdir(Datasets_dir)
    except OSError:
        print('请检查数据库的路径')
        raise

    all_feature = []
    paths = []
    for fi in sub_dirs:
        img_paths = Datasets_dir + fi + '/'
        if not os.path.isdir(img_paths):
            # A stray file in the root would make the nested listdir fail.
            continue
        for fj in os.listdir(img_paths):
            img_dir = img_paths + fj
            all_feature.append(get_img_feature(img_dir))
            paths.append(img_dir)
            print('正在提特征的图像是:', img_dir)
    return all_feature, paths


def calEuclideanDistance(x, y):
    """Return the Euclidean distance between two equally long vectors.

    Args:
        x: First vector (any 1-D sequence of numbers).
        y: Second vector, same length as ``x``.

    Returns:
        The distance as a numpy float.

    Raises:
        ValueError: If the vectors differ in length (numpy cannot subtract
            them); they are no longer silently truncated to the shorter one.
    """
    diff = np.asarray(x, dtype=float) - np.asarray(y, dtype=float)
    return np.linalg.norm(diff)


def query_sim_img(query_img_feature, Datasets_features, Datasets_paths, top_num):
    """Return the paths of the ``top_num`` dataset images closest to a query.

    Args:
        query_img_feature: Feature vector of the query image.
        Datasets_features: Sequence of dataset feature vectors, one per image.
        Datasets_paths: Image paths, indexed like ``Datasets_features``.
        top_num: Maximum number of paths to return.

    Returns:
        A list of at most ``top_num`` paths ordered by increasing distance.
    """
    distances = np.array([
        calEuclideanDistance(query_img_feature, img_feature)
        for img_feature in Datasets_features
    ])
    # A stable sort keeps equally distant images in dataset order.
    nearest = distances.argsort(kind='stable')[:top_num]
    return [Datasets_paths[index] for index in nearest]
2、img_dataset_feature.py
# -*- coding:utf-8 -*-
"""Pre-compute the features of a whole image dataset and store them on disk.

Writes ``Datasets_features.npy`` and ``Datasets_paths.npy`` to the current
working directory. The dataset root may be given as the first command-line
argument; without one the built-in default path is used.
"""
import os
import sys

import numpy as np

from Related_functions import get_Datasets_feature

if __name__ == '__main__':
    default_dir = 'D:/My_work/python_code/03_lianxi/C00304/Animals_with_Attributes2/JPEGImages'
    Datasets_dir = sys.argv[1] if len(sys.argv) > 1 else default_dir

    features, paths = get_Datasets_feature(Datasets_dir)
    # The two files are parallel: row i of the features belongs to paths[i].
    np.save('Datasets_features.npy', features)
    np.save('Datasets_paths.npy', paths)
3、mian_gui.py
import sys, cv2
from PyQt5 import QtCore, QtGui, QtWidgets
from PyQt5.QtWidgets import QApplication
from PyQt5.QtCore import Qt
from Related_functions import *
import numpy as np
class Main_Ui_Dialog(object):
    """Builds the fixed-size widget layout of the image-retrieval window.

    ``setupUi`` creates every widget as an attribute of this object
    (``label_0`` .. ``label_4``, ``lineEdit1/2``, ``toolButton1/2``,
    ``toolButton_query``, ``toolButton_end`` and, for the query image and the
    eight results, a ``graphicsView_*`` / ``label_*`` pair each).
    ``retranslateUi`` fills in the visible texts.
    """

    @staticmethod
    def _make_font(point_size):
        """Return a "Times New Roman" font of the given point size."""
        font = QtGui.QFont()
        font.setFamily("Times New Roman")
        font.setPointSize(point_size)
        return font

    @staticmethod
    def _make_widget(widget_cls, parent, name, rect, font=None):
        """Create a child widget with geometry, optional font and object name.

        Args:
            widget_cls: Widget class to instantiate with ``parent``.
            parent: Parent widget.
            name: Qt object name to assign.
            rect: ``(x, y, width, height)`` of the widget.
            font: Font to apply, or ``None`` to keep the default.
        """
        widget = widget_cls(parent)
        widget.setGeometry(QtCore.QRect(*rect))
        if font is not None:
            widget.setFont(font)
        widget.setObjectName(name)
        return widget

    @classmethod
    def _make_image_slot(cls, parent, suffix, x, y):
        """Create one 270x270 image slot: a graphics view with a label on top.

        The view is created first so that the label, which shows the pixmap
        scaled to its size, is stacked above it.

        Returns:
            The tuple ``(graphics_view, label)``.
        """
        rect = (x, y, 270, 270)
        view = cls._make_widget(QtWidgets.QGraphicsView, parent,
                                "graphicsView_" + suffix, rect)
        label = cls._make_widget(QtWidgets.QLabel, parent,
                                 "label_" + suffix, rect)
        label.setScaledContents(True)
        return view, label

    def setupUi(self, main_Dialog):
        """Create and place all widgets on ``main_Dialog``."""
        main_Dialog.setObjectName("main_Dialog")
        main_Dialog.resize(1600, 1200)  # the whole window is shown at this size
        main_Dialog.setFixedSize(main_Dialog.width(), main_Dialog.height())

        # Background colour.
        palette = QtGui.QPalette()
        palette.setColor(QtGui.QPalette.Background, Qt.gray)
        main_Dialog.setPalette(palette)

        # Font shared by all widgets except the title.
        font = self._make_font(12)

        # Title.
        self.label_0 = self._make_widget(
            QtWidgets.QLabel, main_Dialog, "label_0",
            (570, 40, 800, 100), self._make_font(32))

        # Dataset chooser row: caption, path field, "..." button.
        # NOTE(review): the path field ends at x=1465 while the button starts
        # at x=1425, so the button overlaps the field's right end -- confirm
        # whether that is intended.
        self.label_1 = self._make_widget(
            QtWidgets.QLabel, main_Dialog, "label_1", (65, 180, 220, 40), font)
        self.lineEdit1 = self._make_widget(
            QtWidgets.QLineEdit, main_Dialog, "lineEdit1",
            (300, 180, 1165, 40), font)
        self.toolButton1 = self._make_widget(
            QtWidgets.QToolButton, main_Dialog, "toolButton1",
            (1425, 180, 100, 40), font)

        # Test-image chooser row: caption, path field, "..." button.
        self.label_2 = self._make_widget(
            QtWidgets.QLabel, main_Dialog, "label_2", (65, 240, 220, 40), font)
        self.lineEdit2 = self._make_widget(
            QtWidgets.QLineEdit, main_Dialog, "lineEdit2",
            (300, 240, 1165, 40), font)
        self.toolButton2 = self._make_widget(
            QtWidgets.QToolButton, main_Dialog, "toolButton2",
            (1425, 240, 100, 40), font)

        # "query" and "clear all" buttons.
        self.toolButton_query = self._make_widget(
            QtWidgets.QToolButton, main_Dialog, "toolButton_query",
            (1170, 320, 160, 60), font)
        self.toolButton_end = self._make_widget(
            QtWidgets.QToolButton, main_Dialog, "toolButton_end",
            (1370, 320, 160, 60), font)

        # Query image slot followed by the eight result slots (two rows).
        self.graphicsView_query, self.label_query = self._make_image_slot(
            main_Dialog, "query", 65, 400)
        self.graphicsView_sim1, self.label_sim1 = self._make_image_slot(
            main_Dialog, "sim1", 360, 400)
        self.graphicsView_sim2, self.label_sim2 = self._make_image_slot(
            main_Dialog, "sim2", 660, 400)
        self.graphicsView_sim3, self.label_sim3 = self._make_image_slot(
            main_Dialog, "sim3", 960, 400)
        self.graphicsView_sim4, self.label_sim4 = self._make_image_slot(
            main_Dialog, "sim4", 1255, 400)
        self.graphicsView_sim5, self.label_sim5 = self._make_image_slot(
            main_Dialog, "sim5", 360, 750)
        self.graphicsView_sim6, self.label_sim6 = self._make_image_slot(
            main_Dialog, "sim6", 660, 750)
        self.graphicsView_sim7, self.label_sim7 = self._make_image_slot(
            main_Dialog, "sim7", 960, 750)
        self.graphicsView_sim8, self.label_sim8 = self._make_image_slot(
            main_Dialog, "sim8", 1255, 750)

        # Captions for the query image and for the result area.
        self.label_3 = self._make_widget(
            QtWidgets.QLabel, main_Dialog, "label_3", (131, 751, 170, 40), font)
        self.label_4 = self._make_widget(
            QtWidgets.QLabel, main_Dialog, "label_4",
            (850, 1050, 170, 40), font)

        self.retranslateUi(main_Dialog)
        QtCore.QMetaObject.connectSlotsByName(main_Dialog)

    def retranslateUi(self, main_Dialog):
        """Set the window title and the text of every captioned widget."""
        _translate = QtCore.QCoreApplication.translate
        main_Dialog.setWindowTitle(_translate("main_Dialog", "Image_Retrieval_System"))
        self.label_0.setText(_translate("main_Dialog", "Image_Retrieval"))
        self.label_1.setText(_translate("main_Dialog", "Choose_Datasets"))
        self.lineEdit1.setText(_translate("main_Dialog", "D:/"))
        self.toolButton1.setText(_translate("main_Dialog", "..."))
        self.label_2.setText(_translate("main_Dialog", "Choose_TestImg"))
        self.lineEdit2.setText(_translate("main_Dialog", "D:/"))
        self.toolButton2.setText(_translate("main_Dialog", "..."))
        self.toolButton_query.setText(_translate("main_Dialog", "query"))
        self.toolButton_end.setText(_translate("main_Dialog", "clear all"))
        self.label_3.setText(_translate("main_Dialog", "Query_Img"))
        self.label_4.setText(_translate("main_Dialog", "Result_Img"))
class Image_Processing(QtWidgets.QWidget, Main_Ui_Dialog):
    """Main window: wires the retrieval workflow to the widgets.

    The user picks a saved dataset file (features or paths ``.npy``), picks a
    query image, and "query" shows the most similar dataset images in the
    eight result slots.
    """

    def __init__(self):
        super(Image_Processing, self).__init__()
        self.setupUi(self)

        # Set up front so the slots can tell whether the inputs were chosen.
        self.Datasets_features = None
        self.Datasets_paths = None
        self.query_img_feature = None
        self.similar_img_path = []

        self.toolButton1.clicked.connect(self.ChooseDatasetPath)
        self.toolButton2.clicked.connect(self.ChooseImgPath)
        self.toolButton_query.clicked.connect(self.query_img)
        self.toolButton_end.clicked.connect(self.clear_img_and_result)

    def _result_labels(self):
        """Return the eight result labels in display order."""
        return [
            self.label_sim1, self.label_sim2, self.label_sim3, self.label_sim4,
            self.label_sim5, self.label_sim6, self.label_sim7, self.label_sim8,
        ]

    def ChooseDatasetPath(self):
        """Let the user pick a dataset file and load features and paths.

        Either of the two saved files may be picked; the other one is located
        by swapping 'features' and 'paths' in the file name. On failure a
        message is printed and the previously loaded dataset is kept.
        """
        file_name = QtWidgets.QFileDialog.getOpenFileName(self, "open file dialog", "E:/")
        print(file_name[0])
        if not file_name[0]:
            return  # empty selection, e.g. the dialog was cancelled

        self.test_datasets_path = file_name[0]
        self.lineEdit1.setText(self.test_datasets_path)

        # Inspect only the file name so a directory that happens to contain
        # 'features' or 'paths' cannot trigger the wrong branch.
        base_name = self.test_datasets_path.rsplit('/', 1)[-1]
        prefix = self.test_datasets_path[:len(self.test_datasets_path) - len(base_name)]
        if 'features' in base_name:
            features_file = self.test_datasets_path
            paths_file = prefix + base_name.replace('features', 'paths')
        elif 'paths' in base_name:
            paths_file = self.test_datasets_path
            features_file = prefix + base_name.replace('paths', 'features')
        else:
            print("test_datasets_path is error")
            return

        try:
            features = np.load(features_file)
            paths = np.load(paths_file)
        except (OSError, ValueError):
            print("test_datasets_path is error")
            return
        # Assign together so a half-loaded dataset is never left behind.
        self.Datasets_features = features
        self.Datasets_paths = paths

    def ChooseImgPath(self):
        """Let the user pick the query image, show it and extract its feature."""
        file_name = QtWidgets.QFileDialog.getOpenFileName(self, "open file dialog", "E:/")
        print(file_name[0])
        if not file_name[0]:
            return  # empty selection, e.g. the dialog was cancelled

        self.query_img_path = file_name[0]
        self.lineEdit2.setText(self.query_img_path)
        self.label_query.setPixmap(QtGui.QPixmap(self.query_img_path))  # show the query image
        try:
            self.query_img_feature = get_img_feature(self.query_img_path)
        except Exception as err:
            # UI boundary: the dialog has no file filter, so any file can be
            # picked; report the failure instead of leaving a stale feature.
            self.query_img_feature = None
            print("query_img_path is error:", err)

    def query_img(self):
        """Look up the images most similar to the query and display them."""
        if self.query_img_feature is None or self.Datasets_features is None \
                or self.Datasets_paths is None:
            print("please choose the datasets and the query image first")
            return
        self.top_num = 8
        self.similar_img_path = query_sim_img(
            self.query_img_feature, self.Datasets_features,
            self.Datasets_paths, self.top_num)
        self.show_img()

    def show_img(self):
        """Show the found images in the result slots, in order.

        Slots without a matching result (fewer than eight hits) are cleared.
        """
        labels = self._result_labels()
        for label in labels:
            label.clear()
        for label, img_path in zip(labels, self.similar_img_path):
            # str() because the paths may be numpy strings from np.load.
            label.setPixmap(QtGui.QPixmap(str(img_path)))

    def clear_img_and_result(self):
        """Clear the query path field, the query image and all result slots."""
        self.lineEdit2.clear()
        self.label_query.clear()
        for label in self._result_labels():
            label.clear()
if __name__ == '__main__':
    # Start the Qt application, show the retrieval window and pass the event
    # loop's return value to the interpreter as the exit status.
    qt_app = QtWidgets.QApplication(sys.argv)
    window = Image_Processing()
    window.show()
    exit_code = qt_app.exec_()
    sys.exit(exit_code)
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 无需6万激活码!GitHub神秘组织3小时极速复刻Manus,手把手教你使用OpenManus搭建本
· Manus爆火,是硬核还是营销?
· 终于写完轮子一部分:tcp代理 了,记录一下
· 别再用vector<bool>了!Google高级工程师:这可能是STL最大的设计失误
· 单元测试从入门到精通