【踩坑】Elasticsearch-DSL向量搜索报错BadRequestError

675 阅读1分钟

使用SBert模型做embedding后入es库后,再用es的余弦相似度做匹配 报错:BadRequestError(400, 'search_phase_execution_exception', 'runtime error') 搜索代码如下:

#入es库的embedding方法
def encode_text(model, text):
    return model.encode([text])[0].tolist()


# 余弦相似度匹配搜索方法
def searchByText(question: string) -> list:
    try:
        
        model = SentenceTransformer(current_app.config.get("SBERT_CHECKPOINT"))
        index_name = "xxx"

        query_vector = model.encode(question)[0].tolist()
        print(f"query_vector={query_vector}")

        # 创建搜索对象
        s = Search(index=index_name)

        # 构建基础查询
        base_query = MatchAll()

        script_query = ScriptScore(
            query=base_query,
            # es的查询得分范围是0到1, 而余弦相似度算出来是-1~+1, 所以要+1
            script={"source": "cosineSimilarity(params.query_vector, 'vector') + 1.0",
                    "params": {"query_vector": query_vector}})

        s = s.query(script_query)

        # 定义查找出来的列
        s = s.source(["id", "category", "category_id", "question", "question_id", "updated_time"])
        # 取前5条记录
        response = s[:5].execute()
        results = []
    

排查发现searchByText.query_vector计算方法中,encode入参送的是字符串而不是数组,导致计算出来的embedding只有1位,而入库的embedding是768位,导致维度匹配不上,修改query_vector即可:

query_vector = model.encode([question])[0].tolist() 

吐槽一下: 这里的debug报错太简单了,能提示维度不一致就省事了