虽然没进决赛,但总结一下oceanbase ai hackthon中所实现的多路召回
1. 重排序(Rerank)实现
实现了基于 bge-reranker-large 模型的重排序功能,通过对检索到的文档进行二次排序,提高检索质量。
核心实现代码:
在 rerank.py 中,项目定义了 rerank_topn 函数:
def rerank_topn(question,docs,N=5):
pairs = []
for i in docs:
pairs.append([question,i.page_content])
with torch.no_grad():
inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
scores = model(**inputs, return_dict=True).logits.view(-1, ).float()
scores = scores.argsort().numpy()[::-1][:N]
bk = []
for i in scores:
bk.append(docs[i])
return bk
这个函数首先将问题和每个文档组成对,然后使用预训练的 bge-reranker-large 模型计算相关性分数,最后返回得分最高的 N 个文档。
在 rag_class.py 中通过 rerank_chain 方法调用这个重排序功能:
def rerank_chain(self,question):
retriever = self.vectstore.as_retriever(search_kwargs={"k": 10})
docs = retriever.invoke(question)
docs = rerank_topn(question,docs,N=5)
_chain = (
self.prompts
| self.llm
| StrOutputParser()
)
answer = _chain.invoke({"context":self.format_docs(docs),"question": question})
return answer
这个方法首先检索 10 个候选文档,然后使用 rerank_topn 筛选出最相关的 5 个文档,最后将这些文档作为上下文传递给大模型生成回答。
2. 融合实现
项目实现了多种文档融合策略,特别是"复杂召回方式"通过问题扩展和递归融合提高了全面性:
# 多问题递归召回,每次召回后,问题和答案同时作为下一次召回的参考,再次用新问题召回
def rag_chain(self, questions):
q_a_pairs = ""
for q in questions:
_chain = (
{"context": itemgetter("question") | self.retriever,
"question": itemgetter("question"),
"q_a_pairs": itemgetter("q_a_paris")
}
| self.decomposition_prompt
| self.llm
| StrOutputParser()
)
answer = _chain.invoke({"question": q, "q_a_paris": q_a_pairs})
q_a_pairs = self.format_qa_pairs(q, answer)
q_a_pairs = q_a_pairs + "\n----\n" + q_a_pairs
return answer
首先通过 decomposition_chain 方法生成多个相关问题:
# 获取问题的 扩展问题
def decomposition_chain(self, question):
_chain = (
{"question": RunnablePassthrough()}
| self.prompt_questions
| self.llm
| StrOutputParser()
| (lambda x: x.split("\n"))
)
questions = _chain.invoke({"question": question}) + [question]
return questions
这种方式使用 LLM 生成多个相关问题,然后对每个问题分别进行检索,并将所有问答对积累起来作为上下文,实现了多角度信息的融合。
3. 过滤实现
过滤功能主要体现在两个方面:文档检索过滤和重排序过滤。
在 rerank_topn 函数中实现了基于相关性得分的过滤:
scores = scores.argsort().numpy()[::-1][:N]
bk = []
for i in scores:
bk.append(docs[i])
return bk
在向量数据库实现中也有文件级别的过滤功能:
# 删除 某个collection中的 某个文件
def del_files(self, del_files_name, c_name):
vectorstore = self.chromadb._client.get_collection(c_name)
del_ids = []
vec_dict = vectorstore.get()
for id, md in zip(vec_dict["ids"], vec_dict["metadatas"]):
for dl in del_files_name:
if dl in md["source"]:
del_ids.append(id)
vectorstore.delete(ids=del_ids)
print("数据块总量:", vectorstore.count())
return vectorstore
这些过滤机制确保了只有最相关的内容会被用于回答生成。
4. 摘要和总结实现
项目中的摘要和总结功能主要通过精心设计的提示模板和 LLM 调用实现:
基本的问答提示模板:
template = """
根据上下文回答以下问题,不要自己发挥,要根据以下参考内容总结答案,如果以下内容无法得到答案,就返回无法根据参考内容获取答案,
参考内容为:{context}
问题: {question}
"""
对于网络搜索结果的总结:
def summarize_with_ollama(model_dropdown,text, question):
prompt = """
根据下边的内容,回答用户问题,
内容为:'{0}'\n
问题为:{1}
""".format(text, question)
ollama_url = 'http://localhost:11434/api/generate' # 替换为你的Ollama实例URL
data = {
'model': model_dropdown,
"prompt": prompt,
"stream": False
}
response = requests.post(ollama_url, json=data)
response.raise_for_status()
return response.json()
更复杂的包含背景问答对的提示模板:
template2 = """
以下是您需要回答的问题:
\n--\n {question} \n---\n
以下是任何可用的背景问答对:
\n--\n {q_a_pairs} \n---\n
以下是与该问题相关的其他上下文:
\n--\n {context} \n---\n
使用以上上下文和背景问答对来回答问题,问题是:{question} ,答案是:
"""
self.decomposition_prompt = ChatPromptTemplate.from_template(template2)
5. 系统整合
这些组件在 webui.py 的 chat_response 函数中被整合起来:
def chat_response(model_dropdown, vector_dropdown, chat_knowledge_base_dropdown, chain_dropdown, message):
global chat_history
if message:
chat_history.append(("User", message))
if chat_knowledge_base_dropdown == "仅使用模型":
rag = RAG_class(model=model_dropdown,persist_directory=DB_directory)
answer = rag.mult_chat(chat_history)
if chat_knowledge_base_dropdown and chat_knowledge_base_dropdown != "仅使用模型":
rag = RAG_class(model=model_dropdown, embed=vector_dropdown, c_name=chat_knowledge_base_dropdown, persist_directory=DB_directory)
if chain_dropdown == "复杂召回方式":
questions = rag.decomposition_chain(message)
answer = rag.rag_chain(questions)
elif chain_dropdown == "简单召回方式":
answer = rag.simple_chain(message)
else:
answer = rag.rerank_chain(message)
response = f" {answer}"
chat_history.append(("Bot", response))
return format_chat_history(chat_history), ""
用户可以选择三种不同的召回方式:
- 复杂召回方式:使用问题扩展和递归融合
- 简单召回方式:直接检索相关文档并生成回答
- rerank:使用重排序提高检索质量
那么如果用户想要对召回的质量进行评估,应该怎么做呢?
现有的重排序功能进行评估
如果实现了重排序功能,这本身就是一种评估和改进召回质量的方法。可以比较使用重排序前后的结果差异
通过修改逻辑在重排序前后分别保存结果,然后比较两者的差异,评估重排序的效果。
2. 实现标准评估指标
可以在项目中添加以下常用的信息检索评估指标:
a. 精确率和召回率
def evaluate_precision_recall(retrieved_docs, relevant_docs):
"""
评估检索结果的精确率和召回率
Args:
retrieved_docs: 系统检索到的文档ID列表
relevant_docs: 标注的相关文档ID列表
Returns:
precision: 精确率
recall: 召回率
"""
if not retrieved_docs:
return 0, 0
relevant_retrieved = set(retrieved_docs).intersection(set(relevant_docs))
precision = len(relevant_retrieved) / len(retrieved_docs)
recall = len(relevant_retrieved) / len(relevant_docs) if relevant_docs else 0
return precision, recall
b. 平均精度均值 (Mean Average Precision, MAP)
def calculate_map(all_queries_results, all_queries_relevant):
"""
计算平均精度均值
Args:
all_queries_results: 每个查询的检索结果 {query_id: [doc_id1, doc_id2, ...]}
all_queries_relevant: 每个查询的相关文档 {query_id: [doc_id1, doc_id2, ...]}
Returns:
map_score: MAP分数
"""
average_precisions = []
for query_id in all_queries_results:
if query_id not in all_queries_relevant:
continue
retrieved = all_queries_results[query_id]
relevant = set(all_queries_relevant[query_id])
if not relevant:
continue
precisions = []
relevant_count = 0
for i, doc_id in enumerate(retrieved):
if doc_id in relevant:
relevant_count += 1
precisions.append(relevant_count / (i + 1))
if precisions:
average_precisions.append(sum(precisions) / len(relevant))
return sum(average_precisions) / len(average_precisions) if average_precisions else 0
c. 归一化折损累积增益 (NDCG)
def calculate_ndcg(retrieved_docs, relevant_docs, k=None):
"""
计算NDCG
Args:
retrieved_docs: 检索结果列表
relevant_docs: 相关文档字典 {doc_id: relevance_score}
k: 截断位置,默认为None表示使用所有结果
Returns:
ndcg: NDCG分数
"""
import numpy as np
if k is not None:
retrieved_docs = retrieved_docs[:k]
dcg = 0
for i, doc_id in enumerate(retrieved_docs):
if doc_id in relevant_docs:
# 使用2^rel-1公式,也可以使用其他公式
rel = relevant_docs[doc_id]
dcg += (2 ** rel - 1) / np.log2(i + 2) # i+2 因为log_2(1)=0
# 计算理想DCG
ideal_ranking = sorted(relevant_docs.items(), key=lambda x: x[1], reverse=True)
ideal_dcg = 0
for i, (doc_id, rel) in enumerate(ideal_ranking[:len(retrieved_docs)]):
ideal_dcg += (2 ** rel - 1) / np.log2(i + 2)
return dcg / ideal_dcg if ideal_dcg > 0 else 0
3. 集成到现有系统中
您可以在 rag_class.py 中添加一个评估方法,例如:
def evaluate_retrieval(self, test_questions, ground_truth):
"""
评估检索效果
Args:
test_questions: 测试问题列表
ground_truth: 每个问题的标准答案 {question: [relevant_doc_ids]}
Returns:
metrics: 评估指标字典
"""
results = {}
for question in test_questions:
# 使用不同的检索方法
# 1. 简单检索
simple_retriever = self.vectstore.as_retriever(search_kwargs={"k": 10})
simple_docs = simple_retriever.invoke(question)
simple_doc_ids = [doc.metadata.get('id') for doc in simple_docs]
# 2. 重排序检索
rerank_docs = rerank_topn(question, simple_docs, N=5)
rerank_doc_ids = [doc.metadata.get('id') for doc in rerank_docs]
# 3. 复杂检索
questions = self.decomposition_chain(question)
complex_docs = []
for q in questions:
docs = self.retriever.invoke(q)
complex_docs.extend(docs)
complex_doc_ids = [doc.metadata.get('id') for doc in complex_docs]
results[question] = {
'simple': simple_doc_ids,
'rerank': rerank_doc_ids,
'complex': complex_doc_ids
}
# 计算评估指标
metrics = {
'simple': {},
'rerank': {},
'complex': {}
}
for method in metrics:
precisions = []
recalls = []
for question, doc_ids in results.items():
if question in ground_truth:
p, r = evaluate_precision_recall(doc_ids[method], ground_truth[question])
precisions.append(p)
recalls.append(r)
metrics[method]['precision'] = sum(precisions) / len(precisions) if precisions else 0
metrics[method]['recall'] = sum(recalls) / len(recalls) if recalls else 0
metrics[method]['f1'] = 2 * metrics[method]['precision'] * metrics[method]['recall'] / (metrics[method]['precision'] + metrics[method]['recall']) if (metrics[method]['precision'] + metrics[method]['recall']) > 0 else 0
return metrics
4. 创建测试数据集
为了进行评估,需要创建一个测试数据集,包含问题和相关文档的标注。:
- 手动创建一组测试问题和标准答案
- 从现有知识库中抽取一部分作为测试集
- 使用 LLM 生成测试问题和答案
例如,可以创建一个测试数据集文件 test_dataset.json:
{
"questions": [
"什么是RAG系统?",
"Easy-RAG支持哪些向量数据库?",
"如何使用rerank功能提高检索质量?"
],
"ground_truth": {
"什么是RAG系统?": ["doc_id_1", "doc_id_5", "doc_id_10"],
"Easy-RAG支持哪些向量数据库?": ["doc_id_3", "doc_id_7"],
"如何使用rerank功能提高检索质量?": ["doc_id_2", "doc_id_8", "doc_id_12"]
}
}
5. 添加可视化评估界面
可以在 webui.py 中添加一个评估标签页,用于可视化评估结果:
with gr.TabItem("评估"):
test_file = gr.File(label="上传测试数据集")
eval_knowledge_base_dropdown = gr.Dropdown(choices=["仅使用模型"] + vectordb.get_all_collections_name(), label="选择知识库")
eval_model_dropdown = gr.Dropdown(choices=get_llm(), label="选择模型")
eval_vector_dropdown = gr.Dropdown(choices=get_embeding_model(), label="选择向量模型")
eval_btn = gr.Button("开始评估")
eval_result = gr.DataFrame(label="评估结果")
def evaluate_system(test_file, knowledge_base, model, vector_model):
# 加载测试数据
import json
with open(test_file.name, 'r', encoding='utf-8') as f:
test_data = json.load(f)
questions = test_data.get('questions', [])
ground_truth = test_data.get('ground_truth', {})
# 初始化RAG系统
rag = RAG_class(model=model, embed=vector_model, c_name=knowledge_base, persist_directory=DB_directory)
# 评估
metrics = rag.evaluate_retrieval(questions, ground_truth)
# 格式化结果为DataFrame
results = []
for method in metrics:
for metric, value in metrics[method].items():
results.append({
"方法": method,
"指标": metric,
"值": f"{value:.4f}"
})
return pd.DataFrame(results)
eval_btn.click(evaluate_system, inputs=[test_file, eval_knowledge_base_dropdown, eval_model_dropdown, eval_vector_dropdown], outputs=eval_result)
6. 利用现有的数据流架构进行评估
根据项目数据流架构,可以在检索过程中插入评估代码, 在不同的检索方法之间添加评估逻辑,比较它们的效果。
7. 人工评估
对于主观质量评估,可以添加一个反馈机制,让用户对检索结果进行评分:
with gr.Row():
feedback_radio = gr.Radio(choices=["非常相关", "相关", "部分相关", "不相关"], label="请评价检索结果的相关性")
feedback_btn = gr.Button("提交反馈")
def collect_feedback(feedback, question, answer):
# 存储用户反馈
with open("feedback_log.jsonl", "a", encoding="utf-8") as f:
import json
feedback_data = {
"question": question,
"answer": answer,
"feedback": feedback,
"timestamp": datetime.now().isoformat()
}
f.write(json.dumps(feedback_data, ensure_ascii=False) + "\n")
return "感谢您的反馈!"
feedback_btn.click(collect_feedback, inputs=[feedback_radio, chat_input, chat_di