机器学习:用Javascript实现k-均值聚类分析算法

991 阅读13分钟

译者:非主流童话

原文链接

机器学习帮助我们分类和处理大量的数据。我们可以对数据提出各种各样的问题,并且希望机器学习能够回答它们:这个数据点最相似的地方是什么?数据是否有模式?根据过去的趋势,我们能否预测未来会发生什么?这些问题适用于所有的研究领域。

这篇文章是JavaScript机器学习系列的一部分本系列介绍了一些基本的机器学习算法,并假设读者具备一定的背景知识。如果你想了解新文章,页面底部会有一个邮件列表;你还可以在Twitter上关注我:@bkanber

或者你只是在寻找一些实例代码?那来吧!

介绍与动机

今天,我们将研究如何找到一组数据点。假设你在一家医疗成像设备公司工作。想象一下,你已经有了一种从图像扫描中识别恶性细胞的方法,但是如果还能自动识别细胞群的中心,那将是一件很好的事情。如果这样,一个机器人就可以进行手术精确地去除问题!

我们要找的就是这一种聚类方法,今天我们就来具体讲讲K-均值算法。

聚类

一般来说,聚类算法会发现相似的数据组。如果你经营一个在线商店,你可能会使用集群算法来识别不同的购物者类型。你可能会发现,有一种访客,他们只浏览三到五个产品就离开。另一个种访客,可能会通过浏览15页左右的产品和评论来做出细致的购买决策,最后只会做出一个高价值的购买。你也可以识别出冲动的买家,他们在没有浏览太多的情况下做了大量的小买卖。一旦确定了网上购物者的统计数据,你就能更好地优化你的网站以增加销售量。你可以发布一些吸引你的冲动买家的功能,因为现在你知道你确实有冲动买家。

虽然这只是k-均值的一个实际例子,但你会发现这个算法在多个领域都有使用。有时它只是二维图像处理,有时候它处理的是跨越数十个维度和参数的巨大数据。就像我们的k-近邻算法一样,k-均值是通用的,易于理解和实现,并且具有非常强大的功能。

k-均值

就像k-近邻算法一样,k-均值里的“k”表示将会有一些数字,我们将需要为算法提供一些数据。具体来说,“k”是我们将在数据中找到的数据集的数量。不幸的是,在解决问题之前很少有可能知道数据集的数量,所以k-均值通常是先通过另一种算法来首先会帮助你找到k的最佳值,然后提供给k-均值算法。

问题在于:k-均值算法将把数据划分为“k”不同的数据集,但它并没有告诉您这是否是正确的数据集数量。数据本身也许会有5个不同的数据集,但是如果你给k-均值算法带入3,那么将会得到3个数据集。比起真实的5个数据集的结果,每个数据集更大,结合得更松散,并且有些失真。

一言以蔽之:为了使用k-均值,你需要知道一开始你需要多少个数据集,或者你需要使用第二个算法来猜测数据集的数量。k-均值只是将你的观点组织成数据集;你需要做一些其他的事情来计算出正确的数据集数量。

就今天来讲,我们将设法从一开始就使用三个数据集。下一次(在k-均值第2部分)中,我们将介绍一种可以自动猜测“k”值的技术。最常见的情况是,这些算法依赖于某种错误分析和k-均值算法的多次传递来优化针对最小误差值的解决方案。

算法步骤

k-均值算法很简单,但如果你在一个有多个维度的数据集上使用它,它就会变得非常强大。今天我们要在二维空间中工作,下次我们将做一些更复杂的事情。下面是算法的主要步骤:

  1. 绘制你的数据点。

  2. 创建“k”额外的点,把它们随机地放在你的图形上。这些点是“数据集中心”——或者是数据集中心的候选对象。

  3. 重复下面步骤:

  4. 将每个数据点分配到离它最近的数据集中心位置 。

  5. 将重心移到属于它的所有数据点的平均位置。

  6. 如果任何一个中心在最后一步移动,重复上述步骤。如果没有任何移动,退出。

就这么简单!如你所见,这是一个迭代过程。可能需要2到3次或数十次迭代,但最终数据集应该收敛于它们的解决方案,停止移动。然后可以对作业进行最后的统计然后你就有了你的数据集。

机器委员会

就像我们在这个系列中会遇到的很多情况一样,这个算法,很容易受到局部优化的影响。如果您在下面几次运行这个示例,您将看到数据集可以以几种不同的配置之一结束。这些是不同的局部最优解被困在其中。从某种随机的种子状态(比如GAs或k-均值)开始的算法特别容易受到局部优化的影响,因为你永远不知道算法是如何开始的,最终的解决方案将会是怎样的。这种种子状态是否会导致局部或全局的最佳状态?还不得而知!

就像遗传算法一样,摆脱局部最优的方法之一就是给这个解决方案带来一点突变。在我们的k-均值例子中,我们可以添加一条规则,如果它在迭代之后没有移动,那么它就会向一个随机方向移动。它可能会回溯到上一个停止点,或者找到了一个新的解决方案。这个微调不应大到重头来过,而是当恰巧陷入某个坑不能自拔时,微转身形,轻盈避过。

我们可以使用的另一种技术称为“机器委员会”,如果您运行的算法运行速度非常快,或者具有并行计算能力,则可以运行良好。很简单:我们运行k-均值算法3或5或51或1万次,然后选择最常返回的解决方案。术语“机器委员会”打比方说实际就有一些人选择在不同的硬件上运行并行算法并得出解决方案,我们的“机器委员会”则对这些解决方案进行表决。

代码

就让我们一探究竟吧。与目前的其他示例不同,我将放弃面向对象的实现,只进行直接的过程。是的,哥有一百种方式搞定这事。我喜欢OOP,但重要的是要不断跳出舒适区。

此外,在本例中,我们只处理二维数据,我希望编写这个算法来处理任意数量的维度(除了画布绘制函数之外)。

让我们看一下我们正在使用的数据——一个简单的“点”数组,每个点由两个元素标识(X和Y值)。

var data = [  
    [1, 2],
    [2, 1],
    [2, 4], 
    [1, 3],
    [2, 2],
    [3, 1],
    [1, 1],

    [7, 3],
    [8, 2],
    [6, 4],
    [7, 4],
    [8, 1],
    [9, 2],

    [10, 8],
    [9, 10],
    [7, 8],
    [7, 9],
    [8, 11],
    [9, 9],
];

接下来,我们定义两个对我们有用的函数,但不是必需的。给定一个点的列表,我想知道每个维度的最大值和最小值是什么,以及每个维度的范围是什么。我想知道“X的范围从1到10,Y的范围从1到11”。(译者注:原文和上述代码不符,修正之)了解这些数据有助于我们在画布上绘制图表,并在初始化随机集群中心时帮助我们(当我们启动时,我们希望它们在数据点范围内)。

请记住,我们正在编写一个关于它们可以处理的维数的通用方法:

function getDataRanges(extremes) {  
    var ranges = [];

    for (var dimension in extremes)
    {
        ranges[dimension] = extremes[dimension].max - extremes[dimension].min;
    }

    return ranges;

}

function getDataExtremes(points) {

    var extremes = [];

    for (var i in data)
    {
        var point = data[i];

        for (var dimension in point)
        {
            if ( ! extremes[dimension] )
            {
                extremes[dimension] = {min: 1000, max: 0};
            }

            if (point[dimension] < extremes[dimension].min)
            {
                extremes[dimension].min = point[dimension];
            }

            if (point[dimension] > extremes[dimension].max)
            {
                extremes[dimension].max = point[dimension];
            }
        }
    }

    return extremes;
}

getDataExtremes()方法循环遍历每个点的所有点和每个维度,找到最小值和最大值(注意这里有一个硬编码的“1000”,如果使用的是大数字,请读者自行更改)。getDataRanges()函数只是一个辅助器,它获取输出并返回每个维度的范围(最大值减去最小值)。

接下来,我们定义一个初始化k个随机集群中心的函数

function initMeans(k) {

    if ( ! k )
    {
        k = 3;
    }

    while (k--)
    {
        var mean = [];

        for (var dimension in dataExtremes)
        {
            mean[dimension] = dataExtremes[dimension].min + ( Math.random() * dataRange[dimension] );
        }

        means.push(mean);
    }

    return means;

};

我们只是在数据集的范围和范围内创建了新的点。

一旦我们有了随机的种子中心,我们就需要进入我们的k-均值循环。提醒一下,循环包括首先将所有的数据点分配给最接近它的中心id,然后将中心id移动到分配给它的所有数据点的平均位置。我们重复这个动作直到中心停止移动。

function makeAssignments() {

    for (var i in data)
    {
        var point = data[i];
        var distances = [];

        for (var j in means)
        {
            var mean = means[j];
            var sum = 0;

            for (var dimension in point)
            {
                var difference = point[dimension] - mean[dimension];
                difference *= difference;
                sum += difference;
            }

            distances[j] = Math.sqrt(sum);
        }

        assignments[i] = distances.indexOf( Math.min.apply(null, distances) );
    }

}

上面的函数由我们的“循环”函数调用,并计算每个点和集群中心之间的欧氏距离

注意,上面的算法循环遍历每个点,然后循环遍历每个集群中心点,从而使之成为O(k*n)算法。这并不可怕,但是如果你有大量的数据点或者大量的集群或者两者都有的话,这可能是计算密集型的。有一些方法可以优化这一点,我们将在以后的文章中讨论这个问题。首先,我们可以尝试消除昂贵的Math.sqrt()调用;我们也可以尝试不用遍历每一个点。

一旦我们有了赋值列表——在本例中,只是一个point index => center index的关联数组——我们就可以继续更新方法的位置(数据集中心)。

function moveMeans() {

    makeAssignments();

    var sums = Array( means.length );
    var counts = Array( means.length );
    var moved = false;

    for (var j in means)
    {
        counts[j] = 0;
        sums[j] = Array( means[j].length );
        for (var dimension in means[j])
        {
            sums[j][dimension] = 0;
        }
    }

    for (var point_index in assignments)
    {
        var mean_index = assignments[point_index];
        var point = data[point_index];
        var mean = means[mean_index];

        counts[mean_index]++;

        for (var dimension in mean)
        {
            sums[mean_index][dimension] += point[dimension];
        }
    }

    for (var mean_index in sums)
    {
        console.log(counts[mean_index]);
        if ( 0 === counts[mean_index] ) 
        {
            sums[mean_index] = means[mean_index];
            console.log(""Mean with no points"");
            console.log(sums[mean_index]);

            for (var dimension in dataExtremes)
            {
                sums[mean_index][dimension] = dataExtremes[dimension].min + ( Math.random() * dataRange[dimension] );
            }
            continue;
        }

        for (var dimension in sums[mean_index])
        {
            sums[mean_index][dimension] /= counts[mean_index];
        }
    }

    if (means.toString() !== sums.toString())
    {
        moved = true;
    }

    means = sums;

    return moved;

}

moveMeans()从调用makeAssignments()函数开始。一旦我们有了赋值,我们就初始化两个数组:一个被称为“总数”,另一个称为“计数”。既然我们在计算算术平均值(或平均值),我们就需要知道点的维数和我们求平均值的点的个数。

然后我们命中了三个循环:

首先,我们循环使用我们的方法,并准备我们的总数和计数数组。我们的总数数组实际上是多维的,因为我们将每个维度的每个平均值的每个维度的总和存储在这个结构中 - 所以我们必须清除这个二维数组的第二个深度级别。

然后,我们循环遍历我们的赋值,并为每个数据集中心增加计数器,并在点的维度上循环,以填入总和数组。此时,我们已经拥有了计算数据集中心新位置所需的所有数据。

最后一个循环将循环遍历我们的结果,计算每个集群中心的平均位置,并移动它。最后一个循环还检查集群中心是否没有分配给它的点。如果没有分配给它的点,我们就给它一个新的随机位置。这只是我们试图将该数据集中心重新引入解决方案。

<font color=""#434343"">最后,我们通过检查来查看我们的集群中心是否已经移动,并返回true或false。

为了让这个算法启动,我们运行以下设置函数

function setup() {

    canvas = document.getElementById('canvas');
    ctx = canvas.getContext('2d');

    dataExtremes = getDataExtremes(data);
    dataRange = getDataRanges(dataExtremes);
    means = initMeans(3);

    makeAssignments();
    draw();

    setTimeout(run, drawDelay);
}

function run() {

    var moved = moveMeans();
    draw();

    if (moved)
    {
        setTimeout(run, drawDelay);
    }

}

setup()初始化我们需要的所有东西,然后我们的run()函数检查,看看算法是否已经停止,并基于计时器进行循环,这样我们就可以在合理的时间范围内观察算法的工作。

k-均值

k-均值算法的一个主要问题不是算法的缺陷,而是算术平均值的概念,或者说平均值。当你有外围数据时,平均值是一个相当糟糕的指标。

如果你在一家公司工作,这里5个人年薪5万美元,但一个人年薪100万美元,则工资中位数是5万美元(这个公司的工资很有代表性),但平均工资是20万美元(不代表这家公司的薪水)!

这种情况会发生在各种数据中,也会发生在k-均值算法中。如果有一个数据集倾向于离群值,你会发现k-均值在异常值上被“卡住”,结果会导致糟糕的结果。在这种情况下,切换到k-中位值!算法几乎是相同的;而不是计算集群中心的平均值,而是使用中位值。我相信——但我不确定——计算中位数也比均值有一个性能优势。

结果

正如下面的示例中可以看到的,k-均值对于我们良好的、整洁的数据非常有用。显然,与其他算法一样,它将更难处理凌乱的数据。

如果你多次运行这个示例(单击jsFiddle上的play按钮),你最终将看到它落入局部最佳状态。这也说明了“机器委员会”解决方法的有用性:当不时地出现一个坏的解决方案时,应该明确的是,机器委员会可靠地产生正确的解决方案。

最后,如果你喜欢这个系列,请登录下面的邮件列表并告诉你的朋友!我乐见讨论,所以请随意使用下面的评论工具。最后,一定要阅读其他的JS机器学习文章哦!