适逢oceanbase举办第一届ai hackthon. 比较遗憾没杀入前八强,在此记录一下自己打比赛的概要设计吧
(一) 项目简介
网络医生系统旨在通过为任何关心自身健康的人提供人工智能健康助手,解决医疗资源不平衡的问题。在优质医疗服务匮乏的地区,人们往往需要前往大城市咨询医疗专业人士,这既耗时又昂贵,也为医疗救助设置了障碍。网络医生旨在通过以下方式弥合这一差距
本系统支持这四个功能
- 基础疾病诊断协助
- 医疗记录分析
- 专业医学知识咨询
- 个人健康管理指导
(二) 核心功能
| 特征 | 描述 |
|---|---|
| 多模态交互 | 支持基于文本、语音和图像的输入/输出 |
| 问答 | 核心功能由 LLM 提供支持,并通过 RAG 增强 |
| 文档生成 | 根据健康相关主题创建 PPT 和 Word 文档 |
| 语音处理 | 支持通过麦克风输入语音,通过TTS输出语音 |
| 知识增强 | 使用知识库、知识图谱和互联网搜索实现 RAG |
| 图像识别 | 识别医学图像中的物体和文本(例如医疗记录、处方) |
| 图像/视频生成 | 根据用户提示创建视觉内容 |
(三) 多模态交互
本系统提供了多种用户交互界面:
- 文本接口:基于文本的通信的主要接口
- 语音接口:专用语音对话模块,支持语音转文本、文本转语音
- 图像处理:能够处理和分析上传的图像
- 文档处理:能够从上传的文档中提取信息
(三) 架构设计
该系统由几个关键组件组成,下面将逐一介绍
-
- 问答(QA)系统
问答系统构成了本系统的核心,协调用户输入,各种知识源和响应生成之间的关系。它分析用户查询以确定意图,并利用适当的工具和模块生成响应
QA模块的关键组件包括:
answer.py:根据问题类型生成适当的响应question_parser.py:解析并确定用户问题的类型purpose_type.py:定义系统可以处理的不同问题类型function_tool.py:包含用于响应生成的实用函数prompt_templates.py:存储用于LLM交互的提示模板
问答分类
'''问答类型判断函数,根据特定输入和大模型进行分类分类。'''
from typing import List, Dict
from client.clientfactory import Clientfactory
from qa.prompt_templates import get_question_parser_prompt
from qa.purpose_type import purpose_map
from qa.purpose_type import userPurposeType
from icecream import ic
def parse_question(question: str, image_url=None) -> userPurposeType:
if "根据知识库" in question:
return purpose_map["基于知识库"]
if "根据知识图谱" in question:
return purpose_map["基于知识图谱"]
if "搜索" in question:
return purpose_map["网络搜索"]
if ("word" in question or "Word" in question or "WORD" in question) and ("生成" in question or "制作" in question):
return purpose_map["Word生成"]
if ("ppt" in question or "PPT" in question or "PPT" in question) and ("生成" in question or "制作" in question):
return purpose_map["PPT生成"]
if image_url is not None:
return purpose_map["图片描述"]
# 在这个函数中我们使用大模型去判断问题类型
prompt = get_question_parser_prompt(question)
response = Clientfactory().get_client().chat_with_ai(prompt)
ic("大模型分类结果:" + response)
if response == "图片生成" and len(question) > 0:
return purpose_map["图片生成"]
if response == "视频生成" and len(question) > 0:
return purpose_map["视频生成"]
if response == "PPT生成" and len(question) > 0:
return purpose_map["PPT生成"]
if response == "Word生成" and len(question) > 0:
return purpose_map["Word生成"]
if response == "音频生成" and len(question) > 0:
return purpose_map["音频生成"]
if response == "文本生成":
return purpose_map["文本生成"]
return purpose_map["其他"]
提示词模版
purpose_type_template = (
f"你扮演文本分类的工具助手,类别有{len(purpose_map)}种,"
f"分别为:文本生成,图片生成,视频生成,音频生成,图片描述,问候语,PPT生成,Word生成,网络搜索,基于知识库,基于知识图谱,其他。"
f"下面给出一些例子用来辅助你判别:"
f"'我想了解糖尿病' ,文本分类结果是文本生成;"
f"'请生成老年人练习太极的图片',文本分类结果是图片生成;"
f"'你可以生成一段关于春天的视频吗',文本分类结果是视频生成;"
f"'请将上述文本转换成语音',文本分类结果是音频生成;"
f"'糖尿病如何治疗?请用语音回答',文本分类结果是音频生成;"
f"'请你描述这张美丽的图片',文本分类结果是图片描述;"
f"'您好!你是谁?',文本分类结果是问候语;"
f"'请你用制作一份关于糖尿病的PPT',文本分类结果是PPT生成;"
f"'请你用word制作一份关于糖尿病的报告',文本分类结果是Word生成;"
f"'请在互联网上找到养生保健的相关知识',文本分类结果是网络搜索;"
f"'知识库中有什么糖尿病相关的知识',文本分类结果是基于知识库;"
f"'知识图谱中有什么糖尿病相关的知识',文本分类结果是基于知识图谱;"
f"'我有糖尿病,帮我制定一个饮食锻炼计划' ,文本分类结果是文本生成;"
f"如果以上内容没有对应的类别,文本分类结果是其他。"
f"请参考上面例子,直接给出一种分类结果,不要解释,不要多余的内容,不要多余的符号,不要多余的空格,不要多余的空行,不要多余的换行,不要多余的标点符号。"
f"请你对以下内容进行文本分类:"
)
-
- LLM客户端系统
LLM客户端系统充当Cyber-Doctor和各种大型语言模型提供商之间的桥梁。它使用工厂模式根据系统配置创建适当的客户端
- 检索增强生成(RAG)系统
4. 音频处理
音频处理模块支持与系统的语音交互,处理语音到文本和文本到语音的转换。
关键组件:
audio_extract.py:处理音频输入并提取文本audio_generate.py:从文本响应生成音频输出
4.1 语音转文本 (STT) 处理
技术栈:
- Whisper模型(OpenAI开发的语音识别模型)
- speech_recognition库(Python语音识别库)
- pydub库(音频处理库,用于格式转换)
- OpenCC库(中文繁简转换)
主要使用了Whisper模型,该由OpenAI开发。具备离线工作能力,以及便捷于系统集成,并且可以和TTS进行协同交互
处理流程如下:
- 首先检查音频文件是否为WAV格式,如果不是则调用
convert_audio_to_wav函数转换 - 使用
speech_recognition库创建识别器对象 - 加载音频文件并记录音频数据
- 使用Whisper模型进行语音识别,自动设置语言为中文("zh")
- 使用OpenCC库将识别结果转换为简体中文
音频格式转换函数:
这个函数使用pydub库的AudioSegment来处理各种格式的音频文件并转换为WAV格式,这是因为Whisper模型需要WAV格式的输入
def audio_to_text(audio_file_path):
# 创建识别器对象
# 如果不是 WAV 格式,先转换为 WAV
if not audio_file_path.endswith(".wav"):
audio_file_path = convert_audio_to_wav(audio_file_path)
recognizer = sr.Recognizer()
with sr.AudioFile(audio_file_path) as source:
audio_data = recognizer.record(source)
# 使用 Google Web Speech API 进行语音识别,不用下载模型但对网络要求高
# text = recognizer.recognize_google(audio_data, language="zh-CN")
# 使用 whisper 进行语音识别,自动下载模型到本地
text = recognizer.recognize_whisper(audio_data, language="zh")
text_simplified = convert_to_simplified(text)
return text_simplified
4.2 文本转语音 (TTS) 处理
文本转语音功能主要使用了Edge-TTS,这是微软Edge浏览器的文本到语音服务。在应用中的多处地方都调用了audio_generate函数,例如:
从上下文可以看出,audio_generate函数接受两个参数:
text: 需要转换为语音的文本内容model_name: 使用的TTS模型名称,例如"zh-CN-YunxiNeural"(这是一个中文男声模型)
def convert_audio_to_wav(audio_file_path):
audio = AudioSegment.from_file(audio_file_path) # 自动识别格式
wav_file_path = audio_file_path.rsplit(".", 1)[0] + ".wav" # 生成 WAV 文件路径
audio.export(wav_file_path, format="wav") # 将音频文件导出为 WAV 格式
return wav_file_path
5. 文档生成
文档生成模块允许系统根据用户查询创建结构化文档,如PowerPoint演示文稿和Word文档。
关键组件, 用的是langchain本身自带的解析Loader
(四) 数据流架构设计
(四)模型选择与训练
- 模型选型
在项目中,模型选型需要考虑医疗健康领域的特殊需求和多模态交互的要求。基于项目的代码库和wiki内容,我们可以看到系统采用了多种模型来处理不同类型的任务:
1.1 语言模型选择
系统主要依赖于大型语言模型(LLM)来处理核心的问答功能。系统设计支持多种LLM提供商:
- 智谱AI:作为主要的中文大模型提供商
- OpenAI兼容API:包括各种支持OpenAI SDK接口的模型
- Ollama:用于本地部署的大模型
系统采用了工厂模式来支持不同的LLM提供商,通过clientfactory.py创建适当的客户端。这种设计使系统能够灵活切换不同的模型,适应不同的部署环境和需求。
1.2 多模态模型选择
除了核心的语言模型外,系统还集成了多种专用模型来处理不同类型的输入和输出:
1.2.1 语音处理模型
- Whisper:用于语音转文本(STT),提供离线工作能力和多语言支持
- Edge-TTS:用于文本转语音(TTS),支持多种语音风格和语言
Whisper模型的选择特别适合医疗场景,因为它可以在本地运行,不依赖网络连接,这对于医疗资源不平衡地区尤为重要。
在进行语音处理之前,需要先将大语言模型转换成语音的文字,
然后判断需要生成哪种语言,需要生成的男声还是女声。之后调用具体的语音处理模型进行处理
def process_audio_tool(
question_type: userPurposeType,
question: str,
history: List[List | None] = None,
image_url=None,
):
# 先让大语言模型生成需要转换成语音的文字
text = extract_text(question, history)
# 判断需要生成哪种语言(东北、陕西、粤...)
lang = extract_language(question)
# 判断需要生成男声还是女声
gender = extract_gender(question)
# 上面三步均与大语言模型进行交互
# 选择用于生成的模型
model_name, success = get_tts_model_name(lang=lang, gender=gender)
if success:
audio_file = audio_generate(text, model_name)
else:
audio_file = audio_generate(
"由于目标语言包缺失,我将用普通话回复您。" + text, model_name
)
return ((audio_file, "audio"), question_type)
1.2.2 向量嵌入模型
系统使用ModelScope嵌入模型来为RAG系统创建文档向量: 模型采用阿里达摩院提供的中文长序列转换模型
iic/nlp_corom_sentence-embedding_chinese-base
这些嵌入模型用于将文本转换为向量表示,以便进行语义搜索和相似度匹配。
如何清除旧的向量库呢?
由于oceanbase本身基于langchain没有实现recordManager记录已经存入数据库的文档的hash值。所以我们选择建立<userid, retriever>存入字典。在用户联网下载资料的时候,用user_id隔离资料库的数据,当文档更新后,如果发现该文档不再存在于本地目录,则删除向量数据库
def build_user_vector_store(self):
"""根据用户的ID加载用户文件夹中的文件并为用户构建向量库"""
user_data_path = os.path.join("user_data", self.user_id) # 用户独立文件夹
if not os.path.exists(user_data_path):
print(f"用户文件夹 {user_data_path} 不存在")
return
try:
# 清理旧的向量库(如果已经存在)
if self.user_id in self._user_retrievers:
del self._user_retrievers[self.user_id]
print(f"用户 {self.user_id} 的旧向量库已删除")
..............
# 合并文档
docs = (
pdf_docs
+ docx_docs
+ txt_docs
+ csv_docs
+ html_docs
+ mhtml_docs
+ markdown_docs
)
if not docs:
print(f"用户 {self.user_id} 文件夹中没有找到文档")
return
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=2000, chunk_overlap=100
)
splits = text_splitter.split_documents(docs)
# 为该用户构建向量库
vectorstore = OceanbaseVectorStore.from_documents(
documents=splits, embedding=self._embedding
)
user_retriever = vectorstore.as_retriever(search_kwargs={"k": 6})
# 将用户的retriever存储到字典中
self._user_retrievers[self.user_id] = user_retriever
print(f"用户 {self.user_id} 的向量库已构建完成")
except Exception as e:
print(f"构建用户 {self.user_id} 向量库时出错: {e}")
1.2.3 PDF处理
如果用户对于pdf的处理不够满意,我们还提供了一个小工具
- 实现了单例模式,确保整个程序中只有一个
ModelManager实例。 - 管理模型的加载、量化和资源使用,包括根据系统配置选择最优设备(CPU 或 GPU),设置量化后端,根据设备和模型类型选择最优量化配置,延迟加载模型,检查内存使用情况并卸载最少使用的模型,定期清理未使用的模型,提供获取模型、清除缓存和获取模型内存使用情况的方法。
- 定义了加载不同模型(Spacy、分类器、句子 Transformer)的方法,并在加载时应用量化(如果配置)
.......
# 定义 ModelManager 类,用于管理模型的加载、量化和资源使用
class ModelManager:
# 单例模式的实例变量,确保整个程序中只有一个 ModelManager 实例
_instance = None
# 存储已加载的模型
_models = {}
# 线程锁,用于线程安全的操作,避免多线程访问时出现数据竞争问题
_lock = threading.Lock()
# 记录每个模型的最后使用时间
_last_used = {}
# 内存使用率阈值,当内存使用率超过该值时,卸载最少使用的模型
_memory_threshold = 0.8
# 模型最大空闲时间(秒),超过该时间未使用的模型将被卸载
_max_idle_time = 300
def __new__(cls):
"""
实现单例模式的方法,确保只有一个 ModelManager 实例被创建
"""
# 检查实例是否已经存在
if cls._instance is None:
# 使用线程锁保证线程安全
with cls._lock:
# 再次检查实例是否已经存在,避免多个线程同时创建实例
if cls._instance is None:
# 创建新的实例
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
"""
初始化 ModelManager 类,设置模型配置、设备和量化后端,并启动清理线程
"""
# 检查是否已经初始化过
if hasattr(self, '_initialized'):
return
# 标记为已初始化
self._initialized = True
# 模型配置字典,包含模型名称、加载函数和量化配置
self._model_configs = {
'spacy': {
# 模型名称
'name': 'en_core_web_sm',
# 加载模型的函数
'loader': self._load_spacy,
# 是否进行量化
'quantize': False
},
'classifier': {
'name': 'facebook/bart-large-mnli',
'loader': self._load_classifier,
'quantize': {
# 是否启用量化
'enabled': True,
# 量化方法
'method': 'dynamic',
# 量化数据类型
'dtype': torch.qint8,
# 需要量化的层
'layers': [torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d],
# 校准集大小
'calibration_size': 100
}
},
'sentence_transformer': {
'name': 'paraphrase-MiniLM-L6-v2',
'loader': self._load_sentence_transformer,
'quantize': {
'enabled': True,
'method': 'dynamic',
'dtype': torch.qint8,
'layers': [torch.nn.Linear],
'calibration_size': 100
}
}
}
# 根据系统配置选择最优的设备(CPU 或 GPU)
self.device = self._get_optimal_device()
# 设置量化后端
self._setup_quantization_backend()
# 创建一个线程,用于定期清理未使用的模型
self._cleanup_thread = threading.Thread(target=self._cleanup_models, daemon=True)
# 启动清理线程
self._cleanup_thread.start()
def _get_optimal_device(self):
"""
根据系统配置选择最优的设备(CPU 或 GPU)
如果 GPU 不可用或 GPU 内存小于 4GB,则选择 CPU
"""
# 检查 GPU 是否可用
if not torch.cuda.is_available():
return 'cpu'
# 获取 GPU 的总内存
gpu_memory = torch.cuda.get_device_properties(0).total_memory
# 如果 GPU 内存小于 4GB
if gpu_memory < 4 * 1024 * 1024 * 1024:
return 'cpu'
return 'cuda'
def _get_optimal_quantization_config(self, model_type):
"""
根据设备和模型类型选择最优的量化配置
如果是 CPU 设备,使用动态量化和较小的校准集
如果是 GPU 设备,使用静态量化和半精度
"""
# 获取模型的基本量化配置
base_config = self._model_configs[model_type]['quantize']
if not base_config:
return base_config
if self.device == 'cpu':
return {
'enabled': True,
'method': 'dynamic',
'dtype': torch.qint8,
'layers': base_config['layers'],
'calibration_size': 50 # CPU 下使用更小的校准集
}
else:
return {
'enabled': True,
'method': 'static',
'dtype': torch.float16, # GPU 下使用半精度
'layers': base_config['layers'],
'calibration_size': base_config['calibration_size']
}
def get_model(self, model_type: str):
"""
延迟加载模型,如果模型未加载,则加载模型并进行量化(如果配置)
检查内存使用情况,必要时卸载最少使用的模型
更新模型的最后使用时间
"""
# 使用线程锁保证线程安全
with self._lock:
# 检查模型是否已经加载
if model_type not in self._models:
# 获取模型配置
config = self._model_configs.get(model_type)
if not config:
# 如果模型类型未知,抛出异常
raise ModelLoadError(f"Unknown model type: {model_type}")
# 检查内存使用情况
self._check_memory_usage()
try:
# 调用加载函数加载模型
model = config['loader'](config['name'])
# 获取最优的量化配置
quant_config = self._get_optimal_quantization_config(model_type)
if quant_config and quant_config['enabled']:
# 如果需要量化,对模型进行量化
model = self._prepare_model_for_quantization(model, quant_config)
# 将模型存储到 _models 字典中
self._models[model_type] = model
except Exception as e:
# 如果加载模型失败,抛出异常
raise ModelLoadError(f"Failed to load model {model_type}: {str(e)}")
# 更新模型的最后使用时间
self._last_used[model_type] = time.time()
# 返回加载好的模型
return self._models[model_type]
def _check_memory_usage(self):
"""
检查内存使用情况,如果内存使用率超过阈值,则卸载最少使用的模型
"""
# 获取当前内存使用率
memory_percent = psutil.virtual_memory().percent / 100
if memory_percent > self._memory_threshold:
# 如果内存使用率超过阈值,卸载最少使用的模型
self._unload_least_used_model()
def _unload_least_used_model(self):
"""
卸载最少使用的模型,如果该模型的空闲时间超过最大空闲时间
"""
# 如果 _last_used 字典为空,直接返回
if not self._last_used:
return
# 获取当前时间
current_time = time.time()
# 找到最少使用的模型
least_used_model = min(self._last_used.items(), key=lambda x: x[1])[0]
# 检查该模型的空闲时间是否超过最大空闲时间
if current_time - self._last_used[least_used_model] > self._max_idle_time:
# 如果超过最大空闲时间,卸载该模型
self._unload_model(least_used_model)
def _unload_model(self, model_type: str):
"""
卸载指定模型,删除模型缓存,进行垃圾回收,并清理 GPU 缓存(如果使用 GPU)
"""
if model_type in self._models:
# 从 _models 字典中删除该模型
del self._models[model_type]
# 从 _last_used 字典中删除该模型的最后使用时间
del self._last_used[model_type]
# 强制进行垃圾回收
import gc
gc.collect()
if self.device == 'cuda':
# 如果使用 GPU,清理 GPU 缓存
torch.cuda.empty_cache()
def _cleanup_models(self):
"""
定期清理未使用的模型,每分钟检查一次
"""
while True:
# 线程休眠 60 秒
time.sleep(60)
# 使用线程锁保证线程安全
with self._lock:
# 获取当前时间
current_time = time.time()
# 找出所有空闲时间超过最大空闲时间的模型
models_to_unload = [
model_type for model_type, last_used in self._last_used.items()
if current_time - last_used > self._max_idle_time
]
for model_type in models_to_unload:
# 卸载这些模型
self._unload_model(model_type)
def _setup_quantization_backend(self):
"""
设置量化后端,根据设备选择合适的量化引擎
"""
if self.device == 'cuda':
# 在 GPU 上使用 CUDA 量化后端
torch.backends.quantized.engine = 'fbgemm'
else:
# 在 CPU 上使用 fbgemm (Windows compatible)
torch.backends.quantized.engine = 'fbgemm'
def _prepare_model_for_quantization(self, model, config):
"""
准备模型进行量化,根据量化方法调用相应的量化函数
"""
if not config['enabled']:
# 如果不启用量化,直接返回模型
return model
if config['method'] == 'dynamic':
# 如果是动态量化,调用动态量化函数
return self._apply_dynamic_quantization(model, config)
elif config['method'] == 'static':
# 如果是静态量化,调用静态量化函数
return self._apply_static_quantization(model, config)
return model
def _apply_dynamic_quantization(self, model, config):
"""
应用动态量化,对指定层进行动态量化
"""
try:
print(f"Applying dynamic quantization with dtype {config['dtype']}")
# 对模型进行动态量化
model = torch.quantization.quantize_dynamic(
model,
qconfig_spec={
layer: torch.quantization.default_dynamic_qconfig
for layer in config['layers']
},
dtype=config['dtype']
)
print("Dynamic quantization applied successfully")
return model
except Exception as e:
print(f"Dynamic quantization failed: {str(e)}")
return model
def _apply_static_quantization(self, model, config):
"""
应用静态量化,包括准备量化配置、融合操作、准备量化、校准和转换为量化模型
"""
try:
print(f"Applying static quantization")
# 准备量化配置
model.qconfig = torch.quantization.get_default_qconfig('fbgemm' if self.device == 'cuda' else 'fbgemm')
# 融合操作,将卷积层、批归一化层和激活函数层融合
model = torch.quantization.fuse_modules(model, [['conv', 'bn', 'relu']])
# 准备量化
model = torch.quantization.prepare(model)
# 校准(这里需要实际的校准数据)
# self._calibrate_model(model, config['calibration_size'])
# 转换为量化模型
model = torch.quantization.convert(model)
print("Static quantization applied successfully")
return model
except Exception as e:
print(f"Static quantization failed: {str(e)}")
return model
def _calibrate_model(self, model, calibration_size):
"""
使用校准数据集校准模型(用于静态量化)
这里使用随机数据作为示例,实际应用中应使用真实的校准数据
"""
# 这里应该使用实际的校准数据
# 为了示例,我们使用随机数据
with torch.no_grad():
for _ in range(calibration_size):
# 生成随机输入数据
dummy_input = torch.randn(1, 3, 224, 224)
# 模型进行推理
model(dummy_input)
def _load_spacy(self):
"""
加载 Spacy 模型,如果模型未下载,则自动下载
"""
try:
# 尝试加载 Spacy 模型
return spacy.load('en_core_web_sm')
except OSError:
# 如果模型未下载,自动下载
spacy.cli.download('en_core_web_sm')
return spacy.load('en_core_web_sm')
def _load_classifier(self):
"""
加载分类器模型,并应用量化(如果配置)
"""
print("Loading classifier model...")
# 使用 pipeline 加载分类器模型
model = pipeline("zero-shot-classification",
model='facebook/bart-large-mnli',
device=self.device)
if self._model_configs['classifier']['quantize']['enabled']:
print("Applying quantization to classifier model")
# 对分类器模型进行量化
model.model = self._prepare_model_for_quantization(
model.model,
self._model_configs['classifier']['quantize']
)
return model
def _load_sentence_transformer(self):
"""
加载句子 Transformer 模型,并应用量化(如果配置)
"""
print("Loading sentence transformer model...")
# 加载句子 Transformer 模型
model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
# 将模型移动到指定设备
model = model.to(self.device)
if self._model_configs['sentence_transformer']['quantize']['enabled']:
print("Applying quantization to sentence transformer model")
# 对句子 Transformer 模型的编码器进行量化
model.encoder = self._prepare_model_for_quantization(
model.encoder,
self._model_configs['sentence_transformer']['quantize']
)
return model
def clear_cache(self, model_type: str = None):
"""
清除指定模型或所有模型的缓存
"""
# 使用线程锁保证线程安全
with self._lock:
if model_type:
if model_type in self._models:
# 如果指定了模型类型,删除该模型的缓存
del self._models[model_type]
else:
# 如果没有指定模型类型,清空所有模型的缓存
self._models.clear()
def get_model_memory_usage(self, model_type: str = None):
"""
获取模型内存使用情况,如果指定模型类型,则返回该模型的内存使用情况
否则返回所有模型的内存使用情况
"""
if model_type:
if model_type in self._models:
# 获取指定模型
model = self._models[model_type]
# 计算该模型的内存使用情况
return self._get_model_size(model)
return None
memory_usage = {}
for model_type, model in self._models.items():
# 计算所有模型的内存使用情况
memory_usage[model_type] = self._get_model_size(model)
return memory_usage
def _get_model_size(self, model):
"""
计算模型大小(以 MB 为单位)
"""
param_size = 0
buffer_size = 0
for param in model.parameters():
# 计算模型参数的总大小
param_size += param.nelement() * param.element_size()
for buffer in model.buffers():
# 计算模型缓冲区的总大小
buffer_size += buffer.nelement() * buffer.element_size()
# 将总大小转换为 MB
size_mb = (param_size + buffer_size) / 1024 / 1024
return round(size_mb, 2)
# 定义 ModelContext 类,作为模型使用的上下文管理器
class ModelContext:
def __init__(self, model_type: str, manager: 'ModelManager'):
"""
初始化上下文管理器,记录模型类型和模型管理器
"""
# 模型类型
self.model_type = model_type
# 模型管理器实例
self.manager = manager
# 存储加载的模型
self.model = None
# 存储可能出现的错误
self.error = None
def __enter__(self):
"""
进入上下文管理器,尝试加载模型,如果出现异常则记录错误并抛出异常
"""
try:
# 从模型管理器中获取模型
self.model = self.manager.get_model(self.model_type)
return self.model
except Exception as e:
# 记录错误信息
self.error = e
# 抛出模型加载错误异常
raise ModelError(f"Error loading model {self.model_type}: {str(e)}")
def __exit__(self, exc_type, exc_val, exc_tb):
"""
退出上下文管理器,如果出现异常则记录错误但不处理,让异常继续传播
"""
if exc_type is not None:
# 记录错误信息
print(f"Error using model {self.model_type}: {str(exc_val)}")
return False # 让异常继续传播
# 定义 ModelError 类,作为模型相关错误的基类
class ModelError(Exception):
pass
# 定义 ModelLoadError 类,继承自 ModelError,用于表示模型加载错误
class ModelLoadError(ModelError):
pass
# 定义 ModelInferenceError 类,继承自 ModelError,用于表示模型推理错误
class ModelInferenceError(ModelError):
pass
# 创建全局的 ModelManager 实例
model_manager = ModelManager()
# 初始化服务器,创建 Server 实例,名称为 "pdf_reader"
server = Server("pdf_reader")
# 下载必要的 NLTK 数据
nltk_resources = ['punkt', 'averaged_perceptron_tagger', 'maxent_ne_chunker', 'words', 'stopwords']
for resource in nltk_resources:
try:
# 检查资源是否已经下载
nltk.data.find(f'tokenizers/{resource}')
except LookupError:
# 如果资源未下载,下载该资源
nltk.download(resource, quiet=True)
@server.list_tools()
async def handle_list_tools() -> list[types.Tool]:
"""
列出可用的工具,返回工具列表,每个工具包含名称、描述和输入模式
"""
return [
types.Tool(
name="extract-text",
description="从 PDF 文件中提取文本内容",
inputSchema={
"type": "object",
"properties": {
"file_path": {
"type": "string",
"description": "PDF 文件的路径",
},
"page_number": {
"type": "integer",
"description": "要提取的页码(从 0 开始)",
},
},
"required": ["file_path"],
},
),
types.Tool(
name="extract-images",
description="从 PDF 文件中提取图片",
inputSchema={
"type": "object",
"properties": {
"file_path": {
"type": "string",
"description": "PDF 文件的路径",
},
"page_number": {
"type": "integer",
"description": "要提取的页码(从 0 开始)",
},
},
"required": ["file_path"],
},
),
types.Tool(
name="extract-tables",
description="从 PDF 文件中提取表格",
inputSchema={
"type": "object",
"properties": {
"file_path": {
"type": "string",
"description": "PDF 文件的路径",
},
"page_number": {
"type": "integer",
"description": "要提取的页码(从 0 开始)",
},
},
"required": ["file_path"],
},
),
types.Tool(
name="analyze-content",
description="分析 PDF 文件内容,提取关键信息",
inputSchema={
"type": "object",
"properties": {
"file_path": {
"type": "string",
"description": "PDF 文件的路径",
},
"analysis_type": {
"type": "string",
"description": "分析类型:entities(实体), summary(摘要), keywords(关键词)",
"enum": ["entities", "summary", "keywords"],
},
},
"required": ["file_path", "analysis_type"],
},
),
types.Tool(
name="get-metadata",
description="获取 PDF 文件的元数据信息",
inputSchema={
"type": "object",
"properties": {
"file_path": {
"type": "string",
"description": "PDF 文件的路径",
},
},
"required": ["file_path"],
},
),
types.Tool(
name="classify-document",
description="对 PDF 文档进行分类",
inputSchema={
"type": "object",
"properties": {
"file_path": {
"type": "string",
"description": "PDF 文件的路径",
},
"categories": {
"type": "array",
"items": {"type": "string"},
"description": "可能的分类类别列表",
},
},
"required": ["file_path", "categories"],
},
),
types.Tool(
name="calculate-similarity",
description="计算两个 PDF 文档的相似度",
inputSchema={
"type": "object",
"properties": {
"file_path1": {
"type": "string",
"description": "第一个 PDF 文件的路径",
},
"file_path2": {
"type": "string",
"description": "第二个 PDF 文件的路径",
},
},
"required": ["file_path1", "file_path2"],
},
),
types.Tool(
name="detect-languages",
description="检测 PDF 文档中使用的语言",
inputSchema={
"type": "object",
"properties": {
"file_path": {
"type": "string",
"description": "PDF 文件的路径",
},
},
"required": ["file_path"],
},
),
types.Tool(
name="advanced-analysis",
description="执行高级文本分析",
inputSchema={
"type": "object",
"properties": {
"file_path": {
"type": "string",
"description": "PDF 文件的路径",
},
},
"required": ["file_path"],
},
),
]
async def extract_text_from_pdf(file_path: str, page_number: int = None) -> str:
"""
从 PDF 中提取文本,如果指定页码,则提取该页的文本
否则提取所有页面的文本
"""
try:
# 打开 PDF 文件
doc = fitz.open(file_path)
if page_number is not None:
if 0 <= page_number < len(doc):
# 如果指定了页码且页码在有效范围内,提取该页的文本
text = doc[page_number].get_text()
# 关闭 PDF 文件
doc.close()
return text
else:
# 如果页码超出范围,关闭 PDF 文件并返回错误信息
doc.close()
return f"页码 {page_number} 超出范围。PDF 共有 {len(doc)} 页。"
# 如果没有指定页码,提取所有页面的文本
text = ""
for page in doc:
text += page.get_text() + "\n"
# 关闭 PDF 文件
doc.close()
return text
except Exception as e:
# 如果出现异常,返回错误信息
return f"提取文本时出错: {str(e)}"
async def extract_images_from_pdf(file_path: str, page_number: int = None):
"""
从 PDF 中提取图片,返回 Base64 编码的图片列表
支持指定页码,使用线程池并行处理图片
"""
try:
# 打开 PDF 文件
doc = fitz.open(file_path)
# 存储提取的图片
images = []
# 如果指定了页码,只处理该页;否则处理所有页
pages = [page_number] if page_number is not None else range(len(doc))
for page_num in pages:
# 获取指定页
page = doc[page_num]
# 获取该页的图片列表
image_list = page.get_images()
# 定义处理图片的函数
def process_image(img_index):
try:
# 获取图片的 XREF 编号
xref = image_list[img_index][0]
# 提取图片数据
base_image = doc.extract_image(xref)
image_bytes = base_image["image"]
# 转换和优化图片
image = Image.open(io.BytesIO(image_bytes))
image = optimize_image(image)
# 将图片转换为 Base64 编码
buffered = io.BytesIO()
image.save(buffered, format="PNG", optimize=True)
img_str = base64.b64encode(buffered.getvalue()).decode()
return img_str
except Exception as e:
print(f"处理图片时出错: {str(e)}")
return None
# 使用线程池并行处理图片
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = [executor.submit(process_image, i) for i in range(len(image_list))]
for future in concurrent.futures.as_completed(futures):
if future.result():
images.append(future.result())
# 关闭 PDF 文件
doc.close()
return images
except Exception as e:
print(f"提取图片时出错: {str(e)}")
return []
async def extract_tables_from_pdf(file_path: str, page_number: int = None) -> List[str]:
"""
从 PDF 中提取表格数据
"""
try:
if page_number is not None:
# 如果指定了页码,提取该页的表格数据
tables = read_pdf(file_path, pages=page_number + 1) # tabula 使用 1-based 页码
else:
# 如果没有指定页码,提取所有页的表格数据
tables = read_pdf(file_path, pages='all')
if not tables:
# 如果没有找到表格,返回提示信息
return ["未找到表格"]
result = []
for i, table in enumerate(tables):
# 将表格数据转换为字符串并添加到结果列表中
result.append(f"表格 {i + 1}:\n{table.to_string()}\n---")
return result
except Exception as e:
# 如果出现异常,返回错误信息
return [f"提取表格时出错: {str(e)}"]
async def analyze_pdf_content(file_path: str, analysis_type: str) -> Dict[str, Any]:
"""
分析 PDF 内容,根据分析类型(实体、摘要、关键词)调用相应的模型进行处理
"""
try:
# 从 PDF 中提取文本
text = await extract_text_from_pdf(file_path)
if analysis_type == "entities":
# 如果分析类型是实体识别
with ModelContext('spacy', model_manager) as nlp:
# 使用 Spacy 模型进行实体识别
doc = nlp(text)
entities = [(ent.text, ent.label_) for ent in doc.ents]
return {"entities": entities}
elif analysis_type == "summary":
# 如果分析类型是摘要提取
with ModelContext('classifier', model_manager) as classifier:
# 使用分类器模型进行摘要提取
sentences = nltk.sent_tokenize(text)
results = classifier(sentences,
candidate_labels=["important", "not important"],
multi_label=False)
important_sentences = [sent for sent, score in zip(sentences, results['scores'])
if score > 0.7]
return {"summary": " ".join(important_sentences[:5])}
elif analysis_type == "keywords":
# 如果分析类型是关键词提取
with ModelContext('spacy', model_manager) as nlp:
# 使用 Spacy 模型进行关键词提取
doc = nlp(text)
keywords = [token.text for token in doc if not token.is_stop and token.is_alpha]
return {"keywords": list(set(keywords[:20]))}
except ModelError as e:
# 如果出现模型错误,返回错误信息
return {"error": f"Model error: {str(e)}"}
except Exception as e:
# 如果出现其他异常,返回错误信息
return {"error": f"Unexpected error: {str(e)}"}
async def get_pdf_metadata(file_path: str) -> Dict[str, Any]:
"""
获取 PDF 文件的元数据信息
"""
try:
# 打开 PDF 文件
doc = fitz.open(file_path)
# 获取元数据
metadata = doc.metadata
# 关闭 PDF 文件
doc.close()
return {
"title": metadata.get("title", "未知"),
"author": metadata.get("author", "未知"),
"subject": metadata.get("subject", "未知"),
"keywords": metadata.get("keywords", "未知"),
"creator": metadata.get("creator", "未知"),
"producer": metadata.get("producer", "未知"),
"creation_date": metadata.get("creationDate", "未知"),
"modification_date": metadata.get("modDate", "未知"),
"page_count": doc.page_count
}
except Exception as e:
# 如果出现异常,返回错误信息
return {"error": str(e)}
async def classify_document(file_path: str, categories: List[str]) -> Dict[str, Any]:
"""
对文档进行分类,使用分类器模型
"""
try:
# 从 PDF 中提取文本
text = pdfminer_extract_text(file_path)
with ModelContext('classifier', model_manager) as classifier:
# 使用分类器模型进行分类
result = classifier(text, categories)
return {
"labels": result["labels"],
"scores": [float(score) for score in result["scores"]]
}
except ModelError as e:
# 如果出现模型错误,返回错误信息
return {"error": f"Model error: {str(e)}"}
except Exception as e:
# 如果出现其他异常,返回错误信息
return {"error": f"Unexpected error: {str(e)}"}
........
1.3 微调与二次开发策略
对于医疗健康领域,通用大模型往往需要进行领域适应和知识增强。采用了以下策略:
- RAG技术:不直接微调模型,而是通过检索增强生成(RAG)技术,将专业知识库、知识图谱和互联网搜索结果注入到模型生成过程中
- 提示工程:通过分类引导的提示模板,引导模型生成符合医疗领域要求的回答
这种方法相比完全微调模型有以下优势:
- 降低了计算资源需求
- 提高了知识更新的灵活性
- 保持了模型的通用能力,同时增强了领域特定知识
2. 训练环境搭建
虽然主要采用预训练模型和RAG技术,但为了支持向量嵌入模型的运行和可能的微调需求,系统仍需要适当的环境配置:
2.1 硬件环境
推荐的硬件环境包括:
- GPU:至少一张支持CUDA的GPU(如NVIDIA RTX系列)用于向量嵌入模型的运行
- 内存:至少16GB RAM,推荐32GB以上,特别是处理大型知识库时
- 存储:SSD存储,至少100GB用于模型、向量数据库和知识库文件
2.2 软件环境
系统的软件环境配置如下:
- 操作系统:支持Linux、Windows和macOS
- Python环境:Python 3.10
- 深度学习框架:PyTorch
- 向量数据库:Oceanbase
- 图数据库:Neo4j(用于知识图谱)
2.3 环境优化
为了提高系统效率,代码中采用了多种优化技术:
- 多线程文档加载:使用
use_multithreading=True参数加速文档处理 - 模型缓存:自动下载并缓存嵌入模型,避免重复下载
- 向量检索优化:使用Oceanbase高效向量检索库,支持快速相似度搜索
采用多种搜索引擎融合搜索
def InternetSearchChain(question, history):
if os.path.exists(_SAVE_PATH):
shutil.rmtree(_SAVE_PATH)
if not os.path.exists(_SAVE_PATH):
os.makedirs(_SAVE_PATH)
whole_question = extract_question(question, history)
question_list = re.split(r"[;;]", whole_question)
requests.packages.urllib3.disable_warnings(InsecureRequestWarning)
threads = []
links = {}
# 为每个问题创建单独的线程
for question in question_list:
# 每个线程执行搜索操作 (bing搜索)
thread = threading.Thread(target=search_bing, args=(question, links, 3))
threads.append(thread)
thread.start()
百度搜索
thread = threading.Thread(target=search_baidu, args=(question, links, 3))
threads.append(thread)
thread.start()
# 等待所有线程完成
for thread in threads:
thread.join()
if has_html_files(_SAVE_PATH):
docs, _context = retrieve_html(question)
prompt = f"根据你现有的知识,辅助以搜索到的文件资料:\n{_context}\n 回答问题:\n{question}\n 尽可能多的覆盖到文件资料"
else:
prompt = question
response = Clientfactory().get_client().chat_with_ai_stream(prompt)
return response, links, has_html_files(_SAVE_PATH)
3. 训练数据准备
虽然主要使用中文长序列预训练模型,但RAG系统需要处理大量文档数据。系统支持多种文档格式,并进行了专门的数据处理:
3.1 支持的数据格式
系统支持多种文档格式作为知识库来源:
- PDF文档(使用PyPDFLoader)
- Word文档(使用UnstructuredWordDocumentLoader)
- 文本文件(使用TextLoader,自动检测编码)
- CSV文件(使用CSVLoader)
- HTML/MHTML文件(使用UnstructuredHTMLLoader/MHTMLLoader)
- Markdown文件(使用UnstructuredMarkdownLoader)
3.2 数据处理流程
RAG系统的数据处理流程包括:
- 文档加载:从指定目录加载各种格式的文档
- 文档分割:使用RecursiveCharacterTextSplitter将文档分割成适当大小的块
- 向量化:使用ModelScopeEmbeddings将文本块转换为向量表示
- 索引构建:使用Oceanbase构建向量索引,支持高效检索
3.3 数据分块策略
系统采用了以下分块策略来优化检索效果:
- 块大小(chunk_size):2000字符,平衡了上下文完整性和检索精度
- 块重叠(chunk_overlap):100字符,确保语义连贯性,避免信息丢失
3.4 用户特定数据处理
系统支持为不同用户构建独立的知识库,实现个性化检索:
用户可以上传自己的文档,系统会自动处理并构建用户专属的向量库:
4. 模型训练与优化策略
知识图谱采用openkg的数据集来丰富样本集: 面向家庭常见疾病的知识图谱, 分别进行数据打标和分类
数据打标:
节点标识:
['一级科室', '二级科室', '其他',"检查手段","治疗方案","生产商","疾病","症状","药物","食物","食谱"]
关系类型:
['好评药物', '宜吃', '属于', '常用药物', '并发症','忌吃','所属科室','推荐食谱','治疗方法','生产药品','症状','诊断建议']
整理实体关系
def relation_tool(entities: List[Dict] | None) -> str | None:
if not entities or len(entities) == 0:
return None
relationships = set() # 使用集合来避免重复关系
relationship_match = []
searchKey = Config.get_instance().get_with_nested_params("model", "graph-entity", "search-key")
# 遍历每个实体并查询与其他实体的关系
for entity in entities:
entity_name = entity[searchKey]
for k, v in entity.items():
relationships.add(f"{entity_name} {k}: {v}")
# 查询每个实体与其他实体的关系a-r-b
relationship_match.append(_dao.query_relationship_by_name(entity_name))
# 抽取并记录每个实体与其他实体的关系
for i in range(len(relationship_match)):
for record in relationship_match[i]:
# 获取起始节点和结束节点的名称
start_name = record["r"].start_node[searchKey]
end_name = record["r"].end_node[searchKey]
# 获取关系类型
rel = type(record["r"]).__name__ # 获取关系名称,比如 CAUSES
# 构建关系字符串并添加到集合,确保不会重复添加
relationships.add(f"{start_name} {rel} {end_name}")
# 返回关系集合的内容
if relationships:
return ";".join(relationships)
else:
return None
4.1 RAG系统优化
RAG系统的关键优化参数包括:
- 检索数量(k):设置为6,表示每次检索返回最相关的6个文档片段
- 向量维度:由所选ModelScope嵌入模型决定
- 相似度计算:使用FAISS的高效向量相似度计算
4.2 Internet搜索模型
Internet搜索模块使用类似的向量检索策略,但专注于从网络下载到本地的文件:
class InternetModel(Modelbase):
_retriever: VectorStoreRetriever
def __init__(self,*args,**krgs):
super().__init__(*args,**krgs)
# 此处需要改成下载embedding模型的位置
self._embedding_model_path =Config.get_instance().get_with_nested_params("model", "embedding", "model-name")
self._text_splitter = RecursiveCharacterTextSplitter
#self._embedding = OpenAIEmbeddings()
self._embedding = ModelScopeEmbeddings(model_id=self._embedding_model_path)
self._data_path = os.path.join(get_app_root(), "data/cache/internet")
#self._logger: Logger = Logger("rag_retriever")
# 建立向量库
def build(self):
# 加载html文件
html_loader = DirectoryLoader(self._data_path, glob="**/*.html", loader_cls=UnstructuredHTMLLoader, silent_errors=True, use_multithreading=True)
html_docs = html_loader.load()
mhtml_loader = DirectoryLoader(self._data_path, glob="**/*.mhtml", loader_cls=MHTMLLoader, silent_errors=True, use_multithreading=True)
mhtml_docs = mhtml_loader.load()
#合并文档
docs = html_docs + mhtml_docs
# 创建一个 RecursiveCharacterTextSplitter 对象,用于将文档分割成块,chunk_size为最大块大小,chunk_overlap块之间可以重叠的大小
text_splitter = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=100)
splits = text_splitter.split_documents(docs)
# 使用 Oceanbase 创建一个向量数据库,存储分割后的文档及其嵌入向量
vectorstore = OceanbaseVectorStore.from_documents(documents=splits, embedding=self._embedding)
# 将向量存储转换为检索器,设置检索参数 k 为 6,即返回最相似的 6 个文档
self._retriever = vectorstore.as_retriever(search_kwargs={"k": 6})
@property
def retriever(self)-> VectorStoreRetriever:
self.build()
return self._retriever
4.3 分布式与并行处理
系统在多个层面实现了并行处理:
- 文档加载并行化:使用多线程加载文档
- 检索并行化:同时从多个知识源(知识库、知识图谱、互联网)检索信息
这种并行策略显著提高了系统的响应速度,特别是在处理大量文档和复杂查询时。
(五) 测试与验证
文本交流界面
病例识别界面
ppt&&word生成展示
知识图谱展示
联网搜索
语音对话界面
总结
项目在模型选择与训练方面采用了灵活而高效的策略。通过结合预训练大模型、RAG技术和多模态处理能力,系统能够为医疗健康领域提供专业、准确的信息,同时保持良好的可扩展性和适应性。系统的设计特别关注了医疗资源不平衡的问题,通过支持离线语音识别、多语言处理和个性化知识库等功能,使医疗信息更加普惠可及。