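# milvus操作 — Milvus operations: connection, collection/index management, insert, query, delete and similarity search.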

import json
import sys
import time
from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType, utility, Index


def connect_milvus(host='xxxxxx', port='31800'):
    """Connect to the Milvus server listening on ``host``:``port``.

    ``host`` defaults to a placeholder ('xxxxxx') that has to be replaced with a real
    address. No ``alias`` is passed to ``connections.connect``, so the connection is
    registered under pymilvus's default alias; the other helpers in this file never
    pass ``using=``, so they rely on that default connection.
    """
    banner = "Connecting to Milvus..."
    print(banner)
    connections.connect(port=port, host=host)


def get_or_create_collection(collection_name, dim=256):
    """Return the Milvus collection ``collection_name``, creating it first if it is missing.

    A newly created collection gets a two-field schema:
      * ``item_code``: VARCHAR(100) primary key, supplied by the caller (no auto ID);
      * ``blip_embedding``: FLOAT_VECTOR of dimension ``dim``.

    When the collection already exists it is returned as-is and ``dim`` is ignored
    (the existing schema is not checked against it).
    """
    if not utility.has_collection(collection_name):
        print(f"Creating collection '{collection_name}'.")
        primary_key = FieldSchema(
            name="item_code",
            dtype=DataType.VARCHAR,
            is_primary=True,
            auto_id=False,
            max_length=100,
        )
        vector = FieldSchema(name="blip_embedding", dtype=DataType.FLOAT_VECTOR, dim=dim)
        schema = CollectionSchema([primary_key, vector], "item_blip向量")
        return Collection(name=collection_name, schema=schema)

    print(f"Collection '{collection_name}' already exists.")
    return Collection(name=collection_name)


def create_index_if_not_exists(collection, field_name="blip_embedding", index_params=None):
    """Create a vector index on ``field_name`` unless the collection already has one.

    Args:
        collection: a pymilvus ``Collection``.
        field_name: name of the vector field to index.
        index_params: pymilvus index parameters. Defaults to IVF_FLAT with the
            inner-product ("IP") metric and ``nlist=1414``. ``nlist`` should be tuned
            to the amount of data in the collection.

    Only the existence of an index is checked (via ``has_index()``). An existing index
    built with different parameters is left untouched; use ``recreate_index`` to
    replace it.
    """
    if collection.has_index():
        print("Index already exists. No need to create a new one.")
        return

    if index_params is None:
        # A fresh dict per call, so callers can never mutate a shared default.
        index_params = {
            "index_type": "IVF_FLAT",
            "metric_type": "IP",  # must match the metric used at search time (search_similar uses "IP")
            "params": {"nlist": 1414},
        }
    print(f"Creating index on '{field_name}'...")
    # Collection.create_index is the documented ORM call for building an index; it
    # replaces constructing an Index object purely for its side effect.
    collection.create_index(field_name, index_params)
    print("Index has created.")


def recreate_index(collection, field_name="blip_embedding", index_params=None):
    """Rebuild the vector index on ``field_name`` from scratch.

    Steps:
      1. Release the collection (Milvus requires this before an index can be dropped).
      2. Drop the existing index, if any.
      3. Create a new index.

    The collection is NOT reloaded afterwards. ``search_similar`` calls
    ``collection.load()`` itself before every search.

    Args:
        collection: a pymilvus ``Collection``.
        field_name: name of the vector field to index.
        index_params: pymilvus index parameters. Defaults to IVF_FLAT with the
            inner-product ("IP") metric and ``nlist=1414``. ``nlist`` should be tuned
            to the amount of data in the collection.
    """
    # Best effort: if release() raises, the error is ignored (and reported as "not
    # loaded", which is what the original author assumed) so the drop is still tried.
    try:
        print("Releasing the collection before dropping the index...")
        collection.release()
    except Exception:
        print("Collection is not loaded, proceeding to drop index.")

    if collection.has_index():
        print("Dropping existing index...")
        collection.drop_index()
        print("Index dropped.")

    if index_params is None:
        # A fresh dict per call, so callers can never mutate a shared default.
        index_params = {
            "index_type": "IVF_FLAT",
            "metric_type": "IP",  # must match the metric used at search time (search_similar uses "IP")
            "params": {"nlist": 1414},
        }
    print(f"Creating new index on '{field_name}'...")
    collection.create_index(field_name, index_params)
    print("New index created.")


def check_index_status(collection):
    """Print and return the description of the collection's index.

    ``Collection.index()`` raises when the collection has no index. That case is
    therefore checked first and reported, instead of crashing the status check.

    Returns:
        The object returned by ``collection.index()``, or ``None`` when the
        collection has no index.
    """
    if not collection.has_index():
        print("Index info: no index exists on this collection.")
        return None
    index_info = collection.index()
    print("Index info:", index_info)
    return index_info


def batch_insert(collection, item_codes, embeddings):
    """Insert one batch of rows into ``collection`` in column-based form.

    ``item_codes`` and ``embeddings`` are parallel lists, one per schema field, given in
    field order (``item_code`` first, then ``blip_embedding``). The insert result is
    printed. Any error is printed rather than raised, so a failed batch does not stop
    the caller. Returns ``None``.
    """
    rows = [item_codes, embeddings]
    try:
        print('Insert result: ', collection.insert(rows))
    except Exception as err:
        print('Error during insert:', err)


def item_code_exists(collection, item_code):
    """Return True if a row whose primary key equals ``item_code`` exists in ``collection``.

    ``item_code`` is embedded in a Milvus boolean expression as a double-quoted string
    literal. Backslashes and double quotes are escaped, so codes containing them can
    neither break the expression nor change its meaning. Any query error is printed
    and treated as "not found" (False).
    """
    # Escape backslashes first, then quotes, so the added escapes are not re-escaped.
    literal = str(item_code).replace('\\', '\\\\').replace('"', '\\"')
    expr = f'item_code == "{literal}"'
    try:
        results = collection.query(expr=expr, output_fields=["item_code"])
    except Exception as e:
        print(f"Error checking item code existence: {e}")
        return False
    return len(results) > 0


def delete_item(collection, item_code):
    """Delete the row whose primary key equals ``item_code`` from ``collection``.

    ``item_code`` is embedded in a Milvus boolean expression as a double-quoted string
    literal. Backslashes and double quotes are escaped, so a code containing them
    cannot alter the delete predicate. Errors are printed, not raised. The success
    message is printed whenever the delete call returns, even if no row matched.
    """
    # Escape backslashes first, then quotes, so the added escapes are not re-escaped.
    literal = str(item_code).replace('\\', '\\\\').replace('"', '\\"')
    expr = f'item_code == "{literal}"'
    try:
        collection.delete(expr)
        print(f"Deleted item code: {item_code}")
    except Exception as e:
        print(f"Error deleting item code: {e}")


def write2milvus(collection_blip, fin='youshi_ic_embedding_all.txt', batch_size=1024):
    """Bulk-load item embeddings from a UTF-8 text file into ``collection_blip``.

    Each non-blank line of ``fin`` is ``<item_code><TAB><v1>,<v2>,...``: an item code,
    a tab character, and the comma-separated vector components. Rows are sent through
    ``batch_insert`` in batches of ``batch_size``, and any final partial batch is sent
    at the end. Blank lines are skipped. A line with the wrong number of tab-separated
    fields, or a vector that does not parse, raises ``ValueError``.

    Args:
        collection_blip: the pymilvus ``Collection`` to insert into.
        fin: path of the embedding file.
        batch_size: number of rows per insert call.
    """
    item_codes = []
    embeddings = []
    with open(fin, 'r', encoding='utf-8') as r_file:
        for index, line in enumerate(r_file):
            # Progress log every 1000 lines. It counts lines read so far, printed
            # before the current line is processed, not rows confirmed inserted.
            if index % 1000 == 0:
                print('商品写入的个数:', index+1)
            line = line.strip()
            if not line:
                continue  # tolerate blank lines, e.g. a trailing empty line
            item_code, emb = line.split('\t')
            # The vector is stored without brackets; wrap it to parse it as a JSON array.
            emb = json.loads('[' + emb + ']')
            item_codes.append(item_code)
            embeddings.append(emb)
            if len(embeddings) >= batch_size:
                batch_insert(collection_blip, item_codes, embeddings)
                item_codes = []
                embeddings = []
    if item_codes:
        batch_insert(collection_blip, item_codes, embeddings)


def drop_collection(collection_name):
    """Drop the Milvus collection named ``collection_name``.

    Failures are printed, not raised. This includes the case where the collection
    does not exist, because constructing ``Collection(name=...)`` then fails.
    """
    try:
        Collection(name=collection_name).drop()
        print(f"Collection '{collection_name}' has been dropped.")
    except Exception as exc:
        print(f"Error dropping collection '{collection_name}': {exc}")


def search_similar(collection, item_vec, limit=50):
    """Return the item codes of the entries most similar to ``item_vec`` (at most ``limit``).

    The collection is loaded on every call. It is then searched on the
    ``blip_embedding`` field with the inner-product ("IP") metric and ``nprobe=128``.
    Codes are returned in the order the hits come back from Milvus.

    Errors raised by the search itself are printed and produce an empty list. Errors
    from ``collection.load()`` propagate to the caller.
    """
    collection.load()
    try:
        hits_per_query = collection.search(
            [item_vec],
            "blip_embedding",
            {"metric_type": "IP", "params": {"nprobe": 128}},
            limit=limit,
            output_fields=["item_code"],
        )
        codes = []
        for hits in hits_per_query:
            codes.extend(hit.entity.get('item_code') for hit in hits)
        return codes
    except Exception as err:
        print(f"Error during search: {err}")
        return []


def get_search_similar_all(collection, fin, fout='youshi_ic_similar_rs.txt', limit=10):
    """Look up similar items for every embedding in ``fin`` and write them to ``fout``.

    ``fin`` uses the same line format that ``write2milvus`` parses:
    ``<item_code><TAB><v1>,<v2>,...``. For each item, up to ``limit`` similar codes
    from ``search_similar`` are written as ``<item_code><TAB><code1>#<code2>#...``.
    Items whose search returns nothing (including search errors, which
    ``search_similar`` turns into an empty list) are omitted. Blank input lines are
    skipped.

    NOTE(review): if the query items are also stored in the collection, each item will
    most likely appear in its own result list — confirm whether that is expected.

    Args:
        collection: the pymilvus ``Collection`` to search.
        fin: path of the UTF-8 input embedding file.
        fout: path of the UTF-8 output file (overwritten).
        limit: maximum number of similar items per query.
    """
    with open(fin, 'r', encoding='utf-8') as r_file, open(fout, 'w', encoding='utf-8') as out:
        for index, line in enumerate(r_file):
            # Progress log every 1000 lines; printed before the current line is processed.
            if index % 1000 == 0:
                print('完成商品相似品计算的个数:', index+1)
            line = line.strip()
            if not line:
                continue  # tolerate blank lines, e.g. a trailing empty line
            item_code, emb = line.split('\t')
            # The vector is stored without brackets; wrap it to parse it as a JSON array.
            emb = json.loads('[' + emb + ']')
            result = search_similar(collection, emb, limit=limit)
            if result:
                out.write("{}\t{}\n".format(item_code, '#'.join(result)))


if __name__ == "__main__":
    # Script entry point: connect, open (or create) the collection, rebuild its index,
    # then write the similar items for every embedding listed in ./test.txt.
    connect_milvus()
    collection = get_or_create_collection('youshi_item_blip_vec')
    # Populating the collection is disabled here; uncomment to (re)insert the
    # embeddings before searching.
    # write2milvus(collection)
    recreate_index(collection)
    # recreate_index has just created an index, so this call only reports that one exists.
    create_index_if_not_exists(collection)
    get_search_similar_all(collection, './test.txt')
# posted @ 2024-09-18 16:16  15375357604  阅读(2)  评论(0)  编辑  收藏  举报