KNN:从数学推导到维度灾难,手写一个完整的K近邻分类器

34 阅读2分钟

摘要

KNNKNN是最直观的非参数分类/回归算法,但其性能高度依赖距离度量与K值选择策略,KNN(K Nearest Neighbor) 直译为K近邻算法,即通过K个最近邻居的类别来判定待测样本的归属。实现了通过简单的"找邻居"方法即可实现未知样本的分类和回归,在样本量n→∞且K以适当的速率增长时,其误差率可无限逼近贝叶斯最优误差率。

引言

假设你刚到一个新城市,想判断一个陌生街区是否安全,最朴素的想法就是看离它最近的三条街的治安如何——这就是KNN的直觉原型。与逻辑回归需要假设数据分布不同,KNN无需任何训练阶段;与决策树容易过拟合分支不同,KNN天然平滑且可解释。

在分类问题中,统计学家拥有的知识存在两个极端:

  • 完全知晓,观测值与真实类别之间的底层联合分布(此时通过标准的贝叶斯分析可获得最优决策和最小错误率);
  • 完全无知,仅能从样本中推断。

如果样本是独立同分布的情况下,一种最直观的分类决策就是:距离相近的样本可能与当前样本具有相同的分类。最简单的非参数决策程序就是最近邻(NN)规则

参考自Nearest Neighbor Pattern Classification1967 年发表的KNN开山之作。

核心原理

KNN算法最核心的不是其准确率,而是其内部使用简单直接的算法同时其预测准确率表现也极佳,可以说是高性价比的算法。

K Nearest Neighbor K个最近的邻居,实现过程也是直接按照其名字执行,通过在数据集中寻找K个最近的样本,根据这K个样本进行分类和回归,如果是分类问题则投票决定类别回归问题则计算K个样本标签均值作为预测结果

根据你所住小区最近K个邻居的工资平均值来估计你的工资,根据电影出现不同镜头的数量来分类本电影的类别,这些问题都是在日常可以见到的。

那如何计算样本之间的距离、K值的选择则是核心的问题。

距离度量

最近邻规则是在一个度量空间(Metric Space)中定义的。这意味着算法的运行基础是必须存在一种能够衡量两点之间“接近程度”的数学方法。

  • 只要在样本空间 XX 上定义了符合数学定义的距离函数(Metric) ρ\rho,该理论KNN结论就成立。

  • 距离的数学要求: 只要该距离度量满足非负性、同一性、对称性和三角不等式,最近邻算法的收敛性分析依然适用。

  • 量纲问题: 如果两个特征的单位相差较大会导致其特征权重不同,如(身高/m,体重/kg)两个样本身高之间的差距会远小于体重差距,最终就会导致体重特征的权重远高于身高。

欧氏距离

欧式距离也称欧几里得距离,是最常见的距离度量,计算的是空间中两个点之间的直线距离。

对于n维空间中的两点 A=(p1,p2,,pn)A=(p_1,p_2,…,p_n)和 B=(q1,q2,,qn)B=(q_1,q_2,…,q_n),欧氏距离定义为:

d(A,B)=i=1n(piqi)2d(A,B)=\sqrt{\sum_{i=1}^{n}(p_i−q_i)^2}

二维平面中,两点 A(x1,y1)A(x_1,y_1) 与 B(x2,y2)B(x_2,y_2) 的欧氏距离就变成勾股定理:

d(A,B)=(x1x2)2+(y1y2)2d(A,B)=\sqrt{(x_1−x_2)^2 + (y_1−y_2)^2}
欧氏距离
图1:欧氏距离二维示意图
  • 优点:最接近人类感知的物理距离,且不会根据坐标轴的改变而改变。

  • 缺点

    • 量纲敏感: 如果不同特征的单位(如米和毫米)不同,大数值特征会主导距离计算。通常需要先进行归一化。
    • 高维灾难: 在超高维空间中,点与点之间的距离趋于相等,欧氏距离的区分度会大幅下降。

曼哈顿距离

曼哈顿距离也称城市街区距离,欧氏距离计算的是两点之间的直线距离,曼哈顿距离则更贴近现实,学校里从宿舍到饭堂是经过大街小巷而非横穿建筑。

对于n维空间中的两点 A=(p1,p2,,pn)A=(p_1,p_2,…,p_n)和 B=(q1,q2,,qn)B=(q_1,q_2,…,q_n),曼哈顿距离定义为:

d(A,B)=i=1npiqid(A,B)=\sum_{i=1}^{n}|p_i−q_i|

二维平面中,两点 A(x1,y1)A(x_1,y_1) 与 B(x2,y2)B(x_2,y_2) 曼哈顿距离简化为::

d(A,B)=x1x2+y1y2d(A,B)=|x_1−x_2| + |y_1−y_2|
image.png
图2:曼哈顿距离二维示意图
image.png
图3:曼哈顿距离路径示意
  • 优点:

    • 计算速度快: 仅涉及加法和绝对值运算,无需平方和开方,对计算机硬件更友好。
  • 缺点:

    • 路径依赖: 无法表示斜向的最短路径。
    • 非旋转不变: 坐标系的旋转会改变曼哈顿距离的值。

切比雪夫距离

切比雪夫距离也称棋盘距离,参考了国际象棋中国王的移动方式(国王可以直行、横行、斜行),所以国王走一步可以移动到相邻8个方格中的任意一个。

可以看作欧氏距离和曼哈顿距离的结合,对于n维空间中的两点 A=(p1,p2,,pn)A=(p_1,p_2,…,p_n)和 B=(q1,q2,,qn)B=(q_1,q_2,…,q_n),切比雪夫距离定义为:

d(A,B)=maxi=1npiqid(A,B)=\max_{i=1}^{n}|p_i−q_i|

二维平面中,两点 A(x1,y1)A(x_1,y_1) 与 B(x2,y2)B(x_2,y_2) 的切比雪夫距离:

d(A,B)=max(x1x2,y1y2)d(A,B)=max(∣x1​−x2​∣,∣y1​−y2​∣)
  • 优点:

    • 特定场景极优: 特别适用于可以同时在多个维度移动的场景,例如国际象棋中王(King)的走法。
    • 关键维度识别: 能突出显示导致两个对象差异最显著的那个单一特征。
  • 缺点:

    • 忽略细节: 它完全忽略了除了最大差异维度之外的所有其他维度信息,可能会导致信息丢失。

闵式距离

闵式距离全称闵可夫斯基距离,并不是一种新的距离度量方式,而是作为上面多种距离度量方式的统一范式。

对于n维空间中的两点 A=(x1,x2,,xn)A=(x_1,x_2,…,x_n)和 B=(y1,y2,,yn)B=(y_1,y_2,…,y_n)

d(A,B)=(i=1nxiyip)1/pd(A,B)=\left(\sum_{i=1}^{n}|x_i-y_i|^p\right)^{1/p}
  • 当 p=1,即曼哈顿距离
  • 当 p=2,即欧氏距离
  • 当 p→∞,即切比雪夫距离

K值选择

了解计算两个样本之间的距离后,就可以计算预测样本与所有训练集样本的距离( sklearn 库默认使用欧氏距离),然后选出TOP-K的样本,分类任务结果是TOP-K样本中占比较多的类别为结果,回归任务结果则是TOP-K样本的标签数据平均值。

那K值应该如何选择呢?

  • 是越大越好,K = n(训练集样本数量) ;
  • 还是越小越好,K = 1
  • 又或是K适中就好,K = n/2

下面将从模型拟合情况、决策边界、偏差和方差的角度去分析不同的K值。方差衡量的是在不同训练集上,模型预测结果的波动情况;偏差衡量的是模型预测值的平均结果与真实值之间的差距

K值过大

从直觉上来说,K值较大时模型应该会越复杂,最终可能会导致过拟合(训练集数据表现好,测试集数据表现差),但是"模型复杂度"往往取决于决策边界的灵活性,而不是参数的数量。

  • 平滑效应:K=NK=N (样本总数)时,无论输入什么,模型都会预测样本数量最多的那个类别。此时,模型的决策边界消失,模型会降维成一个多数表决器
  • 低方差: 方差代表的是模型的稳定性,此时即使训练集的个别样本发生变化,模型最终的预测结果并不会受到太大的影响,因为参与投票的邻居非常多。
  • 高偏差: 预测样本的最终结果完全由数据集的平衡决定,忽略了数据的特征和细微边界,导致模型无法捕捉到真实的逻辑,从而产生巨大的偏差。

综上当K值偏大时,模型的决策边界简单,低方差高偏差,模型欠拟合。

K值过小

直观上可能觉得 K=1K=1 很简单,但实际上 K=1K=1 时的决策边界极其扭曲且复杂

  • 决策边界的灵活性:K=1K=1 时,模型预测的结果完全取决于最近的邻居,如果数据中存在一个异常值(噪声点),决策边界会为了容纳这个点而发生剧烈的抖动,最终决策边界会很复杂。
  • 低偏差: 因为模型完美地拟合了训练数据(在 K=1K=1 的情况下,训练集误差为 0),它对训练样本的表达非常精确,偏差极低。
  • 高方差: 模型对每一个点都异常敏感,如果训练集发生了的轻微的改动(增加或减少一个噪声点),模型的决策边界都会发生大幅改动,最终导致预测结果不稳定。

综上当K值偏小时,模型的决策边界复杂,低偏差高方差,模型过拟合。

选择最优K值

过大过小结果模型都会出现问题,那K值取多大算合适,是否所有的KNN模型都可以使用同一个最优K值,还是最优K值的选择受到其他因素影响,寻找最优K值本质就是在寻找偏差和方差的平衡点,以获得最强的泛化能力

1. 交叉验证

交叉验证是目前工业方面寻找K值最常用的方法,核心思想是通过实现最优解,for K in range(1,n)遍历所有可能的K值去训练模型,综合效果最好的就是最优K值。

sklearn 提供了交叉验证+网格搜索的方式寻找最优超参数(eg:K值),其底层就是组合不同的参数训练不断的模型,全部结束后取结果最优的就是最佳超参数。

image.png
图4:交叉验证网格搜索示意图
sklearn.model_selection.GridSearchCV(模型, 参数字典, 折数)
2. 自适应 K 值

主张的是不应该在全局使用统一的 K 值,而是根据周围的数据密度决定。

  • 基于局部密度的方法: 根据样本点到邻居的距离变化率来决定K值,如果样本点区域数据稀疏,则增大K值,如果样本点区域数据密集,则减少K值。
3. 统计学视角:渐进性能研究

这类论文通常从理论层面探讨 K 随样本量 NN 增长时的收敛性。

  • 收敛性定理:经典研究证明,当 NN \to \infty 时,若满足 KK \to \inftyK/N0K/N \to 0,则 KNN 的误差率会趋于贝叶斯最优误差率。

  • Bias-Variance Tradeoff

    • 较小的 K:低偏差、高方差,模型容易过拟合,对噪声敏感。
    • 较大的 K:低方差、高偏差,模型更平滑,但可能模糊掉细微的类别边界。

邻居权重

找到了最优的K,又找到TOP-K的邻居,此时就可以计算预测结果了,如果K个邻居的‘话语权’一致的话,计算结果就同上面介绍一样,分类问题投票、回归问题平均值。

但是是不是越近的邻居其‘话语权’是否更大一些,结果受更近的邻居影响更大,远一点的邻居其影响就小一点,实现很简单使权重weight与distance成反比。

常见的加权方式包括反距离加权wi=1diw_i = \frac{1}{d_i}和高斯核wi=exp(di22σ2)w_i = \exp\left(-\frac{d_i^2}{2\sigma^2}\right),后者能更平滑地抑制远邻影响。

数学推导

时间复杂度

训练时间复杂度

KNNKNN非参数算法,拿到训练集后并不需要真正的去训练模型,只需要保存训练集和其他参数即可,所以训练模型的时间复杂度是 O(1)

预测时间复杂度

设训练集样本数量为 N ,特征维度为 b :

默认使用欧氏距离计算样本点距离,两个样本之间计算的次数为特征维度的数量 b ,训练集样本数量为 N , 则时间复杂度为 O(N·b)

维度灾难

维度灾难说的是在高维度特征的情况下,到最近邻居的距离和到最远邻居的距离之间的相对差异会趋近于零,最后导致距离计算后无法寻找TOP-K的邻居。

维度灾难的提出在论文《When Is “Nearest Neighbor” Meaningful?》中提出,

这篇论文通过严谨的数学证明指出,在某些分布条件下,随着空间维度 dd 趋向于无穷大:

  • 距离趋同:查询点到“最近邻”的距离与到“最远邻”的距离之间的相对差异会趋近于零。

  • 数学表达式

    limdE[distmaxdistmindistmin]=0\lim_{d \to \infty} E \left[ \frac{dist_{max} - dist_{min}}{dist_{min}} \right] = 0

  • 后果:这意味着在高维空间中,所有的点到查询点的距离几乎都是一样的。在这种情况下,所谓的“最近”已经失去了区分度,导致基于距离的算法 KNN 失效。

sklearn实现

sklearn 是机器学习最常用的库,其中实现了大多数机器学习的算法,如 KNN 、线性回归、逻辑回归、决策树、集成学习和聚类 等等算法。

可以选择在以下两个网站学习:

scikit-learn官方

scikit-learn中文社区

from sklearn.neighbors import KNeighborsClassifier

# 核心方法
estimator = KNeighborsClassifier()
estimator.fit(x_train, y_train)
y_predict = estimator.predict(x_test)

# 初始化KNN分类器对象可传入参数
def __init__(
    self,
    n_neighbors=5, # K值:5
    *,
    weights="uniform", # 权重方式:统一权重
    algorithm="auto", # 计算方式
    leaf_size=30,
    p=2,
    metric="minkowski", # 距离度量:闵式距离
    metric_params=None, # 闵式距离参数p
    n_jobs=None,
):

KNeighborsClassifier是基类,继承自 KNeighborsMixin, NeighborsBase, ClassifierMixin

class KNeighborsClassifier(KNeighborsMixin, ClassifierMixin, NeighborsBase):

执行流程

KNeighborsClassifier的核心方法有三个:

  • fit(X, y):构建搜索索引(如 KD-Tree 或 Ball Tree)

  • predict(X):预测类别

  • kneighbors(X):执行底层搜索,寻找最近的 KK 个邻居

fit(X, y)

fit 阶段,sklearn 不会进行复杂的数学计算,而是根据参数algorithm =('auto','kd_tree', 'ball_tree', 'brute')建立存储训练集的数据机构。

  • auto : 不是具体的数据结构,而是让 sklearn 根据数据集自动选择最优算法;
kd_tree :

通过垂直于坐标轴的超平面来分割空间的二叉树,听起来非常的抽象,在官方论文 kd_tree中找到两张图,通过在空间中不断画横线和竖线,把空间切成一个个矩形小格子。

  • 优点 :查找时间复杂度低,因为每一次都排除了一半的数据,当出现某一个节点选择错误时会使用回溯算法回到上一个节点继续往下搜索。
  • 缺点 :在高维空间中效率下降,会出现维度灾难问题,常适用于通常维度 < 20 情况。

平均搜索复杂度为O(log⁡n),最坏坍缩为O(n)。

image.png

image.png

图片源自论文kd_tree

ball_tree :

不再画直线,而是用一个个圆圈(球体)把点围起来,每一个圈子内部又可以继续画圈,如在操场中可以划分男生大圈女生大圈男生大圈里又分成了“打篮球的”和“踢足球的”小圈。

找到距离最近的圈子就会进入该圈从而实现筛选掉大部分数据,如距离男生大圈中心10m和女生大圈中心20m,此时就会进行男生大圈继续后续的计算直到找到K个最近邻居。

  • 优点:在高维空间中搜索效率高
  • 缺点:构建时间较长,内存消耗较大
image.png

图片源自文章kd-tree-and-ball-tree

brute :

暴力搜索,直接遍历样本与训练集的每一个样本的距离,使用距离度量方式计算并进行排序得到前K个样本就是最近的K个邻居。

  • 优点:简单直接,保证找到精确解
  • 缺点:时间复杂度 O(n),大数据集效率低

predict(X)kneighbors(X)

predict 是最外层的预测函数,传入测试样本即可得到预测类别,其底层实现是通过 kneighbors 寻找 K 个邻居,以下是 sklearn 的KNeighborsClassifier.predict()核心源码。

if self.weights == "uniform":
    # 特殊情况:brute + uniform 权重时的优化路径
    if self._fit_method == "brute" and ArgKminClassMode.is_usable_for(
        X, self._fit_X, self.metric
    ):
        probabilities = self.predict_proba(X)
        if self.outputs_2d_:
            return np.stack(
                [
                    self.classes_[idx][np.argmax(probas, axis=1)]
                    for idx, probas in enumerate(probabilities)
                ],
                axis=1,
            )
        return self.classes_[np.argmax(probabilities, axis=1)]
    # 寻找最近的 K 个邻居的索引neigh_ind和距离neigh_dist
    neigh_ind = self.kneighbors(X, return_distance=False)
    neigh_dist = None
else:
    # 其他情况:适用于 kd_tree、ball_tree 和 brute
    neigh_dist, neigh_ind = self.kneighbors(X)
# 根据邻居索引获取其类别
classes_ = self.classes_
... 加权处理
return y_pred

鸢尾花分类

sklearnKNeighborsClassifier 最经典的分类案例一定是鸢尾花分类,通过 KNN 算法实现鸢尾花种类的识别,下面将分部实现KNN鸢尾花分类KNN鸢尾花分类案例。

数据集

sklearn 所提供的经典数据集中就包含鸢尾花数据集,下面将加载鸢尾花数据集并展示其数据特征:

# 1.导入KNN相关库
from sklearn.datasets import load_iris # 鸢尾花数据集
from sklearn.neighbors import KNeighborsClassifier # KNN分类器
from sklearn.model_selection import train_test_split # 数据集划分
from sklearn.preprocessing import StandardScaler # 特征预处理
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# 2.加载查看数据集
iris_data = load_iris()
print(f'iris数据集描述:\n{iris_data.DESCR}')

# 3.数据集可视化
# 3.1转化为df格式
iris_df=pd.DataFrame(iris_data.data,columns=iris_data.feature_names)
iris_df['label']=iris_data.target
# 3.2展示散点图
col1=iris_data.feature_names[0] # 使用花萼长做x轴
col2=iris_data.feature_names[1] # 使用花萼宽做y轴
sns.lmplot(data=iris_df,x=col1,y=col2,hue='label',fit_reg=True) # hue分组,fit_reg拟合线
plt.show()

image.png image.png

特征工程

看完数据集情况后,就开始进行数据集处理和特征工程了:

# ...接上文
# 4.数据集划分 参1:特征数据;参2:标签数据;参3:测试集占比;参4:随机种子
x_train, x_test,  y_train, y_test = train_test_split(iris_data.data, iris_data.target, 
                                                     test_size=0.2, random_state=1)
print('训练集样本数量:',len(x_train)) # 150*0.8=120
print('测试集样本数量:',len(x_test)) # 150*0.2=30
# 5.特征预处理--处理特征量纲问题,常用归一化和标准化
transfer = StandardScaler()
x_train = transfer.fit_transform(x_train)
x_test = transfer.transform(x_test)
image.png

模型处理

处理完数据集和特征后就可以处理模型了,训练模型->预测模型->评估模型:

# 7.模型训练
estimator = KNeighborsClassifier()
estimator.fit(x_train, y_train)
# 8.模型预测
y_predict = estimator.predict(x_test)
print('预测结果:\n',y_predict[:5],'\n真实结果:\n',y_test[:5])
# 9.模型评估
score = estimator.score(x_test, y_test)
print('准确率:\n',score)
image.png

不同K值下的决策曲线

  • KK越小其越容易受到异常点的影响,其决策曲线会很曲折;
  • KK越大其越容易收到数据平衡的影响从而变得平滑。
from mlxtend.plotting import plot_decision_regions
# 1. 准备数据
iris = load_iris()
# 仅提取前两个特征进行二维可视化
X = iris.data[:, :2]
y = iris.target

# 2. 定义不同的 K 值
k_values = [1, 5, 15]
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
plt.subplots_adjust(wspace=0.3)

# 3. 循环绘图
for i, k in enumerate(k_values):
    # 使用 sklearn 的 KNN(逻辑与你的自定义版本一致,但速度更快支持网格绘图)
    knn = KNeighborsClassifier(n_neighbors=k)
    knn.fit(X, y)

    # 使用 mlxtend 绘制决策区域
    plot_decision_regions(X, y, clf=knn, ax=axes[i], legend=2)

    # 设置标题和标签
    axes[i].set_title(f'K = {k}')
    axes[i].set_xlabel('Sepal length')
    axes[i].set_ylabel('Sepal width')

# 4. 整体展示
plt.suptitle('图5:不同K值下的决策边界对比', fontsize=16, y=1.05)
plt.show()
image.png
图5:不同K值下的决策边界对比

手写简易KNN分类器

了解完 KNN 的核心原理和实现流程,下面将从不同的角度实现多个KNN分类器。

暴力搜索

最直接简单的思路,训练集存储方式就是二维数组,计算距离使用numpynumpy

预测类别就是将测试集的每一条数据与训练集所有数据进行距离计算,根据距离排序,获取邻居索引,找到邻居类别,最后进行投票表决。

predict 函数中的核心字段

  • distance : 数组,存储内容为当前节点到每一个训练集节点的距离;
  • k_ind : 数组,存储内容训练集节点的索引,按照 distance 降序,取前 K 个;
  • k_lables : 数组,存储内容为前 K 个邻居的类别,用于投票决定结果;
  • label_list : 列表,存储内容为测试集样本的预测结果类别。
import time
import numpy as np


class MyKNeighborsClassifier:
    # 初始化
    def __init__(self, k=5, metric='euclidean'):
        self.y_train = None
        self.X_train = None
        self.k = k
        self.metric = metric

    # 计算距离
    def _distance(self, a, b):
        if self.metric == 'euclidean':
            return np.sqrt(np.sum((a - b) ** 2))
        elif self.metric == 'manhattan':
            return np.sum(np.abs(a - b))
        return None

    # 拟合函数
    def fit(self, x, y):
        self.X_train = x
        self.y_train = y

    # 预测函数
    def predict(self, X):
        start_time = time.time()
        label_list = []
        for x in X:
            # 计算两个样本的距离
            distance = np.array([self._distance(x, x_train) for x_train in self.X_train])
            # 按照距离进行排序,获取dist最大的前K个索引
            k_ind = np.argsort(distance)[:self.k]
            # 根据索引获取标签
            k_labels = self.y_train[k_ind]
            # 投票 bincount(k_labels)--统计每个类别的频率;argmax()获取频率最高的索引
            label_list.append(int(np.bincount(k_labels).argmax()))
        end_time = time.time()
        print(f'运行时间:{end_time - start_time:.4f}秒')
        return np.array(label_list)


from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler


def myknn_test():
    # 1.加载鸢尾花数据集
    iris_data = load_iris()
    # 2.数据处理
    x_train, x_test, y_train, y_test = train_test_split(iris_data.data, iris_data.target, 
                                                        test_size=0.2,random_state=23)
    # 3.特征预处理
    transfer = StandardScaler()
    x_train = transfer.fit_transform(x_train)
    x_test = transfer.transform(x_test)
    # 4.模型训练
    myKNN = MyKNeighborsClassifier(k=5, metric='manhattan')
    myKNN.fit(x_train, y_train)
    # 5.模型预测
    y_predict = myKNN.predict(x_test)
    print(f'预测结果:{y_predict}')
    print(f'真实结果:{y_test}')
    # 6.模型评估
    print(f'准确率:{score(y_test, y_predict)}')


def score(y, y_pre):
    accuracy = np.sum(y == y_pre) / len(y)
    return accuracy


if __name__ == '__main__':
    myknn_test()
image.png

kdtree和balltree优化

前面介绍fitfit阶段时, KNN 实现训练集的多种存储方式,有数组和二叉树(kbtree/balltree)(kbtree/balltree),使用二叉树可以明显提高邻居查找的时间效率。

下面代码基于暴力搜索的实现添加了algorithm字段实现数据存储数据结构选择,

  • brute: 暴力搜索,使用二维数组存储训练集;
  • kdtree: 分割树,使用二叉树存储训练集,每层按照一个特征的中位数进行分割;
  • balltree: 球树,使用二叉树存储训练集,训练集样本按照距离进行画圈分隔;
  • auto: 根据特征的维度进行自动选择,特征维度<20是kdtree否则balltree

直接使用sklearn.neighbors import KDTree, BallTree,核心方法:

  • 初始化: KDTree(训练集, 叶子节点样本最少数量)
  • 查询: dist, ind = self.tree.query(测试集, K值)直接返回前K个邻居的距离和索引。
import time
import numpy as np
from sklearn.neighbors import KDTree, BallTree


class MyKNeighborsClassifier:
    # 初始化
    def __init__(self, k=5, metric='euclidean', algorithm='auto', leaf_size=30):
        self.y_train = None
        self.X_train = None
        self.k = k
        self.metric = metric
        # kb-tree和ball-tree优化
        self.algorithm = algorithm
        self.leaf_size = leaf_size
        self.tree = None

    # 计算距离
    def _distance(self, a, b):
        if self.metric == 'euclidean':
            return np.sqrt(np.sum((a - b) ** 2))
        elif self.metric == 'manhattan':
            return np.sum(np.abs(a - b))
        return None

    # 拟合函数
    def fit(self, x, y):
        self.X_train = x
        self.y_train = y
        # 使用不同数据结构存储
        if self.algorithm == 'kdtree':
            self.tree = KDTree(x, leaf_size=self.leaf_size)
        elif self.algorithm == 'balltree':
            self.tree = BallTree(x, leaf_size=self.leaf_size)
        elif self.algorithm == 'auto':
            if x.shape[0] < 20:
                self.tree = KDTree(x, leaf_size=self.leaf_size)
            else:
                self.tree = BallTree(x, leaf_size=self.leaf_size)

    # 预测函数
    def predict(self, X):
        start_time = time.time()
        label_list = []
        # 暴力搜索
        if self.tree is None:
            for x in X:
                distance = np.array([self._distance(x, x_train) for x_train in self.X_train])
                k_ind = np.argsort(distance)[:self.k]
                k_labels = self.y_train[k_ind]
                label_list.append(int(np.bincount(k_labels).argmax()))
        # kdtree和balltree - 批量查询
        else:
            _, k_ind = self.tree.query(X, k=self.k)
            for i in range(len(X)):
                k_labels = self.y_train[k_ind[i]]
                label_list.append(int(np.bincount(k_labels).argmax()))

        end_time = time.time()
        print(f'{self.algorithm}预测运行时间:{end_time - start_time:.4f}秒')
        return np.array(label_list)

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler


def myknn_test():
    # 1.加载鸢尾花数据集
    iris_data = load_iris()
    # 2.数据处理
    x_train, x_test, y_train, y_test = train_test_split(iris_data.data, iris_data.target, test_size=0.2,
                                                        random_state=23)
    # 3.特征预处理
    transfer = StandardScaler()
    x_train = transfer.fit_transform(x_train)
    x_test = transfer.transform(x_test)
    # 4.模型训练
    myKNN = MyKNeighborsClassifier(k=5, metric='manhattan',algorithm='brute')
    myKNN.fit(x_train, y_train)
    # 5.模型预测
    y_predict = myKNN.predict(x_test)
    print(f'预测结果:{y_predict}')
    print(f'真实结果:{y_test}')
    # 6.模型评估
    print(f'准确率:{score(y_test, y_predict)}')


def score(y, y_pre):
    accuracy = np.sum(y == y_pre) / len(y)
    return accuracy


if __name__ == '__main__':
    myknn_test()

注意:sklearn 的 BallTree 仅支持欧氏距离等闵氏距离变体,如需使用曼哈顿距离,请选择算法 brute 或 kdtree

时间对比

对比暴力搜索和二叉树(KD-Tree/Ball-Tree)查找的计算时间:

设训练集样本数量为 N ,测试集样本数量为 n ,样本特征数量为 b

  • 暴力搜索计算单个样本的时间复杂度为O(N*b),计算全部样本为 O(Nbn)O(N*b*n)

  • 二叉树搜索会有额外的搭建二叉树时间为O(N*logN),单个样本为O(logN),计算全部样本为 O(nlogN)O(n*logN)

def myknn_test():
    # 1.加载鸢尾花数据集
    iris_data = load_iris()
    # 复制50倍,150*50=7500行数据
    X_repeated = np.tile(iris_data.data, (50, 1))  # 纵向复制50倍
    y_repeated = np.tile(iris_data.target, 50)
    # 2.数据处理
    x_train, x_test, y_train, y_test = train_test_split(X_repeated, y_repeated, test_size=0.2,
                                                        random_state=23)
    # 3.特征预处理
    transfer = StandardScaler()
    x_train = transfer.fit_transform(x_train)
    x_test = transfer.transform(x_test)
    # 4.模型训练
    myKNN = MyKNeighborsClassifier(k=5, metric='manhattan',algorithm='auto')
    myKNN.fit(x_train, y_train)
    # 5.模型预测
    y_predict = myKNN.predict(x_test)
    print(f'预测结果:{y_predict}')
    print(f'真实结果:{y_test}')
    # 6.模型评估
    print(f'准确率:{score(y_test, y_predict)}')

    # 4.模型训练
    myKNN = MyKNeighborsClassifier(k=5, metric='manhattan', algorithm='brute')
    myKNN.fit(x_train, y_train)
    # 5.模型预测
    y_predict = myKNN.predict(x_test)
    print(f'预测结果:{y_predict}')
    print(f'真实结果:{y_test}')
    # 6.模型评估
    print(f'准确率:{score(y_test, y_predict)}')
image.png

维度灾难

维度灾难被认为是 KNN 算法最核心、最致命的缺陷,虽然还有“计算开销大”、“对异常值敏感”等问题,但是维度灾难从数学本质上影响了 KNN 。

  • “近邻”概念的失效: 前面提到随着空间维度 d 趋向于无穷大,查询点到“最近邻”的距离与到“最远邻”的距离之间的相对差异会趋近于零,此时 KNN 找最近的邻居思想就失效了。
  • 树结构索引(KD Tree/Ball Tree)的崩塌: 在高维时(通常 d>20d > 20),几乎每一个超平面分割都需要被回溯检查,KD Tree 的查询速度会退化到 O(N)O(N),甚至比直接暴力搜索(Brute Force)还要慢,因为还多了维护树结构的开销。
def myknn_test():
    # 1.加载鸢尾花数据集
    iris_data = load_iris()
    # 维度灾难
    n_noise_features = 100000  # 添加100000个噪声特征,总维度=4+100000=100004
    np.random.seed(1)
    noise = np.random.randn(iris_data.data.shape[0], n_noise_features) # 生成呈正态分布的随机值 (150,100000)
    X_high_dim = np.hstack([iris_data.data, noise])  # 横向拼接 (150,100004)
    y = iris_data.target
    # 2.数据处理
    x_train, x_test, y_train, y_test = train_test_split(X_high_dim, y, test_size=0.2,
                                                        random_state=23)
    # 3.特征预处理
    transfer = StandardScaler()
    x_train = transfer.fit_transform(x_train)
    x_test = transfer.transform(x_test)
    # 4.模型训练
    myKNN = MyKNeighborsClassifier(k=5, metric='manhattan',algorithm='auto')
    myKNN.fit(x_train, y_train)
    # 5.模型预测
    y_predict = myKNN.predict(x_test)
    print(f'预测结果:{y_predict}')
    print(f'真实结果:{y_test}')
    # 6.模型评估
    print(f'准确率:{score(y_test, y_predict)}')

    # 4.模型训练
    myKNN = MyKNeighborsClassifier(k=5, metric='manhattan', algorithm='brute')
    myKNN.fit(x_train, y_train)
    # 5.模型预测
    y_predict = myKNN.predict(x_test)
    print(f'预测结果:{y_predict}')
    print(f'真实结果:{y_test}')
    # 6.模型评估
    print(f'准确率:{score(y_test, y_predict)}')

理论上使用暴力搜索和树的准确率是一样的,但是在面对超高维度的数据时存在差异:

  • 浮点数计算误差;
  • 距离趋同导致的“等距离”选择。
image.png

折线图展示

测试在不同的维度下 手写KNN模型 的三种数据结构的计算时间和准确率。

def test_dimension_disaster():
    """系统性测试维度灾难对KNN性能的影响"""
    iris_data = load_iris()

    # 扩展维度范围,从 4 维一直到 2000 维,这样性能差异才显著
    dimensions = [4, 20, 100, 500, 1000, 2000]

    results = []
    # 初始化存储结构
    time_accum = {algo: [] for algo in ['brute', 'kdtree', 'balltree']}
    acc_accum = {algo: [] for algo in ['brute', 'kdtree', 'balltree']}

    for dim in dimensions:
        print(f"\n测试维度: {dim}")

        # 1. 构造高维数据
        n_noise = dim - iris_data.data.shape[1]
        if n_noise > 0:
            np.random.seed(42)
            noise = np.random.randn(iris_data.data.shape[0], n_noise)
            X_extended = np.hstack([iris_data.data, noise])
        else:
            X_extended = iris_data.data

        # 2. 数据处理
        x_train, x_test, y_train, y_test = train_test_split(
            X_extended, iris_data.target, test_size=0.2, random_state=23
        )
        transfer = StandardScaler()
        x_train = transfer.fit_transform(x_train)
        x_test = transfer.transform(x_test)

        # 3. 运行三种算法
        for algo in ['brute', 'kdtree', 'balltree']:
            # 注意:此处实例化你自定义的类
            myKNN = MyKNeighborsClassifier(k=5, algorithm=algo, metric='euclidean')

            # 只计时预测过程,因为 fit 树构建时间在低样本量下极快
            myKNN.fit(x_train, y_train)

            start = time.time()
            y_predict = myKNN.predict(x_test)
            elapsed = time.time() - start

            acc = score(y_test, y_predict)

            acc_accum[algo].append(acc)
            time_accum[algo].append(elapsed)

    # 调用绘图函数
    plot_dimension_curse(dimensions, acc_accum, time_accum)

def plot_dimension_curse(dimensions, acc_accum, time_accum):
    """绘制维度灾难的双轴折线图"""

    # 解决中文显示问题
    plt.rcParams['font.sans-serif'] = ['SimHei', 'Arial Unicode MS', 'sans-serif']
    plt.rcParams['axes.unicode_minus'] = False

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    algorithms = ['brute', 'kdtree', 'balltree']
    colors = {'brute': '#e74c3c', 'kdtree': '#2ecc71', 'balltree': '#3498db'}
    markers = {'brute': 'o', 'kdtree': 's', 'balltree': '^'}

    # 图1: 准确率下降曲线 (展示维度灾难对距离识别的破坏)
    for algo in algorithms:
        ax1.plot(dimensions, acc_accum[algo],
                 label=f'{algo.upper()}',
                 color=colors[algo],
                 marker=markers[algo],
                 linewidth=2)

    ax1.set_title('准确率随维度增加的衰减\n(Accuracy vs Dimension)', fontsize=14)
    ax1.set_xlabel('特征维度 (Dimension)', fontsize=12)
    ax1.set_ylabel('准确率 (Accuracy)', fontsize=12)
    ax1.set_ylim(0, 1.1)
    ax1.grid(True, linestyle='--', alpha=0.6)
    ax1.legend()

    # 图2: 耗时曲线 (展示维度灾难对索引效率的破坏)
    for algo in algorithms:
        ax2.plot(dimensions, time_accum[algo],
                 label=f'{algo.upper()}',
                 color=colors[algo],
                 marker=markers[algo],
                 linewidth=2)

    ax2.set_title('预测耗时随维度增加的变化\n(Time vs Dimension)', fontsize=14)
    ax2.set_xlabel('特征维度 (Dimension)', fontsize=12)
    ax2.set_ylabel('耗时 (Seconds)', fontsize=12)

    # 使用对数坐标轴,因为高维下树的搜索时间会指数级爆炸
    ax2.set_yscale('log')
    ax2.grid(True, which="both", linestyle='--', alpha=0.5)
    ax2.legend()
    plt.tight_layout()
    plt.show()
image.png
图6:手写KNN在不同维度下的准确率与耗时变化

为排除自定义实现可能存在的效率偏差,下面使用 sklearn 原生的 KNeighborsClassifier 重复相同实验,结果如下:

image.png
图7:sklearn KNN在不同维度下的准确率与耗时变化

结果明显,当维度不断提高时,KNN 模型的准确率会急剧下降,同时预测时间会提升, sklearn 的暴力搜索使用了 Cython 优化和向量化运算,避免了 Python 层的显式循环,因此在大数据量下效率显著高于纯 NumPy 逐元素计算,同时维度提高到一定程度时暴力搜索时间要优于树。

算法局限性总结

缺陷根本原因在本文实验中的体现
预测耗时长懒惰学习,每次需全量扫描暴力搜索在7500样本时耗时显著
维度灾难距离在高维空间趋同维度≥500时准确率降至随机水平
存储开销大需保留全部训练数据需存储完整训练集(数组/树结构)
对异常值敏感少数噪声样本会扭曲决策边界可通过交叉验证选择合适K值缓解
特征尺度依赖距离未归一化标准化前大值特征主导距离计算

前沿进展

DW-KNN(双重加权KNN)

传统的KNN 简单易用,但在面对类别不平衡、决策边界不稳定以及无法验证邻居样本可靠性(如存在噪声或标签错误)时表现较差。

DWKNNDW-KNN通过实现双重权重 (距离权重 + 有效性权重) 提高邻居样本的可靠性和模型的稳定性:

  1. 邻居有效性权重:
  • 背景: 传统的 KNN 算法会使用统一权重或距离权重对找的 K 个邻居加权预测结果,但是忽略了邻居是异常点的情况,如果距离近的邻居是异常点,那么就会对预测结果有巨大的影响。
  • 有效性得分: 在模型预测之前,对训练集所有的样本划分为"好邻居"和"坏邻居",通过一个将有效性得分 viv_i 来区分。
  • 具体计算: 算法会遍历训练集每一个样本点xix_i,找到其KvK_v个邻居,有效性得分viv_i取决于这 KvK_v 个邻居中,有多少比例的样本标签与 xix_i 自身的标签一致,如果一个样本被错误标注,或者它深陷于异类样本的包围中(即噪声点),其有效性得分 viv_i 会非常低。
  1. 池化:
  • 背景: 找到 K 个邻居并计算好权重后,传统的 KNN 算法直接进行投票预测,但是预测结果容易受到个别离群点的干扰,举例:样本点旁边有一个距离极近、有效性得分也极高的邻居,但如果同类的其他邻居都离得很远且不可信,结果还是会受到这个点干扰。
  • 距离池化: 将投票的最小单位由邻居修改为类别,算法将这 KK 个邻居按类别(如类别 A、类别 B)分开,距离池化计算所有邻居距离的平均值δc\delta_c,并转化为该类的距离权重 wc(d)w_c^{(d)}
  • 有效性池化: 计算该类所有邻居 viv_i平均值,得到该类的有效性权重 wc(v)w_c^{(v)}。公式为:wc(v)=1kci:yi=cviw_c^{(v)} = \frac{1}{k_c} \sum_{i:y_i=c} v_i
  • 好处: 池化后衡量的不再是单个邻居的素质,它衡量的是某个类别邻居的整体素质,这种机制大大增强了算法在面对高维噪声数据时的预测稳定性

DW-KNN: 双加权K近邻改进算法

优化算法+KNN

K 值、距离度量类型和特征/邻居权重等超参数是影响一个 KNN 模型效率的核心,传统的格搜索(GridSearchCV)效率低下,而遗传算法(GA)粒子群优化(PSO) 等元启发式算法可在此空间中高效搜索最优超参数组合,这正是智能优化计算在机器学习中的直接应用。

总结

KNN以“零训练”的极致简洁和逼近贝叶斯最优的理论保证,在低维小样本场景中依然是最强有力的基线算法之一。但其维度灾难、计算开销和类别不平衡敏感等问题提醒我们:任何算法都有适用边界。理解这些边界,比记住公式本身更为重要。

参考文献

  1. Cover T M, Hart P E. Nearest neighbor pattern classification[J]. IEEE Transactions on Information Theory, 1967, 13(1): 21-27.
  2. Beyer K, Goldstein J, Ramakrishnan R, et al. When is “nearest neighbor” meaningful?[C]. International Conference on Database Theory, 1999.
  3. DW-KNN: 双加权K近邻改进算法 
  4. scikit-learn官方文档: scikit-learn.org/stable/modu…