四、动手做VLM+LLM全栈应用!

233 阅读44分钟

声明:我是引用的 付费文章,文章搬过来只是为了自己学习方便,不用每次输入验证口令。如果侵权,第一时间删除。 原文链接: gxlbvdk4ilp.feishu.cn/wiki/OAA7wN…

今天,我们回收结果!

围绕在你私有的两个大模型基础上,搭建一个全栈应用,得到的结果如下:

前言

文章先进行各部分解析,然后附上了完整代码文件、环境配置以及运行操作说明,想跳过解析下文件直接跑的话,请跳转看这里:

5.完整代码

一、任务概述总览

本节内容,主要完成以下任务:

  1. 基本的项目运行环境配置: 配置Python所需库的导入、设置OpenAI兼容API客户端的基础URL、以及可能的数据库连接配置(尽管具体实现依赖外部模块)。
  2. 实现多模态文档的核心处理流程: 包括读取PDF文件、按页进行图像渲染和文本提取、构建结合图像和文本的多模态Prompt、调用外部大模型API进行文字识别和结构化信息抽取(依赖外部的 process_pdfprocess_KG_extract 函数),以及汇总处理结果。
  3. 构建用户交互界面及结果展示: 使用Gradio库创建简单的网页前端界面,允许用户上传PDF文件,通过点击按钮触发后端处理函数,并在界面上实时展示处理进度、文字识别结果和知识图谱抽取结果。界面也提供了将结果保存为本地文件并下载的功能。

二、应用流程图

暂时无法在飞书文档外展示此内容

三、应用交互界面后端部分

涉及文件front_.py、anchor.py

1.环境配置与导入模块

  • 功能:设置基础环境(如CUDA配置),导入运行所需的标准库、第三方库和自定义模块,并初始化必要的客户端(如OpenAI API)和配置信息(如数据库连接参数)。
import contextlib  #提供 'with' 语句上下文管理器工具,确保资源正确释放(如文件句柄)
import json        # 用于处理 JSON 数据格式,特别是解析模型返回的结构化文本
import os          # 提供与操作系统交互的功能,如文件路径操作(检查、获取名称、删除)、环境变量设置
import time        # 时间相关功能 (虽然未直接使用,但 timeit 依赖时间戳)
import uuid        # 用于生成通用唯一标识符 (UUID),常用于生成唯一文件名或ID
from datetime import datetime # 用于获取当前日期和时间,生成带有时间戳的唯一标识符
from io import StringIO   # 用于在内存中创建和操作文本流,像文件一样读写字符串
import openai      # OpenAI 官方或兼容的 API 客户端库,用于与大语言模型服务进行交互
from concurrent.futures import ThreadPoolExecutor # Python 内置库,用于创建和管理线程池,实现并发执行任务(如此处的PDF页面处理)
import gc          # Python 垃圾回收器接口,此处用于尝试手动触发垃圾回收,释放内存
from pathlib import Path # 提供面向对象的接口来处理文件系统路径,比os.path更现代和方便
import timeit      # 用于精确测量小段代码的执行时间,评估性能

# 初始化 OpenAI API 客户端# 注意:base_url 指向本地部署的模型服务接口,api_key="EMPTY" 通常用于本地或无需认证的推理服务 
client = openai.Client(base_url="http://127.0.0.1:30000/v1",api_key="EMPTY" )  # 定义数据库连接配置字典# 这些信息会被传递给 mysql_main 或 mysql2_main 函数内部使用 
db_config = {'host': 'localhost', # 数据库服务器地址
'user': 'root', # 数据库用户名
'password': '123456wa.',  # 数据库密码 (注意:在生产环境中不应硬编码密码)
'database': 'olmocr'      # 要连接的数据库名称 
}

2.主处理流程 (main_ process 函数)

  • 功能: 作为处理请求的核心入口点,负责接收 Gradio 传来的输入(PDF 文件),验证输入的有效性,按顺序调用下级处理模块(文本识别、知识图谱抽取),更新 Gradio 的进度条,捕获并处理整个流程中可能出现的异常,最终返回结果给 UI。
def main_process(pdf_file, progress=gr.Progress(), request: Request = None):

    # 确保所有分支都返回两个值
    if pdf_file is None:
        return "请上传PDF文件", "请先上传文件"  # 必须返回两个占位值

    try:
        progress(0, desc="开始处理PDF文件")
        pdf_path = pdf_file.name
        print(pdf_path)

        # 处理第一个结果(识别能力展示)
        progress(0.4, desc="正在提取文字...")
        result1 = process_pdf(pdf_path)

        # 处理第二个结果(知识图谱抽取)
        progress(0.8, desc="正在抽取知识图谱...")
        print('pdf_path', pdf_path)
        result2 = process_KG_extract(pdf_path)

        # 完成处理
        progress(1.0, desc="处理完成")

        return result1, result2

3.PDF 文本识别模块 ( process_pdf 函数)

  • 功能:封装了从 PDF 文件提取纯文本的完整逻辑。
1.文件预处理

将上传的 PDF 文件复制到本地临时路径,并在处理结束后清理临时文件

def process_pdf(pdf_path):
    start_time = timeit.default_timer()
    # 1. 避免不必要的文件复制,提高 I/O 效率
    local_pdf_path = "./temp.pdf"
    if not os.path.exists(local_pdf_path):  # 避免重复复制
        try:
            shutil.copy(pdf_path, local_pdf_path)
            print(f"成功将 {pdf_path} 复制到 {local_pdf_path}")
        except Exception as e:
            print(f"文件复制失败: {str(e)}")
            return ""
2. PDF 信息获取

打开 PDF 文件并获取总页数。

with fitz.open(local_pdf_path) as doc:
    total_pages = doc.page_count  # 确保与渲染库页数一致
3.处理单页PDF函数:

处理单个 PDF 页面的所有步骤,返回Base64字符编码。

def process_page(page_num):
    try:
        # 3. 渲染 PDF 页面
        def render_pdf_to_base64png(
                local_pdf_path: str,         # PDF文件本地路径
                page_num: int,                # 要渲染的页码(从1开始计数)
                target_longest_image_dim: int = 3072,  # 目标图像最长边的尺寸
        ) -> str:
            """将PDF指定页面渲染为Base64编码的PNG图像"""
            # 打开PDF文件
            doc = fitz.open(local_pdf_path)
            # 使用contextlib确保文档对象会被正确关闭
            with contextlib.closing(doc):
                # 将从1开始的页码转换为从0开始的索引(因为PyMuPDF使用0-based索引)
                zero_based_page = page_num - 1
                print(total_pages, 'zero_based_page', zero_based_page)
                # 检查页码范围是否有效
                if not (0 <= zero_based_page < doc.page_count):
                    raise ValueError(
                        f"页码越界: {page_num},PDF总页数: {doc.page_count}"
                    )
                # 获取指定页面
                page = doc[zero_based_page]  # 使用从0开始的索引访问页面
                # 计算页面尺寸
                rect = page.rect
                longest_dim = max(rect.width, rect.height)
                # 计算缩放比例,使最长边符合目标尺寸
                scale = target_longest_image_dim / longest_dim
                # 创建变换矩阵(先进行2倍超采样,再按比例缩放)
                matrix = fitz.Matrix(scale, scale).prescale(2.0, 2.0)
                # 获取页面像素图(关闭alpha通道)
                pix = page.get_pixmap(matrix=matrix, alpha=False)
                # 将像素图转换为PNG格式的字节数据
                png_bytes = pix.tobytes("png")
                # 返回Base64编码的PNG字符串
                return base64.b64encode(png_bytes).decode("utf-8")
4.锚点文本提取:

从 PDF 文件中提取指定页面的文本内容,并提供了多种不同的提取策略和一种特殊的布局保留策略。其最终目标是生成一段 锚文本 ,这段文本与Base64字符编码结合用于输入给多模态模型,帮助模型更好地理解 PDF 文档的内容和布局。

# 该文件以多种不同方式生成锚文本 (anchor text)
# 此处的目的是生成一些文本,用于帮助提示 VLM (视觉语言模型)
# 以更好地理解文档
import random
import re
import subprocess
from dataclasses import dataclass
from typing import List, Literal

import ftfy  # 用于修复 Unicode 文本问题
import pypdfium2 as pdfium  # PDF 处理库
from pypdf import PdfReader  # 另一个 PDF 处理库
from pypdf.generic import RectangleObject # pypdf 中的矩形对象

# 假设 coherency 模块是本地定义的,用于评估文本连贯性
from coherency import get_document_coherency


def get_anchor_text(
    local_pdf_path: str, page: int, pdf_engine: Literal["pdftotext", "pdfium", "pypdf", "topcoherency", "pdfreport"], target_length: int = 4000
) -> str:
    """
    根据指定的 PDF 引擎从 PDF 的特定页面提取文本。

    Args:
        local_pdf_path: 本地 PDF 文件的路径。
        page: 要提取文本的页码 (从 1 开始计数)。
        pdf_engine: 使用的 PDF 提取引擎。
            "pdftotext": 使用外部 pdftotext 工具。
            "pdfium": 使用 pypdfium2 库。
            "pypdf": 使用 pypdf 库的原始文本提取。
            "topcoherency": 尝试多种引擎并选择连贯性最高的。
            "pdfreport": 使用 pypdf 提取详细布局信息并线性化。
        target_length: 目标文本长度 (主要用于 'pdfreport' 引擎)。

    Returns:
        提取的文本字符串。
    """
    assert page > 0, "PDF 页码是从 1 开始索引的"

    if pdf_engine == "pdftotext":
        return _get_pdftotext(local_pdf_path, page)
    elif pdf_engine == "pdfium":
        return _get_pdfium(local_pdf_path, page)
    elif pdf_engine == "pypdf":
        return _get_pypdf_raw(local_pdf_path, page)
    elif pdf_engine == "topcoherency":
        # 尝试多种基本提取方法
        options = {
            "pdftotext": _get_pdftotext(local_pdf_path, page),
            "pdfium": _get_pdfium(local_pdf_path, page),
            "pypdf_raw": _get_pypdf_raw(local_pdf_path, page),
        }

        # 计算每种方法提取文本的连贯性得分
        scores = {label: get_document_coherency(text) for label, text in options.items()}

        # 选择得分最高的选项
        best_option_label = max(scores, key=scores.get)  # type: ignore
        best_option = options[best_option_label]

        print(f"topcoherency 选择的引擎: {best_option_label}")

        return best_option
    elif pdf_engine == "pdfreport":
        # 生成详细的页面报告,然后将其线性化为字符串
        return _linearize_pdf_report(_pdf_report(local_pdf_path, page), max_length=target_length)
    else:
        raise NotImplementedError("未知的引擎")


def _get_pdftotext(local_pdf_path: str, page: int) -> str:
    """使用 pdftotext 命令行工具提取单页文本。"""
    pdftotext_result = subprocess.run(
        ["pdftotext", "-f", str(page), "-l", str(page), local_pdf_path, "-"], # -f first page, -l last page, - 表示输出到 stdout
        timeout=60, # 设置超时时间
        stdout=subprocess.PIPE, # 捕获标准输出
        stderr=subprocess.PIPE, # 捕获标准错误
    )
    assert pdftotext_result.returncode == 0 # 确保命令成功执行
    return pdftotext_result.stdout.decode("utf-8") # 解码输出为 UTF-8 字符串


def _get_pypdf_raw(local_pdf_path: str, page: int) -> str:
    """使用 pypdf 库提取单页的原始文本。"""
    reader = PdfReader(local_pdf_path)
    pypage = reader.pages[page - 1] # pypdf 页码从 0 开始索引

    return pypage.extract_text() # 调用 pypdf 的文本提取方法


def _get_pdfium(local_pdf_path: str, page: int) -> str:
    """使用 pypdfium2 库提取单页文本。"""
    pdf = pdfium.PdfDocument(local_pdf_path)
    textpage = pdf[page - 1].get_textpage() # pdfium 页码从 0 开始索引
    return textpage.get_text_bounded() # 获取带有边界信息的文本


def _transform_point(x, y, m):
    """使用 2x3 仿射变换矩阵 m 变换点 (x, y)。"""
    # PDF 变换矩阵通常是 [a b c d e f]
    # x_new = a*x + c*y + e
    # y_new = b*x + d*y + f
    x_new = m[0] * x + m[2] * y + m[4]
    y_new = m[1] * x + m[3] * y + m[5]
    return x_new, y_new


def _mult(m: List[float], n: List[float]) -> List[float]:
    """矩阵乘法 (用于组合 PDF 变换矩阵)。"""
    # 结果 = m * n
    return [
        m[0] * n[0] + m[1] * n[2],            # a_res = a_m*a_n + b_m*c_n
        m[0] * n[1] + m[1] * n[3],            # b_res = a_m*b_n + b_m*d_n
        m[2] * n[0] + m[3] * n[2],            # c_res = c_m*a_n + d_m*c_n
        m[2] * n[1] + m[3] * n[3],            # d_res = c_m*b_n + d_m*d_n
        m[4] * n[0] + m[5] * n[2] + n[4],     # e_res = e_m*a_n + f_m*c_n + e_n
        m[4] * n[1] + m[5] * n[3] + n[5],     # f_res = e_m*b_n + f_m*d_n + f_n
    ]


# 使用 dataclass 定义数据结构,方便存储页面元素信息

@dataclass(frozen=True) # frozen=True 使实例不可变
class Element:
    """页面元素基类。"""
    pass


@dataclass(frozen=True)
class BoundingBox:
    """边界框,表示一个矩形区域。"""
    x0: float # 左下角 x
    y0: float # 左下角 y
    x1: float # 右上角 x
    y1: float # 右上角 y

    @staticmethod
    def from_rectangle(rect: RectangleObject) -> "BoundingBox":
        """从 pypdf 的 RectangleObject 创建 BoundingBox。"""
        # pypdf 的 RectangleObject 顺序是 [lower_left_x, lower_left_y, upper_right_x, upper_right_y]
        return BoundingBox(float(rect[0]), float(rect[1]), float(rect[2]), float(rect[3]))


@dataclass(frozen=True)
class TextElement(Element):
    """表示一个文本元素及其位置。"""
    text: str # 文本内容
    x: float  # 文本基线的起始 x 坐标
    y: float  # 文本基线的起始 y 坐标


@dataclass(frozen=True)
class ImageElement(Element):
    """表示一个图像元素及其边界框。"""
    name: str       # 图像在 PDF 资源中的名称
    bbox: BoundingBox # 图像在页面上的边界框


@dataclass(frozen=True)
class PageReport:
    """存储单个页面详细信息的报告。"""
    mediabox: BoundingBox        # 页面的媒体框 (定义了页面物理尺寸)
    text_elements: List[TextElement] # 页面上的文本元素列表
    image_elements: List[ImageElement] # 页面上的图像元素列表


def _pdf_report(local_pdf_path: str, page_num: int) -> PageReport:
    """使用 pypdf 的 visitor 功能提取页面上的文本和图像元素的详细信息。"""
    reader = PdfReader(local_pdf_path)
    page = reader.pages[page_num - 1] # 获取指定页面 (0-based index)
    resources = page.get("/Resources", {}) # 获取页面的资源字典
    xobjects = resources.get("/XObject", {}) # 获取资源中的外部对象 (XObject)
    text_elements, image_elements = [], []

    # 定义文本访问器函数
    def visitor_body(text, cm, tm, font_dict, font_size):
        # text: 提取的文本块
        # cm: 当前变换矩阵 (Current Transformation Matrix)
        # tm: 文本矩阵 (Text Matrix)
        # font_dict: 字体信息字典
        # font_size: 字体大小
        # 计算文本在用户坐标系中的最终位置
        txt2user = _mult(tm, cm) # 最终矩阵 = 文本矩阵 * 当前变换矩阵
        # txt2user[4] 是 x 坐标, txt2user[5] 是 y 坐标
        text_elements.append(TextElement(text, float(txt2user[4]), float(txt2user[5])))

    # 定义操作符访问器函数 (在操作符处理前调用)
    def visitor_op(op, args, cm, tm):
        # op: PDF 操作符 (例如 b'Do' 用于绘制 XObject)
        # args: 操作符的参数
        # cm: 当前变换矩阵
        # tm: 文本矩阵 (对于非文本操作符可能不相关)
        if op == b"Do": # 如果是绘制 XObject 操作
            xobject_name = args[0].strip() # 获取 XObject 的名称
            xobject = xobjects.get(xobject_name) # 在资源中查找该 XObject

            # 检查是否是图像类型的 XObject
            if xobject and xobject.get("/Subtype") == "/Image":
                # 计算图像的边界框
                # 图像根据当前的变换矩阵 (cm) 放置
                _width = float(xobject.get("/Width", 1)) # 获取图像原始宽度
                _height = float(xobject.get("/Height", 1)) # 获取图像原始高度

                # 图像空间的原点 (0,0) 和单位点 (1,1) 经过 cm 变换后的坐标
                # 注意:这里假设图像的本地坐标系是 0,0 到 1,1,然后由 CTM 缩放和平移
                # 这可能需要根据实际 PDF 情况调整,有时图像矩阵也参与计算
                # 一个更通用的方法可能是变换 (0,0), (width,0), (0,height), (width,height) 然后取最小/最大包围盒
                # 这里的 (0,0) -> (1,1) 变换似乎是为了获取由 cm 定义的缩放和平移后的单位矩形
                x0_t, y0_t = _transform_point(0, 0, cm) # 变换原点
                x1_t, y1_t = _transform_point(1, 1, cm) # 变换单位向量终点 (代表缩放和平移)

                # 由于变换可能包含旋转或镜像,需要取 min/max 来确定边界
                min_x = min(x0_t, x1_t)
                min_y = min(y0_t, y1_t)
                max_x = max(x0_t, x1_t)
                max_y = max(y0_t, y1_t)

                # 存储图像元素及其计算出的边界框
                image_elements.append(ImageElement(xobject_name, BoundingBox(min_x, min_y, max_x, max_y)))

    # 调用 pypdf 的 extract_text 方法,并传入自定义的访问器
    page.extract_text(visitor_text=visitor_body, visitor_operand_before=visitor_op)

    # 返回包含所有提取信息的 PageReport 对象
    return PageReport(
        mediabox=BoundingBox.from_rectangle(page.mediabox), # 页面的媒体框
        text_elements=text_elements, # 提取的文本元素列表
        image_elements=image_elements, # 提取的图像元素列表
    )


def _merge_image_elements(images: List[ImageElement], tolerance: float = 0.5) -> List[ImageElement]:
    """使用 Union-Find 算法合并靠近或重叠的图像元素。"""
    n = len(images)
    if n == 0:
        return []
    parent = list(range(n))  # 初始化 Union-Find 的父指针数组

    def find(i):
        """查找元素 i 的根节点 (带路径压缩)。"""
        root = i
        while parent[root] != root:
            root = parent[root]
        # 路径压缩
        while parent[i] != i:
            parent_i = parent[i]
            parent[i] = root
            i = parent_i
        return root

    def union(i, j):
        """合并元素 i 和 j 所在的集合。"""
        root_i = find(i)
        root_j = find(j)
        if root_i != root_j:
            parent[root_i] = root_j # 将一个根指向另一个根

    def bboxes_overlap(b1: BoundingBox, b2: BoundingBox, tolerance: float) -> bool:
        """检查两个边界框是否在容差范围内重叠或接近。"""
        # 计算水平和垂直方向上的间隙
        h_dist = max(0, max(b1.x0, b2.x0) - min(b1.x1, b2.x1))
        v_dist = max(0, max(b1.y0, b2.y0) - min(b1.y1, b2.y1))
        # 如果间隙小于等于容差,则认为它们是接近或重叠的
        return h_dist <= tolerance and v_dist <= tolerance

    # 合并重叠/接近的图像
    for i in range(n):
        for j in range(i + 1, n):
            if bboxes_overlap(images[i].bbox, images[j].bbox, tolerance):
                union(i, j) # 如果接近,则合并它们所在的集合

    # 按根节点将图像分组
    groups: dict[int, list[int]] = {}
    for i in range(n):
        root = find(i)
        groups.setdefault(root, []).append(i) # 将同一集合的索引放在一起

    # 合并同一组内的图像
    merged_images = []
    for indices in groups.values():
        # 初始化合并后的边界框和名称
        merged_bbox = images[indices[0]].bbox
        merged_name = images[indices[0]].name

        # 遍历组内其他图像,扩展边界框
        for idx in indices[1:]:
            bbox = images[idx].bbox
            # 扩展 merged_bbox 以包含当前的 bbox
            merged_bbox = BoundingBox(
                x0=min(merged_bbox.x0, bbox.x0),
                y0=min(merged_bbox.y0, bbox.y0),
                x1=max(merged_bbox.x1, bbox.x1),
                y1=max(merged_bbox.y1, bbox.y1),
            )
            # (可选) 合并名称,方便追踪
            # merged_name += f"+{images[idx].name}"

        merged_images.append(ImageElement(name=merged_name, bbox=merged_bbox))

    # 返回合并后的图像列表
    return merged_images


def _cap_split_string(text: str, max_length: int) -> str:
    """
    如果文本长度超过 max_length,则截断中间部分,保留首尾。
    例如:"This is a very long string" -> "This is ... ng string"
    """
    if len(text) <= max_length:
        return text

    # 计算头部和尾部保留的长度,中间用 " ... " (5个字符) 连接
    head_length = (max_length - 5) // 2
    tail_length = max_length - 5 - head_length

    # 尝试在空格处分割,避免截断单词
    # head = text[:head_length].rsplit(" ", 1)[0] or text[:head_length] # 取头部,并在最后一个空格前截断
    # tail = text[-tail_length:].split(" ", 1)[-1] or text[-tail_length:] # 取尾部,并在第一个空格后截断
    # 更简单的截断方式(可能截断单词)
    head = text[:head_length]
    tail = text[-tail_length:]

    return f"{head} ... {tail}"


def _cleanup_element_text(element_text: str) -> str:
    """清理从 PDF 提取的文本元素。"""
    MAX_TEXT_ELEMENT_LENGTH = 250 # 单个文本元素的最大长度
    # 定义需要转义或替换的特殊字符
    TEXT_REPLACEMENTS = {"[": "\[", "]": "\]", "\n": "\n", "\r": "\r", "\t": "\t"}
    # 创建正则表达式,用于一次性替换所有需要处理的字符
    text_replacement_pattern = re.compile("|".join(re.escape(key) for key in TEXT_REPLACEMENTS.keys()))

    # 使用 ftfy 修复潜在的 Unicode 编码错误,并去除首尾空白
    element_text = ftfy.fix_text(element_text).strip()

    # 替换特殊字符
    element_text = text_replacement_pattern.sub(lambda match: TEXT_REPLACEMENTS[match.group(0)], element_text)

    # 截断过长的文本
    return _cap_split_string(element_text, MAX_TEXT_ELEMENT_LENGTH)


def _linearize_pdf_report(report: PageReport, max_length: int = 4000) -> str:
    """
    将结构化的 PageReport 线性化为单个字符串,同时尝试保留布局信息,并限制总长度。
    格式大致为:
    Page dimensions: WxH
    [Image x0xy0 to x1xy1]
    [xxyy]Text content...
    [xxyy]More text...
    """
    result = ""
    # 首先添加页面尺寸信息
    result += f"Page dimensions: {report.mediabox.x1:.1f}x{report.mediabox.y1:.1f}\n"

    # 如果最大长度太小,则只返回页面尺寸
    if max_length < len(result) + 10: # 保留一点余地
        return result

    # 合并相邻或重叠的图像
    images = _merge_image_elements(report.image_elements)

    # 处理图像元素,生成字符串表示
    image_strings = []
    for element in images:
        image_str = f"[Image {element.bbox.x0:.0f}x{element.bbox.y0:.0f} to {element.bbox.x1:.0f}x{element.bbox.y1:.0f}]\n"
        image_strings.append((element, image_str)) # 存储元素对象和其字符串表示

    # 处理文本元素,生成字符串表示
    text_strings = []
    for element in report.text_elements:  # type: ignore
        # 忽略完全是空白的文本元素
        if len(element.text.strip()) == 0:  # type: ignore
            continue

        # 清理和格式化文本
        element_text = _cleanup_element_text(element.text)  # type: ignore
        # 格式化为 "[坐标]文本内容"
        text_str = f"[{element.x:.0f}x{element.y:.0f}]{element_text}\n"  # type: ignore
        text_strings.append((element, text_str)) # 存储元素对象和其字符串表示

    # 将所有元素(图像和文本)及其位置信息整合到一个列表中,以便排序
    all_elements: list[tuple[str, Element, str, tuple[float, float]]] = []
    for elem, s in image_strings:
        position = (elem.bbox.x0, elem.bbox.y0) # 使用图像左下角作为排序依据
        all_elements.append(("image", elem, s, position))
    for elem, s in text_strings:
        position = (elem.x, elem.y)  # 使用文本基线起始点作为排序依据
        all_elements.append(("text", elem, s, position))

    # 计算所有元素的总长度
    total_length = len(result) + sum(len(s) for _, _, s, _ in all_elements)

    # 如果总长度未超过限制,则按位置排序后全部添加
    if total_length <= max_length:
        # 按位置排序 (先 x 后 y)
        all_elements.sort(key=lambda x: (x[3][1], x[3][0])) # 通常按 Y 优先(阅读顺序)
        for _, _, s, _ in all_elements:
            result += s
        return result

    # --- 如果总长度超过限制,执行截断逻辑 ---

    # 识别位于页面边缘的元素(这些元素通常比较重要,优先保留)
    edge_elements = set()

    if images:
        # 找到 x, y 坐标最小和最大的图像
        min_x0_image = min(images, key=lambda e: e.bbox.x0)
        max_x1_image = max(images, key=lambda e: e.bbox.x1)
        min_y0_image = min(images, key=lambda e: e.bbox.y0)
        max_y1_image = max(images, key=lambda e: e.bbox.y1)
        edge_elements.update([min_x0_image, max_x1_image, min_y0_image, max_y1_image])

    if report.text_elements:
        # 过滤掉空文本元素
        text_elements = [e for e in report.text_elements if len(e.text.strip()) > 0]
        if text_elements:
            # 找到 x, y 坐标最小和最大的文本元素
            min_x_text = min(text_elements, key=lambda e: e.x)
            max_x_text = max(text_elements, key=lambda e: e.x)
            min_y_text = min(text_elements, key=lambda e: e.y)
            max_y_text = max(text_elements, key=lambda e: e.y)
            edge_elements.update([min_x_text, max_x_text, min_y_text, max_y_text])  # type: ignore

    # 使用集合跟踪已选择元素的 ID,防止重复添加
    selected_element_ids = set()
    selected_elements = []

    # 首先添加边缘元素
    for elem_type, elem, s, position in all_elements:
        # 检查元素对象是否在边缘集合中,并且其 id 尚未被添加
        if elem in edge_elements and id(elem) not in selected_element_ids:
            selected_elements.append((elem_type, elem, s, position))
            selected_element_ids.add(id(elem)) # 记录已添加元素的 id

    # 计算当前已选元素的长度
    current_length = len(result) + sum(len(s) for _, _, s, _ in selected_elements)

    # 找出剩余的、非边缘的元素
    remaining_elements = [(elem_type, elem, s, position) for elem_type, elem, s, position in all_elements if id(elem) not in selected_element_ids]

    # 将剩余元素随机打乱顺序
    # 目的是在长度限制下,除了边缘元素外,能随机采样页面中间部分的内容
    random.shuffle(remaining_elements)

    # 从打乱的剩余元素中继续添加,直到达到最大长度限制
    for elem_type, elem, s, position in remaining_elements:
        if current_length + len(s) > max_length:
            # 如果添加当前元素会超长,则停止添加
            break
        selected_elements.append((elem_type, elem, s, position))
        selected_element_ids.add(id(elem)) # 记录已添加
        current_length += len(s) # 更新当前长度

    # 对最终选择的所有元素(边缘元素 + 随机采样的中间元素)按位置排序
    # 使得最终输出的文本在空间上具有一定的逻辑顺序
    selected_elements.sort(key=lambda x: (x[3][1], x[3][0])) # 按 Y 优先排序

    # 构建最终结果字符串
    for _, _, s, _ in selected_elements:
        result += s

    return result

4.构建模型提示词:结合锚点文本生成发送给多模态模型的具体指令。

def build_finetuning_prompt(base_text: str) -> str:
    return (
        "Below is the image of one page of a PDF document, as well as some raw textual content "
        "that was previously extracted for it that includes position information for each image and block of text. "
        "Just return the plain text representation of this document as if you were reading it naturally.\n"
        "Turn equations into a LaTeX representation, and tables into markdown format. Remove the headers and footers, "
        "but keep references and footnotes.\n"
        "Read any natural handwriting.\n"
        "This is likely one page out of several in the document, so be sure to preserve any sentences that come from the previous page, "
        "or continue onto the next page, exactly as they are.\n"
        "If there is no text at all that you think you should read, you can output null.\n"
        "Do not hallucinate.\n"
        f"RAW_TEXT_START\n{base_text}\nRAW_TEXT_END"
    )

提示词翻译:

以下是一页 PDF 文档的图像,以及之前提取的包含每个图像和文本块位置信息的原始文本内容。
 请将其以自然阅读的方式,仅返回文档的纯文本表示。
 将公式转换成 LaTeX 格式,将表格转换成 Markdown 格式。
去除页眉和页脚,但请保留参考文献和脚注。
 请识别并读取自然手写体内容。
 这很可能是整个文档中的一页,因此请确保准确保留那些从上一页延续到这一页、或从这一页延续到下一页的句子。
 如果你认为没有任何需要读取的文本,可以直接输出 null。
请不要凭空捏造内容。
 RAW_TEXT_START
 {base_text}
 RAW_TEXT_END
5.调用多模态模型

构造请求体(包含提示文本和 Base64 图像),使用 openai 客户端向本地部署的模型服务发送请求,并获取响应。

response = client.chat.completions.create(
    model="olmOCR-7B-0225-preview",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": prompt},
            {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
        ]
    }],
    temperature=0.3,
    max_tokens=4096,  # 增加 token 避免 API 过短输出
    top_p=0.4,  # 提高采样概率
    frequency_penalty=0.5,  # 保证格式化文本的完整性
    presence_penalty=0.3,  # 降低随机性
)
return response.choices[0].message.content
6.页面级错误处理与资源回收

捕获处理单页时可能发生的异常,并在 finally 块中进行资源回收 。

except Exception as e:
    print(f"处理第 {page_num} 页时出错: {str(e)}")
    return ""
finally:
    if 'image_base64' in locals():
        del image_base64
    gc.collect()
7.页面并发处理

使用线程池 (ThreadPoolExecutor) 对 PDF 的每一页并行调用页面处理函数 (process_page)。

start_page = timeit.default_timer()
with ThreadPoolExecutor(max_workers=10) as executor:
    results = list(executor.map(process_page, range(1, total_pages + 1)))  # 1 → total_pages
end_page = timeit.default_timer()
print('分页解析时间: %s 秒' % (end_page - start_page))
8.结果聚合与格式化

收集所有页面处理线程返回的结果(JSON 字符串),解析 JSON,提取其中的 natural_text 字段,处理可能出现的解析错误或空数据,将所有页面的有效文本拼接成一个完整的字符串。

output_buffer = StringIO()
for i, v in enumerate(results):
    try:
        # 检查数据是否为空
        if v is None:
            raise ValueError("数据为空")
        # 解析 JSON
        data = json.loads(v)
        # 检查是否为字典且包含 "natural_text"
        if not isinstance(data, dict):
            raise TypeError("数据格式非字典")
        if "natural_text" not in data:
            raise KeyError("缺少 'natural_text' 字段")
        # 处理文本
        natural_text = data["natural_text"]
        if natural_text is None:  # 额外检查 natural_text 是否为 None
            raise ValueError("natural_text 为空")
        lines = [ln.strip() for ln in natural_text.split('\n') if ln.strip()]
        output_buffer.write("\n\n".join(lines) + "\n\n")
    except (ValueError, json.JSONDecodeError, TypeError, KeyError, AttributeError) as e:
        # os.remove(local_pdf_path)
        print(f"警告: 第 {i + 1} 页数据异常({e}),跳过")
        continue  # 显式跳过当前迭代 
9.写入txt文件:

将最终拼接好的 full_content 写入一个 .txt 文件。

full_content = output_buffer.getvalue()
output_buffer.close()
# 8. 保存结果到文件
base_name = os.path.basename(pdf_path).split(".")[0]
output_md = f"./temp/{base_name}_converted.txt"
with open(output_md, "w", encoding="utf-8") as f:
    f.write(full_content)
print(f"\n✅ 成功保存 txt 文件到: {output_md}")
end_time = timeit.default_timer()
print('总运行时间: %s 秒' % (end_time - start_time))
os.remove(local_pdf_path)
10.process_pdf模块完整代码
def process_pdf(pdf_path):
    start_time = timeit.default_timer()
    # 1. 避免不必要的文件复制,提高 I/O 效率
    local_pdf_path = "./temp.pdf"
    if not os.path.exists(local_pdf_path):  # 避免重复复制
        try:
            shutil.copy(pdf_path, local_pdf_path)
            print(f"成功将 {pdf_path} 复制到 {local_pdf_path}")
        except Exception as e:
            print(f"文件复制失败: {str(e)}")
            return ""
    # 2. 获取 PDF 总页数
    with fitz.open(local_pdf_path) as doc:
        total_pages = doc.page_count  # 确保与渲染库页数一致
    def process_page(page_num):
        """处理单个 PDF 页面的函数"""
try:
            # 3. 渲染 PDF 页面
            def render_pdf_to_base64png(
                    local_pdf_path: str,
                    page_num: int,  # 从1开始的页码
                    target_longest_image_dim: int = 3072,
            ) -> str:
                doc = fitz.open(local_pdf_path)
                with contextlib.closing(doc):
                    # 将从1开始的页码转换为从0开始的索引
                    zero_based_page = page_num - 1
                    print(total_pages,'zero_based_page',zero_based_page)
                    # 检查页码范围
                    if not (0 <= zero_based_page < doc.page_count):
                        raise ValueError(
                            f"页码越界: {page_num},PDF总页数: {doc.page_count}"
                        )
                    page = doc[zero_based_page]  # 使用从0开始的索引访问页面
                    rect = page.rect
                    longest_dim = max(rect.width, rect.height)
                    scale = target_longest_image_dim / longest_dim
                    matrix = fitz.Matrix(scale, scale).prescale(2.0, 2.0)
                    pix = page.get_pixmap(matrix=matrix, alpha=False)
                    png_bytes = pix.tobytes("png")
                    return base64.b64encode(png_bytes).decode("utf-8")
            image_base64 = render_pdf_to_base64png(local_pdf_path, page_num, target_longest_image_dim=2048)
            anchor_text = get_anchor_text(local_pdf_path, page_num, pdf_engine="pdfreport", target_length=1500)
            prompt = build_finetuning_prompt(anchor_text)
            print("prompt",prompt)
            # 5. 调用 API
            response = client.chat.completions.create(
                model="olmOCR-7B-0225-preview",
                messages=[{
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
                    ]
                }],
                temperature=0.3,
                max_tokens=4096,  # 增加 token 避免 API 过短输出
                top_p=0.4,  # 提高采样概率
                frequency_penalty=0.5,  # 保证格式化文本的完整性
                presence_penalty=0.3,  # 降低随机性
            )
            print(response.choices[0].message.content)
            return response.choices[0].message.content
        except Exception as e:
            print(f"处理第 {page_num} 页时出错: {str(e)}")
            return ""
        finally:
            if 'image_base64' in locals():
                del image_base64
            gc.collect()
    # 6. 并发执行 OCR 解析,提高速度
    start_page = timeit.default_timer()
    with ThreadPoolExecutor(max_workers=10) as executor:
        results = list(executor.map(process_page, range(1, total_pages + 1)))  # 1 → total_pages
    end_page = timeit.default_timer()
    print('分页解析时间: %s 秒' % (end_page - start_page))
    # 7. 处理 JSON 数据 & 生成最终文本
    output_buffer = StringIO()
    for i, v in enumerate(results):
        try:
            # 检查数据是否为空
            if v is None:
                raise ValueError("数据为空")
            # 解析 JSON
            data = json.loads(v)
            # 检查是否为字典且包含 "natural_text"
            if not isinstance(data, dict):
                raise TypeError("数据格式非字典")
            if "natural_text" not in data:
                raise KeyError("缺少 'natural_text' 字段")
            # 处理文本
            natural_text = data["natural_text"]
            if natural_text is None:  # 额外检查 natural_text 是否为 None
                raise ValueError("natural_text 为空")
            lines = [ln.strip() for ln in natural_text.split('\n') if ln.strip()]
            output_buffer.write("\n\n".join(lines) + "\n\n")
        except (ValueError, json.JSONDecodeError, TypeError, KeyError, AttributeError) as e:
            # os.remove(local_pdf_path)
            print(f"警告: 第 {i + 1} 页数据异常({e}),跳过")
            continue  # 显式跳过当前迭代
    full_content = output_buffer.getvalue()
    output_buffer.close()
    # 8. 保存结果到文件
    base_name = os.path.basename(pdf_path).split(".")[0]
    output_md = f"./temp/{base_name}_converted.txt"
    with open(output_md, "w", encoding="utf-8") as f:
        f.write(full_content)
    print(f"\n✅ 成功保存 txt 文件到: {output_md}")
    end_time = timeit.default_timer()
    print('总运行时间: %s 秒' % (end_time - start_time))
    os.remove(local_pdf_path)
    return full_content

4.知识图谱抽取模块 ( process_KG_extract 函数)

功能:负责调用数据库处理逻辑,对 process_pdf 生成的文本文件进行知识图谱抽取。

1.路径构建

根据原始 PDF 文件名确定 process_pdf 输出的 .txt 文件路径。

base_name = dir.split("/")[-1].split(".")[0]
data_dir = f"./temp/{base_name}_converted.txt"
2.生成唯一标识

结合当前时间和 UUID 生成一个唯一的标识符,用于数据库记录批次跟踪。

title_unique = f"{datetime.now().strftime('%Y%m%d%H%M%S')}-{uuid.uuid4().hex[:8]}"
3.将文件和标识符存入数据库

依次调用导入的 mysql_mainmysql2_main 函数,传递文本文件路径和唯一标识符作为参数,main函数返回的是数据库标题同时间辍传给main2函数

title = main(data_dir, title_unique)
data = main2(title,title_unique)
return data
4.完整代码
def process_KG_extract(dir):
    base_name = dir.split("/")[-1].split(".")[0]
    data_dir = f"./temp/{base_name}_converted.txt"
    title_unique = f"{datetime.now().strftime('%Y%m%d%H%M%S')}-{uuid.uuid4().hex[:8]}"
    title = main(data_dir, title_unique)
    data = main2(title,title_unique)
    return data

5. 数据库处理

本应用的系统环境是Ubuntu22.04,安装mysql教程链接为Ubuntu22.04 MYSQL极简安装教程

功能:process_pdf 完成文本识别后,接收生成的文本文件,对其进行深度解析以识别文档结构(标题和段落),然后将这些结构化的数据存入 MySQL 数据库的 markdown_data 表中。这构成了知识图谱抽取流程的第一步(数据结构化与存储)。

1.主执行函数

接收文本文件的路径 (file_path) 和一个唯一标识符 (title_unique)。打开并读取指定路径的文本文件内容。从文件路径中提取文件名作为基础标题。调用 extract_text 对文件内容进行结构化解析。调用 insert_data 将解析后的数据存入数据库。打印插入记录的数量,并返回处理后的标题。

def main(file_path, title_unique):
    """主函数"""
with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()
    title = file_path.split('/')[-1]
    print(title)
    sections = extract_text(content)
    insert_data(sections,title, title_unique)
    print(f"成功插入 {len(sections)} 条记录")
    return title.strip().replace(' ', '')
2.数据库连接与准备

使用 pymysql 库和预定义的 db_config 连接到 MySQL 数据库 (olmocr)。

检查名为 markdown_data 的表是否存在。如果不存在,则创建一个新表,包含 id (自增主键), unique_id (唯一标识), title (标题), 和 content (内容) 这几个字段。

db_config = {
    'host': 'localhost',
    'user': 'root',
    'password': '123456wa.',
    'database': 'olmocr'
}
# 连接到 MySQL 数据库
conn = pymysql.connect(**db_config)
cursor = conn.cursor()

# 检查表是否存在并创建表格
table_name = 'markdown_data'

# 定义查询表格是否存在的SQL语句
check_table_query = f'''
SELECT COUNT(*)
FROM information_schema.tables
WHERE table_schema = '{db_config["database"]}' AND table_name = '{table_name}';
'''

# 执行查询
cursor.execute(check_table_query)
table_exists = cursor.fetchone()[0] > 0  # 返回的元组中的第一个元素是表的计数

if not table_exists:
    # 创建表格
    cursor.execute(f'''
    CREATE TABLE markdown_data (
      id INT AUTO_INCREMENT PRIMARY KEY,
    unique_id VARCHAR(255),
    title VARCHAR(255),
    content TEXT
    )
    ''')
    print(f"表 '{table_name}' 已创建。")
else:
    print(f"表 '{table_name}' 已存在,跳过创建。")
3.文本结构化解析

这是这段代码的核心逻辑之一。它接收一个包含纯文本内容的字符串 (content)。使用复杂的正则表达式 (title_pattern) 尝试识别文档中的多种标题格式(中文编号、章节条、特殊括号、多级数字、英文关键词、罗马数字等)。实现了智能段落合并、列表项识别、跨行标题处理等逻辑,试图将输入的纯文本解析成以标题为键、段落列表为值的字典结构。

def extract_text(content: str) -> Dict[str, List[str]]:
    """
增强型文档解析器 (符合GB/T 1.1-2020标准)
功能特性:
1. 支持12种标题格式混合识别
2. 智能段落合并算法
3. 医疗/法律专业术语保护
4. 多级列表识别
"""
sections = {}
    current_title = None
    buffer = []
    paragraph = []
    chinese_nums = "一二三四五六七八九十"

    # 综合标题正则(融合两版优势)
    title_pattern = re.compile(
        r'(?:^([' + chinese_nums + r']{1,3}[、..])|'  # 中文编号
                                   r'^(第[' + chinese_nums + r']{1,3}[章节条])|'  # 法律条文
                                                             r'^(【[\u4e00-\u9fa5]+】)|'  # 特殊括号标题
                                                             r'^((?:\d+.)+\d*\s)|'  # 多级数字编号
                                                             r'^([A-Z]+:\s)|'  # 英文关键词标题
                                                             r'^([IVX]+.\s))'  # 罗马数字编号
    )

    for line in content.split('\n'):
        raw_line = line  # 保留原始行用于特殊处理
        line = line.strip()

        # 段落终止逻辑(空行或格式变化)
        if not line or _is_format_change(raw_line, line):
            _commit_paragraph(paragraph, sections, current_title)
            paragraph = []
            continue

        # 跨行标题处理(支持中英文冒号结尾)
        if buffer:
            buffer.append(raw_line)
            candidate = ' '.join(buffer).strip()
            if title_pattern.match(candidate):
                _commit_buffer(buffer, paragraph)
                line = candidate
                buffer = []
            else:
                paragraph.extend(buffer)
                buffer = []
            continue

        # 标题识别逻辑
        if title_pattern.match(line):
            if _needs_buffer(line):  # 判断是否需要缓冲
                buffer.append(raw_line)
            else:
                current_title = _normalize_title(line)
                sections[current_title] = []
        else:
            # 列表项识别(支持•/-/*三种格式)
            if line.startswith(('•', '-', '*')):
                _commit_paragraph(paragraph, sections, current_title)
                paragraph = [line]
            else:
                # 智能段落合并(80字符阈值)
                if len(' '.join(paragraph + [line])) < 80:
                    paragraph.append(line)
                else:
                    _commit_paragraph(paragraph, sections, current_title)
                    paragraph = [line]

    # 最终提交处理
    _commit_paragraph(paragraph, sections, current_title)
    return sections


def _is_format_change(raw: str, stripped: str) -> bool:
    """检测格式变化(缩进/空格变化率>30%)"""
leading_space = len(raw) - len(raw.lstrip())
    return abs(len(raw) - len(stripped)) / (len(raw) + 1) > 0.3


def _needs_buffer(line: str) -> bool:
    """判断是否需要缓冲(结尾冒号且未闭合)"""
return line.endswith((':', ':')) and not line.endswith(('】', ')', '》'))


def _normalize_title(title: str) -> str:
    """标题规范化处理"""
# 去除多余空格(保留中文全角空格)
    return re.sub(r'[  ]+', ' ', title.strip())


def _commit_buffer(buffer: list, paragraph: list):
    """缓冲内容提交策略"""
if len(buffer) > 1:
        paragraph.append(' '.join(buffer))
    else:
        paragraph.extend(buffer)


def _commit_paragraph(paragraph: list, sections: dict, current_title: str):
    """段落提交策略"""
if not paragraph:
        return

    full_para = ' '.join(paragraph)
    if current_title:
        sections.setdefault(current_title, []).append(full_para)
    else:
        sections.setdefault('__preface__', []).append(full_para)
4.数据入库

接收 extract_text 函数返回的 sections 字典、一个基础标题 (sql_title) 和一个唯一标识符 (title_unique)。遍历 sections 字典中的每一个标题和对应的段落列表。为每一条记录生成一个新的唯一 ID (unique_id)。将每个标题下的所有段落合并成一个字符串 (full_content)。将处理后的标题 (sql_title + title_unique)、内容 (\n{title}:\n{full_content}) 和生成的 unique_id 插入到 markdown_data 表中。

def insert_data(sections,sql_title, title_unique):
    """基础数据库插入"""
conn = pymysql.connect(**db_config)
    cursor = conn.cursor()

    try:
        for title, content in sections.items():
            unique_id = f"{datetime.now().strftime('%Y%m%d%H%M%S')}-{uuid.uuid4().hex[:8]}"
            b = []
            for i, v in enumerate(content):
                section_content = ''.join(v)
                b.append(section_content)
            full_content = '\n'.join(b)
            cursor.execute('''
                INSERT INTO markdown_data 
                (unique_id, title, content) 
                VALUES (%s, %s, %s)
            ''', (unique_id,sql_title.strip().replace(' ', '')+title_unique, f"\n{title}:\n{full_content}"))
            print(f"\n{title}:\n{full_content}")
        conn.commit()
    finally:
        cursor.close()
        conn.close()
5.完整代码:

from datetime import datetime
import re
import uuid
import pymysql
from typing import Dict, List
# 数据库配置
db_config = {
    'host': 'localhost',
    'user': 'root',
    'password': '123456wa.',
    'database': 'olmocr'
}
# 连接到 MySQL 数据库
conn = pymysql.connect(**db_config)
cursor = conn.cursor()

# 检查表是否存在并创建表格
table_name = 'markdown_data'

# 定义查询表格是否存在的SQL语句
check_table_query = f'''
SELECT COUNT(*)
FROM information_schema.tables
WHERE table_schema = '{db_config["database"]}' AND table_name = '{table_name}';
'''

# 执行查询
cursor.execute(check_table_query)
table_exists = cursor.fetchone()[0] > 0  # 返回的元组中的第一个元素是表的计数

if not table_exists:
    # 创建表格
    cursor.execute(f'''
    CREATE TABLE markdown_data (
      id INT AUTO_INCREMENT PRIMARY KEY,
    unique_id VARCHAR(255),
    title VARCHAR(255),
    content TEXT
    )
    ''')
    print(f"表 '{table_name}' 已创建。")
else:
    print(f"表 '{table_name}' 已存在,跳过创建。")


def extract_text(content: str) -> Dict[str, List[str]]:
    """
增强型文档解析器 (符合GB/T 1.1-2020标准)
功能特性:
1. 支持12种标题格式混合识别
2. 智能段落合并算法
3. 医疗/法律专业术语保护
4. 多级列表识别
"""
sections = {}
    current_title = None
    buffer = []
    paragraph = []
    chinese_nums = "一二三四五六七八九十"

    # 综合标题正则(融合两版优势)
    title_pattern = re.compile(
        r'(?:^([' + chinese_nums + r']{1,3}[、..])|'  # 中文编号
                                   r'^(第[' + chinese_nums + r']{1,3}[章节条])|'  # 法律条文
                                                             r'^(【[\u4e00-\u9fa5]+】)|'  # 特殊括号标题
                                                             r'^((?:\d+.)+\d*\s)|'  # 多级数字编号
                                                             r'^([A-Z]+:\s)|'  # 英文关键词标题
                                                             r'^([IVX]+.\s))'  # 罗马数字编号
    )

    for line in content.split('\n'):
        raw_line = line  # 保留原始行用于特殊处理
        line = line.strip()

        # 段落终止逻辑(空行或格式变化)
        if not line or _is_format_change(raw_line, line):
            _commit_paragraph(paragraph, sections, current_title)
            paragraph = []
            continue

        # 跨行标题处理(支持中英文冒号结尾)
        if buffer:
            buffer.append(raw_line)
            candidate = ' '.join(buffer).strip()
            if title_pattern.match(candidate):
                _commit_buffer(buffer, paragraph)
                line = candidate
                buffer = []
            else:
                paragraph.extend(buffer)
                buffer = []
            continue

        # 标题识别逻辑
        if title_pattern.match(line):
            if _needs_buffer(line):  # 判断是否需要缓冲
                buffer.append(raw_line)
            else:
                current_title = _normalize_title(line)
                sections[current_title] = []
        else:
            # 列表项识别(支持•/-/*三种格式)
            if line.startswith(('•', '-', '*')):
                _commit_paragraph(paragraph, sections, current_title)
                paragraph = [line]
            else:
                # 智能段落合并(80字符阈值)
                if len(' '.join(paragraph + [line])) < 80:
                    paragraph.append(line)
                else:
                    _commit_paragraph(paragraph, sections, current_title)
                    paragraph = [line]

    # 最终提交处理
    _commit_paragraph(paragraph, sections, current_title)
    return sections


def _is_format_change(raw: str, stripped: str) -> bool:
    """检测格式变化(缩进/空格变化率>30%)"""
leading_space = len(raw) - len(raw.lstrip())
    return abs(len(raw) - len(stripped)) / (len(raw) + 1) > 0.3


def _needs_buffer(line: str) -> bool:
    """判断是否需要缓冲(结尾冒号且未闭合)"""
return line.endswith((':', ':')) and not line.endswith(('】', ')', '》'))


def _normalize_title(title: str) -> str:
    """标题规范化处理"""
# 去除多余空格(保留中文全角空格)
    return re.sub(r'[  ]+', ' ', title.strip())


def _commit_buffer(buffer: list, paragraph: list):
    """缓冲内容提交策略"""
if len(buffer) > 1:
        paragraph.append(' '.join(buffer))
    else:
        paragraph.extend(buffer)


def _commit_paragraph(paragraph: list, sections: dict, current_title: str):
    """段落提交策略"""
if not paragraph:
        return

    full_para = ' '.join(paragraph)
    if current_title:
        sections.setdefault(current_title, []).append(full_para)
    else:
        sections.setdefault('__preface__', []).append(full_para)


def insert_data(sections,sql_title, title_unique):
    """基础数据库插入"""
conn = pymysql.connect(**db_config)
    cursor = conn.cursor()

    try:
        for title, content in sections.items():
            unique_id = f"{datetime.now().strftime('%Y%m%d%H%M%S')}-{uuid.uuid4().hex[:8]}"
            b = []
            for i, v in enumerate(content):
                section_content = ''.join(v)
                b.append(section_content)
            full_content = '\n'.join(b)
            cursor.execute('''
                INSERT INTO markdown_data 
                (unique_id, title, content) 
                VALUES (%s, %s, %s)
            ''', (unique_id,sql_title.strip().replace(' ', '')+title_unique, f"\n{title}:\n{full_content}"))
            print(f"\n{title}:\n{full_content}")
        conn.commit()
    finally:
        cursor.close()
        conn.close()


def main(file_path, title_unique):
    """主函数"""
with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()
    title = file_path.split('/')[-1]
    print(title)
    sections = extract_text(content)
    insert_data(sections,title, title_unique)
    print(f"成功插入 {len(sections)} 条记录")
    return title.strip().replace(' ', '')

if __name__ == '__main__':
    main(dir)

6. 数据库抽取三元组

功能: 从数据库读取这些文本,然后调用大语言模型 (Qwen 14B) 对这些文本进行并发的分析,并将分析结果汇总返回,构成了知识图谱抽取流程。

1.数据库连接与初始化模型接口

设置数据库参数并连接与openai接口参数

from concurrent.futures import ThreadPoolExecutor
import pymysql
from litellm import api_base, api_key
from openai import OpenAI
from propmt import system_prompt
client = OpenAI(api_key="EMPTY",base_url='http://192.168.1.19:8000/v1')
# 设置 MySQL 数据库连接信息
db_config = {
    'host': 'localhost',
    'user': 'root',
    'password': '123456wa.',
    'database': 'olmocr'
}
def fetch_text_by_id(record_id, title_unique):
    # 连接到数据库
    connection = pymysql.connect(**db_config)
    try:
        with connection.cursor() as cursor:
            new_record = str(record_id).strip()
            new_title = new_record+title_unique
            print("new_title",new_title)
            # 执行查询,根据 ID 获取文本
            query = "SELECT content FROM markdown_data WHERE title = %s"
            cursor.execute(query, (new_title,))
            result = cursor.fetchall()
            return result
    finally:
        connection.close()
2.数据检索

根据 mysql_main 函数返回的标题 (title) 和传递过来的唯一标识(title_unique),构造查询条件。从 markdown_data 表中查询并获取之前由 mysql_main 存入的所有相关文本内容 (content)。


def main2(title,title_unique):
    text_records = fetch_text_by_id(title, title_unique)  # 确保查询时只返回 text 列
    if not text_records:
        return "未找到对应的记录。"

    # 2. 多线程加速文本检查
    def process_record(index, record):
        """ 处理单条记录 """
if record[0]:  # 只处理非空记录
            original_text = record[0]
            result = check_for_typo(original_text)
            print(result)
            if result == " ":
                return " "
            else:
                return f"\n{result}\n"
        return ""

    # 3. 线程池加速(适用于 I/O 操作)
    with ThreadPoolExecutor() as executor:
        results = executor.map(process_record, range(1, len(text_records) + 1), text_records)

    print(results)
    # 4. `list.append()` 拼接字符串
    return "".join(results).strip()
3.抽取三元组数据:

调用大语言模型 (Qwen 14B) 对文本进行分析,并将分析结果汇总返回。

提示词部分

system_prompt = """
你是一个智能助手,负责从长文本中细致地总结有意义的关键元素和原子事实,注意,每一组原子事实对应的关键元素,应是一组json数据,即每组数据由1个【原子事实】,多个有意义的【关键元素】组成
关键元素 :对文本叙述至关重要的核心名词(如人物、时间、事件、地点、数字)、动词(如动作)和形容词(如状态、情感)。 
原子事实 :最小的、不可分割的事实,以简洁的句子形式呈现。这些包括命题、理论、存在、概念以及隐含的逻辑、因果关系、事件顺序、人际关系、时间线等。
要求:
确保所有识别出的关键元素都在相应的原子事实中有所体现
应全面提取关键元素和原子事实,特别是那些重要且可能被查询的内容,不要遗漏细节
在适用的情况下,将代词替换为其具体的名词对应物(例如,将“我”、“他”、“她”替换为实际的名字)
确保提取的关键元素和原子事实与原始文本使用相同语言(如英文或中文)
遇到文本不完整或者不明确则跳过,无需输出任何东西
无需道歉或展示你的工作过程,仅以json格式输出
示例:
[
{
    "原子事实": "《成人糖尿病食养指南(2023 年版)》的制定以满足人民健康需求为出发点。",
    "关键元素": ["《成人糖尿病食养指南(2023 年版)》", "制定", "满足", "人民健康需求"]
  },
 {
    "原子事实": "《成人糖尿病食养指南(2023 年版)》旨在预防和控制我国人群糖尿病的发生和发展。",
    "关键元素": ["《成人糖尿病食养指南(2023 年版)》", "预防", "控制", "我国人群", "糖尿病", "发生", "发展"]
  },
{
        "原子事实": "下颌升支和喙突骨折可导致张口受限",
        "关键元素": ["下颌升支", "喙突骨折", "张口受限"]
    }
]
"""
def check_for_typo(text):
    response =client.chat.completions.create(
        model='Qwen2.5-14B-Instruct-GPTQ-Int8',  # 使用的模型
        temperature=0.3,
        messages=[
            {"role": "user", "content": f"\n{system_prompt}"},
            {"role": "user", "content": f"\n{text}"}
        ]
    )
    return response.choices[0].message.content
4.完整代码
from concurrent.futures import ThreadPoolExecutor

import pymysql
from openai import OpenAI
from propmt import system_prompt
client = OpenAI(api_key="EMPTY",base_url='http://192.168.1.19:8000/v1')
# 设置 MySQL 数据库连接信息
db_config = {
    'host': 'localhost',
    'user': 'root',
    'password': '123456wa.',
    'database': 'olmocr'
}
# 设置 OpenAI API key

def fetch_text_by_id(record_id, title_unique):
    # 连接到数据库
    connection = pymysql.connect(**db_config)
    try:
        with connection.cursor() as cursor:
            new_record = str(record_id).strip()
            new_title = new_record+title_unique
            print("new_title",new_title)
            # 执行查询,根据 ID 获取文本
            query = "SELECT content FROM markdown_data WHERE title = %s"
            cursor.execute(query, (new_title,))
            result = cursor.fetchall()
            return result
    finally:
        connection.close()
def check_for_typo(text):
    response =client.chat.completions.create(
        model='Qwen2.5-14B-Instruct-GPTQ-Int8',  # 使用的模型
        temperature=0.3,
        messages=[
            {"role": "user", "content": f"\n{system_prompt}"},
            {"role": "user", "content": f"\n{text}"}
        ]
    )
    return response.choices[0].message.content


def main2(title,title_unique):
    # 1. 数据库查询优化
    text_records = fetch_text_by_id(title, title_unique)  # 确保查询时只返回 text 列

    if not text_records:
        return "未找到对应的记录。"

    # 2. 多线程加速文本检查
    def process_record(index, record):
        """ 处理单条记录 """
if record[0]:  # 只处理非空记录
            original_text = record[0]
            result = check_for_typo(original_text)
            print(result)
            if result == " ":
                return " "
            else:
                return f"\n{result}\n"
        return ""

    # 3. 线程池加速(适用于 I/O 操作)
    with ThreadPoolExecutor() as executor:
        results = executor.map(process_record, range(1, len(text_records) + 1), text_records)

    print(results)
    # 4. `list.append()` 拼接字符串
    return "".join(results).strip()

if __name__ == "__main__":
    main2(str,id)

四、应用交互界面前端部分

1.整体布局:

创建了 Gradio 应用的基础容器。title 设置了浏览器标签页的标题。css 参数允许注入自定义 CSS 样式,这里用来调整行和列的间距与内边距。

with gr.Blocks(title="多模态数据处理驱动范式",
               css=".gr-row {gap: 0!important;} .gr-column {padding: 0 5px!important;}") as app:
2.顶部信息展示区 :

with gr.Row(equal_height=True):: 创建一个行容器,equal_height=True 使此行内的列高度一致。

with gr.Column(scale=5, min_width=0):: 在行内创建一个列,scale 控制其相对宽度。

gr.Markdown(...): 这个组件用于显示 Markdown 格式的文本。特别的是,这里直接嵌入了大段的 HTML 代码来构建一个复杂的、带有 Logo 和分栏介绍的顶部横幅:

  • <div style="display: flex; ...">: 使用 Flexbox 布局创建横向排列的区域。
  • <img src='data:image/png;base64,{logo_b64}' .../>: 显示 Logo。Logo 图片数据是通过 base64.b64encode(logo_path.read_bytes()).decode() 从本地文件 login-logo-small.png 读取并编码后嵌入的。
  • 后续的几个 <div style="..."> 分别定义了 "一意AI增效家"、"课程简介" 和 "演示技术说明" 这三个信息区块,包含标题 (<h3>) 和段落 (<p>) 以及列表 (<ul>),并应用了内联 CSS 样式。
 with gr.Row(equal_height=True):
        # 合并为一个弹性容器
        with gr.Column(scale=5, min_width=0):
            gr.Markdown(f"""
<div style="display: flex; margin: 10px 0px 0; background: #f8f9fa; border-radius: 6px; overflow: hidden;">
  <img src='data:image/png;base64,{logo_b64}' style='height:90px; margin-right:-45px'/>
  <div style="flex: 1; padding: 20px; border-right: 1px solid #e0e0e0;">
    <h3 style="margin:0 0 15px 0; color: #2c3e50; font-size: 1.2em;">一意AI增效家</h3>
    <p style="margin:0; line-height: 1.6; color: #34495e;">多模态模型驱动的数据处理范式-课程演示</p>
  </div>

  <div style="flex: 1.5; padding: 20px; border-right: 1px solid #e0e0e0;">
    <h3 style="margin:0 0 15px 0; color: #2c3e50; font-size: 1.2em;">多模态数据处理范式课程简介</h3>
    <p style="margin:0; line-height: 1.6; color: #34495e;">
      以传统OCR处理数据,只能做线性任务,企业需求复杂,多格式+跨部门非结构文档处理,是每一个AI项目顺利落地的重要技术,使用多模态模型,能接受任何开放的文档,识别提取后,转为结构化文档,接本地大模型,处理入库;
    </p>
  </div>

  <div style="flex: 2; padding: 20px;">
    <h3 style="margin:0 0 0px 0; color: #2c3e50; font-size: 1.2em;">演示技术说明:</h3>
    <p style="margin:0 10 5px 0; line-height: 1; color: #34495e;">
      我们使用48G显存,在本地推理了两个模型:
    </p>
    <ul style="margin:0 0 0px 0; padding-left: 20px; line-height: 1; color: #34495e;">
      <li>多模态模型-7B(微调增强)</li>
      <li>大语音模型-14B(qwen系列)</li>
    </ul>
    <p style="margin:0; line-height: 1.3; color: #34495e;">
      上传PDF文档后,首先由7B的多模态模型对文档进行识别提取,将文本在本地保存为.md格式,然后通过动态递归分块的策略,传入到14B的大模型做知识图谱的抽取。
    </p>
  </div>
</div>
            """)
3.主交互区域 (左右分栏):

with gr.Row():: 创建一个新的行来容纳主要的交互元素。

左侧列 ( with gr.Column(scale=4): ) :

  • tabs = gr.Tabs(): 创建选项卡容器(虽然这里只有一个 TabItem,但结构上用了 Tabs)。

  • with gr.TabItem(label='', scale=3):: 创建一个无标签的选项卡。

    • text_output = gr.Textbox(label="识别能力展示", lines=20, interactive=True): 这是第一个主要的输出文本框,用于显示 process_pdf 函数返回的识别结果。lines=20 设置了初始显示行数。interactive=True 允许用户编辑内容(虽然通常是程序填充)。
    • inp = gr.Textbox(label="与文档对话(未开放)", placeholder="请输入"): 一个文本输入框,目前看是未启用状态的占位符。
  • pdf_input = gr.File(label="上传PDF文件", file_types=[".pdf"]): 文件上传组件,限制只能上传 .pdf 文件。这是 process 函数的主要输入。

  • process_btn = gr.Button("开始工作", variant="huggingface"): 核心的触发按钮,点击后会调用 process 函数。variant 控制按钮外观。

右侧列 ( with gr.Blocks() as demo: ... with gr.Column(scale=8): ) :

  • 这里嵌套了一个 gr.Blocks,虽然不常见,但功能上等同于直接用 gr.Columnscale=8 使其比左侧列更宽。

  • tabs = gr.Tabs(): 创建另一组选项卡。

  • with gr.TabItem("知识图谱抽取"):: 第一个选项卡。

    • text1_output = gr.Textbox(label="知识图谱抽取", lines=15, interactive=True): 第二个主要的输出文本框,用于显示 process_pdf2 函数返回的结果。
  • with gr.TabItem("书写错字检查"):, with gr.TabItem("流程合规审核"):, with gr.TabItem("党建文件审核"):: 其他几个选项卡,目前都包含一个 placeholder="待开放"interactive=False 的文本框,作为未来功能的占位符。

  • with gr.Row():: 在选项卡下方创建一个新行。

    • with gr.Column(scale=4):: 行内的一个列。

      • save_btn = gr.Button("生成文件", variant="primary"): 用于触发保存结果到文件的按钮。
      • outputs = gr.components.File(label="下载文件"): 文件输出组件,当 save_to_file 函数返回文件路径时,这里会显示一个下载链接。
with gr.Row():
    with gr.Column(scale=4):
        tabs = gr.Tabs()
        with tabs:
            with gr.TabItem(label='', scale=3):
                text_output = gr.Textbox(label="识别能力展示", lines=20, interactive=True)
                inp = gr.Textbox(label="与文档对话(未开放)", placeholder="请输入")
        pdf_input = gr.File(label="上传PDF文件", file_types=[".pdf"])
        process_btn = gr.Button("开始工作", variant="huggingface")
    with gr.Blocks() as demo:
        with gr.Column(scale=8):
            tabs = gr.Tabs()
            with tabs:
                with gr.TabItem("知识图谱抽取"):
                    text1_output = gr.Textbox(label="知识图谱抽取", lines=15, interactive=True)
                with gr.TabItem("书写错字检查"):
                    log_output = gr.Textbox(placeholder="待开放", label="书写错字检查", lines=15, interactive=False)
                with gr.TabItem("流程合规审核"):
                    log_output = gr.Textbox(placeholder="待开放", label="流程合规审核", lines=15, interactive=False)
                with gr.TabItem("党建文件审核"):
                    log_output = gr.Textbox(placeholder="待开放", label="党建文件审核", lines=15, interactive=False)

            with gr.Row():
                with gr.Column(scale=4):
                    save_btn = gr.Button("生成文件", variant="primary")
                    outputs = gr.components.File(label="下载文件")

                demo.close = lambda: demo.server.close()
4.事件绑定 :

process_btn.click(fn=process, inputs=[pdf_input], outputs=[text_output, text1_output]): 将 "开始工作" 按钮的点击事件绑定到 process 函数。inputs 指定 pdf_input 组件的值作为函数的输入,outputs 指定函数的返回值分别更新到 text_outputtext1_output 组件。

save_btn.click(fn=save_to_file, inputs=[text_output, text1_output], outputs=outputs): 将 "生成文件" 按钮的点击事件绑定到 save_to_file 函数。inputs 指定两个文本框的内容作为输入,outputs 指定函数的返回值(文件路径)更新到 outputs 文件组件,从而提供下载。

def save_to_file(text, text1):
    file_path = "output.txt"
    with open(file_path, "w", encoding="utf-8") as f:
        f.write("识别能力展示:\n")
        f.write(text + "\n\n")
        f.write("纠错结果:\n")
        f.write(text1)
    # 生成下载链接(本地 Gradio 服务器)
    return file_path

process_btn.click(
    # 直接调用核心处理函数 process
    fn=process,
    inputs=[pdf_input],
    outputs=[text_output, text1_output]
)
save_btn.click(
    fn=save_to_file,
    inputs=[text_output, text1_output],
    outputs=outputs
)
5.应用启动:

app.launch(server_name='0.0.0.0', server_port=7860): 启动 Gradio 应用,使其在本地所有网络接口 0.0.0.07860 端口上监听访问。

app.launch(server_name='0.0.0.0', server_port=7860)
6.完整代码:

logo_path = Path("./login-logo-small.png")
logo_b64 = base64.b64encode(logo_path.read_bytes()).decode()


def save_to_file(text, text1):
    file_path = "output.txt"
    with open(file_path, "w", encoding="utf-8") as f:
        f.write("识别能力展示:\n")
        f.write(text + "\n\n")
        f.write("纠错结果:\n")
        f.write(text1)
    # 生成下载链接(本地 Gradio 服务器)
    return file_path

with gr.Blocks(title="多模态数据处理驱动范式",
               css=".gr-row {gap: 0!important;} .gr-column {padding: 0 5px!important;}") as app:
    # 标题行优化
    with gr.Row(equal_height=True):
        # 合并为一个弹性容器
        with gr.Column(scale=5, min_width=0):
            gr.Markdown(f"""
<div style="display: flex; margin: 10px 0px 0; background: #f8f9fa; border-radius: 6px; overflow: hidden;">
  <img src='data:image/png;base64,{logo_b64}' style='height:90px; margin-right:-45px'/>
  <div style="flex: 1; padding: 20px; border-right: 1px solid #e0e0e0;">
    <h3 style="margin:0 0 15px 0; color: #2c3e50; font-size: 1.2em;">一意AI增效家</h3>
    <p style="margin:0; line-height: 1.6; color: #34495e;">多模态模型驱动的数据处理范式-课程演示</p>
  </div>

  <div style="flex: 1.5; padding: 20px; border-right: 1px solid #e0e0e0;">
    <h3 style="margin:0 0 15px 0; color: #2c3e50; font-size: 1.2em;">多模态数据处理范式课程简介</h3>
    <p style="margin:0; line-height: 1.6; color: #34495e;">
      以传统OCR处理数据,只能做线性任务,企业需求复杂,多格式+跨部门非结构文档处理,是每一个AI项目顺利落地的重要技术,使用多模态模型,能接受任何开放的文档,识别提取后,转为结构化文档,接本地大模型,处理入库;
    </p>
  </div>

  <div style="flex: 2; padding: 20px;">
    <h3 style="margin:0 0 0px 0; color: #2c3e50; font-size: 1.2em;">演示技术说明:</h3>
    <p style="margin:0 10 5px 0; line-height: 1; color: #34495e;">
      我们使用48G显存,在本地推理了两个模型:
    </p>
    <ul style="margin:0 0 0px 0; padding-left: 20px; line-height: 1; color: #34495e;">
      <li>多模态模型-7B(微调增强)</li>
      <li>大语音模型-14B(qwen系列)</li>
    </ul>
    <p style="margin:0; line-height: 1.3; color: #34495e;">
      上传PDF文档后,首先由7B的多模态模型对文档进行识别提取,将文本在本地保存为.md格式,然后通过动态递归分块的策略,传入到14B的大模型做知识图谱的抽取。
    </p>
  </div>
</div>
            """)

    with gr.Row():
        with gr.Column(scale=4):
            tabs = gr.Tabs()
            with tabs:
                with gr.TabItem(label='', scale=3):
                    text_output = gr.Textbox(label="识别能力展示", lines=20, interactive=True)
                    inp = gr.Textbox(label="与文档对话(未开放)", placeholder="请输入")
            pdf_input = gr.File(label="上传PDF文件", file_types=[".pdf"])
            process_btn = gr.Button("开始工作", variant="huggingface")
        with gr.Blocks() as demo:
            with gr.Column(scale=8):
                tabs = gr.Tabs()
                with tabs:
                    with gr.TabItem("知识图谱抽取"):
                        text1_output = gr.Textbox(label="知识图谱抽取", lines=15, interactive=True)
                    with gr.TabItem("书写错字检查"):
                        log_output = gr.Textbox(placeholder="待开放", label="书写错字检查", lines=15, interactive=False)
                    with gr.TabItem("流程合规审核"):
                        log_output = gr.Textbox(placeholder="待开放", label="流程合规审核", lines=15, interactive=False)
                    with gr.TabItem("党建文件审核"):
                        log_output = gr.Textbox(placeholder="待开放", label="党建文件审核", lines=15, interactive=False)

                with gr.Row():
                    with gr.Column(scale=4):
                        save_btn = gr.Button("生成文件", variant="primary")
                        outputs = gr.components.File(label="下载文件")

                    demo.close = lambda: demo.server.close()

    process_btn.click(
        # 直接调用核心处理函数 process
        fn=process,
        inputs=[pdf_input],
        outputs=[text_output, text1_output]
        # 注意: 移除了 queue=True 参数,这意味着可以并发处理
    )
    save_btn.click(
        fn=save_to_file,
        inputs=[text_output, text1_output],
        outputs=outputs
    )

app.launch(server_name='0.0.0.0', server_port=7860)

五、完整代码及操作说明

1.文件结构放置如下:
.
├── temp
│   ├── 20250324-国信证券-腾讯控股-0700_converted.txt
│   ├── 慢性肾脏病患者膳食指导2017-08-01发布_converted.txt
│   └── ...
├── temp_downloads
│   ├── result_20250403_164439.txt
│   ├── result_20250403_164450.txt
├── requirements.txt
├── anchor.py
├── coherency.py
├── mysql_extract.py
├── mysql_save.py
├── prompts.py
├── propmt.py
└── front_.py
2.配置环境及依赖:
  1. 创建conda环境
 # 创建名为 wukong2 的新环境,Python 版本为 3.10(可以根据需求更改)
conda create -n olmocr python=3.10

2. 激活运行项目的 olmocr 环境(换成你自己的环境名称):

# 激活名为 olmocr 的新环境
conda activate olmocr

3. 在olmocr 环境中,利用 requirements.txt 文件安装项目依赖:

# 使用清华大学 PyPI 镜像源安装依赖
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple

暂时无法在飞书文档外展示此内容

3.完整代码:

暂时无法在飞书文档外展示此内容

六、效果展示

  1. 完整界面:

  1. 传入文件:

  1. 开始工作:

  1. 下载文件:

七、视频链接

讲解视频:

运行视频: