PIKE-RAG Knowledge Base Local Deployment: Chunking


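I have recently been working on a local RAG project, where the data has to stay on-premises and the models have to be hosted locally as well, so I am writing the process down. This series is built on the open-source PIKE-RAG knowledge base and covers local adaptation, wrapping the interfaces with FastAPI, building a front-end page, and more. This post covers only two things: deploying PIKE-RAG, and using a locally deployed large model as the chat model to chunk documents.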

About the PIKE-RAG Knowledge Base

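PIKE-RAG is a modular knowledge base system open-sourced by Microsoft. It covers document parsing, knowledge extraction, knowledge storage, knowledge retrieval, knowledge organization, knowledge-centric reasoning, and task decomposition and invocation. Apart from lacking a UI, it can handle every stage of a knowledge base pipeline.

Compared with existing knowledge bases, it introduces two main innovations:

1. Knowledge atomization: a passage is broken into "smallest useful knowledge units", and each unit is given a "question tag" (for example, a passage saying "drug X was approved in 2020" gets the tag "In which year was drug X approved?"). At query time, the key information can be found quickly whether you search the source text directly or search the question tags.
2. Knowledge-aware task decomposition: before splitting a complex question, the system first looks at what the knowledge base actually contains and then decides how to split it. Take the question "How many interchangeable biosimilars are there?" If the knowledge base already has an "interchangeable list", it simply counts the entries. If it only has a "list of all biosimilars", the question is split into "find the list → check whether each item is interchangeable → count", which avoids blind decompositions that lead to detours.

GitHub repository: github.com/microsoft/P…
Gitee mirror: gitee.com/mirrors_mic…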
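To make the knowledge atomization idea concrete, here is a minimal sketch of retrieval that scores a query against both the chunk text and the question tags. It is not PIKE-RAG's implementation: the `Atom` class, the `retrieve` function, and the word-overlap score are all hypothetical stand-ins (a real setup would use an embedding model), and the sample data is made up for illustration.

```python
from dataclasses import dataclass


@dataclass
class Atom:
    """One atomic knowledge unit: a question tag pointing back to its source chunk."""
    question: str  # the "question tag"
    chunk_id: int  # index of the chunk that answers it


def tokenize(text: str) -> set[str]:
    """Lowercased word set; a crude stand-in for a real embedding model."""
    return {w.strip("?.,!\"'").lower() for w in text.split()} - {""}


def overlap(a: str, b: str) -> float:
    """Jaccard overlap between the word sets of two strings."""
    ta, tb = tokenize(a), tokenize(b)
    return len(ta & tb) / len(ta | tb) if ta | tb else 0.0


def retrieve(query: str, chunks: list[str], atoms: list[Atom]) -> int:
    """Return the id of the best chunk, scored by chunk text or by its question tags."""
    scores = {i: overlap(query, c) for i, c in enumerate(chunks)}
    for atom in atoms:
        scores[atom.chunk_id] = max(scores[atom.chunk_id], overlap(query, atom.question))
    return max(scores, key=scores.get)


chunks = [
    "Drug X received regulatory approval in 2020 after a phase III trial.",
    "Drug Y is stored at room temperature and taken twice a day.",
]
atoms = [
    Atom("In which year was drug X approved?", 0),
    Atom("How should drug Y be stored?", 1),
]
query = "What year was drug X approved?"
best = retrieve(query, chunks, atoms)
print(chunks[best])
```

In this toy example the query shares far more words with the question tag than with the raw chunk, which is the intuition behind indexing the tags alongside the text.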

Setting Up the PIKE-RAG Knowledge Base

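Code Structure

Main directories:

  • Core code: the pikerag/ directory, which contains the core components such as document loaders and transformers.
    • document_loaders/: document loading and reading utilities;
    • document_transformers/: document splitting and filtering, including the LLM-based tagger/splitter;
    • knowledge_retrievers/: several retriever implementations, such as the BM25, Chroma, and ChunkAtom retrievers;
    • llm_client/: language model client interfaces, supporting the OpenAI API, Azure, HuggingFace, and others;
    • prompts/: prompt template definitions covering chunking, QA, generation, and more;
    • utils/: general utilities such as logging, config parsing, and path management;
    • workflows/: core workflow wrappers, including the control modules for QA, evaluation, tagging, and so on.
  • Data processing: the data_process/ directory, with scripts for sentence splitting, benchmark data processing, and so on (e.g. chunk_by_sentence.py, retrieval_contexts_as_chunks.py).
  • Example scripts: the examples/ directory, with examples for scenarios such as biology, HotpotQA, and MuSiQue (QA, evaluation, tagging scripts, etc.).
  • Documentation: the docs/ directory, with guides for environment setup, running the examples, and so on.
  • Helper scripts: the scripts/ directory, with Azure-related install and login scripts.
  • Configuration files: the configs/ directory under each example scenario, containing YAML configuration files (e.g. for the tagging and QA workflows).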

Local Model Deployment

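I used Xinference to deploy a 4-bit quantized DeepSeek-R1-32B as the chat model and bge-m3 as the embedding model. To learn how to deploy models with Xinference, see: mp.weixin.qq.com/s/glAeQDgdX… You can also deploy the chat model and embedding model in whatever way you are familiar with.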
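Before wiring the model into PIKE-RAG, it can help to check that the chat endpoint answers at all. The sketch below uses only the standard library and assumes the server exposes an OpenAI-style `/v1/chat/completions` route at the URL and model name used in the configuration later in this post; the helper names `build_chat_request` and `ask` are made up for illustration. Adjust the values to your own deployment.

```python
import json
import urllib.request


def build_chat_request(base_url: str, model: str, prompt: str) -> urllib.request.Request:
    """Build (but do not send) an OpenAI-style chat completion request."""
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 16,
    }
    return urllib.request.Request(
        url=base_url.rstrip("/") + "/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json", "Authorization": "Bearer xinference"},
        method="POST",
    )


def ask(request: urllib.request.Request, timeout: float = 300.0) -> str:
    """Send the request and return the first choice's text. Needs a running server."""
    with urllib.request.urlopen(request, timeout=timeout) as resp:
        body = json.load(resp)
    return body["choices"][0]["message"]["content"]


# Assumed values, taken from the chunking config used later in this post.
request = build_chat_request("http://localhost:9997/v1", "DeepSeek-R1-32B-AWQ", "Say hi.")
print(request.full_url)
# print(ask(request))  # uncomment once the model has been launched
```

If `ask` raises a connection error, the server is not reachable at that address; if it returns an HTTP error, the model name is a likely suspect.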

Environment Setup

# Install uv
curl -LsSf https://astral.sh/uv/install.sh | sh
# Initialize the project directory
uv init PithyRAG
cd PithyRAG  # then change the Python version to 3.12
uv run main.py
# Clone the repository
git clone https://gitee.com/mirrors_microsoft/PIKE-RAG.git
# Copy pikerag into the PithyRAG directory
cp -r PIKE-RAG/pikerag ./

Delete the uv.lock file, then overwrite pyproject.toml with the following content.

[project]
name = "pithyrag"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
    "bs4>=0.0.2",
    "chromadb>=1.1.1",
    "dacite>=1.9.2",
    "datasets>=4.2.0",
    "fastapi[standard]>=0.120.0",
    "jsonlines>=4.0.0",
    "langchain>=0.3.27",
    "langchain-chroma>=0.2.6",
    "langchain-community>=0.3.31",
    "langchain-huggingface>=0.3.1",
    "locust>=2.41.6",
    "markdown>=3.9",
    "openai>=2.3.0",
    "openpyxl>=3.1.5",
    "pandas>=2.3.3",
    "pickledb>=1.3.2",
    "pydantic-settings>=2.11.0",
    "python-docx>=1.2.0",
    "rank-bm25>=0.2.2",
    "rouge>=1.0.1",
    "sentence-transformers>=5.1.1",
    "spacy>=3.8.7",
    "tabulate>=0.9.0",
    "torch>=2.8.0",
    "tqdm>=4.67.1",
    "transformers>=4.57.0",
    "unstructured>=0.18.15",
    "word2number>=1.1",
    "xinference-client>=1.10.1",
]

[[tool.uv.index]]
url = "https://pypi.tuna.tsinghua.edu.cn/simple"
default = true

Run uv sync to download the dependencies.

Writing the Local LLM Client

First, add an xinference_client.py file under the pikerag/llm_client directory and copy the following code into it.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    :   2025/11/26 19:05:27
# @Author  :   Jsm
# @Version :   1.0
# @Desc    :   Xinference LLM client for PIKE-RAG

import json
import re
import time
from typing import List, Literal, Optional, Union
import os

import openai
from langchain_core.embeddings import Embeddings
from openai import OpenAI
from openai.types import CreateEmbeddingResponse
from openai.types.chat.chat_completion import ChatCompletion
from pickledb import PickleDB

from pikerag.llm_client.base import BaseLLMClient
from pikerag.utils.logger import Logger
# Needed only when testing this module on its own
# from config.config import load_config
# model_config = load_config().model_config

# def parse_wait_time_from_error(error: openai.RateLimitError) -> Optional[int]:
#     """Parse wait time from OpenAI RateLimitError.

#     Args:
#         error (openai.RateLimitError): The rate limit error from OpenAI API.

#     Returns:
#         Optional[int]: The suggested wait time in seconds, None if parsing failed.
#     """
#     try:
#         info_str: str = error.args[0]
#         info_dict_str: str = info_str[info_str.find("{"):]
#         error_info: dict = json.loads(re.compile(r"(?<!\\)'").sub('"', info_dict_str))
#         error_message = error_info["error"]["message"]
#         matches = re.search(r"Try again in (\d+) seconds", error_message)
#         wait_time = int(matches.group(1)) + 3  # Add 3 seconds buffer
#         return wait_time
#     except Exception:
#         return None


class XinferenceClient(BaseLLMClient):
    """Xinference client implementation for DeepSeek models."""

    NAME = "XinferenceClient"

    def __init__(
        self,
        location: str = None,
        auto_dump: bool = True,
        logger: Logger = None,
        max_attempt: int = 5,
        exponential_backoff_factor: int = None,
        unit_wait_time: int = 60,
        **kwargs,
    ) -> None:
        """LLM Communication Client for Xinference endpoints with models.

        Args:
            location (str): The file location of the LLM client communication cache. No cache would be created if set to
                None. Defaults to None.
            auto_dump (bool): Automatically save the Client's communication cache or not. Defaults to True.
            logger (Logger): Client logger. Defaults to None.
            max_attempt (int): Maximum attempt time for LLM requesting. Request would be skipped if max_attempt reached.
                Defaults to 5.
            exponential_backoff_factor (int): Set to enable exponential backoff retry manner. Every time the wait time
                would be `exponential_backoff_factor ^ num_attempt`. Set to None to disable and use the `unit_wait_time`
                manner. Defaults to None.
            unit_wait_time (int): `unit_wait_time` would be used only if the exponential backoff mode is disabled. Every
                time the wait time would be `unit_wait_time * num_attempt`, with seconds (s) as the time unit. Defaults
                to 60.
            **kwargs: Additional arguments for Xinference client initialization.
            yml config example:
            ...
                llm_client:
                    module_path: pikerag.llm_client
                    class_name: XinferenceClient
                    args:
                        base_url: http://localhost:9997/v1  # Xinference server URL
                        api_key: xinference  # Default API key for Xinference
            ...
        """
        super().__init__(location, auto_dump, logger, max_attempt, exponential_backoff_factor, unit_wait_time, **kwargs)

        # Xinference exposes an OpenAI-compatible API, so the standard OpenAI client is reused.
        client_configs = {
            "api_key": kwargs.get("api_key"),
            "base_url": kwargs.get("base_url"),
            # Local inference can be slow: default to a 5-minute timeout unless the config overrides it.
            "timeout": kwargs.get("timeout", 300),
        }

        self._client = OpenAI(**client_configs)

    def _get_response_with_messages(self, messages: List[dict], **llm_config) -> ChatCompletion:
        """Get response from Xinference chat completion API with retry mechanism.

        Args:
            messages (List[dict]): The messages to send to Xinference chat completion API.
            **llm_config: Additional configuration for the chat completion API.

        Returns:
            ChatCompletion: The response from Xinference API.
        """
        response: ChatCompletion = None
        num_attempt: int = 0

        while num_attempt < self._max_attempt:
            try:
                # Xinference may have different default parameters
                # Ensure we use appropriate defaults for DeepSeek models
                response = self._client.chat.completions.create(messages=messages, **llm_config)
                break
            # except openai.RateLimitError as e:
            #     self.warning("  Failed due to RateLimitError...")
            #     wait_time = parse_wait_time_from_error(e)
            #     self._wait(num_attempt, wait_time=wait_time)
            #     self.warning("  Retrying...")
            except openai.BadRequestError as e:
                self.warning(f"  Failed due to BadRequestError: {e}")
                # For Xinference, BadRequestError might indicate model not ready
                # Wait a bit longer and retry
                num_attempt += 1
                self._wait(num_attempt, wait_time=30)  # Wait 30 seconds for model readiness
                self.warning("  Retrying...")
            except openai.APIConnectionError as e:
                self.warning(f"  Failed due to APIConnectionError: {e}")
                # Xinference server might be starting up
                num_attempt += 1
                self._wait(num_attempt, wait_time=10)  # Wait 10 seconds for server startup
                self.warning("  Retrying...")
            except Exception as e:
                self.warning(f"  Failed due to Exception: {e}")
                num_attempt += 1
                self._wait(num_attempt)
                self.warning("  Retrying...")

        return response

    def _get_content_from_response(self, response: ChatCompletion, messages: List[dict] = None) -> str:
        """Extract content from Xinference chat completion response.

        Args:
            response (ChatCompletion): The response from Xinference chat completion API.
            messages (List[dict], optional): The original messages sent to API. Defaults to None.

        Returns:
            str: The extracted content or empty string if extraction failed.
        """
        try:
            content = response.choices[0].message.content
            if content is None:
                finish_reason = response.choices[0].finish_reason
                warning_message = f"Non-Content returned due to {finish_reason}"

                # Xinference might have different content filter structure
                if hasattr(response.choices[0], 'content_filter_results'):
                    for reason, res_dict in response.choices[0].content_filter_results.items():
                        if res_dict.get("filtered", False) or res_dict.get("severity", "safe") != "safe":
                            warning_message += f", '{reason}': {res_dict}"

                self.warning(warning_message)
                self.debug(f"  -- Complete response: {response}")
                if messages is not None and len(messages) >= 1:
                    self.debug(f"  -- Last message: {messages[-1]}")

                content = ""
        except Exception as e:
            self.warning(f"Try to get content from response but get exception:\n  {e}")
            self.debug(
                f"  Response: {response}\n"
                f"  Last message: {messages}"
            )
            content = ""

        return content
    
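    async def generate_content_with_messages(self, messages: List[dict], stream: bool = False, **llm_config) -> str:
        """Generate content with messages using Xinference chat completion API.

        NOTE: this override is a coroutine, so callers must `await` it; calling it without `await` returns a
        coroutine object rather than the generated text.

        Args:
            messages (List[dict]): The messages to send to Xinference chat completion API.
            stream (bool): Whether to request a streamed response. Defaults to False.
            **llm_config: Configuration for the chat completion API. Only `model`, `max_tokens`, and `temperature`
                are forwarded; any other key (such as `top_k`) is dropped.

        Returns:
            str: The generated content, or the raw streaming response object when `stream` is True.
        """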
        llm_config = {
            "model": llm_config.get("model"),
            "max_tokens": llm_config.get("max_tokens"),
            "temperature": llm_config.get("temperature"),
            "stream": stream,
        }
        response = self._get_response_with_messages(messages, **llm_config)
        if not stream:
            response = self._get_content_from_response(response, messages)
        
        # Optionally keep only the content that follows the </think> tag
        # response = response.split("</think>")[-1].strip()
        # print(f"response: {response}")
        return response

    def close(self):
        """Close the Xinference client."""
        super().close()
        self._client.close()
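The commented-out lines near the end of generate_content_with_messages hint at a practical issue: reasoning models in the DeepSeek-R1 family usually wrap their chain of thought in `<think>` tags, and that text can get in the way when the output is parsed downstream. A minimal, standalone sketch of a cleanup helper follows; the name `strip_reasoning` is made up here and is not part of PIKE-RAG, and whether you need it depends on how your server returns the reasoning text.

```python
import re

# Matches a complete <think>...</think> block, including newlines inside it.
_THINK_BLOCK = re.compile(r"<think>.*?</think>", flags=re.DOTALL)


def strip_reasoning(text: str) -> str:
    """Remove reasoning blocks so that only the final answer is left."""
    cleaned = _THINK_BLOCK.sub("", text)
    # Some servers emit only the closing tag; keep whatever follows the last one.
    if "</think>" in cleaned:
        cleaned = cleaned.split("</think>")[-1]
    return cleaned.strip()


raw = "<think>\nThe user wants a summary.\n</think>\n\nThis chunk introduces PIKE-RAG."
print(strip_reasoning(raw))
```

Text without any `<think>` tag passes through unchanged apart from trimming, so the helper is safe to apply to non-reasoning models as well.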

Then register the XinferenceClient class in pikerag/llm_client/__init__.py.

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from pikerag.llm_client.azure_meta_llama_client import AzureMetaLlamaClient
from pikerag.llm_client.azure_open_ai_client import AzureOpenAIClient
from pikerag.llm_client.base import BaseLLMClient
from pikerag.llm_client.hf_meta_llama_client import HFMetaLlamaClient
from pikerag.llm_client.standard_openai_api import StandardOpenAIClient
from pikerag.llm_client.xinference_client import XinferenceClient


__all__ = ["AzureMetaLlamaClient", "AzureOpenAIClient", "BaseLLMClient", "HFMetaLlamaClient", "StandardOpenAIClient", "XinferenceClient"]
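The export matters because the YAML configuration refers to the client by a `module_path` and a `class_name` rather than by a Python import. The sketch below shows how such a pair can be resolved at runtime with `importlib`. It is an illustration of the general technique, not PIKE-RAG's own loader, and `load_class` is a hypothetical name.

```python
import importlib


def load_class(module_path: str, class_name: str) -> type:
    """Import `module_path` and return the attribute named `class_name`."""
    module = importlib.import_module(module_path)
    try:
        return getattr(module, class_name)
    except AttributeError as exc:
        raise ImportError(f"{module_path} does not export {class_name}") from exc


# Demonstrated with a standard-library class so the snippet runs anywhere;
# in this project the pair would be ("pikerag.llm_client", "XinferenceClient").
cls = load_class("collections", "OrderedDict")
print(cls.__name__)
```

With a lookup of this kind, a class that is missing from the package's `__init__.py` fails at the `getattr` step even though its source file exists.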

Adding the Chunking Configuration

Create an example/parenting directory under PithyRAG and add a chunking.yml configuration file to it.

# Environment Variable Setting
################################################################################
dotenv_path: null


# Logging Setting
################################################################################
log_root_dir: logs/parenting

# experiment_name: would be used to create log_dir = log_root_dir/experiment_name/
experiment_name: chunking


# Input Document & Output Dir Setting
################################################################################
input_doc_setting:
  doc_dir: data/parenting/contents

output_doc_setting:
  doc_dir: data/parenting/chunks


# LLM Setting
################################################################################
llm_client:
  module_path: pikerag.llm_client
  # available class_name: AzureMetaLlamaClient, AzureOpenAIClient, HFMetaLlamaClient, StandardOpenAIClient, XinferenceClient
  class_name: XinferenceClient
  args: {
    api_key: xinference,
    base_url: http://localhost:9997/v1
    }

  llm_config:
    #api_key: xinference
    #base_url: http://10.96.242.110:9997/v1
    model: DeepSeek-R1-32B-AWQ
    temperature: 0
    top_k: 30

  cache_config:
    # location: will be joined with log_dir to generate the full path;
    #   if set to null, the experiment_name would be used
    location_prefix: null
    auto_dump: True


# Splitter Setting
################################################################################
chunking_protocol:
  module_path: pikerag.prompts.chunking
  chunk_summary: chunk_summary_protocol_Chinese
  chunk_summary_refinement: chunk_summary_refinement_protocol_Chinese
  chunk_resplit: chunk_resplit_protocol_Chinese


splitter:
  module_path: pikerag.document_transformers
  class_name: LLMPoweredRecursiveSplitter
  args:
    separators:
      - "\n"
    is_separator_regex: False
    chunk_size: 1024
    chunk_overlap: 0

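Here llm_client holds the large-model configuration. Because the model is a local one served by Xinference, api_key can be set to any value. base_url is the model's API endpoint, and model is the name of the model to use; make sure that model has already been launched in Xinference. chunking_protocol selects the chunking strategy: with this configuration, the content is first split into chunks and the LLM then summarizes each chunk.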
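The splitter settings in the configuration (separators of "\n", chunk_size of 1024, chunk_overlap of 0) bound how much text goes into each chunk. As a rough mental model of the size-bounded part only, here is a simplified sketch that greedily packs lines into chunks. It is not the LLMPoweredRecursiveSplitter algorithm: the LLM-driven summarizing and re-splitting steps are left out entirely, and `pack_lines` is a hypothetical helper.

```python
def pack_lines(text: str, chunk_size: int = 1024, separator: str = "\n") -> list[str]:
    """Greedily merge separator-delimited pieces into chunks of at most chunk_size characters."""
    chunks: list[str] = []
    current = ""
    for piece in text.split(separator):
        if not piece.strip():
            continue  # skip empty lines
        candidate = piece if not current else current + separator + piece
        if len(candidate) <= chunk_size or not current:
            # Either it still fits, or a single oversized piece starts a chunk and is kept whole.
            current = candidate
        else:
            chunks.append(current)
            current = piece
    if current:
        chunks.append(current)
    return chunks


sample = "\n".join(f"line {i}: " + "x" * 40 for i in range(10))
chunks = pack_lines(sample, chunk_size=120)
print([len(c) for c in chunks])
```

With an overlap of 0, joining the chunks with the separator gives back the original text (minus empty lines), which is a handy sanity check when tuning chunk_size.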

Adding the Chunking Helper Functions

Create utils.py under the example/parenting directory and copy the following code into it.

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import pickle
from typing import List, Literal, Tuple

from datasets import load_dataset, Dataset
from tqdm import tqdm

from langchain_core.documents import Document

from pikerag.utils.walker import list_files_recursively
from pikerag.workflows.common import MultipleChoiceQaData


def load_testing_suite(path: str="cais/mmlu", name: str="college_biology") -> List[MultipleChoiceQaData]:
    dataset: Dataset = load_dataset(path, name)["test"]
    testing_suite: List[MultipleChoiceQaData] = []
    for qa in dataset:
        testing_suite.append(
            MultipleChoiceQaData(
                question=qa["question"],
                metadata={
                    "subject": qa["subject"],
                },
                options={
                    chr(ord('A') + i): choice
                    for i, choice in enumerate(qa["choices"])
                },
                answer_mask_labels=[chr(ord('A') + qa["answer"])],
            )
        )
    return testing_suite


def load_ids_and_chunks(chunk_file_dir: str) -> Tuple[Literal[None], List[Document]]:
    chunks: List[Document] = []
    chunk_idx: int = 0
    for doc_name, doc_path in tqdm(
        list_files_recursively(directory=chunk_file_dir, extensions=["pkl"]),
        desc="Loading Files",
    ):
        with open(doc_path, "rb") as fin:
            chunks_in_file: List[Document] = pickle.load(fin)

        for doc in chunks_in_file:
            doc.metadata.update(
                {
                    "filename": doc_name,
                    "chunk_idx": chunk_idx,
                }
            )
            chunk_idx += 1

        chunks.extend(chunks_in_file)

    return None, chunks
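load_ids_and_chunks reads pickled chunk lists from every .pkl file in a directory and numbers the chunks across files. The self-contained sketch below reproduces that pattern with the standard library only, so it can be run without PIKE-RAG installed. The dictionaries with `page_content` and `metadata` keys are a hypothetical stand-in for langchain's Document objects, and using the file name as `filename` is an assumption.

```python
import pickle
import tempfile
from pathlib import Path


def load_chunks(chunk_dir: Path) -> list[dict]:
    """Load every .pkl file under chunk_dir and number the chunks globally."""
    chunks: list[dict] = []
    for path in sorted(chunk_dir.rglob("*.pkl")):
        with path.open("rb") as fin:
            chunks_in_file = pickle.load(fin)
        for chunk in chunks_in_file:
            # chunk_idx keeps counting across files, as in load_ids_and_chunks.
            chunk["metadata"].update({"filename": path.name, "chunk_idx": len(chunks)})
            chunks.append(chunk)
    return chunks


def make_chunk(text: str) -> dict:
    """Stand-in for a Document: text plus a metadata dictionary."""
    return {"page_content": text, "metadata": {}}


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "a.pkl").write_bytes(pickle.dumps([make_chunk("first"), make_chunk("second")]))
    (root / "b.pkl").write_bytes(pickle.dumps([make_chunk("third")]))
    loaded = load_chunks(root)

print([(c["metadata"]["filename"], c["metadata"]["chunk_idx"]) for c in loaded])
```

Only unpickle files you produced yourself: pickle.load can execute arbitrary code from an untrusted file.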

Running the Chunking

In the PithyRAG directory, create a chunking.py file and copy the following code into it.

import argparse
import os
import shutil
import yaml

from pikerag.workflows.chunking import ChunkingWorkflow


def load_yaml_config(config_path: str, args: argparse.Namespace) -> dict:
    with open(config_path, "r") as fin:
        yaml_config: dict = yaml.safe_load(fin)

    # Create logging dir if not exists
    experiment_name = yaml_config["experiment_name"]
    log_dir = os.path.join(yaml_config["log_root_dir"], experiment_name)
    yaml_config["log_dir"] = log_dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    shutil.copy(config_path, log_dir)

    # LLM cache config
    if "llm_client" in yaml_config:
        if yaml_config["llm_client"]["cache_config"]["location_prefix"] is None:
            yaml_config["llm_client"]["cache_config"]["location_prefix"] = experiment_name

    # input doc dir
    input_doc_dir = yaml_config["input_doc_setting"]["doc_dir"]
    assert os.path.exists(input_doc_dir), f"Input doc dir {input_doc_dir} not exist!"
    if "extensions" not in yaml_config["input_doc_setting"]:
        yaml_config["input_doc_setting"]["extensions"] = None
    elif isinstance(yaml_config["input_doc_setting"]["extensions"], str):
        yaml_config["input_doc_setting"]["extensions"] = [yaml_config["input_doc_setting"]["extensions"]]

    # output doc dir
    output_dir: str = yaml_config["output_doc_setting"]["doc_dir"]
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    # else:
    #     assert (
    #         not os.path.isfile(output_dir)
    #         and len(os.listdir(output_dir)) == 0
    #     ), f"Output directory {output_dir} not empty!"

    return yaml_config


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("config", type=str, help="the path of the yaml config file you want to use",
                        default="examples/chunk/config.yaml")
    # TODO: add more options here, and let the ones in cmd line replace the ones in yaml file
    args = parser.parse_args()

    # Load yaml config.
    yaml_config: dict = load_yaml_config(args.config, args)

    # Environment variables are not loaded: the local model does not depend on an OpenAI/Azure key
    # load_dot_env(env_path=yaml_config["dotenv_path"])

    workflow = ChunkingWorkflow(yaml_config)
    workflow.run()

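Next, create the data/parenting/contents directory and add some test files. Plain txt files are the best choice. Other formats work too, but they require installing extra packages, and quite a lot of them; don't ask me how I know (the video will show it).

Install the dependency packages: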

apt-get install poppler-utils
apt-get install tesseract-ocr

Run uv run chunking.py example/parenting/chunking.yml to chunk the content.

The video has been fully recorded and will be edited soon!


## Reference https://blog.csdn.net/qq_62044436/article/details/149331019