图像搜索,1000张图片中搜索出最匹配的10张图,正确率85%以上
上代码.
拍照注意事项:
1. 照片不要拍糊了,要清晰
2. 背景颜色要尽量统一,我是把物品放在A4纸上拍的
3. 照片除了要识别的物体不要出现其他东西
4. 同一个物品要从多个角度拍多张照片用于采样(如:前、后、左、右、中、上,可以通过移动A4纸来变换拍摄角度)
5. 光线不能太暗了,太暗了细节特征提取不出来,导致识别率下降
import json
import os

import annoy
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from PIL import Image


class ImageSimilaritySearch:
    """Nearest-neighbour image search using CNN features and an Annoy index.

    Every image found under ``image_dir`` is run through a pretrained Keras
    application model, and the resulting feature vector is stored in an Annoy
    index that uses the Euclidean metric. The index is saved to
    ``index_file``; the list of image paths, in item-id order, is saved next
    to it as ``index_file + '.paths.json'`` so that item ids can be mapped
    back to files when the index is loaded again.

    Args:
        image_dir: Directory that is walked recursively for images.
        model_name: Name of a model constructor in ``tf.keras.applications``.
        feature_dim: Length of the feature vectors stored in the index. It
            must match the length of the vector the model outputs.
        num_trees: Number of trees passed to ``AnnoyIndex.build``.
        index_file: Path of the Annoy index file.
    """

    # File name suffixes (compared case-insensitively) treated as images.
    _IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, image_dir='./uploads/6266/', model_name='ResNet50', feature_dim=2048, num_trees=10, index_file='annoy_index.ann'):
        self.image_dir = image_dir
        self.model_name = model_name
        self.feature_dim = feature_dim
        self.num_trees = num_trees
        self.base_model = None
        self.index = None
        self.index_file = index_file
        self.image_files = None
        # Sidecar file holding the item-id -> image-path mapping of the index.
        self._paths_file = index_file + '.paths.json'

        # Load the pretrained CNN used as the feature extractor.
        self.load_base_model()

        # Reuse a saved index when there is one, otherwise build it.
        if os.path.exists(self.index_file):
            self.load_annoy_index()
        else:
            self.setup_annoy_index()

    def load_base_model(self):
        """Instantiate the pretrained CNN named by ``self.model_name``."""
        model_factory = getattr(tf.keras.applications, self.model_name)
        self.base_model = model_factory(
            weights='imagenet', include_top=False, pooling='avg')

    def setup_annoy_index(self):
        """Build the index from every image under ``self.image_dir`` and save it."""
        self.image_files = self.get_image_paths(self.image_dir)
        self._build_and_save(
            self.extract_features(path) for path in self.image_files)

    def load_annoy_index(self):
        """Load the saved index together with its item-id -> path mapping.

        If the saved path list is missing (an index written before the
        sidecar file existed), the directory is scanned again instead. If the
        number of paths does not match the number of indexed items, the index
        is stale and is rebuilt from scratch.
        """
        self.index = annoy.AnnoyIndex(self.feature_dim, 'euclidean')
        self.index.load(self.index_file)

        if os.path.exists(self._paths_file):
            with open(self._paths_file, encoding='utf-8') as fh:
                self.image_files = json.load(fh)
        else:
            self.image_files = self.get_image_paths(self.image_dir)

        if len(self.image_files) != self.index.get_n_items():
            # Ids can no longer be mapped to files reliably; release the
            # memory-mapped file before it is overwritten by the rebuild.
            self.index.unload()
            self.setup_annoy_index()

    def update_annoy_index_with_images(self, new_image_paths):
        """Add new images to the index and save the result.

        Annoy does not accept new items once an index has been built, so the
        existing vectors are copied into a fresh index together with the new
        ones, and that index is built and saved in place of the old one.

        Args:
            new_image_paths: Iterable of paths of the images to add.
        """
        new_image_paths = list(new_image_paths)
        if not new_image_paths:
            return

        # Extract first: if an image cannot be read, nothing has changed yet.
        new_vectors = [self.extract_features(path) for path in new_image_paths]
        # Materialise the old vectors before the index file is overwritten.
        old_vectors = [
            self.index.get_item_vector(item_id)
            for item_id in range(self.index.get_n_items())
        ]
        self.index.unload()

        self.image_files = list(self.image_files) + new_image_paths
        self._build_and_save(old_vectors + new_vectors)

    def _build_and_save(self, vectors):
        """Build a fresh index from ``vectors`` and persist it with the path list.

        Item ids are assigned in iteration order, so ``vectors`` must be in
        the same order as ``self.image_files``.
        """
        index = annoy.AnnoyIndex(self.feature_dim, 'euclidean')
        for item_id, vector in enumerate(vectors):
            index.add_item(item_id, vector)
        index.build(self.num_trees)
        index.save(self.index_file)
        self.index = index

        with open(self._paths_file, 'w', encoding='utf-8') as fh:
            json.dump(self.image_files, fh, ensure_ascii=False, indent=2)

    def get_image_paths(self, image_dir):
        """Return the paths of all .jpg/.jpeg/.png files under ``image_dir``.

        The directory is walked recursively and the extension check ignores
        case. Paths are returned in ``os.walk`` order.
        """
        image_paths = []
        for root, _dirs, files in os.walk(image_dir):
            image_paths.extend(
                os.path.join(root, name)
                for name in files
                if name.lower().endswith(self._IMAGE_EXTENSIONS)
            )
        return image_paths

    def extract_features(self, image_path):
        """Return the model's feature vector for one image as a flat array.

        NOTE(review): the 224x224 input size and the ResNet50
        ``preprocess_input`` are hard-coded, so they are only known to be
        right for the default ``model_name``; confirm before using another
        model.
        """
        img = tf.keras.preprocessing.image.load_img(image_path, target_size=(224, 224))
        img_array = tf.keras.preprocessing.image.img_to_array(img)
        img_array = tf.keras.applications.resnet50.preprocess_input(img_array)
        # The model expects a batch, so add a leading batch axis.
        img_array = np.expand_dims(img_array, axis=0)
        # verbose=0 avoids printing a progress bar for every single image.
        features = self.base_model.predict(img_array, verbose=0)
        return features.flatten()

    def search_similar_images(self, query_image_path, top_k=5):
        """Return the paths of the ``top_k`` indexed images nearest to the query."""
        query_features = self.extract_features(query_image_path)
        similar_indices = self.index.get_nns_by_vector(query_features, top_k)
        return [self.image_files[i] for i in similar_indices]

    def show_similar_images(self, similar_images):
        """Display the given images side by side in one matplotlib figure."""
        if not similar_images:
            # A figure with zero columns cannot be created.
            return

        # squeeze=False always yields a 2-D axes array, even for one image.
        _fig, axes = plt.subplots(
            1, len(similar_images), figsize=(15, 5), squeeze=False)
        for ax, image_path in zip(axes[0], similar_images):
            ax.imshow(mpimg.imread(image_path))
            ax.axis('off')

        plt.show()


if __name__ == "__main__":
    # Create the search object (builds or loads the index).
    similarity_search = ImageSimilaritySearch()

    # Example usage.
    query_image_path = './uploads/6266/Bracelet/SL101-HseK-101/1711003989672129.jpg'
    similar_images = similarity_search.search_similar_images(query_image_path)
    print("Similar images:", similar_images)

    # Display the similar images.
    # similarity_search.show_similar_images(similar_images)