Reverse Image Search
How reverse image search works:
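Reverse image search is a content-based image retrieval (CBIR) technique². It understands what an image contains without needing any keywords and relies mainly on AI models; some of the best-ranked image classification models currently reach 99% top-5 accuracy³. This article uses an AI model to extract a feature vector from each image and then finds similar images by comparing those feature vectors.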
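The core idea can be shown without any model. Once every image is represented by a vector, the images most similar to a query are the ones whose vectors lie closest to it. The minimal sketch below runs an exhaustive (brute-force) L2 search in pure Python. The toy 3-dimensional vectors and the names `l2_distance` and `top_k_l2` are made up for illustration; real ResNet50 features have far more dimensions.

```python
import math


def l2_distance(a, b):
    """Euclidean (L2) distance between two equal-length vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def top_k_l2(query, database, k):
    """Exhaustive search: return the k (image_id, distance) pairs nearest to query."""
    scored = [(image_id, l2_distance(query, vec)) for image_id, vec in database.items()]
    scored.sort(key=lambda pair: pair[1])  # smaller distance = more similar
    return scored[:k]


# Toy 3-d "feature vectors" (hypothetical data)
database = {
    "lion_1": [0.9, 0.1, 0.0],
    "lion_2": [0.8, 0.2, 0.1],
    "car_1": [0.0, 0.1, 0.9],
}
result = top_k_l2([0.88, 0.12, 0.02], database, k=2)
print([image_id for image_id, _ in result])  # → ['lion_1', 'lion_2']
```

A vector database such as Milvus does the same job at scale. It also offers index structures that avoid comparing the query against every stored vector.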
I. Towhee & Milvus
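Towhee (http://github.com/towhee-io/towhee) provides out-of-the-box embedding pipelines that turn any unstructured data (images, video, audio, etc.) into feature vectors. With Towhee, running a single pipeline is enough to get the vectors.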
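Milvus (http://github.com/milvus-io/milvus) is an open-source vector database. It supports many vector index types and distance metrics and can run similarity search over millions, billions, or even trillions of vectors. It is designed to be flexible, stable and reliable, and fast to query.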
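Together, Towhee and Milvus form an end-to-end pipeline for analyzing images and other unstructured data. Towhee extracts the feature vectors, Milvus stores and searches them, and the results most similar to the query are returned and displayed.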
Installing Towhee and Milvus:
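Note: Milvus can be installed standalone or as a cluster. This article installs standalone Milvus with docker-compose (http://milvus.io/docs/v2.0.x/install_standalone-docker.md). Before doing so, check that your machine meets the hardware and software requirements (http://milvus.io/docs/v2.0.x/prerequisite-docker.md).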
```shell
# Install Towhee
$ pip install towhee
# Install standalone Milvus
$ wget http://github.com/milvus-io/milvus/releases/download/v2.0.2/milvus-standalone-docker-compose.yml -O docker-compose.yml
$ docker-compose up -d
```
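Towhee provides feature extraction methods for unstructured data, such as image embedding, audio embedding and video embedding. These methods are called Towhee operators, and an operator is a single node in a pipeline. An image feature extraction pipeline can be built by chaining the image_decode (http://towhee.io/image-decode/cv2) operator and the image_embedding.timm (http://towhee.io/image-embedding/timm) operator. Passing model_name="resnet50" to the embedding operator makes it generate feature vectors with the ResNet50 model.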
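Code:

```python
import towhee

# Decode one image, embed it with ResNet50 (timm), then display the image and its vector
towhee.glob['path']('./test/lion/n02129165_13728.JPEG') \
    .image_decode['path', 'img']() \
    .image_embedding.timm['img', 'vec'](model_name='resnet50') \
    .select['img', 'vec']() \
    .show()
```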
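Next, create a collection in the Milvus database. The collection has two fields, id and embedding, and id is the primary key. We also create an IVF_FLAT (http://milvus.io/docs/v2.0.x/index.md#IVF_FLAT) quantization-based index on embedding, with index parameter nlist=2048 and the "L2" (Euclidean distance) metric: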
Code:

```python
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility


def create_milvus_collection(collection_name, dim):
    connections.connect(host='127.0.0.1', port='19530')

    # Start from a fresh collection if one with this name already exists
    if utility.has_collection(collection_name):
        utility.drop_collection(collection_name)

    fields = [
        FieldSchema(name='id', dtype=DataType.INT64, description='ids', is_primary=True, auto_id=False),
        FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, description='embedding vectors', dim=dim)
    ]
    schema = CollectionSchema(fields=fields, description='reverse image search')
    collection = Collection(name=collection_name, schema=schema)

    # create IVF_FLAT index for collection.
    index_params = {
        'metric_type': 'L2',
        'index_type': "IVF_FLAT",
        'params': {"nlist": 2048}
    }
    collection.create_index(field_name="embedding", index_params=index_params)
    return collection


# dim=2048 matches the size of the ResNet50 feature vectors used in this pipeline
collection = create_milvus_collection('reverse_image_search', 2048)
```
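The nlist parameter above controls how an IVF (inverted file) index partitions the data. The stored vectors are first clustered into nlist buckets, each represented by a centroid. At query time only the nprobe buckets whose centroids are closest to the query are scanned, and IVF_FLAT compares the query against the raw (unquantized) vectors in those buckets.

Below is a toy pure-Python sketch of that idea, not Milvus's implementation. The class name `ToyIVFFlat`, the tiny nlist, and the plain k-means with a fixed number of iterations are illustrative assumptions.

```python
import math
import random


def l2(a, b):
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def mean(vectors):
    return [sum(col) / len(vectors) for col in zip(*vectors)]


class ToyIVFFlat:
    """Toy IVF_FLAT: k-means buckets (nlist) + exact L2 search inside the probed buckets."""

    def __init__(self, nlist, iterations=10, seed=0):
        self.nlist = nlist
        self.iterations = iterations
        self.rng = random.Random(seed)

    def _assign(self):
        # inverted lists: bucket c holds the ids of vectors whose nearest centroid is c
        buckets = [[] for _ in range(self.nlist)]
        for vid, v in enumerate(self.vectors):
            nearest = min(range(self.nlist), key=lambda c: l2(v, self.centroids[c]))
            buckets[nearest].append(vid)
        return buckets

    def build(self, vectors):
        self.vectors = vectors
        self.centroids = self.rng.sample(vectors, self.nlist)
        for _ in range(self.iterations):  # a few rounds of plain k-means
            buckets = self._assign()
            self.centroids = [mean([vectors[i] for i in b]) if b else self.centroids[c]
                              for c, b in enumerate(buckets)]
        self.buckets = self._assign()  # final inverted lists for the final centroids

    def search(self, query, k, nprobe):
        # scan only the nprobe buckets whose centroids are closest to the query
        probed = sorted(range(self.nlist), key=lambda c: l2(query, self.centroids[c]))[:nprobe]
        candidates = [vid for c in probed for vid in self.buckets[c]]
        candidates.sort(key=lambda vid: l2(query, self.vectors[vid]))
        return candidates[:k]


points = [[0.0, 0.0], [0.1, 0.0], [0.0, 0.1], [5.0, 5.0], [5.1, 5.0], [5.0, 5.1]]
index = ToyIVFFlat(nlist=2)
index.build(points)
print(index.search([0.02, 0.01], k=2, nprobe=1))  # → [0, 1]
```

A small nprobe scans less data but can miss neighbours that fall into an unprobed bucket. With nprobe equal to nlist the search becomes exhaustive and exact.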
Loading images into Milvus
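Besides its rich set of operators for unstructured data, Towhee offers simple interfaces for handling all kinds of data and integrates some basic Milvus operations. By chaining these operators and interfaces in a pipeline, loading images into Milvus becomes very simple.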
```python
import towhee

dc = (
    towhee.read_csv('reverse_image_search.csv')  # read a CSV table with id, path and label columns
        .runas_op['id', 'id'](func=lambda x: int(x))  # convert each row's id from str to int
        .image_decode['path', 'img']()  # load the image at each row's path and decode it into Towhee's image format
        .image_embedding.timm['img', 'vec'](model_name='resnet50')  # extract the feature vector
        .tensor_normalize['vec', 'vec']()  # normalize the vector
        .to_milvus['id', 'vec'](collection=collection, batch=100)  # insert id and vec into the Milvus collection, 100 rows per batch
)
```
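Normalizing before insertion has a useful consequence. For unit-length vectors, squared L2 distance and cosine similarity are linked by ‖a − b‖² = 2 − 2·cos(a, b). Ranking by the "L2" metric chosen for the index therefore gives the same order as ranking by cosine similarity.

The pure-Python sketch below checks that identity. `l2_normalize` is a hypothetical helper, and we assume it does the same job as `tensor_normalize` here, namely dividing by the L2 norm.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit L2 length (assumed equivalent of tensor_normalize)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def cosine_similarity(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def squared_l2(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


a = l2_normalize([3.0, 4.0, 0.0])  # [0.6, 0.8, 0.0]
b = l2_normalize([1.0, 2.0, 2.0])  # [1/3, 2/3, 2/3]
# For unit vectors: ||a - b||^2 == 2 - 2 * cos(a, b); both are about 0.5333 (8/15) here
print(squared_l2(a, b), 2 - 2 * cosine_similarity(a, b))
```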
Querying images and displaying the results
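Querying uses the same image-processing operators as before: image_decode, image_embedding.timm and tensor_normalize. To analyze the search results at the end, we use the read_images function defined in the data preparation step (not shown in this article). It is added to the Towhee pipeline by passing it as the func argument of runas_op.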
```python
(towhee.glob['path']('./test/w*/*.JPEG')  # read every image file matching the pattern into the path column
    .image_decode['path', 'img']()  # load the image at each row's path and decode it into Towhee's image format
    .image_embedding.timm['img', 'vec'](model_name='resnet50')  # extract the feature vector
    .tensor_normalize['vec', 'vec']()  # normalize the vector
    .milvus_search['vec', 'result'](collection=collection, limit=5)  # search the Milvus collection and return the top 5 hits
    .runas_op['result', 'result_img'](func=read_images)  # turn the Milvus hits into images for display
    .select['img', 'result_img']()  # keep only the specified columns
    .show()
)
```
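`read_images` is not part of this article. The code below is a hypothetical sketch of what such a helper does: it maps the ids returned by Milvus back to file paths, using the same CSV that was loaded earlier. To stay self-contained it returns paths rather than decoded images. A small `Hit` dataclass stands in for a Milvus search hit, which exposes the matched primary key as `.id`. The names `load_id_to_path` and `read_image_paths` and the sample rows are made up for illustration.

```python
import csv
import io
from dataclasses import dataclass


@dataclass
class Hit:
    """Stand-in for a Milvus search hit: matched primary key and distance."""
    id: int
    distance: float


def load_id_to_path(csv_file):
    """Build {id: path} from a CSV with id, path, label columns (like reverse_image_search.csv)."""
    return {int(row['id']): row['path'] for row in csv.DictReader(csv_file)}


def read_image_paths(results, id_to_path):
    """Hypothetical read_images variant: return the file path of every hit, in rank order."""
    return [id_to_path[hit.id] for hit in results]


sample_csv = io.StringIO(
    "id,path,label\n"
    "0,./train/lion/a.JPEG,lion\n"
    "1,./train/tiger/b.JPEG,tiger\n"
)
id_to_path = load_id_to_path(sample_csv)
print(read_image_paths([Hit(1, 0.12), Hit(0, 0.34)], id_to_path))
# → ['./train/tiger/b.JPEG', './train/lion/a.JPEG']
```

In the real pipeline each path still has to be decoded into an image (for example with OpenCV) before `show()` can display it.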
II. Building an image search service
1. Use a ResNet network to extract image features.
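2. Create a Milvus collection to store the image features. Each feature corresponds one-to-one to an image through a unique ID (called milvus_id here). Then create a SQL table with milvus_id as a unique index to store the rest of each image's information.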
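3. Add images asynchronously and search images synchronously.
   - Adding: the volume of images to add is usually large, so features are loaded into Milvus asynchronously in batches. The image add service stores each request in SQL, and a separate script periodically loads the image features into Milvus in batches.
   - Deduplication: because loading is asynchronous, the same image may be loaded more than once, so Redis is used to deduplicate.
   - Searching: search requests are usually far fewer than add requests, so image search returns its results synchronously.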
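(Summary: three tables are needed. Milvus table 1 stores the image features. SQL table 2 stores the image information and maps one-to-one to Milvus table 1. SQL table 3 stores the image add requests and drives the asynchronous batch loading of features into Milvus.)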
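The add/load flow of points 2 and 3 can be simulated in memory to make the roles of the tables explicit. In this sketch:

- an in-memory `sqlite3` database stands in for the MySQL request log (table 3);
- a Python set stands in for the Redis deduplication key;
- a dict stands in for the Milvus collection (table 1).

Table 2 is left out for brevity. All table, column and function names are hypothetical, and feature extraction is replaced by a dummy vector.

```python
import sqlite3

# SQL table 3 (request log), simulated with an in-memory SQLite database
db = sqlite3.connect(":memory:")
db.execute("CREATE TABLE add_log (id INTEGER PRIMARY KEY, image_path TEXT, is_load INTEGER DEFAULT 0)")

milvus = {}     # stand-in for Milvus table 1: milvus_id -> feature vector
loaded = set()  # stand-in for the Redis dedup set


def add_image(path):
    """Add service: only record the request; the features are loaded later."""
    db.execute("INSERT INTO add_log (image_path) VALUES (?)", (path,))
    db.commit()


def fake_feature(path):
    """Dummy feature extractor; a real system would run ResNet here."""
    return [float(len(path)), 1.0]


def batch_load():
    """Periodic job: move pending requests into 'Milvus', skipping ids already loaded."""
    rows = db.execute("SELECT id, image_path FROM add_log WHERE is_load = 0").fetchall()
    inserted = 0
    for row_id, path in rows:
        if row_id in loaded:  # duplicate: already in Milvus
            continue
        milvus[row_id] = fake_feature(path)  # here the log id doubles as milvus_id
        loaded.add(row_id)
        inserted += 1
    db.executemany("UPDATE add_log SET is_load = 1 WHERE id = ?", [(row_id,) for row_id, _ in rows])
    db.commit()
    return inserted


add_image("a.jpg")
add_image("b.jpg")
print(batch_load(), sorted(milvus))  # → 2 [1, 2]
# simulate a lost status update: the row is pending again but already loaded
db.execute("UPDATE add_log SET is_load = 0 WHERE id = 1")
print(batch_load())  # → 0 (skipped by the dedup set)
```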
Image vectorization
"""
功能:图像向量化
"""
from
keras.applications.resnet50
import
ResNet50
from
keras.preprocessing
import
image
from
keras.applications.resnet50
import
preprocess_input, decode_predictions
import
numpy as np
from
numpy
import
linalg as LA
import
time
model
=
ResNet50(weights
=
'imagenet'
)
# model.summary()
def
img2feature(img_path, input_dim
=
224
):
# 图像路径???图像数据
img
=
image.load_img(img_path, target_size
=
(input_dim, input_dim))
x
=
image.img_to_array(img)
x
=
np.expand_dims(x, axis
=
0
)
x
=
preprocess_input(x)
x
=
model.predict(x)
x
=
x
/
LA.norm(x)
return
x
def
main():
img_path
=
'1.jpg'
t0
=
time.time()
res
=
img2feature(img_path)
print
(time.time()
-
t0, res.shape)
# print(res, type(res), res.shape)
if
__name__
=
=
"__main__"
:
main()
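`preprocess_input` for Keras's ResNet50 follows the "caffe" convention. It converts RGB to BGR and subtracts the per-channel ImageNet mean (B 103.939, G 116.779, R 123.68), without scaling values to [0, 1].

The pure-Python sketch below reproduces that transform on single pixels to show what happens to the array before `model.predict`. It is an illustration only, not a replacement for the Keras function. The helper names are hypothetical.

```python
# ImageNet channel means in BGR order, as used by Keras' "caffe" preprocessing
IMAGENET_MEAN_BGR = (103.939, 116.779, 123.68)


def caffe_preprocess_pixel(rgb):
    """RGB pixel (0-255) -> BGR pixel zero-centred by the ImageNet channel means."""
    r, g, b = rgb
    return tuple(c - m for c, m in zip((b, g, r), IMAGENET_MEAN_BGR))


def caffe_preprocess_image(rows):
    """Apply the per-pixel transform to an H x W image given as nested lists of RGB tuples."""
    return [[caffe_preprocess_pixel(px) for px in row] for row in rows]


# A pixel equal to the ImageNet mean (given in RGB order) maps to zero
print(caffe_preprocess_pixel((123.68, 116.779, 103.939)))  # → (0.0, 0.0, 0.0)
```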
Milvus collection operations
```python
# coding:utf-8
import time

from img2feature import img2feature  # the image vectorization module above
from pymilvus import (
    connections,
    FieldSchema, CollectionSchema, DataType,
    Collection, utility
)

field_name = 'image_feature'  # used as the collection name
host = '***.***.***.***'
port = '19530'
dim = 1000  # output size of ResNet50 with its classification head

default_fields = [
    FieldSchema(name="milvus_id", dtype=DataType.INT64, is_primary=True),
    FieldSchema(name="feature", dtype=DataType.FLOAT_VECTOR, dim=dim),
    FieldSchema(name="create_time", dtype=DataType.INT64)
]


# create_table
def create_table():
    connections.connect(host=host, port=port)
    # create collection
    default_schema = CollectionSchema(fields=default_fields, description="test collection")
    print("\nCreate collection...")
    collection = Collection(name=field_name, schema=default_schema)
    print("\nCreate index...")
    # FLAT is an exhaustive (brute-force) index, so nlist has no effect here
    default_index = {"index_type": "FLAT", "params": {"nlist": 128}, "metric_type": "L2"}
    collection.create_index(field_name="feature", index_params=default_index)
    print("\nCreate index... done")
    collection.load()  # load the collection into memory so that it can be searched


# insert data
def insert_data():
    connections.connect(host=host, port=port)
    default_schema = CollectionSchema(fields=default_fields, description="test collection")
    collection = Collection(name=field_name, schema=default_schema)
    vectors = img2feature('1.jpg').tolist()[0]
    print(type(vectors), len(vectors))
    # column-based insert: one list per field (milvus_id, feature, create_time)
    data1 = [
        [123],
        [vectors],
        [int(time.time())]
    ]
    collection.insert(data1)
    print('insert complete')


# search data
def search_data():
    print('search')
    connections.connect(host=host, port=port)
    collection = Collection(name=field_name)
    print('connected')
    # The first query needs the index to be built and load() to be called:
    # default_index = {"index_type": "FLAT", "params": {"nlist": 128}, "metric_type": "L2"}
    # print("\nCreate index...")
    # collection.create_index(field_name="feature", index_params=default_index)
    # print("\nCreate index... done")
    # collection.load()
    # exit()
    vectors = img2feature('1.jpg').tolist()[0]
    topK = 10
    search_params = {"metric_type": "L2", "params": {"nprobe": 10}}  # nprobe only matters for IVF indexes
    res = collection.search(
        [vectors],                     # query vectors
        "feature",                     # vector field to search
        search_params,
        topK,                          # number of results per query
        "create_time > {}".format(0),  # boolean filter on a scalar field
        output_fields=["milvus_id"]
    )
    print('>>>', res)
    for hits in res:
        print(len(hits))
        for hit in hits:
            print(hit)
    print('search finished')


def show_nums():
    connections.connect(host=host, port=port)
    collection = Collection(name=field_name)
    print('ok')
    print(collection.num_entities)


# delete data
def delete_table():
    connections.connect(host=host, port=port)
    default_schema = CollectionSchema(fields=default_fields, description="test collection")
    collection = Collection(name=field_name, schema=default_schema)
    print('>>>', utility.has_collection(field_name))
    collection.drop()
    print('>>>', utility.has_collection(field_name))


if __name__ == "__main__":
    t1 = time.time()
    # create_table()
    # insert_data()
    # search_data()
    show_nums()
    # delete_table()
    print('time cost: {}'.format(time.time() - t1))
```
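The search call above combines a vector search with a boolean expression on a scalar field (`create_time > 0`) and asks for `milvus_id` to be returned with each hit. Conceptually, only entities that satisfy the expression are ranked by L2 distance.

The in-memory sketch below illustrates those semantics in plain Python. It says nothing about how Milvus executes the filter internally. The row layout mirrors the collection's three fields; the function name `filtered_search` and the sample rows are hypothetical. The ranking is the same whether distances are plain or squared Euclidean.

```python
import math


def filtered_search(rows, query, top_k, predicate, output_fields):
    """Rank rows passing predicate by L2 distance to query; return (distance, selected fields)."""
    candidates = [r for r in rows if predicate(r)]
    candidates.sort(key=lambda r: math.dist(query, r["feature"]))
    return [
        (math.dist(query, r["feature"]), {f: r[f] for f in output_fields})
        for r in candidates[:top_k]
    ]


rows = [
    {"milvus_id": 123, "feature": [0.6, 0.8], "create_time": 1650000000},
    {"milvus_id": 124, "feature": [0.8, 0.6], "create_time": 1650000100},
    {"milvus_id": 125, "feature": [0.6, 0.8], "create_time": 0},  # excluded by the filter
]
hits = filtered_search(rows, [0.6, 0.8], top_k=10,
                       predicate=lambda r: r["create_time"] > 0,
                       output_fields=["milvus_id"])
for distance, entity in hits:
    print(entity["milvus_id"], round(distance, 4))  # prints "123 0.0", then "124 0.2828"
```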
Image add and search services
```python
import logging

from rest_framework.views import APIView as View
from kpdjango.response import SucessAPIResponse, ErrorAPIResponse
from kpmysql.base import Kpmysql
# Imported under an alias: the view class below is also called search_image and would shadow it
from core import search_image as search_image_core
import kplog  # in-house logging module (not shown), presumably configures the "console" logger

log = logging.getLogger("console")


class add_image(View):
    def post(self, request):
        try:
            db = Kpmysql.connect("db168")
            cur = db.cursor()
            image_info = request.POST.get('image_info')
            image_path = request.POST.get('image_path')
            # Only record the request; the batch loader inserts the features into Milvus later
            sql = "INSERT INTO t_image_search_image_add_log(image_path, info) VALUES(%s, %s)"
            cur.execute(sql, (image_path, image_info))
            db.commit()
            log.info('Image added: {}-{}'.format(image_path, image_info))
            return SucessAPIResponse(msg="Success")
        except Exception as e:
            log.error('Failed to add image: {}'.format(e))
            return ErrorAPIResponse(msg="Fail")


class search_image(View):
    def post(self, request):
        try:
            image_path = request.POST.get('image_path')
            # Synchronous search via core.search_image (not shown in this article)
            res = search_image_core(image_path)
            log.info('Image search succeeded: {}-{}'.format(image_path, res))
            return SucessAPIResponse(msg="Success", data={"data": res})
        except Exception as e:
            log.error('Image search failed: {}'.format(e))
            return ErrorAPIResponse(msg="Fail")
```
Asynchronous batch loading of images
```python
import time
import datetime
import logging
from concurrent.futures import ThreadPoolExecutor

import redis
from kpmysql.base import Kpmysql
from conf.setting import REDIS
from core import insert_data_many, str2time
import kplog  # in-house logging module (not shown), presumably configures the loggers below

log = logging.getLogger("console")
log_addimgs = logging.getLogger("console_addimgs")


def worker(datas):
    """Load one batch of (id, image_path, create_time) rows into Milvus."""
    try:
        redis_cli = redis.Redis(host=REDIS.get('host'), port=REDIS.get('port'),
                                password=REDIS.get('password'), db=REDIS.get('db'))
        dics = []
        ids = []
        for data in datas:
            if redis_cli.zscore('image_search', str(data[0])) is not None:
                # deduplicate with redis: this id has already been loaded
                continue
            dics.append({'image_path': data[1], 'create_time': data[2]})
            ids.append((data[0],))  # one-element tuple per row for executemany
            redis_cli.zadd('image_search', {str(data[0]): str2time(data[2])})
        if not dics:
            return
        # insert the batch into milvus
        insert_data_many(dics)
        # mark the rows as loaded: set t_image_search_image_add_log.is_load = 1
        sql_update = """UPDATE t_image_search_image_add_log SET is_load=1 WHERE id=%s"""
        db168 = Kpmysql.connect("db168")
        cur168 = db168.cursor()
        cur168.executemany(sql_update, ids)
        db168.commit()
    except Exception:
        log.exception('batch load failed')


def main():
    max_workers = 20  # maximum number of threads
    pool = ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix='Thread')
    task_list = []
    # start 13 hours in the past so that the first batch runs immediately
    init_time = datetime.datetime.now() - datetime.timedelta(hours=13)
    create_time_init = '2020-2-22 00:00:00'
    while True:
        now = datetime.datetime.now()
        diff = now - init_time
        # total_seconds(): timedelta.seconds alone ignores whole days
        if diff.total_seconds() > 3600:
            # load the rows of t_image_search_image_add_log where is_load=0
            db168 = Kpmysql.connect("db168")
            cur168 = db168.cursor()
            sql = """SELECT id, image_path, create_time FROM t_image_search_image_add_log WHERE is_load=0 and create_time >= %s ORDER BY create_time"""
            cur168.execute(sql, (create_time_init,))
            datas = cur168.fetchall()
            if datas:
                create_time_init = datas[-1][2]  # resume from the newest row seen so far
                # wait until fewer than 90% of the workers are busy
                while True:
                    task_list = [t for t in task_list if not t.done()]
                    if len(task_list) < int(max_workers * 0.9):
                        break
                    time.sleep(1)
                task_list.append(pool.submit(worker, datas))
            init_time = now
        time.sleep(600)


if __name__ == "__main__":
    main()
```
Optimizations
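1. Keras is less convenient than PyTorch when running on the GPU with multiple threads, and PyTorch uses less GPU memory.
2. Instead of periodically pulling data from the database, switch to a Kafka producer-consumer model; the code becomes more concise and the logic simpler.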
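The producer-consumer design of optimization 2 can be sketched with the standard library. Here `queue.Queue` and threads stand in for a Kafka topic, producer and consumer; this is not Kafka client code. The batch size, the `None` stop sentinel and the function names are assumptions made for the sketch. The add service plays the producer, and the loader plays the consumer, inserting each batch into Milvus as soon as enough requests have arrived instead of waiting for the next polling round.

```python
import queue
import threading


def producer(q, paths):
    """Add service: publish one message per image-add request."""
    for path in paths:
        q.put(path)
    q.put(None)  # sentinel: no more messages (a real Kafka consumer would keep polling)


def consumer(q, batch_size, sink):
    """Loader: consume messages and flush them in batches (e.g. one Milvus insert per batch)."""
    batch = []
    while True:
        item = q.get()
        if item is None:
            break
        batch.append(item)
        if len(batch) >= batch_size:
            sink.append(batch)
            batch = []
    if batch:
        sink.append(batch)  # flush the remainder


q = queue.Queue()
batches = []
t_prod = threading.Thread(target=producer, args=(q, [f"img_{i}.jpg" for i in range(5)]))
t_cons = threading.Thread(target=consumer, args=(q, 2, batches))
t_cons.start()
t_prod.start()
t_prod.join()
t_cons.join()
print(batches)
# → [['img_0.jpg', 'img_1.jpg'], ['img_2.jpg', 'img_3.jpg'], ['img_4.jpg']]
```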
III. Other approaches combine VGG image features with Milvus:
Reference: https://cloud.tencent.com/developer/article/1605032
References:
1. https://maimai.cn/article/detail?fid=1743956531&efid=sTnHYzKAy8MK8AhgjSi7Bg
2. https://www.cnblogs.com/niulang/p/15921786.html
This article is from cnblogs; author: zwbsoft. Please credit the original link when reposting: https://www.cnblogs.com/zwbsoft/p/16891539.html