3. k近邻法

425 阅读6分钟

在这里插入图片描述

k近邻算法


  • k近邻法(k-nearest neighbor,k-NN):给定一个训练数据集,对新的输入实例,在训练数据集中找到与该实例最近邻的k个实例,这k个实例的多数属于某个类,就把该输入实例分为这个类。KNN使用的模型实际上对应于特征空间的划分,没有显式的训练过程。 在这里插入图片描述

k近邻模型


模型


  • 该模型有三个基本要素: 距离度量,k值的选择,分类决策规则 .当这三个要素确定后,便能对于任何一个新的输入实例,给出唯一确定的分类.这里用图片说明更清楚,对于训练集中的每一个样本,距离该点比其他点更近的所有点组成一片区域,叫做单元.每个样本都拥有一个单元,所有样本的单元最终构成对整个特征空间的划分,且对每个样本而言,它的标签就是该单元内所有点的标记.这样每个单元的样本点的标签也就是唯一确定的. 在这里插入图片描述

三要素


  1. 距离度量

    • pp距离:特征空间中两个实例点的距离是两个实例点相优程度的反映。设输入实例 xRn,xixjLpx \in \mathbb{R}^{n}, x_{i} 和 x_{j} 的 L_{p}距离定义为: ​Lp(xi,xj)=(l=1nxilxjlpp)(1p)L_{p}\left(x_{i}, x_{j}\right)=\left(\sum_{l=1}^{n}\left|x_{i}^{l}-x_{j}^{l}\right| p^{p}\right)^{\left(\frac{1}{p}\right)}\qquad\quad\quad\quad\quad\quad\quad\quad\quad\quad\quad\quad\quad\quad\quad
    • 欧氏距离:p=2p=2时的特殊情况。 ​L2(xi,xj)=(l=1nxi(l)xj(l)2)12L_{2}\left(x_{i}, x_{j}\right)=\left(\sum_{l=1}^{n}\left|x_{i}^{(l)}-x_{j}^{(l)}\right|^{2}\right)^{\frac{1}{2}}
    • 曼哈顿距离:p=1p=1时的特殊情况。 ​L1(xi,xj)=l=1nxi(l)xj(l)L_{1}\left(x_{i}, x_{j}\right)=\sum_{l=1}^{n}\left|x_{i}^{(l)}-x_{j}^{(l)}\right|
    • 切比雪夫距离:p=p={\infty}时的特殊情况。 ​L(xi,xj)=maxlxi(l)xj(l)L_{\infty}\left(x_{i}, x_{j}\right)=\max _{l}\left|x_{i}^{(l)}-x_{j}^{(l)}\right| 在这里插入图片描述
  2. k值的选择

    • 较小的k值代表整体模型变得复杂,分类结果容易被噪声点影响,容易发生过拟合。 较大的k值代表整体模型变得简单,容易欠拟合。 在应用中,k值一般取一个比较小的数值,通常采用交叉验证法来选取最优的k值。
  3. 分类决策规则

    • k近邻法中的分类决策规则往往是多数表决,即由输入实例的k个邻近的训练实例中的多数类决定输入实例的类。
    • 多数表决规则 如果分类的损失函数为010-1损失函数,分类函数为 ​f:Rnc1,c2,,ckf: \mathbf{R}^{n} \rightarrow c_{1}, c_{2}, \ldots, c_{k} 那误分类的概率是 ​P(Yf(X))=1P(Y=f(X))P(Y \neq f(X))=1-P(Y=f(X)) 对给定的实例 xχ1x \in \chi_{1} 其最近邻的kk个训练实例点构成集合 Nk(x)N_{k}(x) 。如果涵盖 Nk(x)N_{k}(x) 的区域类别是 cjc_{j} ,那误分类率是 ​1kxiNk(x)I(yicj)=1kxiNk(x)I(yi=cj)\frac{1}{k} \sum_{x_{i} \in N_{k}(x)} I\left(y_{i} \neq c_{j}\right)=1-k \sum_{x_{i} \in N_{k}(x)} I\left(y_{i}=c_{j}\right) 要使误分类率最小即经验风险最小,就要使xiNk(x)I(yi=cj)\sum_{x_{i} \in N_{k}(x)} I\left(y_{i}=c_{j}\right)最大,所以多数表决规则等价于经验风险最小化。

k近邻算法的实现:kd树


构造kd树


  • 输入: k 维空间数据集: T={x1,x2,,xN}T=\left\{x_{1}, x_{2}, \cdots, x_{N}\right\} 其中, xi=(xi(1),xi(2),,xi(k))Tx_{i}=\left(x_{i}^{(1)}, x_{i}^{(2)}, \cdots, x_{i}^{(k)}\right)^{T} 输出:kd树

    1. 开始:构造根节点。 选取 x(1)x^{(1)} 为坐标轴,以训练集中的所有数据 x(1)x^{(1)} 坐标中的中位数作为切分点, 将超矩形区域切割成两个子区域, 将该切分点作为根结点。 由根结点生出深度为 1 的左右子结点,左节点对应坐标小于切分点,右结点对应坐标大于切分点。
    2. 重 复 对深度为 jj 的结点, 选择 x(I)x^{(I)} 为切分坐标轴, I=j(modk)+1I=j(\bmod k)+1 , 以该结点区域中所有实例 x(l)x^{(l)} 坐标的中位数 作为切分点, 将区域分为两个子区域。 生成深度为 j+1j+1 的左、右子结点。左节点对应坐标小于切分点,右结点对应坐标大于切分点。
    3. 直到两个子 区域没有实例时停止。
  • 举例 输入:训练集: T={(2,3),(5,4),(9,6),(4,7),(8,1),(7,2)}T=\{(2,3),(5,4),(9,6),(4,7),(8,1),(7,2)\} 输出: kd 树 在这里插入图片描述

    x(1):2,4,5,7,8,9x^{(1)}: \quad 2,4,5,7,8,9 开始:选择 x(1)x^{(1)} 为坐标轴,中位数为 77, 即 (7,2)(7,2) 为切分点, 切分整个区域 在这里插入图片描述

    再次划分区域: 以 x(2)x^{(2)} 为坐标轴,选择中位数, 左边区域为 44, 右边区域为 66。故左边区域切分点为 (5,4)(5,4) , 右边区域切分点坐标为 (9,6)(9,6) 在这里插入图片描述

    划分左边区域: 以 x(1)x^{(1)} 为坐标轴,选择中位数,上边区域为 44, 下边区域为 22。故上边 区域切分点为 (4,7)(4,7) , 下边区域切分点坐标为 (2,3)(2,3) 在这里插入图片描述

    划分右边区域: 以 x(1)x^{(1)} 为坐标轴,选择中位数, 上边区域无实例点, 下边区域为 88。故 下边区域切分点坐标为 (8,1)(8,1) 在这里插入图片描述 最终划分结果 在这里插入图片描述

    kd树 在这里插入图片描述

    至此算法完成

搜索kd树


  • 用kd树的最近邻搜索

    • 寻找 “当前最近点” 寻找最近邻的子结点作为目标点的“当前最近点”。
    • 回溯 以目标点和“当前最近点” 的距离沿树根部进行回溯和迭代。
    • 详细描述 输入:已构造的 kd 树, 目标点 x 输出: x 的最近邻
      • 寻找“当前最近点”
        • 从根结点出发, 递归访问 kd 树, 找出包含 x 的叶结点; 以此叶结点为“当前最近点"”;
      • 回溯
        • 若该结点比 “当前最近点” 距离目标点更近, 更新“当前最近点”;
        • 当前最近点一定存在于该结点一个子结点对应的区域, 检查子结点 的父结点的另一子结点对应的区域是否有更近的点。
      • 当回退到根结点时, 搜索结束, 最后的“当前最近点” 即为 x 的最近邻点。
  • 举例 输入: kd 树, 目 标点 x=(2,4.5)x=(2,4.5) ; 输出:最近邻点 在这里插入图片描述

    第一次回溯 在这里插入图片描述 第二次回溯,最近邻点:(2,3)(2,3) 在这里插入图片描述

    如果实例点是随机分布的, kd 树搜索的平均计算复杂度是 O(logN)O(\log N) , 这里 NN 是 训练实例数。 kd 树更适用于训练实例数远大于空间维数时的 k 近邻搜索。当空间维委 接近训练实例数时,它的效率会迅速下降, 几乎接近线性扫描。

代码


import torch
import random
import matplotlib.pyplot as plt


class DrawTool():
    """画图类"""

    # 画点[数据集,x点,离 x点 最近的点]
    def drawPoint(self, points, x, nearestPoint):
        XMax = max(points[:, 0])  # X 轴范围
        YMax = max(points[:, 1])  # Y 轴范围
        precision = max(XMax, YMax) // 10 + 1  # 坐标轴精度
        #plt.rcParams['font.sans-serif'] = ['SimHei']  # 防止中文乱码
        plt.scatter(points[:, 0], points[:, 1], label="data")
        plt.scatter(x[0], x[1], c='c', marker='*', s=100, label="x(input)")
        plt.scatter(nearestPoint[0], nearestPoint[1], c='r', label="nearest")
        plt.xticks(torch.arange(0, XMax, precision))  # 设置 X 轴
        plt.yticks(torch.arange(0, YMax, precision))  # 设置 Y 轴
        plt.legend(loc='upper left')
        plt.show()


class DataLoader():
    """"数据加载类"""

    # 初始化[creat:人造数据集,random:随机数据集]
    def __init__(self, kind="creat"):
        self.points = None
        if (kind == "creat"):
            self.x = [2, 5, 9, 4, 8, 7]
            self.y = [3, 4, 6, 7, 1, 2]
        elif kind == "random":
            nums = random.randint(20, 40)
            self.x = random.sample(range(0, 40), nums)
            self.y = random.sample(range(0, 40), nums)

    # 处理数据
    def getData(self):
        self.points = [[self.x[i], self.y[i]] for i in range(len(self.x))]
        return self.points

    # 得到一个与数据集不重复的点,作为 x 点
    def getRandomPoint(self):
        points = torch.tensor(self.points)
        x, y, i = -1, -1, 0
        while x == -1 or y == -1:
            if x == -1 and i not in points[:, 0]:
                x = i
            if y == -1 and i not in points[:, 1]:
                y = i
            i += 2
        return x, y


class KDNode():#二叉树
    """"节点类"""

    def __init__(self, point):
        self.point = point
        self.left = None
        self.right = None


class KDTree():
    """KD树"""

    def __init__(self):
        self.root = None
        self.nearestPoint = None
        self.nearestDis = float('inf')

    # 创造和搜索 KD树[数据集,x]
    def creatAndSearch(self, points, x):
        self.root = self.creatTree(points)
        self.searchTree(self.root, x)

    # 创造 KD树[数据集,维度]
    def creatTree(self, points, col=0):
        if len(points) == 0:
            return None
        points = sorted(points, key=lambda point: point[col])
        mid = len(points) >> 1
        node = KDNode(points[mid])
        node.left = self.creatTree(points[0:mid], col ^ 1)
        node.right = self.creatTree(points[mid + 1:len(points)], col ^ 1)
        return node

    # 搜索 KD树[KD树,x,维度]
    def searchTree(self, tree, x, col=0):
        if tree == None:
            return

        # 对应算法中第 1 步
        if x[col] < tree.point[col]:
            self.searchTree(tree.left, x, col ^ 1)
        else:
            self.searchTree(tree.right, x, col ^ 1)

        disCurAndX = self.dis(tree.point, x)
        if disCurAndX < self.nearestDis:
            self.nearestDis = disCurAndX
            self.nearestPoint = tree.point

        # 判断目前最小圆是否与其他区域相交,即判断 |x(按轴读值)-节点(按轴读值)| < 最近的值(圆的半径)
        # 对应算法中第 3 步中的 (b)
        if abs(tree.point[col] - x[col]) < self.nearestDis:
            if tree.point[col] < x[col]:
                self.searchTree(tree.right, x, col ^ 1)
            else:
                self.searchTree(tree.left, x, col ^ 1)

    # 两点间距离[a点,b点]
    def dis(self, a, b):
        return sum([(a[i] - b[i]) ** 2 for i in range(len(a))]) ** 0.5#欧氏距离

    # 前序遍历 KD树(测试使用)[KD树]
    def printTree(self, root):
        if root != None:
            print(root.point)
            self.printTree(root.left)
            self.printTree(root.right)


if __name__ == '__main__':
    drawTool = DrawTool()
    dataLoader = DataLoader("random")
    kdTree = KDTree()

    points = dataLoader.getData()
    x = dataLoader.getRandomPoint()

    kdTree.creatAndSearch(points, x)
    drawTool.drawPoint(torch.tensor(points), x, kdTree.nearestPoint)

在这里插入图片描述