相似图像检索
VGGNet特征提取
利用VGGnet的预训练模型来实现图像的检索,先用预训练模型来抽取图片的特征,然后把待检索的图像和数据库中的所有图像进行匹配,找出相似度最高的
在jupyter notebook上实现
文件路径设置:
root|____ code
|____ images|____ img_class_1
|____ img_class_2
|____ img_class_3
|.... .....
|____ img_class_n
|____models
|____queryimg
- root: 根目录
- images: 存放各类别的图片文件夹
- img_class_i: 存放相应类别的图片
- database: 用于存放数据
- queryimg: 存放待检索图片
Step 1. 构造特征提取器
这里用了Keras的应用模块(Keras.applications)提供的带有预训练权值的模型
初始化一个模型的时候,会自动下载权重到~/.keras/models/目录下
详细参考🔗
这里用VGG16预训练模型构造一个特征提取器
import numpy as np
from keras.applications.vgg16 import VGG16
from keras.applications.vgg16 import preprocess_input as preprocess_input_vgg
from keras.preprocessing import image
from numpy import linalg as LA
class VGGNet:
def __init__(self):
self.input_shape = (224, 224, 3)
self.weight = 'imagenet'
self.pooling = 'max'
self.model_vgg = VGG16(weights=self.weight,
input_shape=(self.input_shape[0], self.input_shape[1], self.input_shape[2]),
pooling=self.pooling, include_top=False)
# 提取vgg16最后一层卷积特征
def vgg_extract_feat(self, img_path):
img = image.load_img(img_path, target_size=(self.input_shape[0], self.input_shape[1]))
img = image.img_to_array(img)
img = np.expand_dims(img, axis=0)
img = preprocess_input_vgg(img)
feat = self.model_vgg.predict(img)
norm_feat = feat[0] / LA.norm(feat[0])
return norm_feat
keras.applications.vgg16.VGG16()
参数设置:
include_top: 是否包括顶层的全连接层。
weights: None 代表随机初始化, 'imagenet' 代表加载在 ImageNet 上预训练的权值。
input_tensor: 可选,Keras tensor 作为模型的输入(即 layers.Input() 输出的 tensor)。
input_shape: 可选,输入尺寸元组,仅当 include_top=False 时有效,否则输入形状必须是 (244, 244, 3)(对于 channels_last 数据格式),或者 (3, 244, 244)(对于 channels_first 数据格式)。它必须拥有 3 个输入通道,且宽高必须不小于 32。例如 (200, 200, 3) 是一个合法的输入尺寸。
pooling: 可选,当 include_top 为 False 时,该参数指定了特征提取时的池化方式。
- None 代表不池化,直接输出最后一层卷积层的输出,该输出是一个四维张量。
- 'avg' 代表全局平均池化(GlobalAveragePooling2D),相当于在最后一层卷积层后面再加一层全局平均池化层,输出是一个二维张量。
- 'max' 代表全局最大池化
classes: 可选,图片分类的类别数,仅当 include_top 为 True 并且不加载预训练权值时可用。
Step 2. 保存图片数据特征
用VGGnet提取图片特征
把图片的特征向量和文件路径存到文件中
import os
import h5py
import numpy as np
root = os.path.abspath('..')
save_path = os.path.join(root,'database','vgg_featureCNN.h5')
print("--------------------------------------------------")
print(" feature extraction starts")
print("--------------------------------------------------")
imgdir = os.path.join(root,'images')
imgpaths = []
for subdir in os.listdir(imgdir)[:]:
curpath = os.path.join(imgdir,subdir)
for imgname in os.listdir(curpath):
imgpaths += [os.path.join(curpath,imgname)] # 添加图片路径
feats = [] # 保存图片特征向量
model = VGGNet()
for i, img_path in enumerate(imgpaths):
norm_feat = model.vgg_extract_feat(img_path)
feats.append(norm_feat)
print("extracting feature from image No. %d , %d images in total" % ((i + 1), len(imgpaths)))
feats = np.array(feats)
print("--------------------------------------------------")
print(" writing feature extraction results ...")
print("--------------------------------------------------")
h5f = h5py.File(save_path, 'w')
h5f.create_dataset('dataset_1', data = feats)
h5f.create_dataset('dataset_2', data = np.string_(imgpaths))
h5f.close()
print(" writing has ended. ")
Step 3. 图片检索
把待检索图片存到queryimg中, 进行检索,输出前maxres张匹配度最高的图片
import h5py
from cv2 import cv2
import matplotlib.pyplot as plt
import numpy as np
import os
from extract_cnn_vgg16_keras import VGGNet
root = os.path.abspath('..')
save_path = os.path.join(root,'database','vgg_featureCNN.h5')
h5f = h5py.File(save_path, 'r')
feats = h5f['dataset_1'][:]
imgpaths = h5f['dataset_2'][:]
h5f.close()
querydir = os.path.join(root,'queryimg')
# init VGGNet16 model
model = VGGNet()
# 待检索图片名
imgname = 'xxx.jpg'
print("--------------------------------------------------")
print(" searching starts")
print("--------------------------------------------------")
# 待检索图片地址
querypath = os.path.join(querydir,imgname)
queryImg = cv2.imread(querypath)
queryImg = cv2.cvtColor(queryImg, cv2.COLOR_BGR2RGB)
plt.title("Query Image")
plt.imshow(queryImg)
plt.show()
# 提取待检索图片的特征
queryVec = model.vgg_extract_feat(querypath)
# 和数据库中的每张图片的特征匹配,计算匹配分数
scores = np.dot(queryVec, feats.T)
# 按匹配分数从大到小排序
rank_ID = np.argsort(scores)[::-1]
rank_score = scores[rank_ID]
maxres = 3 # 检索出三张相似度最高的图片
imlist = []
for i, index in enumerate(rank_ID[0:maxres]):
imlist.append(imgpaths[index])
print("image names: " + str(imgpaths[index]) + " scores: %f" % rank_score[i])
print("top %d images in order are: " % maxres, imlist)
# 输出检索到的图片
for i, im in enumerate(imlist):
impath = str(im)[2:-1] # 得到的im是一个byte型的数据格式,需要转换成字符串
print(impath)
image = cv2.imread(impath)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
plt.title("search output %d" % (i + 1))
plt.imshow(image)
plt.show()
RESNet50进行特征提取
RESNet50的计算量比VGG16低一点,跑得更快,同时内存使用量也更小
import numpy as np
from keras.applications.resnet50 import ResNet50
from keras.preprocessing import image
from numpy import linalg as LA
from keras.applications.resnet50 import preprocess_input as preprocess_input_resnet
class RESNet:
def __init__(self):
self.input_shape = (224, 224, 3)
self.weight = 'imagenet'
self.pooling = 'max'
self.model_resnet = ResNet50(weights=self.weight,
input_shape=self.input_shape,
pooling=self.pooling, include_top=False)
# 提取resnet50最后一层卷积特征
def resnet_extract_feat(self, img_path):
img = image.load_img(img_path, target_size=(self.input_shape[0], self.input_shape[1]))
img = image.img_to_array(img)
img = np.expand_dims(img, axis=0)
img = preprocess_input_resnet(img)
feat = self.model_resnet.predict(img)
norm_feat = feat[0]/LA.norm(feat[0])
return norm_feat
import os
import h5py
import numpy as np
root = os.path.abspath('..')
model = RESNet()
save_path = os.path.join(root,'models','resnet_featureCNN.h5')
print("--------------------------------------------------")
print(" feature extraction starts")
print("--------------------------------------------------")
imgdir = os.path.join(root,'images')
imgpaths = []
for subdir in os.listdir(imgdir)[:3]:
curpath = os.path.join(imgdir,subdir)
for imgname in os.listdir(curpath):
imgpaths += [os.path.join(curpath,imgname)]
feats = []
model = RESNet()
for i, img_path in enumerate(imgpaths):
norm_feat = model.resnet_extract_feat(img_path)
feats.append(norm_feat)
print("extracting feature from image No. %d , %d images in total" % ((i + 1), len(imgpaths)))
feats = np.array(feats)
print("--------------------------------------------------")
print(" writing feature extraction results ...")
print("--------------------------------------------------")
h5f = h5py.File(save_path, 'w')
h5f.create_dataset('dataset_1', data=feats)
h5f.create_dataset('dataset_2', data=np.string_(imgpaths))
h5f.close()
print(" writing has done. ")
import h5py
from cv2 import cv2
import matplotlib.pyplot as plt
import numpy as np
import os
root = os.path.abspath('..')
save_path = os.path.join(root,'models','resnet_featureCNN.h5')
h5f = h5py.File(save_path, 'r')
feats = h5f['dataset_1'][:]
imgpaths = h5f['dataset_2'][:]
h5f.close()
querydir = os.path.join(root,'queryimg')
model = RESNet()
queryList = ['AK47', "american-flag", 'backpack', 'baseball-bat', 'baseball-glove', 'basketball-hoop', 'bat'
, 'bathtub', 'bear', 'beer-mug']
imgname = queryList[0] + '.jpg'
print("--------------------------------------------------")
print(" searching starts")
print("--------------------------------------------------")
# read and show query image
querypath = os.path.join(querydir,imgname)
# queryImg = mpimg.imread(querypath)
queryImg = cv2.imread(querypath)
queryImg = cv2.cvtColor(queryImg, cv2.COLOR_BGR2RGB)
plt.title("Query Image")
plt.imshow(queryImg)
plt.show()
# extract query image's feature, compute simlarity score and sort
queryVec = model.resnet_extract_feat(querypath) # 修改此处改变提取特征的网络
scores = np.dot(queryVec, feats.T)
rank_ID = np.argsort(scores)[::-1]
rank_score = scores[rank_ID]
# number of top retrieved images to show
maxres = 3 # 检索出三张相似度最高的图片
imlist = []
for i, index in enumerate(rank_ID[0:maxres]):
imlist.append(imgpaths[index])
print("image names: " + str(imgpaths[index]) + " scores: %f" % rank_score[i])
print("top %d images in order are: " % maxres, imlist)
# show top #maxres retrieved result one by one
for i, im in enumerate(imlist):
impath = str(im)[2:-1]
print(impath)
image = cv2.imread(impath)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
plt.title("search output %d" % (i + 1))
plt.imshow(image)
plt.show()
参考:🔗