[AIGC]超详细解析dify中知识库检索的原理

4,869 阅读17分钟

前言

在dify中,知识库检索是核心功能之一,这里面大大小小的环节也比较多,但是在前端来看只有一个输入和一个输出,那么我们在知识库设定的那些设置 和 在知识库检索节点设定的召回设置等都是怎么发挥作用的?如果两个设置中都设置了rerank,那么是如何工作的?

这就需要从源码来看,看看知识库检索这个节点到底是如何工作的。

知识库检索

知识库检索这节点的代码在dify项目中api/core/workflow/nodes/knowledge_retrieval目录下的knowledge_retrieval_node.py,入口函数就是_run函数:

def _run(self) -> NodeRunResult:
    ...
    try:
        results = self._fetch_dataset_retriever(node_data=node_data, query=query)
        outputs = {"result": results}
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs
        )
    ...

召回设置

执行_fetch_dataset_retriever函数得到结果返回,这个函数就很大了,聚焦关键代码即可。在新版本中dify已经不支持单路召回了,所以现在都是多路召回,所以看这个分支即可:

elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value:
    if node_data.multiple_retrieval_config.reranking_mode == "reranking_model":    #代码1
        if node_data.multiple_retrieval_config.reranking_model:
            reranking_model = {
                "reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider,
                "reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model,
            }
        else:
            reranking_model = None
        weights = None
    elif node_data.multiple_retrieval_config.reranking_mode == "weighted_score":    #代码2
        reranking_model = None
        vector_setting = node_data.multiple_retrieval_config.weights.vector_setting
        weights = {
            "vector_setting": {
                "vector_weight": vector_setting.vector_weight,
                "embedding_provider_name": vector_setting.embedding_provider_name,
                "embedding_model_name": vector_setting.embedding_model_name,
            },
            "keyword_setting": {
                "keyword_weight": node_data.multiple_retrieval_config.weights.keyword_setting.keyword_weight
            },
        }
    else:
        reranking_model = None
        weights = None
    all_documents = dataset_retrieval.multiple_retrieve(
        self.app_id,
        self.tenant_id,
        self.user_id,
        self.user_from.value,
        available_datasets,
        query,
        node_data.multiple_retrieval_config.top_k,
        node_data.multiple_retrieval_config.score_threshold,
        node_data.multiple_retrieval_config.reranking_mode,
        reranking_model,
        weights,
        node_data.multiple_retrieval_config.reranking_enable,
    )       #代码3

这里看代码1和代码2,根据multiple_retrieval_configreranking_mode来做不同的动作,这里multiple_retrieval_config就是在知识库检索节点中设定的召回设置,如下

召回设置.jpg

有两种方式:权重设置和rerank模型。reranking_mode的两个值weighted_scorereranking_model就是对应着两种方式。

所以代码1和代码2中是将召回设置的信息组装一下,用于后面使用。

在代码3执行了multiple_retrieve函数

开始检索

multiple_retrieve函数源码如下:

def multiple_retrieve(
    self,
    ...
):
    ...
    for dataset in available_datasets:           #代码1
        index_type = dataset.indexing_technique
        retrieval_thread = threading.Thread(
            target=self._retriever,
            kwargs={
                "flask_app": current_app._get_current_object(),
                "dataset_id": dataset.id,
                "query": query,
                "top_k": top_k,
                "all_documents": all_documents,
            },
        )
        threads.append(retrieval_thread)
        retrieval_thread.start()
    for thread in threads:
        thread.join()

    with measure_time() as timer:
        if reranking_enable:      #代码2
            # do rerank for searched documents
            data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)

            all_documents = data_post_processor.invoke(
                query=query, documents=all_documents, score_threshold=score_threshold, top_n=top_k
            )
        else:
            ...
    ...
    return all_documents

这部分是核心代码,有两大块代码,正好对应知识库检索的两个重要过程:检索和重排。

代码1遍历节点中配置的知识库,调用_retriever函数对每个知识库进行检索,得到检索结果。

代码2则对所有检索结果进行重排序,是通过DataPostProcessor来实现的。

我们一步步来看。

检索知识库

先来看检索这一步,_retriever函数如下:

def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list):
    with flask_app.app_context():
        dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()

        if not dataset:
            return []

        if dataset.provider == "external":
            ...
        else:
            # get retrieval model , if the model is not setting , using default
            retrieval_model = dataset.retrieval_model or default_retrieval_model

            if dataset.indexing_technique == "economy":
                # use keyword table query
                documents = RetrievalService.retrieve(
                    retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=top_k
                )          #代码1
                if documents:
                    all_documents.extend(documents)
            else:
                if top_k > 0:
                    # retrieval source
                    documents = RetrievalService.retrieve(
                        retrieval_method=retrieval_model["search_method"],
                        dataset_id=dataset.id,
                        query=query,
                        top_k=retrieval_model.get("top_k") or 2,
                        score_threshold=retrieval_model.get("score_threshold", 0.0)
                        if retrieval_model["score_threshold_enabled"]
                        else 0.0,
                        reranking_model=retrieval_model.get("reranking_model", None)
                        if retrieval_model["reranking_enable"]
                        else None,
                        reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
                        weights=retrieval_model.get("weights", None),
                    )          #代码2

                    all_documents.extend(documents)

先在数据库中获取知识库详情,然后看else分支,这里根据dataset.indexing_technique做不同的处理,这个indexing_technique就是我们在知识库的设置中设定的索引模式

索引模式.jpg

虽然处理不一样,但是可以看到都是调用了RetrievalService.retrieve来检索。

检索方式

RetrievalService.retrieve函数比较长,主要代码如下:

def retrieve(
    cls,
    retrieval_method: str,
    dataset_id: str,
    query: str,
    top_k: int,
    score_threshold: Optional[float] = 0.0,
    reranking_model: Optional[dict] = None,
    reranking_mode: Optional[str] = "reranking_model",
    weights: Optional[dict] = None,
):
    dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
    ...
    # retrieval_model source with keyword
    if retrieval_method == "keyword_search":   #代码1
        ...
    # retrieval_model source with semantic
    if RetrievalMethod.is_support_semantic_search(retrieval_method):   #代码2
        embedding_thread = threading.Thread(
            target=RetrievalService.embedding_search,
            kwargs={
                "flask_app": current_app._get_current_object(),
                "dataset_id": dataset_id,
                "query": query,
                "top_k": top_k,
                "score_threshold": score_threshold,
                "reranking_model": reranking_model,
                "all_documents": all_documents,
                "retrieval_method": retrieval_method,
                "exceptions": exceptions,
            },
        )
        threads.append(embedding_thread)
        embedding_thread.start()

    # retrieval source with full text
    if RetrievalMethod.is_support_fulltext_search(retrieval_method):   #代码3
        full_text_index_thread = threading.Thread(
            target=RetrievalService.full_text_index_search,
            kwargs={
                "flask_app": current_app._get_current_object(),
                "dataset_id": dataset_id,
                "query": query,
                "retrieval_method": retrieval_method,
                "score_threshold": score_threshold,
                "top_k": top_k,
                "reranking_model": reranking_model,
                "all_documents": all_documents,
                "exceptions": exceptions,
            },
        )
        threads.append(full_text_index_thread)
        full_text_index_thread.start()

    ...

    if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value:   #代码4
        data_post_processor = DataPostProcessor(
            str(dataset.tenant_id), reranking_mode, reranking_model, weights, False
        )
        all_documents = data_post_processor.invoke(
            query=query, documents=all_documents, score_threshold=score_threshold, top_n=top_k
        )
    return all_documents

这里重点关注四个if语句,即代码1-4,根据不同的检索方式执行不同的检索流程。这个检索方式就是知识库的设置中的检索设置

检索设置.jpg

其中代码1的关键字检索应该是废弃了,代码2、3、4分别对应向量检索、全文检索和混合检索。

这里注意观察代码2和代码3,并不是直接对比,而是用了两个is_support_xxx函数,这两个函数如下:

@staticmethod
def is_support_semantic_search(retrieval_method: str) -> bool:
    return retrieval_method in {RetrievalMethod.SEMANTIC_SEARCH.value, RetrievalMethod.HYBRID_SEARCH.value}

@staticmethod
def is_support_fulltext_search(retrieval_method: str) -> bool:
    return retrieval_method in {RetrievalMethod.FULL_TEXT_SEARCH.value, RetrievalMethod.HYBRID_SEARCH.value}

可以看到除了对应的检索方式,混合检索也会执行代码2和代码3。所以在代码4我们看不到检索相关的代码,因为如果是混合检索,那么既要执行向量检索(代码2)也要执行全文检索(代码3),所有的检索结果都会存入all_documents中。代码4仅仅是将这两个结果进行重排序,这部分代码可以看到跟multiple_retrieve函数中召回设置的重排序代码是一样的,用DataPostProcessor来实现,所以后面一起来说。

全文检索

先来看看代码3的全文检索,调用了RetrievalService.full_text_index_search,这个函数代码如下:

@classmethod
def full_text_index_search(
    cls,
    ...
):
    with flask_app.app_context():
        try:
            ...
            documents = vector_processor.search_by_full_text(cls.escape_query_for_search(query), top_k=top_k)  #代码1
            if documents:
                if (
                    reranking_model
                    and reranking_model.get("reranking_model_name")
                    and reranking_model.get("reranking_provider_name")
                    and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH.value
                ):       #代码2
                    data_post_processor = DataPostProcessor(
                        str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False
                    )
                    all_documents.extend(
                        data_post_processor.invoke(
                            query=query, documents=documents, score_threshold=score_threshold, top_n=len(documents)
                        )
                    )
                else:
                    all_documents.extend(documents)
        except Exception as e:
            exceptions.append(str(e))

代码1进行全文检索;代码2如果全文检索下设置了rerank则进使用rerank模型对结果进行重排序,可以看到这部分也是用DataPostProcessor来实现,dify中所有重排序都是使用DataPostProcessor,后面一起说。

看回代码1,执行了vector_processor.search_by_full_text,代码如下:

def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
    return self._vector_processor.search_by_full_text(query, **kwargs)

这里的_vector_processor是一个BaseVector对象,所有的实现都在子类中。具体是哪个子类,跟我们使用的向量数据库有关,默认是使用weaviate作为向量数据库,那么就是WeaviateVector这个类,他的search_by_full_text代码如下:

def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
    """Return docs using BM25F.

    Args:
        query: Text to look up documents similar to.
        k: Number of Documents to return. Defaults to 4.

    Returns:
        List of Documents most similar to the query.
    """
    collection_name = self._collection_name
    content: dict[str, Any] = {"concepts": [query]}
    properties = self._attributes
    properties.append(Field.TEXT_KEY.value)
    if kwargs.get("search_distance"):
        content["certainty"] = kwargs.get("search_distance")
    query_obj = self._client.query.get(collection_name, properties)
    if kwargs.get("where_filter"):
        query_obj = query_obj.with_where(kwargs.get("where_filter"))
    query_obj = query_obj.with_additional(["vector"])
    properties = ["text"]
    result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get("top_k", 4)).do()    #代码1
    if "errors" in result:
        raise ValueError(f"Error during query: {result['errors']}")
    docs = []
    for res in result["data"]["Get"][collection_name]:
        text = res.pop(Field.TEXT_KEY.value)
        additional = res.pop("_additional")
        docs.append(Document(page_content=text, vector=additional["vector"], metadata=res))
    return docs

可以看到这里是使用BM25算法来进行检索,取top_k个条目即可。

本身weaviate支持两种索引:

  • 近似最近邻(ANN)索引 - 用于所有向量搜索查询
  • 倒排索引 - 支持按属性过滤查询和 BM25 查询

所以这里直接使用它自带的特性进行检索的。可以对比看其他实现,比如MilvusVector的对应函数如下:

def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
    # milvus/zilliz doesn't support bm25 search
    return []

可以看到这个数据库不支持BM25算法检索,所以直接返回空数组了。

向量检索

我们再来看看向量检索,在RetrievalService.retrieve的代码2处可以看到是执行RetrievalService.embedding_search来进行向量检索的,这个函数如下:

@classmethod
def embedding_search(
    cls,
    ...
):
    with flask_app.app_context():
        try:
            ...
            documents = vector.search_by_vector(
                cls.escape_query_for_search(query),
                search_type="similarity_score_threshold",
                top_k=top_k,
                score_threshold=score_threshold,
                filter={"group_id": [dataset.id]},
            )   #代码1

            if documents:
                if (
                    reranking_model
                    and reranking_model.get("reranking_model_name")
                    and reranking_model.get("reranking_provider_name")
                    and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH.value
                ):    #代码2
                    data_post_processor = DataPostProcessor(
                        str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False
                    )
                    all_documents.extend(
                        data_post_processor.invoke(
                            query=query, documents=documents, score_threshold=score_threshold, top_n=len(documents)
                        )
                    )
                else:
                    all_documents.extend(documents)
        except Exception as e:
            exceptions.append(str(e))

可以看到与全文检索代码几乎是一样的,先进行检索(代码1),然后如果向量检索下设置了rerank就进行重排(代码2),重排也是DataPostProcessor,后面一起说。

检索(代码1)是执行了vector.search_by_vector,这个函数代码:

def search_by_vector(self, query: str, **kwargs: Any) -> list[Document]:
    query_vector = self._embeddings.embed_query(query)
    return self._vector_processor.search_by_vector(query_vector, **kwargs)

先执行了_embeddings.embed_query函数对query进行向量化,这块我们简单展开说说。

向量化

embed_query的函数内容如下:

def embed_query(self, text: str) -> list[float]:
    """Embed query text."""
    # use doc embedding cache or store if not exists
    hash = helper.generate_text_hash(text)
    embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model}_{hash}"
    embedding = redis_client.get(embedding_cache_key)  #代码1
    if embedding:
        redis_client.expire(embedding_cache_key, 600)
        return list(np.frombuffer(base64.b64decode(embedding), dtype="float"))
    try:
        embedding_result = self._model_instance.invoke_text_embedding(
            texts=[text], user=self._user, input_type=EmbeddingInputType.QUERY
        )       #代码2
    except Exception as ex:
        if dify_config.DEBUG:
            logging.exception(f"Failed to embed query text: {ex}")
        raise ex

    try:
        # encode embedding to base64
        embedding_vector = np.array(embedding_results)
        vector_bytes = embedding_vector.tobytes()
        # Transform to Base64
        encoded_vector = base64.b64encode(vector_bytes)
        # Transform to string
        encoded_str = encoded_vector.decode("utf-8")
        redis_client.setex(embedding_cache_key, 600, encoded_str)    #代码3
    except Exception as ex:
        if dify_config.DEBUG:
            logging.exception("Failed to add embedding to redis %s", ex)
        raise ex

    return embedding_results

可以看到先在redis中查找之前是否有记录(代码1),如果有直接返回,这样同样的query就不必反复调用text-embedding模型了。这个记录有效时间是600秒,即10分钟(代码3)。

如果记录才会执行_model_instance.invoke_text_embedding来进行向量化(代码2)。这个函数代码:

def invoke_text_embedding(
    ...
) -> TextEmbeddingResult:
    ...
    self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
    return self._round_robin_invoke(
        function=self.model_type_instance.invoke,  #代码1
        model=self.model,
        credentials=self.credentials,
        texts=texts,
        user=user,
        input_type=input_type,
    )

调用了model_type_instance.invoke(代码1),model_type_instanceTextEmbeddingModel的一个对象,它的invoke函数实际上是调用了自己的_invoke 函数,这个函数是在子类里实现的,至于是哪个子类,就要看在知识库设置中配置的text-embedding模型是哪个,会使用这个模型对应供应商的实现。比如使用的模型是bge-large-zh,是百度文心的,那么调用的子类就是WenxinTextEmbeddingModel

具体的实现这里就不展示了,每个供应商都不一样,但是都是发送一个请求并获取结果。

数据库检索

回到vector.search_by_vector这个函数,向量化后与全文检索一样,执行了_vector_processor的函数,它是一个BaseVector对象,所以看它的子类WeaviateVector的对应函数:

def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
    ...
    query_obj = self._client.query.get(collection_name, properties)

    vector = {"vector": query_vector}
    if kwargs.get("where_filter"):     #代码1
        query_obj = query_obj.with_where(kwargs.get("where_filter"))
    result = (
        query_obj.with_near_vector(vector)
        .with_limit(kwargs.get("top_k", 4))
        .with_additional(["vector", "distance"])
        .do()
    )
    if "errors" in result:
        raise ValueError(f"Error during query: {result['errors']}")

    docs_and_scores = []
    for res in result["data"]["Get"][collection_name]:     #代码2
        text = res.pop(Field.TEXT_KEY.value)
        score = 1 - res["_additional"]["distance"]
        docs_and_scores.append((Document(page_content=text, metadata=res), score))

    docs = []
    for doc, score in docs_and_scores:      #代码3
        score_threshold = float(kwargs.get("score_threshold") or 0.0)
        # check score threshold
        if score > score_threshold:
            doc.metadata["score"] = score
            docs.append(doc)
    # Sort the documents by score in descending order
    docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
    return docs

整个过程分成三步,先在数据库中检索出top_k个相似条目(代码1),然后计算得分(代码2),最后如果向量检索设置下配置了Score阈值则进行过滤(代码3)

混合检索

上面也说了,混合检索就是进行全文检索top_k和向量检索top_k,然后根据混合检索的设置进行重排序(这部分跟召回设置一样,所以后面一起说),经过Score阈值过滤后(如果设置了)得到最高的top_k个结果即可。

重排序

上面用很大一个篇幅来讲了在知识库中是如何进行检索的,这部分的流程主要依赖的是知识库自身的配置。下面我们回到 #开始检索# 这一大节中。

上面提到multiple_retrieve函数有两大块代码,正好对应知识库检索的两个重要过程:检索和重排。

刚刚结束的 #检索知识库# 这一大节对应的就是第一个过程,现在我们从设置的知识库(一个或者多个)中得到了检索结果。如果只是一个知识库还好,因为结果是排序过的。但是如果是多个知识库,那么每个知识库自己的结果是排序过的,但是它们之间的并没有排序。

所以第二个过程就是重排(但是如果只有一个知识库也需要这一步,我认为这里dify可以优化一下)。

通过上面我们知道,在dify中,对检索结果的重排都是通过DataPostProcessor来实现的,它的invoke函数如下:

def invoke(
    ...
) -> list[Document]:
    if self.rerank_runner:
        documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user)

    if self.reorder_runner:
        documents = self.reorder_runner.run(documents)

    return documents

这里有两种,重排都会进入第一种rerank_runner,它是BaseRerankRunner的对象,这个的实现也是在子类中,有两个子类:

  • RerankModelRunner:rerank模型重排
  • WeightRerankRunner:权重重排

再来看召回设置

召回设置.jpg 召回设置里有两种方式:权重设置和rerank模型。上面两个类就是对应这两种方式的。

这里要提一下,在知识库设置中,检索模式如果是全文检索或向量检索,当开启rerank得时候就是用rerank模型来重排,如果不开启则不重排;而在混合检索中,可以看到与召回设置一样,可以选择两种方式。

下面我们就一个个来看看这两种方式。

rerank模型重排

先来看看rerank模型重排,RerankModelRunner的run函数:

def run(
    ...
) -> list[Document]:
    ...
    for document in dify_documents:    #代码1
        if document.metadata["doc_id"] not in doc_id:
            doc_id.append(document.metadata["doc_id"])
            docs.append(document.page_content)
            unique_documents.append(document)
    for document in external_documents:
        docs.append(document.page_content)
        unique_documents.append(document)

    documents = unique_documents

    rerank_result = self.rerank_model_instance.invoke_rerank(
        query=query, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user
    )     #代码2
    ...
    return rerank_documents

先将各个知识库的检索结果整合到一起(代码1),然后执行rerank_model_instance.invoke_rerank进行重排(代码2)。

rerank_model_instance这里其实跟上面 #向量化# 章节中的流程类似,最终会根据设置的rerank模型选择其供应商对应的实现类来执行,本质上同样是执行一个请求获取结果。这里就不细说了。

所以rerank模型重排其实比较简单,完全依赖模型来重排,包括top_k和Score阈值都是模型处理的。

权重重排

然后来看权重重排,WeightRerankRunner的run函数如下:

def run(
    ...
) -> list[Document]:
    ...
    for document in documents:    #代码1
        if document.metadata["doc_id"] not in doc_id:
            doc_id.append(document.metadata["doc_id"])
            docs.append(document.page_content)
            unique_documents.append(document)

    documents = unique_documents

    rerank_documents = []
    query_scores = self._calculate_keyword_score(query, documents)    #代码2

    query_vector_scores = self._calculate_cosine(self.tenant_id, query, documents, self.weights.vector_setting)       #代码3
    
    for document, query_score, query_vector_score in zip(documents, query_scores, query_vector_scores):        #代码4
        # format document
        score = (
            self.weights.vector_setting.vector_weight * query_vector_score
            + self.weights.keyword_setting.keyword_weight * query_score
        )       #代码4
        if score_threshold and score < score_threshold:   #代码5
            continue
        document.metadata["score"] = score
        rerank_documents.append(document)
    rerank_documents = sorted(rerank_documents, key=lambda x: x.metadata["score"], reverse=True)  #代码6
    return rerank_documents[:top_n] if top_n else rerank_documents  #代码7

权重重排有一个权重设置,如图:

召回设置.jpg

可以设置语义和关键词的权重比,上面的代码就体现了这部分。

首先还是先整合所有知识库的检索结果(代码1);然后进行关键词打分(代码2)和语义打分(代码3);最后根据设置好的权重比,对关键词打分和语义打分进行权重计算重新打分(代码4),如果设置了Score阈值这里还会进行过滤(代码5);最终重新排序(代码6)并返回top_k个结果(代码7)

这里dify可以优化一下,因为如果权重设置中有一方为0的情况下,那么就不必进行打分(比如如果关键词权重是0,就没必要进行关键词打分)。dify这里全都进行打分,有些浪费。

权重计算其实就是乘法运算,所以这里值得说的就是关键词打分(代码2)和语义打分(代码3),一个个来说。

语义打分

语义打分是在_calculate_cosine这个函数中:

def _calculate_cosine(
    self, tenant_id: str, query: str, documents: list[Document], vector_setting: VectorSetting
) -> list[float]:
    ...
    embedding_model = model_manager.get_model_instance(
        tenant_id=tenant_id,
        provider=vector_setting.embedding_provider_name,
        model_type=ModelType.TEXT_EMBEDDING,
        model=vector_setting.embedding_model_name,
    )
    cache_embedding = CacheEmbedding(embedding_model)
    query_vector = cache_embedding.embed_query(query)   #代码1
    for document in documents:
        # calculate cosine similarity
        if "score" in document.metadata:   #代码2
            query_vector_scores.append(document.metadata["score"])
        else:              #代码3
            # transform to NumPy
            vec1 = np.array(query_vector)
            vec2 = np.array(document.vector)

            # calculate dot product
            dot_product = np.dot(vec1, vec2)

            # calculate norm
            norm_vec1 = np.linalg.norm(vec1)
            norm_vec2 = np.linalg.norm(vec2)

            # calculate cosine similarity
            cosine_sim = dot_product / (norm_vec1 * norm_vec2)
            query_vector_scores.append(cosine_sim)

    return query_vector_scores

首先会对query进行向量化(代码1),这部分在前面 #向量化# 章节中详细说了,尤其是里面是用了redis缓存,那么注意这里因为刚刚进行了知识库的检索,如果当时是向量检索或混合检索,这里就一定会有缓存,不会再调用模型了;如果当时是全文检索,那么这里可能就没有缓存,就需要调用模型。

然后对每一个片段进行打分,如果片段结果中有score,那么就直接使用即可(代码2);如果没有,则通过query向量和片段向量计算相似度进行打分(代码3)。前面知识库的检索过程中可以看到,如果是向量检索或混合检索,结果一定会有得分score,这里就直接用;如果是全文检索,则没有得分score,这里就需要重新计算一下。

所以如果前面知识库环节使用的是向量检索或者混合检索,语义打分这里基本就什么都不做。

关键词打分

关键词打分是在_calculate_keyword_score这个函数中:

def _calculate_keyword_score(self, query: str, documents: list[Document]) -> list[float]:
    """
    Calculate BM25 scores
    :param query: search query
    :param documents: documents for reranking

    :return:
    """
    keyword_table_handler = JiebaKeywordTableHandler()    #代码1
    query_keywords = keyword_table_handler.extract_keywords(query, None)    #代码2
    documents_keywords = []
    for document in documents:
        # get the document keywords
        document_keywords = keyword_table_handler.extract_keywords(document.page_content, None)     #代码3
        document.metadata["keywords"] = document_keywords
        documents_keywords.append(document_keywords)

    # Counter query keywords(TF)
    query_keyword_counts = Counter(query_keywords)

    # total documents
    total_documents = len(documents)

    # calculate all documents' keywords IDF
    all_keywords = set()
    for document_keywords in documents_keywords:
        all_keywords.update(document_keywords)

    keyword_idf = {}
    for keyword in all_keywords:
        # calculate include query keywords' documents
        doc_count_containing_keyword = sum(1 for doc_keywords in documents_keywords if keyword in doc_keywords)
        # IDF
        keyword_idf[keyword] = math.log((1 + total_documents) / (1 + doc_count_containing_keyword)) + 1

    query_tfidf = {}

    for keyword, count in query_keyword_counts.items():
        tf = count
        idf = keyword_idf.get(keyword, 0)
        query_tfidf[keyword] = tf * idf

    # calculate all documents' TF-IDF
    documents_tfidf = []
    for document_keywords in documents_keywords:
        document_keyword_counts = Counter(document_keywords)
        document_tfidf = {}
        for keyword, count in document_keyword_counts.items():
            tf = count
            idf = keyword_idf.get(keyword, 0)
            document_tfidf[keyword] = tf * idf
        documents_tfidf.append(document_tfidf)

    def cosine_similarity(vec1, vec2):
        intersection = set(vec1.keys()) & set(vec2.keys())
        numerator = sum(vec1[x] * vec2[x] for x in intersection)

        sum1 = sum(vec1[x] ** 2 for x in vec1)
        sum2 = sum(vec2[x] ** 2 for x in vec2)
        denominator = math.sqrt(sum1) * math.sqrt(sum2)

        if not denominator:
            return 0.0
        else:
            return float(numerator) / denominator

    similarities = []
    for document_tfidf in documents_tfidf:
        similarity = cosine_similarity(query_tfidf, document_tfidf)
        similarities.append(similarity)

    return similarities

先使用Jiaba这个库(代码1)对query和检索结果进行关键字提取(代码2、代码3)。

然后下面的代码就是通过关键字来计算相似度,即得分的。这里使用的依然是BM25算法,不是本篇文章的重点,就不细说了,感兴趣的同学可以自己查阅相关文档。

计算权重

得到语言打分和关键词打分后,就根据设置的权重计算最终得分,最后重新排序,取top_k个结果即可。

总结

上面就是dify中知识库检索的完整流程,这个应该也是dify所有功能中最复杂的一个了,理解了这个流程对才能了解如何去正确使用,尤其是优化流程。比如知识库使用了向量检索,而且检索结果很理想的情况下,召回设置就可以使用权重重排,这样语言打分不用重新计算,只进行关键词打分,相对于rerank模型重排速度会快很多,会减少这个环节的耗时。