3、向量数据库Milvus实现图片查询

39 阅读3分钟

以下代码将演示如何提取图片的特征向量,可配合向量数据库实现图片的搜索功能。

下载测试资源: github.com/towhee-io/e…

以下是我的代码:

import csv
from glob import glob
from pathlib import Path
from statistics import mean

from towhee import pipe, ops, DataCollection
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility

# 安装依赖
import subprocess
# Towhee参数,指定使用的模型为resnet50
MODEL = 'resnet50'
# 如果为None,则使用默认设备(如果可用则启用CUDA)
DEVICE = None

# Milvus参数
# Milvus服务的主机地址
HOST = '127.0.0.1'
# Milvus服务的端口号
PORT = '19530'
# 搜索时返回的结果数量
TOPK = 10
# 由MODEL提取的嵌入维度
DIM = 2048
# Milvus集合的名称
COLLECTION_NAME = 'reverse_image_search'
# Milvus索引类型
INDEX_TYPE = 'IVF_FLAT'
# 距离度量类型
METRIC_TYPE = 'L2'

# 插入数据的源,可为CSV文件或图像路径模式
INSERT_SRC = 'reverse_image_search.csv'
# 查询数据的源,图像路径模式
QUERY_SRC = './test/*/*.JPEG'

# 加载图像路径
def load_image(x):
    """
    该函数用于加载图像路径。
    如果输入是CSV文件,则读取CSV文件中的图像路径;
    如果输入是路径模式,则使用glob函数匹配路径。

    :param x: 输入的CSV文件路径或图像路径模式
    :yield: 图像路径
    """
    # 判断输入是否为CSV文件
    if x.endswith('csv'):
        # 打开CSV文件
        with open(x) as f:
            # 创建CSV读取器
            reader = csv.reader(f)
            # 跳过标题行
            next(reader)
            # 遍历每一行
            for item in reader:
                # 假设第二列是图像路径
                yield item[1]
    else:
        # 遍历匹配到的所有路径
        for item in glob(x):
            # 生成图像路径
            yield item

# 嵌入管道
p_embed = (
    # 定义输入源
    pipe.input('src')
    # 对输入源进行扁平化处理,将其转换为图像路径
    .flat_map('src', 'img_path', load_image)
    # 对图像路径进行解码,得到图像数据
    .map('img_path', 'img', ops.image_decode())
    # 使用指定模型提取图像的嵌入向量
    .map('img', 'vec', ops.image_embedding.timm(model_name=MODEL, device=DEVICE))
)

# 连接到 Milvus
def connect_to_milvus():
    # 连接到Milvus服务
    connections.connect(host=HOST, port=PORT)

# 创建 Milvus 集合
def create_milvus_collection():
    # 检查指定名称的集合是否已存在
    if utility.has_collection(COLLECTION_NAME):
        # 若存在,获取该集合
        collection = Collection(COLLECTION_NAME)
        return collection
    # 定义集合的字段
    fields = [
        # 定义主键字段,类型为64位整数,自动生成ID
        FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
        # 定义图像路径字段,类型为可变长度字符串,最大长度为512
        FieldSchema(name="img_path", dtype=DataType.VARCHAR, max_length=512),
        # 定义向量字段,类型为浮点向量,维度为DIM
        FieldSchema(name="vec", dtype=DataType.FLOAT_VECTOR, dim=DIM)
    ]
    # 创建集合模式
    schema = CollectionSchema(fields=fields)
    # 创建集合
    collection = Collection(name=COLLECTION_NAME, schema=schema)
    # 定义索引参数
    index_params = {
        "metric_type": METRIC_TYPE,
        "index_type": INDEX_TYPE,
        "params": {"nlist": 128}
    }
    # 为向量字段创建索引
    collection.create_index(field_name="vec", index_params=index_params)
    return collection

# 插入数据到 Milvus
def insert_data_to_milvus(collection):
    # 执行嵌入管道,输出图像路径和向量
    results = p_embed.output('img_path', 'vec')(INSERT_SRC).to_list()
    # 存储图像路径的列表
    img_paths = []
    # 存储向量的列表
    vectors = []
    # 遍历结果
    for result in results:
        # 将图像路径添加到列表中
        img_paths.append(result[0])
        # 将向量转换为列表并添加到列表中
        vector = result[1].tolist()
        vectors.append(vector)
        # 逐行显示向量
        print(f"Image path: {result[0]}, Vector: {vector}")
    # 组合数据
    data = [img_paths, vectors]
    # 插入数据到集合中
    collection.insert(data)
    # 刷新集合,使插入的数据生效
    collection.flush()

# 图片搜索功能
def search_images(collection):
    # 执行嵌入管道,输出查询向量
    query_results = p_embed.output('vec')(QUERY_SRC).to_list()
    # 遍历查询结果
    for result in query_results:
        # 将查询向量转换为列表
        query_vector = result[0].tolist()
        # 定义搜索参数
        search_params = {
            "metric_type": METRIC_TYPE,
            "params": {"nprobe": 10}
        }
        # 执行搜索操作
        results = collection.search(
            data=[query_vector],
            anns_field="vec",
            param=search_params,
            limit=TOPK,
            output_fields=["img_path"]
        )
        # 遍历搜索结果
        for hit in results[0]:
            # 打印图像路径和距离
            print(f"Image path: {hit.entity.get('img_path')}, Distance: {hit.distance}")

if __name__ == "__main__":
    # 连接到Milvus服务
    connect_to_milvus()
    # 创建或获取Milvus集合
    collection = create_milvus_collection()
    # 插入数据到Milvus集合
    insert_data_to_milvus(collection)
    # 执行图片搜索
    search_images(collection)