用JS实现一个支持向量机

291 阅读2分钟

支持向量机(Support Vector Machine,SVM)是一种基于统计学习理论的分类算法。SVM 的目的是寻找一个超平面,将训练数据划分为不同的类别,同时最大化两个类别之间的间隔。换句话说,它是一种最大间隔分类器。

数据准备

在实现支持向量机之前,需要准备数据。这里使用一个简单的二维数据集作为示例,数据集中有两类点:蓝色点和红色点。通过这个数据集训练一个支持向量机,以便它能够对新的点进行分类。

const data = [
  { x: 1, y: 2, label: -1 },
  { x: 2, y: 3, label: -1 },
  { x: 3, y: 1, label: -1 },
  { x: 6, y: 5, label: 1 },
  { x: 7, y: 7, label: 1 },
  { x: 8, y: 6, label: 1 },
];

其中x和y表示点的坐标,label表示点的类别。

训练模型

接下来需要使用训练数据来训练支持向量机。训练过程包括以下几个步骤:

  1. 初始化模型参数,包括权重w和偏置b。
  2. 对训练数据进行处理,将数据转换为可以用于训练的格式。
  3. 定义损失函数,用于评估模型的效果。
  4. 使用梯度下降算法来更新模型参数,以最小化损失函数。
class SVM {
  constructor() {
    this.w = [0, 0];
    this.b = 0;
  }

  train(data) {
    // 将数据转换为可以用于训练的格式
    const formattedData = data.map((d) => [d.x, d.y, d.label]);

    // 定义损失函数
    const loss = (w, b, x, y) => {
      const wx = w[0] * x[0] + w[1] * x[1] + b;
      return Math.max(0, 1 - y * wx);
    };

    // 使用梯度下降算法来更新模型参数
    const learningRate = 0.1;
    const iterations = 1000;
    for (let i = 0; i < iterations; i++) {
      let gradientW = [0, 0];
      let gradientB = 0;
      for (let j = 0; j < formattedData.length; j++) {
        const lossValue = loss(this.w, this.b, formattedData[j], formattedData[j][2]);
        if (lossValue > 0) {
          gradientW[0] += -formattedData[j][2] * formattedData[j][0];
          gradientW[1] += -formattedData[j][2] * formattedData[j][1];
          gradientB += -formattedData[j][2];
        }
      }
      this.w[0] -= learningRate * this.w[0] + gradientW[0];
      this.w[1] -= learningRate * this.w[1] + gradientW[1];
      this.b -= learningRate * gradientB;
    }
  }
}

预测

训练完成后,可以使用模型来对新的点进行预测。预测过程包括以下几个步骤:

  1. 对新的点进行处理,将点转换为可以用于预测的格式。
  2. 使用模型计算出预测值。
  3. 根据预测值判断点的类别。
const svm = new SVM();
svm.train(data);

const predict = (svm, point) => {
  const wx = svm.w[0] * point[0] + svm.w[1] * point[1] + svm.b;
  return wx > 0 ? 1 : -1;
};

const newPoint = [4, 4];
const predictedLabel = predict(svm, newPoint);
console.log(`(${newPoint[0]}, ${newPoint[1]}) is predicted as ${predictedLabel}`);