用JS实现K-NN算法

102 阅读2分钟

K-NN(K近邻)是一种经典的机器学习算法,它可以用于分类和回归问题。在这个任务中,我们将使用JavaScript实现K-NN算法,而不依赖于任何第三方库。

什么是K-NN算法?

K-NN算法是一种基于实例的学习方法。给定一个新的输入向量,该算法会在训练数据集中找到与其最接近的K个实例,然后将这K个实例中最常见的标签作为该向量的标签。因此,K-NN算法的核心是计算向量之间的距离。

实现K-NN算法

数据集

在这个任务中,我们将使用一个简单的数据集,它包含两个属性(x和y)和两个标签(0和1)。以下是数据集的JavaScript表示:

const dataset = [
  { x: 1, y: 2, label: 0 },
  { x: 2, y: 1, label: 0 },
  { x: 2, y: 3, label: 1 },
  { x: 3, y: 2, label: 1 },
  { x: 4, y: 3, label: 1 },
  { x: 3, y: 4, label: 0 },
  { x: 4, y: 4, label: 1 },
  { x: 5, y: 5, label: 0 },
];

计算距离

在K-NN算法中,我们需要计算向量之间的距离。在这个任务中,我们将使用欧几里得距离来计算距离。以下是计算距离的JavaScript函数:

function distance(a, b) {
  return Math.sqrt((a.x - b.x) ** 2 + (a.y - b.y) ** 2);
}

找到最近的K个实例

现在,我们可以使用distance函数来找到训练数据集中与给定向量最近的K个实例。以下是找到最近的K个实例的JavaScript函数:

function nearestNeighbors(k, dataset, input) {
  const distances = dataset.map((d) => ({
    distance: distance(d, input),
    label: d.label,
  }));
  distances.sort((a, b) => a.distance - b.distance);
  return distances.slice(0, k);
}

预测标签

最后,我们可以使用nearestNeighbors函数来预测给定向量的标签。以下是预测标签的JavaScript函数:

function predict(k, dataset, input) {
  const neighbors = nearestNeighbors(k, dataset, input);
  const labels = neighbors.map((n) => n.label);
  const counts = labels.reduce(
    (acc, label) => ((acc[label] = (acc[label] || 0) + 1), acc),
    {}
  );
  return Object.keys(counts).reduce(
    (a, b) => (counts[a] > counts[b] ? a : b)
  );
}

现在,我们已经实现了K-NN算法。以下是使用K-NN算法来预测新向量的标签的JavaScript代码:

const input = { x: 2, y: 2 };
const k = 3;
const label = predict(k, dataset, input);
console.log(`Input (${input.x}, ${input.y}) is classified as ${label}`);

输出:

Input (2, 2) is classified as 0