【技术专题】嵌入模型与Chroma向量数据库 - 自定义Embedding Functions

0 阅读1分钟

大家好,我是锋哥。最近连载更新《嵌入模型与Chroma向量数据库 AI大模型应用开发必备知识》技术专题。

QQ截图20260226134650.jpg 本课程主要介绍和讲解嵌入模型与向量数据库简介,Qwen3嵌入模型使用,Chroma向量数据库使用,Chroma安装,Client-Server模式,集合添加,修改,删除,查询操作以及自定义Embedding Functions。。。 同时也配套视频教程 《1天学会 嵌入模型与Chroma向量数据库 AI大模型应用开发必备知识 视频教程》

我们前面的所有实例,默认使用的DefaultEmbeddingFunction,基于Sentence Transformers的all-MiniLM-L6-v2模型,把传入的文本转成向量。

from chromadb.utils import embedding_functions
​
# 默认使用的DefaultEmbeddingFunction,基于Sentence Transformers的all-MiniLM-L6-v2模型
embedding_functions.DefaultEmbeddingFunction()

我们如果需要使用自己定义的嵌入模型,可以使用自定义Embedding Functions。定义一个MyEmbeddingFunction,继承EmbeddingFunction即可。

# 自定义Embedding Function
class MyEmbeddingFunction(EmbeddingFunction):
​
    def __init__(self):
        model_path="Qwen3-VL-Embedding-2B"
        self.model = AutoModel.from_pretrained(model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
​
    def __call__(self, text):
        # 设置模型到评估模式
        self.model.eval()
        inputs = self.tokenizer(text, return_tensors='pt', padding=True, truncation=True)
        with torch.no_grad():
            outputs = self.model(**inputs)
            embeddings = outputs.last_hidden_state.mean(dim=1)  # 平均所有token的embedding
        return embeddings.squeeze().float().numpy()

完整实例代码:

import torch
from chromadb import EmbeddingFunction
from transformers import AutoModel, AutoTokenizer
​
# 自定义Embedding Function
class MyEmbeddingFunction(EmbeddingFunction):
​
    def __init__(self):
        model_path="Qwen3-VL-Embedding-2B"
        self.model = AutoModel.from_pretrained(model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
​
    def __call__(self, text):
        # 设置模型到评估模式
        self.model.eval()
        inputs = self.tokenizer(text, return_tensors='pt', padding=True, truncation=True)
        with torch.no_grad():
            outputs = self.model(**inputs)
            embeddings = outputs.last_hidden_state.mean(dim=1)  # 平均所有token的embedding
        return embeddings.squeeze().float().numpy()
​
import chromadb
chroma_client = chromadb.Client() # 创建Chroma客户端
​
​
my_embedding_function = MyEmbeddingFunction()
​
collection = chroma_client.create_collection(
    name="my_collection",
    embedding_function=my_embedding_function
)
​
collection.add(
    ids=["id1", "id2", "id3"],
    documents=["lorem ipsum...", "doc2", "doc3"],
    metadatas=[{"chapter": 3, "verse": 16}, {"chapter": 3, "verse": 5}, {"chapter": 29, "verse": 11}],
)
​
# 查询
results = collection.query(
    query_texts=["doc"], # Chroma will embed this for you
    n_results=3 # how many results to return
)
print(results)

运行结果:

Loading weights: 100%|██████████| 625/625 [00:00<00:00, 6243.72it/s, Materializing param=visual.pos_embed.weight]
embedding: [[-1.8515625  -0.546875   -0.00741577 ...  0.390625    0.32421875
  -0.19921875]
 [ 0.05273438 -2.546875   -0.41796875 ... -0.265625    0.05737305
   0.5078125 ]
 [-0.9453125  -3.15625    -0.71875    ... -0.22167969  0.24023438
  -0.09863281]] 3
embedding: [-0.87890625 -2.9375     -0.52734375 ...  0.48828125 -0.87109375
  0.9296875 ] 2048
{'ids': [['id3', 'id2', 'id1']], 'embeddings': None, 'documents': [['doc3', 'doc2', 'lorem ipsum...']], 'uris': None, 'included': ['metadatas', 'documents', 'distances'], 'data': None, 'metadatas': [[{'chapter': 29, 'verse': 11}, {'chapter': 3, 'verse': 5}, {'verse': 16, 'chapter': 3}]], 'distances': [[2719.9794921875, 2903.89501953125, 6400.73583984375]]}