Spring Boot 中使用预训练模型进行意图识别

513 阅读2分钟

在 Spring Boot 中使用预训练模型进行意图识别,通常需要结合自然语言处理(NLP)模型和深度学习框架。以下是实现方法和常用预训练模型的详细指南:


一、常用的意图识别预训练模型

以下模型均支持 微调(Fine-tuning)  以适应特定意图分类任务:

模型名称特点适用场景框架支持
BERT双向Transformer,通用性强,支持多语言通用意图分类、复杂语义理解PyTorch/TensorFlow
RoBERTaBERT的优化版,训练更充分,效果更佳高精度要求的意图识别PyTorch
DistilBERTBERT的轻量版,速度更快,性能接近原版资源有限的边缘部署PyTorch
ALBERT参数共享技术,模型体积小,适合移动端移动端应用或低延迟场景TensorFlow/PyTorch
ERNIE(百度)中文优化,融入实体知识,中文意图识别效果更优中文对话系统、客服场景PaddlePaddle
ELECTRA训练效率高,小模型也能达到BERT性能快速迭代和低成本场景TensorFlow/PyTorch

二、Spring Boot 集成预训练模型的两种方案

方案1:本地加载模型(通过 ONNX/DJL)

适用于直接部署模型到 Java 环境。

步骤1:将模型转换为 ONNX 格式(以 BERT 为例)
# 使用 Python 转换 PyTorch 模型到 ONNX
from transformers import BertForSequenceClassification, BertTokenizer
import torch

model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

dummy_input = tokenizer("test", return_tensors="pt")
torch.onnx.export(
    model, 
    (dummy_input["input_ids"], dummy_input["attention_mask"]),
    "bert_intent.onnx",
    input_names=["input_ids", "attention_mask"],
    output_names=["logits"],
    dynamic_axes={"input_ids": {0: "batch"}, "attention_mask": {0: "batch"}}
)
步骤2:Spring Boot 中加载 ONNX 模型

使用 Deep Java Library (DJL)

<!-- pom.xml 添加依赖 -->
<dependency>
    <groupId>ai.djl</groupId>
    <artifactId>api</artifactId>
    <version>0.23.0</version>
</dependency>
<dependency>
    <groupId>ai.djl.onnxruntime</groupId>
    <artifactId>onnxruntime-engine</artifactId>
    <version>0.23.0</version>
</dependency>
// 意图分类服务类
import ai.djl.Model;
import ai.djl.inference.Predictor;
import ai.djl.modality.nlp.bert.BertTokenizer;
import ai.djl.translate.TranslateException;

@Service
public class BertIntentService {

    private Predictor<String, String> predictor;
    private BertTokenizer tokenizer;

    @PostConstruct
    public void init() throws ModelNotFoundException, IOException {
        Model model = Model.newInstance("bert_intent");
        model.load(Paths.get("path/to/bert_intent.onnx"));
        
        // 初始化分词器和预测器
        tokenizer = new BertTokenizer();
        predictor = model.newPredictor(new SimpleTranslator(tokenizer));
    }

    public String predictIntent(String text) throws TranslateException {
        return predictor.predict(text);
    }
}

方案2:通过 REST API 调用 Python 服务

适用于模型部署在 Python 环境,Spring Boot 通过 HTTP 调用。

Python 端(Flask 服务)
from flask import Flask, request, jsonify
from transformers import pipeline

app = Flask(__name__)
classifier = pipeline("text-classification", model="bert-base-uncased")

@app.route("/predict", methods=["POST"])
def predict():
    text = request.json["text"]
    result = classifier(text)
    return jsonify({"intent": result[0]["label"], "score": result[0]["score"]})

if __name__ == "__main__":
    app.run(host="0.0.0.0", port=5000)
Spring Boot 调用代码
@Service
public class PythonModelService {

    private final WebClient webClient;

    public PythonModelService(WebClient.Builder webClientBuilder) {
        this.webClient = webClientBuilder.baseUrl("http://localhost:5000").build();
    }

    public String predictIntent(String text) {
        return webClient.post()
                .uri("/predict")
                .bodyValue(Map.of("text", text))
                .retrieve()
                .bodyToMono(JsonNode.class)
                .map(response -> response.get("intent").asText())
                .block();
    }
}

三、完整 Spring Boot 意图识别流程

  1. 预处理文本

    • 分词、去除停用词、转换为小写。
    public String preprocess(String text) {
        return text.toLowerCase().replaceAll("[^a-zA-Z0-9\s]", "");
    }
    
  2. 调用模型推理

    @RestController
    public class IntentController {
        @Autowired
        private BertIntentService bertService;
    
        @PostMapping("/detect-intent")
        public String detectIntent(@RequestBody String userInput) {
            String cleanedText = preprocess(userInput);
            return bertService.predictIntent(cleanedText);
        }
    }
    
  3. 返回意图结果

    • 输出示例:{"intent": "book_flight", "confidence": 0.92}