理解Embedding的本质、掌握向量化模型选型、实现语义搜索和相似度计算
时间:45分钟 | 难度:⭐⭐⭐ | Week 3 Day 16
官方Example信息
- GitHub链接:EmbeddingExample.java
- 相关Example:EmbeddingModelExample、EmbeddingStoreExample
- 所在路径:src/main/java/dev/langchain4j/examples/
- 代码行数:约50-100行
- 难度:中级 ⭐⭐⭐
学习目标
- 理解Embedding的概念和工作原理
- 掌握不同Embedding模型的选型
- 学会计算向量相似度
- 使用LangChain4J的EmbeddingModel接口
- 实现基于Embedding的语义搜索
- 掌握Embedding缓存和优化策略
🚀 快速入门:什么是Embedding?
Embedding的本质
Embedding = 把文本转换为数字向量
"今天天气真好" → [0.23, -0.15, 0.87, 0.42, ..., 0.11] (1536维向量)
"天气很不错" → [0.21, -0.13, 0.85, 0.40, ..., 0.09] (相似!)
"我要写代码" → [-0.55, 0.72, -0.12, 0.33, ..., 0.68] (不同!)
核心思想:
- 语义相近的文本 → 向量距离近
- 语义不同的文本 → 向量距离远
为什么需要Embedding?
问题:计算机不理解"语义"
- "苹果公司" vs "苹果水果" → 文字相同,含义不同
- "快乐" vs "开心" → 文字不同,含义相同
解决:Embedding把语义编码为数学向量
- "快乐" → [0.8, 0.6, 0.2, ...]
- "开心" → [0.79, 0.61, 0.19, ...] ← 距离很近!
- "悲伤" → [-0.7, -0.5, 0.1, ...] ← 距离很远!
Embedding工作原理
┌──────────────────────────────────────┐
│ Embedding 工作原理 │
│ │
│ 文本输入 │
│ │ │
│ ▼ │
│ ┌──────────────┐ │
│ │ Tokenization │ 文本 → Token序列 │
│ └──────┬───────┘ │
│ │ │
│ ▼ │
│ ┌──────────────┐ │
│ │ 神经网络模型 │ Token → 隐层表示 │
│ │ (Transformer)│ │
│ └──────┬───────┘ │
│ │ │
│ ▼ │
│ ┌──────────────┐ │
│ │ 池化层 │ 多个向量 → 1个向量 │
│ │ (Pooling) │ │
│ └──────┬───────┘ │
│ │ │
│ ▼ │
│ [0.23, -0.15, 0.87, ..., 0.11] │
│ (固定维度的浮点数向量) │
└──────────────────────────────────────┘
深度讲解
1️⃣ Embedding模型对比
| 模型 | 维度 | 性能 | 成本 | 语言 | 适用场景 |
|---|---|---|---|---|---|
| text-embedding-3-small | 1536 | ⭐⭐⭐⭐ | $0.00002/1K | 多语言 | 通用推荐 |
| text-embedding-3-large | 3072 | ⭐⭐⭐⭐⭐ | $0.00013/1K | 多语言 | 高精度 |
| all-MiniLM-L6-v2 | 384 | ⭐⭐⭐ | 免费(本地) | 英文 | 低延迟 |
| bge-large-zh | 1024 | ⭐⭐⭐⭐ | 免费(本地) | 中文 | 中文专用 |
| Ollama embedding | 768-4096 | ⭐⭐⭐ | 免费(本地) | 多语言 | 离线场景 |
2️⃣ 在LangChain4J中使用EmbeddingModel
OpenAI Embedding
// 1. 配置OpenAI Embedding模型
@Bean
public EmbeddingModel embeddingModel() {
return OpenAiEmbeddingModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.modelName("text-embedding-3-small") // 1536维
.build();
}
// 2. 基本用法:单文本向量化
@Service
public class EmbeddingService {
@Autowired
private EmbeddingModel embeddingModel;
public Embedding embed(String text) {
Response<Embedding> response = embeddingModel.embed(text);
Embedding embedding = response.content();
System.out.println("维度:" + embedding.vector().length); // 1536
System.out.println("前5个值:" + Arrays.toString(
Arrays.copyOf(embedding.vector(), 5)));
return embedding;
}
}
批量Embedding
// 批量处理多个文本
@Service
public class BatchEmbeddingService {
@Autowired
private EmbeddingModel embeddingModel;
public List<Embedding> embedBatch(List<String> texts) {
// 将文本转为TextSegment
List<TextSegment> segments = texts.stream()
.map(TextSegment::from)
.collect(Collectors.toList());
// 批量Embedding(一次API调用)
Response<List<Embedding>> response = embeddingModel.embedAll(segments);
return response.content();
}
// 使用示例
public void demo() {
List<String> texts = List.of(
"Java是一种编程语言",
"Python适合数据分析",
"今天天气不错"
);
List<Embedding> embeddings = embedBatch(texts);
System.out.println("生成了 " + embeddings.size() + " 个向量");
}
}
本地Embedding模型
// 使用ONNX本地模型(无需API调用)
@Bean
public EmbeddingModel localEmbeddingModel() {
return new AllMiniLmL6V2EmbeddingModel();
// 384维,英文模型,完全本地运行
}
// 使用Ollama本地模型
@Bean
public EmbeddingModel ollamaEmbeddingModel() {
return OllamaEmbeddingModel.builder()
.baseUrl("http://localhost:11434")
.modelName("nomic-embed-text")
.build();
}
3️⃣ 向量相似度计算
三种相似度算法对比
| 算法 | 公式 | 范围 | 特点 | 推荐场景 |
|---|---|---|---|---|
| 余弦相似度 | cos(A,B) | [-1, 1] | 只关注方向 | 文本相似度(最常用) |
| 欧几里得距离 | ||A-B|| | [0, ∞) | 关注距离 | 聚类分析 |
| 点积 | A·B | (-∞, +∞) | 关注方向+大小 | 推荐系统 |
手动实现余弦相似度
/**
* 余弦相似度:衡量两个向量的方向相似程度
* 值越接近1 = 越相似
* 值越接近0 = 无关
* 值越接近-1 = 完全相反
*/
public class CosineSimilarity {
public static double calculate(float[] vectorA, float[] vectorB) {
if (vectorA.length != vectorB.length) {
throw new IllegalArgumentException("向量维度不同");
}
double dotProduct = 0.0; // 点积
double normA = 0.0; // A的模
double normB = 0.0; // B的模
for (int i = 0; i < vectorA.length; i++) {
dotProduct += vectorA[i] * vectorB[i];
normA += vectorA[i] * vectorA[i];
normB += vectorB[i] * vectorB[i];
}
if (normA == 0 || normB == 0) return 0;
return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
}
// 测试
public static void main(String[] args) {
float[] v1 = {1.0f, 2.0f, 3.0f};
float[] v2 = {1.0f, 2.0f, 3.0f};
float[] v3 = {-1.0f, -2.0f, -3.0f};
System.out.println("完全相同:" + calculate(v1, v2)); // 1.0
System.out.println("完全相反:" + calculate(v1, v3)); // -1.0
}
}
使用LangChain4J内置相似度
// LangChain4J提供了内置的相似度计算
@Service
public class SimilarityService {
@Autowired
private EmbeddingModel embeddingModel;
public double calculateSimilarity(String text1, String text2) {
Embedding e1 = embeddingModel.embed(text1).content();
Embedding e2 = embeddingModel.embed(text2).content();
// 使用内置的余弦相似度
return CosineSimilarity.between(e1, e2);
}
// 比较多个文本的相似度
public void comparePairs() {
String[] texts = {
"Java是一种面向对象编程语言",
"Java是一种OOP语言",
"今天天气很好",
"Python是一种脚本语言"
};
for (int i = 0; i < texts.length; i++) {
for (int j = i + 1; j < texts.length; j++) {
double sim = calculateSimilarity(texts[i], texts[j]);
System.out.printf("[%.3f] '%s' vs '%s'%n", sim, texts[i], texts[j]);
}
}
}
}
// 预期输出:
// [0.923] 'Java是一种面向对象编程语言' vs 'Java是一种OOP语言' ← 高相似
// [0.412] 'Java是一种面向对象编程语言' vs '今天天气很好' ← 低相似
// [0.756] 'Java是一种面向对象编程语言' vs 'Python是一种脚本语言' ← 中等相似
4️⃣ 语义搜索实现
语义搜索流程:
┌─────────┐ ┌──────────┐ ┌──────────┐
│ 用户查询 │───>│ 向量化 │───>│ 查询向量 │
└─────────┘ └──────────┘ └────┬─────┘
│
比较相似度 │
▼
┌─────────┐ ┌──────────┐ ┌──────────┐
│ 最相似 │<───│ 排序 │<───│ 文档向量库│
│ 的文档 │ └──────────┘ └──────────┘
└─────────┘
/**
* 简单的语义搜索实现
*/
@Service
public class SemanticSearchService {
@Autowired
private EmbeddingModel embeddingModel;
@Autowired
private EmbeddingStore<TextSegment> embeddingStore;
/**
* 索引文档(存入向量库)
*/
public void indexDocuments(List<String> documents) {
for (String doc : documents) {
TextSegment segment = TextSegment.from(doc);
Embedding embedding = embeddingModel.embed(segment).content();
embeddingStore.add(embedding, segment);
}
System.out.println("已索引 " + documents.size() + " 个文档");
}
/**
* 语义搜索(按相似度排序)
*/
public List<SearchResult> search(String query, int topK) {
// 1. 把查询文本向量化
Embedding queryEmbedding = embeddingModel.embed(query).content();
// 2. 在向量库中搜索最相似的文档
List<EmbeddingMatch<TextSegment>> matches =
embeddingStore.findRelevant(queryEmbedding, topK);
// 3. 返回结果
return matches.stream()
.map(match -> new SearchResult(
match.embedded().text(),
match.score()))
.collect(Collectors.toList());
}
// 使用示例
public void demo() {
// 索引
indexDocuments(List.of(
"LangChain4J是Java的LLM框架",
"Spring Boot是Java Web框架",
"React是JavaScript前端框架",
"向量数据库用于存储Embedding"
));
// 搜索
List<SearchResult> results = search("Java框架", 3);
for (SearchResult r : results) {
System.out.printf("[%.3f] %s%n", r.score(), r.text());
}
// 预期:LangChain4J和Spring Boot排在前面
}
}
record SearchResult(String text, double score) {}
5️⃣ Embedding与EmbeddingStore集成
/**
* 使用InMemoryEmbeddingStore(适合开发和小规模场景)
*/
@Configuration
public class EmbeddingStoreConfig {
@Bean
public EmbeddingStore<TextSegment> embeddingStore() {
return new InMemoryEmbeddingStore<>();
}
@Bean
public EmbeddingModel embeddingModel() {
return OpenAiEmbeddingModel.builder()
.apiKey(System.getenv("OPENAI_API_KEY"))
.modelName("text-embedding-3-small")
.build();
}
/**
* ContentRetriever = EmbeddingStore + EmbeddingModel
* 用于RAG系统的内容检索
*/
@Bean
public ContentRetriever contentRetriever(
EmbeddingStore<TextSegment> store,
EmbeddingModel model) {
return EmbeddingStoreContentRetriever.builder()
.embeddingStore(store)
.embeddingModel(model)
.maxResults(5) // 最多返回5个结果
.minScore(0.7) // 最低相似度0.7
.build();
}
}
6️⃣ Embedding缓存策略
/**
* 缓存Embedding避免重复计算
* 同一文本不需要每次都调用API
*/
@Service
public class CachedEmbeddingService {
@Autowired
private EmbeddingModel embeddingModel;
// 使用LRU缓存
private final Map<String, Embedding> cache =
Collections.synchronizedMap(new LinkedHashMap<>(1000, 0.75f, true) {
@Override
protected boolean removeEldestEntry(Map.Entry eldest) {
return size() > 10000; // 最多缓存10000个
}
});
public Embedding embed(String text) {
// 先查缓存
String key = text.hashCode() + "";
Embedding cached = cache.get(key);
if (cached != null) {
return cached;
}
// 缓存未命中,调用API
Embedding embedding = embeddingModel.embed(text).content();
cache.put(key, embedding);
return embedding;
}
}
// 使用Spring Cache(更优雅)
@Service
public class SpringCachedEmbedding {
@Autowired
private EmbeddingModel embeddingModel;
@Cacheable(value = "embeddings", key = "#text.hashCode()")
public Embedding embed(String text) {
return embeddingModel.embed(text).content();
}
}
💻 实战:完整的文本相似度比较系统
@Service
public class TextSimilaritySystem {
@Autowired
private EmbeddingModel embeddingModel;
/**
* 文本分类:找出最匹配的类别
*/
public String classify(String text, Map<String, String> categories) {
Embedding textEmbedding = embeddingModel.embed(text).content();
String bestCategory = null;
double bestScore = -1;
for (Map.Entry<String, String> entry : categories.entrySet()) {
Embedding catEmbedding = embeddingModel.embed(entry.getValue()).content();
double score = CosineSimilarity.between(textEmbedding, catEmbedding);
if (score > bestScore) {
bestScore = score;
bestCategory = entry.getKey();
}
}
return bestCategory + " (置信度: " + String.format("%.2f", bestScore) + ")";
}
/**
* 去重检测:找出相似度过高的文本
*/
public List<String[]> findDuplicates(List<String> texts, double threshold) {
List<Embedding> embeddings = texts.stream()
.map(t -> embeddingModel.embed(t).content())
.collect(Collectors.toList());
List<String[]> duplicates = new ArrayList<>();
for (int i = 0; i < texts.size(); i++) {
for (int j = i + 1; j < texts.size(); j++) {
double sim = CosineSimilarity.between(embeddings.get(i), embeddings.get(j));
if (sim >= threshold) {
duplicates.add(new String[]{texts.get(i), texts.get(j),
String.format("%.3f", sim)});
}
}
}
return duplicates;
}
// 使用示例
public void demo() {
// 文本分类
Map<String, String> categories = Map.of(
"技术", "编程、代码、软件、开发、框架",
"天气", "晴天、下雨、温度、气候",
"美食", "餐厅、烹饪、食材、菜品"
);
System.out.println(classify("Spring Boot如何配置", categories));
// 输出:技术 (置信度: 0.82)
// 去重检测
List<String[]> dups = findDuplicates(List.of(
"Java是编程语言", "Java是一种编程语言", "今天天气好"
), 0.85);
dups.forEach(d -> System.out.printf("重复: '%s' ≈ '%s' (%.3s)%n", d[0], d[1], d[2]));
}
}
🔧 最佳实践
✅ 好的做法
// 1. 选择合适的模型维度
// 小规模项目:384维(all-MiniLM-L6-v2)
// 生产项目:1536维(text-embedding-3-small)
// 高精度需求:3072维(text-embedding-3-large)
// 2. 批量处理减少API调用
List<TextSegment> segments = texts.stream()
.map(TextSegment::from).toList();
embeddingModel.embedAll(segments); // 一次调用
// 3. 缓存常用的Embedding
@Cacheable("embeddings")
public Embedding embed(String text) { ... }
// 4. 设置合理的相似度阈值
ContentRetriever retriever = EmbeddingStoreContentRetriever.builder()
.minScore(0.7) // 过滤低质量结果
.build();
❌ 坏的做法
// 1. 逐条处理(浪费API调用)
for (String text : texts) {
embeddingModel.embed(text); // N次API调用
}
// 2. 不缓存(重复计算浪费钱)
// 每次查询都重新计算Embedding
// 3. 模型选型不当
// 简单场景用3072维模型(浪费)
// 高精度场景用384维模型(不够)
学习成果检查:
- 能解释Embedding的工作原理
- 能选择合适的Embedding模型
- 能计算向量相似度
- 能实现语义搜索
- 能优化Embedding的性能和成本
下一步:学习向量数据库,掌握大规模向量存储和高效检索。