检索增强生成(Retrieval-Augmented Generation, RAG)简介
在自然语言处理和生成任务中,检索增强生成(RAG)是一种将信息检索(Retrieval)和文本生成(Generation)结合的方法,旨在提升文本生成的质量和准确性。这种方法尤其适合于需要引用大量知识或实时信息的任务,比如问答系统、聊天机器人以及一些需要动态内容生成的应用。
RAG的核心原理
RAG的核心原理在于结合检索与生成的能力。传统生成模型(如GPT-3、BERT等)在生成答案时,依赖于训练过程中学习到的知识,但这种知识是固定的,无法随时间和需求的变化而更新。而RAG模型通过引入一个动态检索机制,可以在生成内容时引入外部信息,使得生成的答案更为实时和相关。其工作流程可以简单概括为以下几个步骤:
-
检索阶段:给定一个用户的查询,RAG模型会先在外部知识库中进行检索,比如查找相应的文档、短文或其他相关信息。这些知识库可以是事先准备的结构化或非结构化数据,比如文档库、网页内容等。
-
生成阶段:在得到相关的信息后,生成模型会将这些检索到的内容与原始查询一同输入,以生成最终的回答。这一阶段结合了生成模型的语言能力以及检索信息的辅助,使得模型能够提供更丰富的回答内容。
-
融合结果:最终答案可能会基于检索到的多个结果生成,或直接从这些检索结果中提取并融合,生成一个精简而全面的答案。
RAG的简单原理
RAG的简单原理是将文本向量化,判断多个Chunk和问题的相似度。这是一个经典的NLP问题,比较简单的判断手段是余弦相似度。
其中, 和 是两个向量, 表示它们的点积, 和 分别表示它们的模长
代码实现
import os
from volcenginesdkarkruntime import Ark
from typing import List, Any
from langchain.embeddings.base import Embeddings
from langchain.pydantic_v1 import BaseModel
# 初始化Embedding类
class DoubaoEmbeddings(BaseModel, Embeddings):
client: Ark = None
api_key: str = ""
model: str
def __init__(self, **data: Any):
super().__init__(**data)
if self.api_key == "":
self.api_key = os.environ["OPENAI_API_KEY"]
self.client = Ark(
base_url=os.environ["OPENAI_BASE_URL"],
api_key=self.api_key
)
def embed_query(self, text: str) -> List[float]:
"""
生成输入文本的 embedding.
Args:
texts (str): 要生成 embedding 的文本.
Return:
embeddings (List[float]): 输入文本的 embedding,一个浮点数值列表.
"""
embeddings = self.client.embeddings.create(model=self.model, input=text)
return embeddings.data[0].embedding
def embed_documents(self, texts: List[str]) -> List[List[float]]:
return [self.embed_query(text) for text in texts]
class Config:
arbitrary_types_allowed = True
embd = DoubaoEmbeddings(
model=os.environ["EMBEDDING_MODELEND"],
)
在这里定义了一个豆包的词向量模型,用于将文本转化成向量嵌入。
texts = [
["今天天气真好,适合出去散步。"],
["今天阳光明媚,适合户外活动。"],
["这本书非常有趣,我一口气读完了。"],
["这本书内容精彩,我一晚上就看完了。"],
["他是一个非常聪明的人,总是能找到解决问题的方法。"],
["他非常机智,总能找到解决难题的办法。"],
]
vec = embd.embed_documents(texts)
print(vec)
可以看到,这些文本被转化成一维向量的模式,便于后续的数学计算。
from sklearn.metrics.pairwise import cosine_similarity
similarity_matrix = cosine_similarity(vec)
# 打印相似度矩阵
print(similarity_matrix)
在这段代码中,
cosine_similarity函数是从sklearn.metrics.pairwise模块中导入的,这个函数用于计算两个向量之间的余弦相似度。vec是一个二维数组,其中每一行代表一个向量。函数将返回一个二维数组,其中的每个元素(i, j)表示第i个向量和第j个向量之间的余弦相似度。
很显然,它们分别在各自的语义空间上更接近,所以余弦相似度更接近于1。
总结
相似度嵌入通常用于文本分析、推荐系统、图像处理等领域,用于找出数据集中最相似的对象。例如,在文本分析中,它可以用来找出与某个文档最相似的其他文档。通过学习这个原理,就能在后续检索增强生成中明白如何切分Chunk,找到合适的文本,如何构建索引给LLM。