本文将简单介绍下LangChain4j框架对文本分类能力实现与支持!
什么是文本分类
文本分类是自然语言处理的一个基本任务,将文本分类为一个或多个不同类别以组织、构造和过滤成任何参数的过程。例如,将新闻文章分为“政治”、“经济”、“科技”等类别,或将电子邮件分为“垃圾邮件”,“非垃圾邮件”。
应用场景
- 垃圾邮件过滤:将邮件分为“垃圾邮件”和“非垃圾邮件”。对垃圾邮件进行拦截并处理。
- 新闻推荐:根据用户阅读历史,将新闻分为“关注” 和 “非关注”。
- 情感分析/词性标注:将文本划分为“积极”、“消极” 和 “中性”。
- 识别仇恨言论:将文本划分为 “反动言论”、“非反动言论”,一般在论坛场景中,提示发言违规。
- 营销与广告:同新闻推荐一样的思路!比如商品分为“好评”、“中性”和“差评”,从而对商品进行不同程度的营销策略等。
- ....
其实还有很多应用场景,在遇到相应问题时,在脑海中会浮现,哟!还有这么一种文本分类的解决方案。
LangChain4j 框架对分本分类的支持
TextClassifier 接口
首先:需要定义一个分类的枚举,接口泛型E必须是枚举
public interface TextClassifier<E extends Enum<E>> {
/**
* 对文本进行分类
*
* @param text 需要分类的文本
* @return 文本所属分类列表,比如某个文本既属于A、又属于 B、C、D.
*/
List<E> classify(String text);
/**
* 对文本片段进行分类,文本片段可能是一个句子、段落或者有意义的文本单位
* TextSegment:包含文本和一些元数据(在构建TextSegment开发者放入,比如用户ID)
*
* @param textSegment 需要分类的文本片段.
* @return 文本所属分类列表。
*/
default List<E> classify(TextSegment textSegment) {
return classify(textSegment.text());
}
/**
* 对文档进行分类,一般代表一个文件(txt、doc、pdf、ppt、html)中的内容。
*
* @param document 要分类的文档。
* @return 文本所属分类列表。
*/
default List<E> classify(Document document) {
return classify(document.text());
}
}
EmbeddingModelTextClassifier 实现类
EmbeddingModelTextClassifier 底层实现依然依赖于嵌入模型,再加上一些预定义示例(文本与分类标签的关系)。
注意:每个标签的示例数量增多,分类质量就会提高。但是耗时、算力都会增加!!!
public class EmbeddingModelTextClassifier<E extends Enum<E>> implements TextClassifier<E> {
// 使用的嵌入模型
private final EmbeddingModel embeddingModel;
// 枚举与嵌入对象的对应关系
private final Map<E, List<Embedding>> exampleEmbeddingsByLabel;
// 返回的结果数
private final int maxResults;
// 控制返回的所有数据评分必须大于 minScore
private final double minScore;
// 这是一个在[0~1]之间的值。
private final double meanToMaxScoreRatio;
@Override
public List<E> classify(String text) {
// 向量化
Embedding textEmbedding = embeddingModel.embed(text).content();
List<LabelWithScore> labelsWithScores = new ArrayList<>();
exampleEmbeddingsByLabel.forEach((label, exampleEmbeddings) -> {
// 计算每个标签,在使用余弦相似度进行计算最高分和平均分
double meanScore = 0;
double maxScore = 0;
for (Embedding exampleEmbedding : exampleEmbeddings) {
// 余弦相似度计算相似度得分
double cosineSimilarity = CosineSimilarity.between(textEmbedding, exampleEmbedding);
// ????
double score = RelevanceScore.fromCosineSimilarity(cosineSimilarity);
// 计算得分总和
meanScore += score;
// 计算最大分
maxScore = Math.max(score, maxScore);
}
// 计算平均分
meanScore /= exampleEmbeddings.size();
labelsWithScores.add(new LabelWithScore(label, aggregatedScore(meanScore, maxScore)));
});
return labelsWithScores.stream()
.filter(it -> it.score >= minScore)
// sorting in descending order to return highest score first
.sorted(comparingDouble(labelWithScore -> 1 - labelWithScore.score))
.limit(maxResults)
.map(it -> it.label)
.collect(toList());
}
// 计算最终的得分,其中 meanScore平均分, maxScore:最大分
private double aggregatedScore(double meanScore, double maxScore) {
return (meanToMaxScoreRatio * meanScore) + ((1 - meanToMaxScoreRatio) * maxScore);
}
}
解释下每个字段的含义;
- maxResults 与 minScore,如下图在返回数据时使用,作用已经相当明显了。
- meanToMaxScoreRatio:在计算最终得分时使用平均分和最大分的比率。比如:当
meanToMaxScoreRatio=0时,则使用最大分作为最终得分,当meanToMaxScoreRatio=1时,则使用平均分作为最后的得分。
代码实现
实现一个情感分析
依赖
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-ollama-spring-boot-starter</artifactId>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j</artifactId>
</dependency>
application.yml
langchain4j:
ollama:
embedding-model:
base-url: http://localhost:11434
model-name: qwen:7b
枚举定义
package org.ivy.classification.enums;
import lombok.AllArgsConstructor;
import lombok.Getter;
@Getter
@AllArgsConstructor
public enum Sentiment {
POSITIVE, NEUTRAL, NEGATIVE
}
示例数据定义
package org.ivy.classification.example;
import org.ivy.classification.enums.Sentiment;
import java.util.List;
import java.util.Map;
public class SentimentExamples {
public static final Map<Sentiment, List<String>> examples = Map.of(
Sentiment.POSITIVE, List.of("This is great!", "Wow, awesome!"),
Sentiment.NEUTRAL, List.of("Well, it's fine", "It's ok"),
Sentiment.NEGATIVE, List.of("It is pretty bad", "Worst experience ever!")
);
}
Service实现
package org.ivy.classification.service;
import dev.langchain4j.classification.EmbeddingModelTextClassifier;
import dev.langchain4j.model.ollama.OllamaEmbeddingModel;
import lombok.RequiredArgsConstructor;
import org.ivy.classification.enums.CustomerServiceCategory;
import org.ivy.classification.enums.Sentiment;
import org.ivy.classification.example.ClassifyExamples;
import org.ivy.classification.example.SentimentExamples;
import org.springframework.stereotype.Service;
import java.util.List;
@Service
@RequiredArgsConstructor
public class ClassifyService {
private final OllamaEmbeddingModel ollamaEmbeddingModel;
/**
* 情感分析
*
* @param text 文本
* @return 分类列表
*/
public List<Sentiment> classifySentiment(String text) {
EmbeddingModelTextClassifier<Sentiment> classifier =
new EmbeddingModelTextClassifier<>(ollamaEmbeddingModel, SentimentExamples.examples);
return classifier.classify(text);
}
}
Controller 实现
package org.ivy.classification.controller;
import lombok.RequiredArgsConstructor;
import org.ivy.classification.enums.CustomerServiceCategory;
import org.ivy.classification.enums.Sentiment;
import org.ivy.classification.service.ClassifyService;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import java.util.List;
@RestController
@RequestMapping("/classify/")
@RequiredArgsConstructor
public class ClassificationController {
private final ClassifyService classificationService;
@GetMapping("/sentiment")
public List<Sentiment> sentiment(String text) {
return classificationService.classifySentiment(text);
}
}
测试结果 「http/rest.http」
GET /classify/sentiment?text=Awesome! HTTP/1.1
Host: localhost:8080
Content-Type: application/json
###
GET /classify/sentiment?text=I love it HTTP/1.1
Host: localhost:8080
Content-Type: application/json
###
GET /classify/sentiment?text=Worst website ever! HTTP/1.1
Host: localhost:8080
Content-Type: application/json
代码示例与总结
本文主要介绍了文本分类的一些应用场景,并详细介绍了LangChain4j框架接口和实现类,并对源码进行了分析。在示例代码中提供了情感分析示例代码和客户服务分类两个场景的实现。由于篇幅没有将后者的实现代码在本文中体现。
客户服务分类 示例由于提供的分类数据过多,并且使用本地大模型,可能无法运行出结果!如果无法获取结果,大家就将就看看代码吧!!!