前言
在dify中,知识库检索
是核心功能之一,这里面大大小小的环节也比较多,但是在前端来看只有一个输入和一个输出,那么我们在知识库设定的那些设置 和 在知识库检索
节点设定的召回设置等都是怎么发挥作用的?如果两个设置中都设置了rerank,那么是如何工作的?
这就需要从源码来看,看看知识库检索
这个节点到底是如何工作的。
知识库检索
知识库检索
这节点的代码在dify项目中api/core/workflow/nodes/knowledge_retrieval
目录下的knowledge_retrieval_node.py
,入口函数就是_run
函数:
def _run(self) -> NodeRunResult:
...
try:
results = self._fetch_dataset_retriever(node_data=node_data, query=query)
outputs = {"result": results}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs
)
...
召回设置
执行_fetch_dataset_retriever
函数得到结果返回,这个函数就很大了,聚焦关键代码即可。在新版本中dify已经不支持单路召回了,所以现在都是多路召回,所以看这个分支即可:
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value:
if node_data.multiple_retrieval_config.reranking_mode == "reranking_model": #代码1
if node_data.multiple_retrieval_config.reranking_model:
reranking_model = {
"reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider,
"reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model,
}
else:
reranking_model = None
weights = None
elif node_data.multiple_retrieval_config.reranking_mode == "weighted_score": #代码2
reranking_model = None
vector_setting = node_data.multiple_retrieval_config.weights.vector_setting
weights = {
"vector_setting": {
"vector_weight": vector_setting.vector_weight,
"embedding_provider_name": vector_setting.embedding_provider_name,
"embedding_model_name": vector_setting.embedding_model_name,
},
"keyword_setting": {
"keyword_weight": node_data.multiple_retrieval_config.weights.keyword_setting.keyword_weight
},
}
else:
reranking_model = None
weights = None
all_documents = dataset_retrieval.multiple_retrieve(
self.app_id,
self.tenant_id,
self.user_id,
self.user_from.value,
available_datasets,
query,
node_data.multiple_retrieval_config.top_k,
node_data.multiple_retrieval_config.score_threshold,
node_data.multiple_retrieval_config.reranking_mode,
reranking_model,
weights,
node_data.multiple_retrieval_config.reranking_enable,
) #代码3
这里看代码1和代码2,根据multiple_retrieval_config
的reranking_mode
来做不同的动作,这里multiple_retrieval_config
就是在知识库检索
节点中设定的召回设置,如下
有两种方式:权重设置和rerank模型。reranking_mode
的两个值weighted_score
和reranking_model
就是对应着两种方式。
所以代码1和代码2中是将召回设置的信息组装一下,用于后面使用。
在代码3执行了multiple_retrieve
函数
开始检索
multiple_retrieve
函数源码如下:
def multiple_retrieve(
self,
...
):
...
for dataset in available_datasets: #代码1
index_type = dataset.indexing_technique
retrieval_thread = threading.Thread(
target=self._retriever,
kwargs={
"flask_app": current_app._get_current_object(),
"dataset_id": dataset.id,
"query": query,
"top_k": top_k,
"all_documents": all_documents,
},
)
threads.append(retrieval_thread)
retrieval_thread.start()
for thread in threads:
thread.join()
with measure_time() as timer:
if reranking_enable: #代码2
# do rerank for searched documents
data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
all_documents = data_post_processor.invoke(
query=query, documents=all_documents, score_threshold=score_threshold, top_n=top_k
)
else:
...
...
return all_documents
这部分是核心代码,有两大块代码,正好对应知识库检索
的两个重要过程:检索和重排。
代码1遍历节点中配置的知识库,调用_retriever
函数对每个知识库进行检索,得到检索结果。
代码2则对所有检索结果进行重排序,是通过DataPostProcessor
来实现的。
我们一步步来看。
检索知识库
先来看检索这一步,_retriever
函数如下:
def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list):
with flask_app.app_context():
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if not dataset:
return []
if dataset.provider == "external":
...
else:
# get retrieval model , if the model is not setting , using default
retrieval_model = dataset.retrieval_model or default_retrieval_model
if dataset.indexing_technique == "economy":
# use keyword table query
documents = RetrievalService.retrieve(
retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=top_k
) #代码1
if documents:
all_documents.extend(documents)
else:
if top_k > 0:
# retrieval source
documents = RetrievalService.retrieve(
retrieval_method=retrieval_model["search_method"],
dataset_id=dataset.id,
query=query,
top_k=retrieval_model.get("top_k") or 2,
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
else 0.0,
reranking_model=retrieval_model.get("reranking_model", None)
if retrieval_model["reranking_enable"]
else None,
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
weights=retrieval_model.get("weights", None),
) #代码2
all_documents.extend(documents)
先在数据库中获取知识库详情,然后看else分支,这里根据dataset.indexing_technique
做不同的处理,这个indexing_technique
就是我们在知识库的设置中设定的索引模式
虽然处理不一样,但是可以看到都是调用了RetrievalService.retrieve
来检索。
检索方式
RetrievalService.retrieve
函数比较长,主要代码如下:
def retrieve(
cls,
retrieval_method: str,
dataset_id: str,
query: str,
top_k: int,
score_threshold: Optional[float] = 0.0,
reranking_model: Optional[dict] = None,
reranking_mode: Optional[str] = "reranking_model",
weights: Optional[dict] = None,
):
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
...
# retrieval_model source with keyword
if retrieval_method == "keyword_search": #代码1
...
# retrieval_model source with semantic
if RetrievalMethod.is_support_semantic_search(retrieval_method): #代码2
embedding_thread = threading.Thread(
target=RetrievalService.embedding_search,
kwargs={
"flask_app": current_app._get_current_object(),
"dataset_id": dataset_id,
"query": query,
"top_k": top_k,
"score_threshold": score_threshold,
"reranking_model": reranking_model,
"all_documents": all_documents,
"retrieval_method": retrieval_method,
"exceptions": exceptions,
},
)
threads.append(embedding_thread)
embedding_thread.start()
# retrieval source with full text
if RetrievalMethod.is_support_fulltext_search(retrieval_method): #代码3
full_text_index_thread = threading.Thread(
target=RetrievalService.full_text_index_search,
kwargs={
"flask_app": current_app._get_current_object(),
"dataset_id": dataset_id,
"query": query,
"retrieval_method": retrieval_method,
"score_threshold": score_threshold,
"top_k": top_k,
"reranking_model": reranking_model,
"all_documents": all_documents,
"exceptions": exceptions,
},
)
threads.append(full_text_index_thread)
full_text_index_thread.start()
...
if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value: #代码4
data_post_processor = DataPostProcessor(
str(dataset.tenant_id), reranking_mode, reranking_model, weights, False
)
all_documents = data_post_processor.invoke(
query=query, documents=all_documents, score_threshold=score_threshold, top_n=top_k
)
return all_documents
这里重点关注四个if语句,即代码1-4,根据不同的检索方式执行不同的检索流程。这个检索方式就是知识库的设置中的检索设置
其中代码1的关键字检索应该是废弃了,代码2、3、4分别对应向量检索、全文检索和混合检索。
这里注意观察代码2和代码3,并不是直接对比,而是用了两个is_support_xxx
函数,这两个函数如下:
@staticmethod
def is_support_semantic_search(retrieval_method: str) -> bool:
return retrieval_method in {RetrievalMethod.SEMANTIC_SEARCH.value, RetrievalMethod.HYBRID_SEARCH.value}
@staticmethod
def is_support_fulltext_search(retrieval_method: str) -> bool:
return retrieval_method in {RetrievalMethod.FULL_TEXT_SEARCH.value, RetrievalMethod.HYBRID_SEARCH.value}
可以看到除了对应的检索方式,混合检索也会执行代码2和代码3。所以在代码4我们看不到检索相关的代码,因为如果是混合检索,那么既要执行向量检索(代码2)也要执行全文检索(代码3),所有的检索结果都会存入all_documents
中。代码4仅仅是将这两个结果进行重排序,这部分代码可以看到跟multiple_retrieve
函数中召回设置的重排序代码是一样的,用DataPostProcessor
来实现,所以后面一起来说。
全文检索
先来看看代码3的全文检索,调用了RetrievalService.full_text_index_search
,这个函数代码如下:
@classmethod
def full_text_index_search(
cls,
...
):
with flask_app.app_context():
try:
...
documents = vector_processor.search_by_full_text(cls.escape_query_for_search(query), top_k=top_k) #代码1
if documents:
if (
reranking_model
and reranking_model.get("reranking_model_name")
and reranking_model.get("reranking_provider_name")
and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH.value
): #代码2
data_post_processor = DataPostProcessor(
str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False
)
all_documents.extend(
data_post_processor.invoke(
query=query, documents=documents, score_threshold=score_threshold, top_n=len(documents)
)
)
else:
all_documents.extend(documents)
except Exception as e:
exceptions.append(str(e))
代码1进行全文检索;代码2如果全文检索下设置了rerank则进使用rerank模型对结果进行重排序,可以看到这部分也是用DataPostProcessor
来实现,dify中所有重排序都是使用DataPostProcessor
,后面一起说。
看回代码1,执行了vector_processor.search_by_full_text
,代码如下:
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return self._vector_processor.search_by_full_text(query, **kwargs)
这里的_vector_processor
是一个BaseVector
对象,所有的实现都在子类中。具体是哪个子类,跟我们使用的向量数据库有关,默认是使用weaviate作为向量数据库,那么就是WeaviateVector
这个类,他的search_by_full_text
代码如下:
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
"""Return docs using BM25F.
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
Returns:
List of Documents most similar to the query.
"""
collection_name = self._collection_name
content: dict[str, Any] = {"concepts": [query]}
properties = self._attributes
properties.append(Field.TEXT_KEY.value)
if kwargs.get("search_distance"):
content["certainty"] = kwargs.get("search_distance")
query_obj = self._client.query.get(collection_name, properties)
if kwargs.get("where_filter"):
query_obj = query_obj.with_where(kwargs.get("where_filter"))
query_obj = query_obj.with_additional(["vector"])
properties = ["text"]
result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get("top_k", 4)).do() #代码1
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
docs = []
for res in result["data"]["Get"][collection_name]:
text = res.pop(Field.TEXT_KEY.value)
additional = res.pop("_additional")
docs.append(Document(page_content=text, vector=additional["vector"], metadata=res))
return docs
可以看到这里是使用BM25算法来进行检索,取top_k个条目即可。
本身weaviate支持两种索引:
- 近似最近邻(ANN)索引 - 用于所有向量搜索查询
- 倒排索引 - 支持按属性过滤查询和 BM25 查询
所以这里直接使用它自带的特性进行检索的。可以对比看其他实现,比如MilvusVector
的对应函数如下:
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
# milvus/zilliz doesn't support bm25 search
return []
可以看到这个数据库不支持BM25算法检索,所以直接返回空数组了。
向量检索
我们再来看看向量检索,在RetrievalService.retrieve
的代码2处可以看到是执行RetrievalService.embedding_search
来进行向量检索的,这个函数如下:
@classmethod
def embedding_search(
cls,
...
):
with flask_app.app_context():
try:
...
documents = vector.search_by_vector(
cls.escape_query_for_search(query),
search_type="similarity_score_threshold",
top_k=top_k,
score_threshold=score_threshold,
filter={"group_id": [dataset.id]},
) #代码1
if documents:
if (
reranking_model
and reranking_model.get("reranking_model_name")
and reranking_model.get("reranking_provider_name")
and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH.value
): #代码2
data_post_processor = DataPostProcessor(
str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False
)
all_documents.extend(
data_post_processor.invoke(
query=query, documents=documents, score_threshold=score_threshold, top_n=len(documents)
)
)
else:
all_documents.extend(documents)
except Exception as e:
exceptions.append(str(e))
可以看到与全文检索代码几乎是一样的,先进行检索(代码1),然后如果向量检索下设置了rerank就进行重排(代码2),重排也是DataPostProcessor
,后面一起说。
检索(代码1)是执行了vector.search_by_vector
,这个函数代码:
def search_by_vector(self, query: str, **kwargs: Any) -> list[Document]:
query_vector = self._embeddings.embed_query(query)
return self._vector_processor.search_by_vector(query_vector, **kwargs)
先执行了_embeddings.embed_query
函数对query进行向量化,这块我们简单展开说说。
向量化
embed_query
的函数内容如下:
def embed_query(self, text: str) -> list[float]:
"""Embed query text."""
# use doc embedding cache or store if not exists
hash = helper.generate_text_hash(text)
embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model}_{hash}"
embedding = redis_client.get(embedding_cache_key) #代码1
if embedding:
redis_client.expire(embedding_cache_key, 600)
return list(np.frombuffer(base64.b64decode(embedding), dtype="float"))
try:
embedding_result = self._model_instance.invoke_text_embedding(
texts=[text], user=self._user, input_type=EmbeddingInputType.QUERY
) #代码2
except Exception as ex:
if dify_config.DEBUG:
logging.exception(f"Failed to embed query text: {ex}")
raise ex
try:
# encode embedding to base64
embedding_vector = np.array(embedding_results)
vector_bytes = embedding_vector.tobytes()
# Transform to Base64
encoded_vector = base64.b64encode(vector_bytes)
# Transform to string
encoded_str = encoded_vector.decode("utf-8")
redis_client.setex(embedding_cache_key, 600, encoded_str) #代码3
except Exception as ex:
if dify_config.DEBUG:
logging.exception("Failed to add embedding to redis %s", ex)
raise ex
return embedding_results
可以看到先在redis中查找之前是否有记录(代码1),如果有直接返回,这样同样的query就不必反复调用text-embedding模型了。这个记录有效时间是600秒,即10分钟(代码3)。
如果记录才会执行_model_instance.invoke_text_embedding
来进行向量化(代码2)。这个函数代码:
def invoke_text_embedding(
...
) -> TextEmbeddingResult:
...
self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
return self._round_robin_invoke(
function=self.model_type_instance.invoke, #代码1
model=self.model,
credentials=self.credentials,
texts=texts,
user=user,
input_type=input_type,
)
调用了model_type_instance.invoke
(代码1),model_type_instance
是TextEmbeddingModel
的一个对象,它的invoke
函数实际上是调用了自己的_invoke
函数,这个函数是在子类里实现的,至于是哪个子类,就要看在知识库设置中配置的text-embedding模型是哪个,会使用这个模型对应供应商的实现。比如使用的模型是bge-large-zh
,是百度文心的,那么调用的子类就是WenxinTextEmbeddingModel
。
具体的实现这里就不展示了,每个供应商都不一样,但是都是发送一个请求并获取结果。
数据库检索
回到vector.search_by_vector
这个函数,向量化后与全文检索一样,执行了_vector_processor
的函数,它是一个BaseVector
对象,所以看它的子类WeaviateVector
的对应函数:
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
...
query_obj = self._client.query.get(collection_name, properties)
vector = {"vector": query_vector}
if kwargs.get("where_filter"): #代码1
query_obj = query_obj.with_where(kwargs.get("where_filter"))
result = (
query_obj.with_near_vector(vector)
.with_limit(kwargs.get("top_k", 4))
.with_additional(["vector", "distance"])
.do()
)
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
docs_and_scores = []
for res in result["data"]["Get"][collection_name]: #代码2
text = res.pop(Field.TEXT_KEY.value)
score = 1 - res["_additional"]["distance"]
docs_and_scores.append((Document(page_content=text, metadata=res), score))
docs = []
for doc, score in docs_and_scores: #代码3
score_threshold = float(kwargs.get("score_threshold") or 0.0)
# check score threshold
if score > score_threshold:
doc.metadata["score"] = score
docs.append(doc)
# Sort the documents by score in descending order
docs = sorted(docs, key=lambda x: x.metadata["score"], reverse=True)
return docs
整个过程分成三步,先在数据库中检索出top_k个相似条目(代码1),然后计算得分(代码2),最后如果向量检索设置下配置了Score阈值则进行过滤(代码3)
混合检索
上面也说了,混合检索就是进行全文检索top_k和向量检索top_k,然后根据混合检索的设置进行重排序(这部分跟召回设置一样,所以后面一起说),经过Score阈值过滤后(如果设置了)得到最高的top_k个结果即可。
重排序
上面用很大一个篇幅来讲了在知识库中是如何进行检索的,这部分的流程主要依赖的是知识库自身的配置。下面我们回到 #开始检索# 这一大节中。
上面提到multiple_retrieve
函数有两大块代码,正好对应知识库检索
的两个重要过程:检索和重排。
刚刚结束的 #检索知识库# 这一大节对应的就是第一个过程,现在我们从设置的知识库(一个或者多个)中得到了检索结果。如果只是一个知识库还好,因为结果是排序过的。但是如果是多个知识库,那么每个知识库自己的结果是排序过的,但是它们之间的并没有排序。
所以第二个过程就是重排(但是如果只有一个知识库也需要这一步,我认为这里dify可以优化一下)。
通过上面我们知道,在dify中,对检索结果的重排都是通过DataPostProcessor
来实现的,它的invoke
函数如下:
def invoke(
...
) -> list[Document]:
if self.rerank_runner:
documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user)
if self.reorder_runner:
documents = self.reorder_runner.run(documents)
return documents
这里有两种,重排都会进入第一种rerank_runner
,它是BaseRerankRunner
的对象,这个的实现也是在子类中,有两个子类:
- RerankModelRunner:rerank模型重排
- WeightRerankRunner:权重重排
再来看召回设置
召回设置里有两种方式:权重设置和rerank模型。上面两个类就是对应这两种方式的。
这里要提一下,在知识库设置中,检索模式如果是全文检索或向量检索,当开启rerank得时候就是用rerank模型来重排,如果不开启则不重排;而在混合检索中,可以看到与召回设置一样,可以选择两种方式。
下面我们就一个个来看看这两种方式。
rerank模型重排
先来看看rerank模型重排,RerankModelRunner
的run函数:
def run(
...
) -> list[Document]:
...
for document in dify_documents: #代码1
if document.metadata["doc_id"] not in doc_id:
doc_id.append(document.metadata["doc_id"])
docs.append(document.page_content)
unique_documents.append(document)
for document in external_documents:
docs.append(document.page_content)
unique_documents.append(document)
documents = unique_documents
rerank_result = self.rerank_model_instance.invoke_rerank(
query=query, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user
) #代码2
...
return rerank_documents
先将各个知识库的检索结果整合到一起(代码1),然后执行rerank_model_instance.invoke_rerank
进行重排(代码2)。
rerank_model_instance
这里其实跟上面 #向量化# 章节中的流程类似,最终会根据设置的rerank模型选择其供应商对应的实现类来执行,本质上同样是执行一个请求获取结果。这里就不细说了。
所以rerank模型重排其实比较简单,完全依赖模型来重排,包括top_k和Score阈值都是模型处理的。
权重重排
然后来看权重重排,WeightRerankRunner
的run函数如下:
def run(
...
) -> list[Document]:
...
for document in documents: #代码1
if document.metadata["doc_id"] not in doc_id:
doc_id.append(document.metadata["doc_id"])
docs.append(document.page_content)
unique_documents.append(document)
documents = unique_documents
rerank_documents = []
query_scores = self._calculate_keyword_score(query, documents) #代码2
query_vector_scores = self._calculate_cosine(self.tenant_id, query, documents, self.weights.vector_setting) #代码3
for document, query_score, query_vector_score in zip(documents, query_scores, query_vector_scores): #代码4
# format document
score = (
self.weights.vector_setting.vector_weight * query_vector_score
+ self.weights.keyword_setting.keyword_weight * query_score
) #代码4
if score_threshold and score < score_threshold: #代码5
continue
document.metadata["score"] = score
rerank_documents.append(document)
rerank_documents = sorted(rerank_documents, key=lambda x: x.metadata["score"], reverse=True) #代码6
return rerank_documents[:top_n] if top_n else rerank_documents #代码7
权重重排有一个权重设置,如图:
可以设置语义和关键词的权重比,上面的代码就体现了这部分。
首先还是先整合所有知识库的检索结果(代码1);然后进行关键词打分(代码2)和语义打分(代码3);最后根据设置好的权重比,对关键词打分和语义打分进行权重计算重新打分(代码4),如果设置了Score阈值这里还会进行过滤(代码5);最终重新排序(代码6)并返回top_k个结果(代码7)
这里dify可以优化一下,因为如果权重设置中有一方为0的情况下,那么就不必进行打分(比如如果关键词权重是0,就没必要进行关键词打分)。dify这里全都进行打分,有些浪费。
权重计算其实就是乘法运算,所以这里值得说的就是关键词打分(代码2)和语义打分(代码3),一个个来说。
语义打分
语义打分是在_calculate_cosine
这个函数中:
def _calculate_cosine(
self, tenant_id: str, query: str, documents: list[Document], vector_setting: VectorSetting
) -> list[float]:
...
embedding_model = model_manager.get_model_instance(
tenant_id=tenant_id,
provider=vector_setting.embedding_provider_name,
model_type=ModelType.TEXT_EMBEDDING,
model=vector_setting.embedding_model_name,
)
cache_embedding = CacheEmbedding(embedding_model)
query_vector = cache_embedding.embed_query(query) #代码1
for document in documents:
# calculate cosine similarity
if "score" in document.metadata: #代码2
query_vector_scores.append(document.metadata["score"])
else: #代码3
# transform to NumPy
vec1 = np.array(query_vector)
vec2 = np.array(document.vector)
# calculate dot product
dot_product = np.dot(vec1, vec2)
# calculate norm
norm_vec1 = np.linalg.norm(vec1)
norm_vec2 = np.linalg.norm(vec2)
# calculate cosine similarity
cosine_sim = dot_product / (norm_vec1 * norm_vec2)
query_vector_scores.append(cosine_sim)
return query_vector_scores
首先会对query进行向量化(代码1),这部分在前面 #向量化# 章节中详细说了,尤其是里面是用了redis缓存,那么注意这里因为刚刚进行了知识库的检索,如果当时是向量检索或混合检索,这里就一定会有缓存,不会再调用模型了;如果当时是全文检索,那么这里可能就没有缓存,就需要调用模型。
然后对每一个片段进行打分,如果片段结果中有score
,那么就直接使用即可(代码2);如果没有,则通过query向量和片段向量计算相似度进行打分(代码3)。前面知识库的检索过程中可以看到,如果是向量检索或混合检索,结果一定会有得分score
,这里就直接用;如果是全文检索,则没有得分score
,这里就需要重新计算一下。
所以如果前面知识库环节使用的是向量检索或者混合检索,语义打分这里基本就什么都不做。
关键词打分
关键词打分是在_calculate_keyword_score
这个函数中:
def _calculate_keyword_score(self, query: str, documents: list[Document]) -> list[float]:
"""
Calculate BM25 scores
:param query: search query
:param documents: documents for reranking
:return:
"""
keyword_table_handler = JiebaKeywordTableHandler() #代码1
query_keywords = keyword_table_handler.extract_keywords(query, None) #代码2
documents_keywords = []
for document in documents:
# get the document keywords
document_keywords = keyword_table_handler.extract_keywords(document.page_content, None) #代码3
document.metadata["keywords"] = document_keywords
documents_keywords.append(document_keywords)
# Counter query keywords(TF)
query_keyword_counts = Counter(query_keywords)
# total documents
total_documents = len(documents)
# calculate all documents' keywords IDF
all_keywords = set()
for document_keywords in documents_keywords:
all_keywords.update(document_keywords)
keyword_idf = {}
for keyword in all_keywords:
# calculate include query keywords' documents
doc_count_containing_keyword = sum(1 for doc_keywords in documents_keywords if keyword in doc_keywords)
# IDF
keyword_idf[keyword] = math.log((1 + total_documents) / (1 + doc_count_containing_keyword)) + 1
query_tfidf = {}
for keyword, count in query_keyword_counts.items():
tf = count
idf = keyword_idf.get(keyword, 0)
query_tfidf[keyword] = tf * idf
# calculate all documents' TF-IDF
documents_tfidf = []
for document_keywords in documents_keywords:
document_keyword_counts = Counter(document_keywords)
document_tfidf = {}
for keyword, count in document_keyword_counts.items():
tf = count
idf = keyword_idf.get(keyword, 0)
document_tfidf[keyword] = tf * idf
documents_tfidf.append(document_tfidf)
def cosine_similarity(vec1, vec2):
intersection = set(vec1.keys()) & set(vec2.keys())
numerator = sum(vec1[x] * vec2[x] for x in intersection)
sum1 = sum(vec1[x] ** 2 for x in vec1)
sum2 = sum(vec2[x] ** 2 for x in vec2)
denominator = math.sqrt(sum1) * math.sqrt(sum2)
if not denominator:
return 0.0
else:
return float(numerator) / denominator
similarities = []
for document_tfidf in documents_tfidf:
similarity = cosine_similarity(query_tfidf, document_tfidf)
similarities.append(similarity)
return similarities
先使用Jiaba这个库(代码1)对query和检索结果进行关键字提取(代码2、代码3)。
然后下面的代码就是通过关键字来计算相似度,即得分的。这里使用的依然是BM25算法,不是本篇文章的重点,就不细说了,感兴趣的同学可以自己查阅相关文档。
计算权重
得到语言打分和关键词打分后,就根据设置的权重计算最终得分,最后重新排序,取top_k个结果即可。
总结
上面就是dify中知识库检索的完整流程,这个应该也是dify所有功能中最复杂的一个了,理解了这个流程对才能了解如何去正确使用,尤其是优化流程。比如知识库使用了向量检索,而且检索结果很理想的情况下,召回设置就可以使用权重重排,这样语言打分不用重新计算,只进行关键词打分,相对于rerank模型重排速度会快很多,会减少这个环节的耗时。