用JavaScript编写一个k-NN分类器

201 阅读1分钟

k-NN(k-最近邻)是一种简单的机器学习算法,可用于分类和回归问题。它可以根据数据集中最近的k个邻居来进行预测。在这个简单的k-NN分类器中,我们将使用欧几里得距离来衡量样本之间的相似性。

步骤一:导入数据集

首先需要将数据集导入代码中。可以使用CSVJSON格式的数据集。这里使用JSON格式的数据集。

const data = [
  { x: 1, y: 2, class: 'A' },
  { x: 2, y: 1, class: 'A' },
  { x: 4, y: 5, class: 'B' },
  { x: 5, y: 4, class: 'B' }
];

步骤二:计算距离

接下来需要计算测试样本与所有训练样本之间的距离。

function euclideanDistance(a, b) {
  return Math.sqrt(Math.pow(a.x - b.x, 2) + Math.pow(a.y - b.y, 2));
}

function getDistances(testSample, data) {
  const distances = [];
  for (let i = 0; i < data.length; i++) {
    distances.push({
      sample: data[i],
      distance: euclideanDistance(testSample, data[i])
    });
  }
  return distances;
}

步骤三:选择最近的邻居

排序距离数组,然后选择前k个最近的邻居。

function getNeighbors(k, distances) {
  return distances.sort((a, b) => a.distance - b.distance).slice(0, k);
}

步骤四:预测分类

使用投票法来预测测试样本的分类。

function predictClass(k, testSample, data) {
  const distances = getDistances(testSample, data);
  const neighbors = getNeighbors(k, distances);
  const classes = {};
  for (let i = 0; i < neighbors.length; i++) {
    const cls = neighbors[i].sample.class;
    if (!classes[cls]) {
      classes[cls] = 0;
    }
    classes[cls]++;
  }
  let result;
  let maxCount = -1;
  for (const cls in classes) {
    if (classes[cls] > maxCount) {
      result = cls;
      maxCount = classes[cls];
    }
  }
  return result;
}

步骤五:使用分类器

最后可以使用我们的分类器进行预测。

const testSample = { x: 3, y: 3 };
const k = 3;
const predictedClass = predictClass(k, testSample, data);
console.log(predictedClass); // 输出:'A'