基于milvus搭建“以图搜图”服务(附代码)
“以图搜图”服务需要的关键功能和准备工作:
1 图像向量化功能,可选的模型有很多,本例选用resnet网络提取图像特征;
2 milvus建表,用milvus存放图像特征,通过唯一ID(此处称:milvus_id)与图像一一对应,sql建表将milvus_id作为唯一索引,存放图像的其他信息(如url,来源等);
3 异步添加图像,同步搜索图像。添加图像的量通常会很大,因此采用异步批量的方式将图像特征加载到milvus:图像添加服务会将每次的请求信息存到sql,再由一个专门的脚本定时批量加载图像特征到milvus。由于是异步操作,可能会出现重复加载的情况,此处使用redis进行去重。图像搜索的请求通常会比图像添加少很多,因此图像搜索采用同步方式返回结果;
(总结:需建立三个表:milvus表1,存放图像特征;sql表2,存放图像信息,数据与milvus表1一一对应;sql表3,存放图像添加请求信息,用于图像特征异步批量加载到milvus)
“以图搜图”服务关键功能及代码(代码仅做参考)
1 图像向量化
"""Image vectorization: turn an image file into an L2-normalized ResNet50 output vector."""
from keras.applications.resnet50 import ResNet50
from keras.preprocessing import image
from keras.applications.resnet50 import preprocess_input, decode_predictions
import numpy as np
from numpy import linalg as LA
import time

# Built once at import time so that every img2feature() call reuses the same
# network instead of reloading the ImageNet weights.
model = ResNet50(weights='imagenet')
# model.summary()


def img2feature(img_path, input_dim=224):
    """Return the normalized ResNet50 output for the image at ``img_path``.

    Args:
        img_path: Path of the image file to load (a file path, not raw image
            data).
        input_dim: Side length, in pixels, the image is resized to before
            being fed to the network.

    Returns:
        The array returned by ``model.predict`` for a batch containing this
        single image, divided by its L2 norm. If the norm is zero the raw
        prediction is returned unchanged, so the result never contains NaN.
    """
    img = image.load_img(img_path, target_size=(input_dim, input_dim))
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)  # add the batch axis expected by predict()
    x = preprocess_input(x)
    x = model.predict(x)
    norm = LA.norm(x)
    if norm == 0:
        # Dividing by zero would yield NaN/inf features that poison the index.
        return x
    return x / norm


def main():
    """Vectorize a sample image and print the elapsed time and output shape."""
    img_path = '1.jpg'
    t0 = time.time()
    res = img2feature(img_path)
    print(time.time() - t0, res.shape)
    # print(res, type(res), res.shape)


if __name__ == "__main__":
    main()
2 milvus表的操作
# coding:utf-8
"""Milvus collection operations for the image-search demo: create, insert, search, count, drop."""
from functools import reduce
import numpy as np
import time
from img2feature import img2feature
from pymilvus import (
    connections, list_collections,
    FieldSchema, CollectionSchema, DataType,
    Collection, utility
)

# NOTE: despite its name, ``field_name`` is used below as the *collection* name.
field_name = 'image_feature'
host = '***.***.***.***'
port = '19530'
dim = 1000  # length of the vectors stored in the "feature" field

default_fields = [
    FieldSchema(name="milvus_id", dtype=DataType.INT64, is_primary=True),
    FieldSchema(name="feature", dtype=DataType.FLOAT_VECTOR, dim=dim),
    FieldSchema(name="create_time", dtype=DataType.INT64)
]


# create_table
def create_table():
    """Create the collection, build an index on "feature" and load the collection."""
    connections.connect(host=host, port=port)

    # create collection
    default_schema = CollectionSchema(fields=default_fields, description="test collection")
    print(f"\nCreate collection...")
    collection = Collection(name=field_name, schema=default_schema)

    print(f"\nCreate index...")
    # NOTE(review): "nlist" (and "nprobe" in search_data) are IVF-style
    # parameters; presumably they are unused by a FLAT index -- confirm
    # against the Milvus index documentation.
    default_index = {"index_type": "FLAT", "params": {"nlist": 128}, "metric_type": "L2"}
    collection.create_index(field_name="feature", index_params=default_index)
    # The original wrapped this in a second print(), which emitted an extra
    # "None" line.
    print(f"\nCreate index...is OKOKOKOKOK")
    collection.load()


# insert data
def insert_data():
    """Insert the feature vector of '1.jpg' as a single row with milvus_id 123."""
    connections.connect(host=host, port=port)
    default_schema = CollectionSchema(fields=default_fields, description="test collection")
    collection = Collection(name=field_name, schema=default_schema)

    # img2feature() returns a batch; [0] picks the vector of the only image.
    vectors = img2feature('1.jpg').tolist()[0]
    print(type(vectors), len(vectors))
    # Column-oriented payload, in the same order as default_fields.
    data1 = [
        [123],
        [vectors],
        [int(time.time())]
    ]
    collection.insert(data1)
    print('insert compete')


# search data
def search_data():
    """Search the collection for the 10 nearest neighbours of '1.jpg' and print the hits."""
    print('search')
    connections.connect(host=host, port=port)
    collection = Collection(name=field_name)
    print('连接成功')

    # Before the first search, build the index and load() the collection:
    # default_index = {"index_type": "FLAT", "params": {"nlist": 128}, "metric_type": "L2"}
    # print(f"\nCreate index...")
    # collection.create_index(field_name="feature", index_params=default_index)
    # print(f"\nCreate index...is OKOKOKOKOK")
    # collection.load()
    # exit()

    vectors = img2feature('1.jpg').tolist()[0]
    topK = 10
    search_params = {"metric_type": "L2", "params": {"nprobe": 10}}

    res = collection.search(
        [vectors],
        "feature",
        search_params,
        topK,
        "create_time > {}".format(0),
        output_fields=["milvus_id"]
    )
    print('>>>', res)
    for hits in res:
        print(len(hits))
        for hit in hits:
            print(hit)
    print('查询结束')


def show_nums():
    """Print the number of entities stored in the collection."""
    connections.connect(host=host, port=port)
    collection = Collection(name=field_name)
    print('ok')
    print(collection.num_entities)


# delete data
def delete_table():
    """Drop the collection, printing whether it exists before and after the drop."""
    connections.connect(host=host, port=port)
    default_schema = CollectionSchema(fields=default_fields, description="test collection")
    collection = Collection(name=field_name, schema=default_schema)
    print('>>>', utility.has_collection(field_name))
    collection.drop()
    print('>>>', utility.has_collection(field_name))


if __name__ == "__main__":
    t1 = time.time()
    # create_table()
    # insert_data()
    # search_data()
    show_nums()
    # delete_table()
    print('time cost: {}'.format(time.time() - t1))
3 创建sql表2、表3
1 | 略 |
4 图像添加、搜索服务
from rest_framework.views import APIView as View
from kpdjango.response import SucessAPIResponse, ErrorAPIResponse
from kpmysql.base import Kpmysql
# Imported under an alias: the view class below is also called ``search_image``
# and would otherwise rebind the name, so the view ended up calling itself
# instead of the search function.
from core import search_image as _search_image_by_path
import kplog
import logging

log = logging.getLogger("console")


class add_image(View):
    """Record an image-add request in SQL; a separate job loads it into Milvus later."""

    def post(self, requests):
        """Insert ``image_path`` / ``image_info`` from the POST body into the add-log table.

        Returns a success response once the row is committed, or an error
        response when ``image_path`` is missing or any step raises.
        """
        try:
            image_info = requests.POST.get('image_info')
            image_path = requests.POST.get('image_path')
            if not image_path:
                # A row without a path can never be vectorized later on.
                log.info('添加图像失败:{}'.format('image_path is required'))
                return ErrorAPIResponse(msg="Fail")

            db = Kpmysql.connect("db168")
            cur = db.cursor()
            # NOTE(review): the connection is not closed here, as in the
            # original; confirm whether Kpmysql pools or owns it.
            sql = "INSERT INTO t_image_search_image_add_log(image_path, info) VALUES(%s, %s)"
            cur.execute(sql, (image_path, image_info))
            db.commit()
            log.info('添加图像成功:{}-{}'.format(image_path, image_info))
            return SucessAPIResponse(msg="Success")
        except Exception as e:
            # log.exception keeps the traceback, which log.info dropped.
            log.exception('添加图像失败:{}'.format(e))
            return ErrorAPIResponse(msg="Fail")


class search_image(View):
    """Synchronously search for images similar to the one at ``image_path``."""

    def post(self, requests):
        """Run the image search for the POSTed ``image_path`` and return its result.

        Returns a success response carrying the search result under
        ``data["data"]``, or an error response when ``image_path`` is missing
        or the search raises.
        """
        try:
            image_path = requests.POST.get('image_path')
            if not image_path:
                log.info('查询图像失败:{}'.format('image_path is required'))
                return ErrorAPIResponse(msg="Fail")

            res = _search_image_by_path(image_path)
            log.info('查询图像成功:{}-{}'.format(image_path, res))
            return SucessAPIResponse(msg="Success", data={"data": res})
        except Exception as e:
            # The original logged this failure with the "success" message.
            log.exception('查询图像失败:{}'.format(e))
            return ErrorAPIResponse(msg="Fail")
5 图像异步批量加载
import time, datetime
from kpmysql.base import Kpmysql
from core import insert_data_many
from concurrent.futures import ThreadPoolExecutor
import redis
from conf.setting import REDIS
from core import str2time
import kplog
import logging

log = logging.getLogger("console")
log_addimgs = logging.getLogger("console_addimgs")


def worker(datas):
    """Load one batch of pending add-log rows into Milvus and mark them as loaded.

    Args:
        datas: Rows of ``(id, image_path, create_time)`` read from
            ``t_image_search_image_add_log``.

    Rows whose id is already in the redis sorted set ``image_search`` are
    skipped (de-duplication across overlapping batches). Errors are logged,
    never raised, so a failing batch cannot kill the pool thread.
    """
    redis_cli = None
    claimed = []      # redis members added by this call
    inserted = False  # True once insert_data_many() has returned
    try:
        redis_cli = redis.Redis(host=REDIS.get('host'), port=REDIS.get('port'),
                                password=REDIS.get('password'), db=REDIS.get('db'))
        dics = []
        ids = []
        for data in datas:
            member = str(data[0])
            # De-duplicate via redis. zscore() returns None for a missing
            # member; compare explicitly so that a score of 0 still counts
            # as "present".
            if redis_cli.zscore('image_search', member) is not None:
                continue
            dics.append({'image_path': data[1], 'create_time': data[2]})
            # One parameter tuple per row for executemany(); the original
            # "(data[0])" was a bare scalar, not a tuple.
            ids.append((data[0],))
            redis_cli.zadd('image_search', {member: str2time(data[2])})
            claimed.append(member)

        if not dics:
            # Everything in this batch was already handled elsewhere.
            return

        # Insert the batch into Milvus.
        insert_data_many(dics)
        inserted = True

        # Mark the rows as loaded: t_image_search_image_add_log.is_load = 1.
        sql_update = """UPDATE t_image_search_image_add_log SET is_load=1 WHERE id=%s"""
        db168 = Kpmysql.connect("db168")
        cur168 = db168.cursor()
        cur168.executemany(sql_update, ids)
        db168.commit()
    except Exception:
        log.exception('image batch load failed')
        if redis_cli is not None and claimed and not inserted:
            # Nothing reached Milvus, so release the claims; otherwise these
            # rows would be skipped forever while still having is_load=0.
            # Assumes insert_data_many() is all-or-nothing -- TODO confirm.
            try:
                redis_cli.zrem('image_search', *claimed)
            except Exception:
                log.exception('failed to release redis claims after a failed batch')


def main():
    """Poll the add-log table about once an hour and hand pending rows to the thread pool."""
    max_workers = 20  # maximum number of worker threads
    pool = ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix='Thread')
    task_list = []
    # Back-dated so that the very first loop iteration triggers a load.
    init_time = datetime.datetime.now() - datetime.timedelta(hours=13)
    create_time_init = '2020-2-22 00:00:00'
    while True:
        now = datetime.datetime.now()
        diff = now - init_time
        # total_seconds() covers the whole interval; ``.seconds`` drops the
        # days component and wraps around every 24 hours.
        if diff.total_seconds() > 3600:
            # Fetch rows of t_image_search_image_add_log with is_load=0.
            db168 = Kpmysql.connect("db168")
            cur168 = db168.cursor()
            sql = """SELECT id, image_path, create_time FROM t_image_search_image_add_log WHERE is_load=0 and create_time >= %s ORDER BY create_time"""
            cur168.execute(sql, (create_time_init,))
            datas = cur168.fetchall()
            if datas:  # an empty result used to raise IndexError on datas[-1]
                # Next poll resumes from the newest create_time seen so far.
                create_time_init = datas[-1][2]
                while True:
                    # Rebuild the list instead of pop()-ing while iterating,
                    # which skipped the entry following each removed one.
                    task_list = [task for task in task_list if not task.done()]
                    if len(task_list) < int(max_workers * 0.9):
                        break
                    time.sleep(1)  # wait for a free slot without busy-spinning
                task_list.append(pool.submit(worker, datas))
            init_time = now
        time.sleep(600)


if __name__ == "__main__":
    main()
优化(重点)
经过实际测试和使用的建议:
1. keras在调用GPU并开启多线程时不如pytorch方便,且pytorch占用显存更少;
2. 将“定时从数据库拿数据”的方式改成kafka生产消费模型,代码更简洁,逻辑更简单;
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· AI与.NET技术实操系列:基于图像分类模型对图像进行分类
· go语言实现终端里的倒计时
· 如何编写易于单元测试的代码
· 10年+ .NET Coder 心语,封装的思维:从隐藏、稳定开始理解其本质意义
· .NET Core 中如何实现缓存的预热?
· 分享一个免费、快速、无限量使用的满血 DeepSeek R1 模型,支持深度思考和联网搜索!
· 25岁的心里话
· 基于 Docker 搭建 FRP 内网穿透开源项目(很简单哒)
· ollama系列01:轻松3步本地部署deepseek,普通电脑可用
· 按钮权限的设计及实现