以下代码将演示如何提取图片的特征向量,可配合向量数据库实现图片的搜索功能。
下载测试资源: github.com/towhee-io/e…
以下是我的代码:
import csv
from glob import glob
from pathlib import Path
from statistics import mean
from towhee import pipe, ops, DataCollection
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility
# 安装依赖
import subprocess
# Towhee参数,指定使用的模型为resnet50
MODEL = 'resnet50'
# 如果为None,则使用默认设备(如果可用则启用CUDA)
DEVICE = None
# Milvus参数
# Milvus服务的主机地址
HOST = '127.0.0.1'
# Milvus服务的端口号
PORT = '19530'
# 搜索时返回的结果数量
TOPK = 10
# 由MODEL提取的嵌入维度
DIM = 2048
# Milvus集合的名称
COLLECTION_NAME = 'reverse_image_search'
# Milvus索引类型
INDEX_TYPE = 'IVF_FLAT'
# 距离度量类型
METRIC_TYPE = 'L2'
# 插入数据的源,可为CSV文件或图像路径模式
INSERT_SRC = 'reverse_image_search.csv'
# 查询数据的源,图像路径模式
QUERY_SRC = './test/*/*.JPEG'
# 加载图像路径
def load_image(x):
"""
该函数用于加载图像路径。
如果输入是CSV文件,则读取CSV文件中的图像路径;
如果输入是路径模式,则使用glob函数匹配路径。
:param x: 输入的CSV文件路径或图像路径模式
:yield: 图像路径
"""
# 判断输入是否为CSV文件
if x.endswith('csv'):
# 打开CSV文件
with open(x) as f:
# 创建CSV读取器
reader = csv.reader(f)
# 跳过标题行
next(reader)
# 遍历每一行
for item in reader:
# 假设第二列是图像路径
yield item[1]
else:
# 遍历匹配到的所有路径
for item in glob(x):
# 生成图像路径
yield item
# 嵌入管道
p_embed = (
# 定义输入源
pipe.input('src')
# 对输入源进行扁平化处理,将其转换为图像路径
.flat_map('src', 'img_path', load_image)
# 对图像路径进行解码,得到图像数据
.map('img_path', 'img', ops.image_decode())
# 使用指定模型提取图像的嵌入向量
.map('img', 'vec', ops.image_embedding.timm(model_name=MODEL, device=DEVICE))
)
# 连接到 Milvus
def connect_to_milvus():
# 连接到Milvus服务
connections.connect(host=HOST, port=PORT)
# 创建 Milvus 集合
def create_milvus_collection():
# 检查指定名称的集合是否已存在
if utility.has_collection(COLLECTION_NAME):
# 若存在,获取该集合
collection = Collection(COLLECTION_NAME)
return collection
# 定义集合的字段
fields = [
# 定义主键字段,类型为64位整数,自动生成ID
FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
# 定义图像路径字段,类型为可变长度字符串,最大长度为512
FieldSchema(name="img_path", dtype=DataType.VARCHAR, max_length=512),
# 定义向量字段,类型为浮点向量,维度为DIM
FieldSchema(name="vec", dtype=DataType.FLOAT_VECTOR, dim=DIM)
]
# 创建集合模式
schema = CollectionSchema(fields=fields)
# 创建集合
collection = Collection(name=COLLECTION_NAME, schema=schema)
# 定义索引参数
index_params = {
"metric_type": METRIC_TYPE,
"index_type": INDEX_TYPE,
"params": {"nlist": 128}
}
# 为向量字段创建索引
collection.create_index(field_name="vec", index_params=index_params)
return collection
# 插入数据到 Milvus
def insert_data_to_milvus(collection):
# 执行嵌入管道,输出图像路径和向量
results = p_embed.output('img_path', 'vec')(INSERT_SRC).to_list()
# 存储图像路径的列表
img_paths = []
# 存储向量的列表
vectors = []
# 遍历结果
for result in results:
# 将图像路径添加到列表中
img_paths.append(result[0])
# 将向量转换为列表并添加到列表中
vector = result[1].tolist()
vectors.append(vector)
# 逐行显示向量
print(f"Image path: {result[0]}, Vector: {vector}")
# 组合数据
data = [img_paths, vectors]
# 插入数据到集合中
collection.insert(data)
# 刷新集合,使插入的数据生效
collection.flush()
# 图片搜索功能
def search_images(collection):
# 执行嵌入管道,输出查询向量
query_results = p_embed.output('vec')(QUERY_SRC).to_list()
# 遍历查询结果
for result in query_results:
# 将查询向量转换为列表
query_vector = result[0].tolist()
# 定义搜索参数
search_params = {
"metric_type": METRIC_TYPE,
"params": {"nprobe": 10}
}
# 执行搜索操作
results = collection.search(
data=[query_vector],
anns_field="vec",
param=search_params,
limit=TOPK,
output_fields=["img_path"]
)
# 遍历搜索结果
for hit in results[0]:
# 打印图像路径和距离
print(f"Image path: {hit.entity.get('img_path')}, Distance: {hit.distance}")
if __name__ == "__main__":
# 连接到Milvus服务
connect_to_milvus()
# 创建或获取Milvus集合
collection = create_milvus_collection()
# 插入数据到Milvus集合
insert_data_to_milvus(collection)
# 执行图片搜索
search_images(collection)