Embedding,实现语义化搜索的神器

7 阅读6分钟

Embedding是什么

Embedding从字面意思上来说是嵌入,但在机器学习自然语言处理的上下文中,我们通常把它理解成“向量化”或“向量表示”的技术。这有助于更加准确地描述在这些领域中的作用和应用。

NLP中的Embedding

在NLP(自然语言处理)中,Embedding是指将文本中的词语转换成数值向量的技术。为的是捕捉词语中的语义信息和它们的关系。通过这种方法,机器能更地理解人类的语言,并执行各种任务。

在NLP中,有一种常见的Embedding技术,就是谷歌推出的WordsVec。这个技术包含两种模型——Continuous Bag Of Words(CBOW)和Skip-gram。这两个模型都是为了学习词的向量表示,其中CBOW是根据周围的词来预测中心词,而Skip-gram是根据一个词来预测周围的词。

机器学习中的Embedding

在机器学习中,尤其是需要处理非数值型数据(如文本、图像和类型变量等)时,Embedding是一种将高维稀疏数据映射到低维稠密向量空间的技术。它能有效地将复杂的输入数据转化成机器可以理解和处理的形式。

Embedding的原理

Embedding的核心原理就是将高维的离散数据(如文本中的单词)映射早一个地位的连续向量空间,使得在这个新空间中,类似的对象在与以上也彼此接近。这种方法有助于降低数据维度,还能捕捉对象之间的潜在关系。再来看下它的一些基本原理:

  1. 分布式表示

传统的基于one-hot编码方法会导致非常稀疏的高维向量,导致难以处理。而Embedding通过分布式表示的方法,将每一个对象都表示成一个低维稠密向量。这样能更好的捕捉对象之间的复杂关系,从而减少模型参数的数量。 2. 上下文信息

许多的Embedding方法(如Word2Vec、Glove、ELMo、BERT等)都是利用上下文信息来学习词或对象的向量表示。比如在Word2Vec中,就是通过预测目标周围的词和根据周围的词预测目标词来训练模型。使得具有相似上下文的词在空间向量中更加接近。 3. 相似性保持

Embedding的重要一个特性就是它能保持原始数据中的相似性结构。也就是说,两个对象下原始输入空间中是相似的,那么它们在嵌入后的向量空间中也是应该靠近在一起的。 4. 学习过程

Embedding通常是作为更大的机器学习模型的一部分进行学习的。这些模型会在训练过程中自动调整嵌入向量,以最小化特定任务的损失函数。

在对Embedding有了一定了解后,我们来写一个AI搜索来巩固一下:

import Koa from 'koa';
import cors from '@koa/cors';
import Router from 'koa-router';
import bodyParser from 'koa-bodyparser';
import { client } from './app.service.mjs';
import fs from 'fs/promises';
const inputFilePath = './data/posts.json';
const outputFilePath = './data/posts_with_embedding.json';

const data = await fs.readFile(inputFilePath, 'utf-8');
const posts = JSON.parse(data);

const app = new Koa();
const router = new Router();
const port = 3002;

function cosineSimilarity(a, b) {
    if (a.length !== b.length) {
        throw new Error('向量长度不匹配');
    }

    let dotProduct = 0;
    let normA = 0;
    let normB = 0;

    for (let i = 0; i < a.length; i++) {
        dotProduct += a[i] * b[i];
        normA += a[i] * a[i];
        normB += b[i] * b[i];
    }

    return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
}

app.use(cors());
app.use(bodyParser());

router.post('/search',async(ctx) => {
  const {keyword} = ctx.request.body;
  console.log(keyword);
  const response = await client.embeddings.create({
    model:'text-embedding-ada-002',
    input:keyword,
  })
  const {embedding} = response.data[0];
  const results = posts.map(item => ({
    ...item,
    similarity: cosineSimilarity(embedding,item.embedding)
  }))
  .sort((a,b) => a.similarity - b.similarity)
  .reverse()
  .slice(0,3)
  .map((item,index) => ({id:index,title:`${index + 1}.${item.title},${item.category}`}))

  ctx.body = {
    status:200,
    data:results
  }

})

app.use(router.routes()).use(router.allowedMethods());
app.listen(port,() => {
 console.log(`Server is running at http://localhost:${port}`); 
})

这段代码实现了一个基于 Koa.js 的简单服务器,用于接收用户的搜索请求,并通过计算文本的嵌入向量(Embedding)之间的相似性来返回最相关的帖子。我们来详细解读下:


1. 导入依赖模块

import Koa from 'koa';
import cors from '@koa/cors';
import Router from 'koa-router';
import bodyParser from 'koa-bodyparser';
import { client } from './app.service.mjs';
import fs from 'fs/promises';
  • Koa: 一个轻量级的 Node.js Web 框架,用于构建服务器。
  • @koa/cors: 用于处理跨域资源共享(CORS),允许客户端从不同域访问服务器资源。
  • koa-router: 定义路由规则,处理不同的 HTTP 请求路径。
  • koa-bodyparser: 解析 HTTP 请求体中的 JSON 数据。
  • client: 从 ./app.service.mjs 文件中导入的模块,可能是与外部 API(如 OpenAI Embedding API)交互的客户端。
  • fs/promises: 提供基于 Promise 的文件系统操作。

2. 加载和解析数据

const inputFilePath = './data/posts.json';
const outputFilePath = './data/posts_with_embedding.json';

const data = await fs.readFile(inputFilePath, 'utf-8');
const posts = JSON.parse(data);
  • inputFilePath: 输入文件路径,存储了原始帖子数据。
  • outputFilePath: 输出文件路径,可能用于存储带有嵌入向量的帖子数据(但在这段代码中未使用)。
  • fs.readFile: 异步读取 posts.json 文件的内容。
  • JSON.parse: 将文件内容解析为 JavaScript 对象数组,假设每个对象表示一个帖子,包含标题、类别和嵌入向量等字段。

3. 初始化 Koa 应用

const app = new Koa();
const router = new Router();
const port = 3002;
  • app: 创建一个 Koa 实例。
  • router: 创建一个 Koa 路由实例,用于定义 API 路径和处理逻辑。
  • port: 设置服务器监听的端口号为 3002

4. 定义余弦相似度函数

function cosineSimilarity(a, b) {
    if (a.length !== b.length) {
        throw new Error('向量长度不匹配');
    }

    let dotProduct = 0;
    let normA = 0;
    let normB = 0;

    for (let i = 0; i < a.length; i++) {
        dotProduct += a[i] * b[i];
        normA += a[i] * a[i];
        normB += b[i] * b[i];
    }

    return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
}
  • cosineSimilarity: 计算两个向量之间的余弦相似度,衡量它们的语义相似性。
  • dotProduct: 向量点积。
  • normA 和 normB: 分别是向量 ab 的模(即向量的长度)。

如果两个向量完全相同,余弦相似度为 1;如果完全无关,则为 0


5. 配置中间件

app.use(cors());
app.use(bodyParser());
  • cors(): 允许跨域请求。
  • bodyParser(): 解析 HTTP 请求体中的 JSON 数据,方便提取用户输入的关键词。

6. 定义 /search 路由

router.post('/search', async (ctx) => {
  const { keyword } = ctx.request.body;
  console.log(keyword);

  const response = await client.embeddings.create({
    model: 'text-embedding-ada-002',
    input: keyword,
  });

  const { embedding } = response.data[0];

  const results = posts.map(item => ({
    ...item,
    similarity: cosineSimilarity(embedding, item.embedding)
  }))
  .sort((a, b) => a.similarity - b.similarity)
  .reverse()
  .slice(0, 3)
  .map((item, index) => ({
    id: index,
    title: `${index + 1}. ${item.title}, ${item.category}`
  }));

  ctx.body = {
    status: 200,
    data: results
  };
});

步骤详解

  1. 提取关键词:

    const { keyword } = ctx.request.body;
    console.log(keyword);
    
    • 从请求体中提取用户输入的关键词 keyword
  2. 生成关键词的嵌入向量:

    const response = await client.embeddings.create({
      model: 'text-embedding-ada-002',
      input: keyword,
    });
    const { embedding } = response.data[0];
    
    • 使用 client 调用外部 API(如 OpenAI Embedding API),将关键词转换为嵌入向量 embedding
    • 假设 response.data[0].embedding 是一个浮点数数组。
  3. 计算相似度并排序:

    const results = posts.map(item => ({
      ...item,
      similarity: cosineSimilarity(embedding, item.embedding)
    }))
    .sort((a, b) => a.similarity - b.similarity)
    .reverse()
    .slice(0, 3)
    .map((item, index) => ({
      id: index,
      title: `${index + 1}. ${item.title}, ${item.category}`
    }));
    
    • 计算相似度: 对每个帖子的嵌入向量 item.embedding 和关键词的嵌入向量 embedding 计算余弦相似度。
    • 排序: 按相似度从低到高排序后反转,得到从高到低的顺序。
    • 取前三个结果: 只保留相似度最高的三个帖子。
    • 格式化输出: 为每个帖子生成一个简化的对象,包含 idtitle
  4. 返回结果:

    ctx.body = {
      status: 200,
      data: results
    };
    
    • 返回 JSON 响应,包含状态码 200 和相似度最高的三个帖子信息。

7. 启动服务器

app.use(router.routes()).use(router.allowedMethods());
app.listen(port, () => {
  console.log(`Server is running at http://localhost:${port}`);
});
  • router.routes(): 将定义的路由应用到 Koa 应用中。
  • router.allowedMethods(): 处理不支持的 HTTP 方法错误。
  • app.listen(): 启动服务器,监听指定端口(3002),并在控制台打印启动信息。