图像搜索,1000张图片中搜索出最匹配的10张图,正确率85%以上

上代码.

拍照注意事项:

1. 照片不要拍糊了,要清晰

2. 背景颜色需要统一一点,我用的是A4纸上拍的

3. 照片除了要识别的物体不要出现其他东西

4. 同一个物品要从多个角度拍多张照片作为样本(如:前、后、左、右、中、上,可以通过移动A4纸来变换角度)

5. 光线不能太暗了,太暗了细节特征提取不出来,导致识别率下降

  1 import tensorflow as tf
  2 import annoy
  3 import numpy as np
  4 import os
  5 from PIL import Image
  6 import matplotlib.pyplot as plt
  7 import matplotlib.image as mpimg
  8 
  9 class ImageSimilaritySearch:
 10     def __init__(self, image_dir='./uploads/6266/', model_name='ResNet50', feature_dim=2048, num_trees=10, index_file='annoy_index.ann'):
 11         self.image_dir = image_dir
 12         self.model_name = model_name
 13         self.feature_dim = feature_dim
 14         self.num_trees = num_trees
 15         self.base_model = None
 16         self.index = None
 17         self.index_file = index_file
 18         self.image_files = None
 19 
 20         # 加载预训练的CNN模型
 21         self.load_base_model()
 22 
 23         # 设置或加载Annoy索引
 24         if os.path.exists(self.index_file):
 25             self.load_annoy_index()
 26         else:
 27             self.setup_annoy_index()
 28 
 29     def load_base_model(self):
 30         # 加载预训练的CNN模型
 31         self.base_model = tf.keras.applications.__getattribute__(self.model_name)(
 32             weights='imagenet', include_top=False, pooling='avg')
 33     
 34     def setup_annoy_index(self):
 35         # 设置Annoy索引
 36         self.index = annoy.AnnoyIndex(self.feature_dim, 'euclidean')
 37         self.image_files = self.get_image_paths(self.image_dir)
 38 
 39         for i, image_file in enumerate(self.image_files):
 40             image_path = image_file
 41             features = self.extract_features(image_path)
 42             self.index.add_item(i, features)
 43         self.index.build(self.num_trees)
 44         # 保存Annoy索引到文件
 45         self.index.save(self.index_file)
 46 
 47     def load_annoy_index(self):
 48         # 从文件加载Annoy索引
 49         self.index = annoy.AnnoyIndex(self.feature_dim, 'euclidean')
 50         self.index.load(self.index_file)
 51         self.image_files = self.get_image_paths(self.image_dir)
 52 
 53     def update_annoy_index_with_images(self, new_image_paths):
 54         # 更新Annoy索引,添加新图片
 55         for i, image_path in enumerate(new_image_paths):
 56             features = self.extract_features(image_path)
 57             self.index.add_item(len(self.image_files) + i, features)
 58 
 59         # 重新构建Annoy索引
 60         self.index.build(self.num_trees)
 61 
 62     def get_image_paths(self, image_dir):
 63         # 获取文件夹下所有图片路径
 64         image_paths = []
 65         for root, dirs, files in os.walk(image_dir):
 66             for file in files:
 67                 if file.endswith('.jpg') or file.endswith('.jpeg') or file.endswith('.png'):
 68                     image_path = os.path.join(root, file)
 69                     image_paths.append(image_path)
 70         return image_paths
 71 
 72     def extract_features(self, image_path):
 73         # 加载图像,并使用模型提取特征向量
 74         img = tf.keras.preprocessing.image.load_img(image_path, target_size=(224, 224))
 75         img_array = tf.keras.preprocessing.image.img_to_array(img)
 76         img_array = tf.keras.applications.resnet50.preprocess_input(img_array)
 77         img_array = np.expand_dims(img_array, axis=0)
 78         features = self.base_model.predict(img_array)
 79         return features.flatten()
 80 
 81     def search_similar_images(self, query_image_path, top_k=5):
 82         # 搜索相似图像
 83         query_features = self.extract_features(query_image_path)
 84         similar_indices = self.index.get_nns_by_vector(query_features, top_k)
 85         similar_images = [self.image_files[i] for i in similar_indices]
 86         return similar_images
 87 
 88     def show_similar_images(self, similar_images):
 89         # 显示相似图像
 90         num_images = len(similar_images)
 91         fig, axes = plt.subplots(1, num_images, figsize=(15, 5))
 92 
 93         for i, image_path in enumerate(similar_images):
 94             ax = axes[i] if num_images > 1 else axes
 95             img = mpimg.imread(image_path)
 96             ax.imshow(img)
 97             ax.axis('off')
 98 
 99         plt.show()
100 
if __name__ == "__main__":
    # Example query image; it lives inside the default image directory.
    query_path = './uploads/6266/Bracelet/SL101-HseK-101/1711003989672129.jpg'

    # Constructing the searcher loads the model and builds or loads the index.
    searcher = ImageSimilaritySearch()

    # Look up the nearest neighbours of the query image and report them.
    matches = searcher.search_similar_images(query_path)
    print("Similar images:", matches)

    # Optional: display the matches in a matplotlib window.
    # searcher.show_similar_images(matches)

 

posted @ 2024-04-01 15:11  看一百次夜空里的深蓝  阅读(89)  评论(0)  编辑  收藏  举报