k近邻算法
- k近邻法(k-nearest neighbor,k-NN):给定一个训练数据集,对新的输入实例,在训练数据集中找到与该实例最近邻的k个实例,这k个实例的多数属于某个类,就把该输入实例分为这个类。KNN使用的模型实际上对应于特征空间的划分,没有显式的训练过程。
k近邻模型
模型
- 该模型有三个基本要素: 距离度量,k值的选择,分类决策规则 .当这三个要素确定后,便能对于任何一个新的输入实例,给出唯一确定的分类.这里用图片说明更清楚,对于训练集中的每一个样本,距离该点比其他点更近的所有点组成一片区域,叫做单元.每个样本都拥有一个单元,所有样本的单元最终构成对整个特征空间的划分,且对每个样本而言,它的标签就是该单元内所有点的标记.这样每个单元的样本点的标签也就是唯一确定的.
三要素
-
距离度量
- 距离:特征空间中两个实例点的距离是两个实例点相优程度的反映。设输入实例 距离定义为:
- 欧氏距离:时的特殊情况。
- 曼哈顿距离:时的特殊情况。
- 切比雪夫距离:时的特殊情况。
-
k值的选择
- 较小的k值代表整体模型变得复杂,分类结果容易被噪声点影响,容易发生过拟合。 较大的k值代表整体模型变得简单,容易欠拟合。 在应用中,k值一般取一个比较小的数值,通常采用交叉验证法来选取最优的k值。
-
分类决策规则
- k近邻法中的分类决策规则往往是多数表决,即由输入实例的k个邻近的训练实例中的多数类决定输入实例的类。
- 多数表决规则
如果分类的损失函数为损失函数,分类函数为
那误分类的概率是
对给定的实例 其最近邻的个训练实例点构成集合 。如果涵盖 的区域类别是 ,那误分类率是
要使误分类率最小即经验风险最小,就要使最大,所以多数表决规则等价于经验风险最小化。
k近邻算法的实现:kd树
构造kd树
-
输入: k 维空间数据集: 其中, 输出:kd树
- 开始:构造根节点。 选取 为坐标轴,以训练集中的所有数据 坐标中的中位数作为切分点, 将超矩形区域切割成两个子区域, 将该切分点作为根结点。 由根结点生出深度为 1 的左右子结点,左节点对应坐标小于切分点,右结点对应坐标大于切分点。
- 重 复 对深度为 的结点, 选择 为切分坐标轴, , 以该结点区域中所有实例 坐标的中位数 作为切分点, 将区域分为两个子区域。 生成深度为 的左、右子结点。左节点对应坐标小于切分点,右结点对应坐标大于切分点。
- 直到两个子 区域没有实例时停止。
-
举例 输入:训练集: 输出: kd 树
开始:选择 为坐标轴,中位数为 , 即 为切分点, 切分整个区域
再次划分区域: 以 为坐标轴,选择中位数, 左边区域为 , 右边区域为 。故左边区域切分点为 , 右边区域切分点坐标为
划分左边区域: 以 为坐标轴,选择中位数,上边区域为 , 下边区域为 。故上边 区域切分点为 , 下边区域切分点坐标为
划分右边区域: 以 为坐标轴,选择中位数, 上边区域无实例点, 下边区域为 。故 下边区域切分点坐标为
最终划分结果
kd树
至此算法完成
搜索kd树
-
用kd树的最近邻搜索
- 寻找 “当前最近点” 寻找最近邻的子结点作为目标点的“当前最近点”。
- 回溯 以目标点和“当前最近点” 的距离沿树根部进行回溯和迭代。
- 详细描述
输入:已构造的 kd 树, 目标点 x
输出: x 的最近邻
- 寻找“当前最近点”
- 从根结点出发, 递归访问 kd 树, 找出包含 x 的叶结点; 以此叶结点为“当前最近点"”;
- 回溯
- 若该结点比 “当前最近点” 距离目标点更近, 更新“当前最近点”;
- 当前最近点一定存在于该结点一个子结点对应的区域, 检查子结点 的父结点的另一子结点对应的区域是否有更近的点。
- 当回退到根结点时, 搜索结束, 最后的“当前最近点” 即为 x 的最近邻点。
- 寻找“当前最近点”
-
举例 输入: kd 树, 目 标点 ; 输出:最近邻点
第一次回溯
第二次回溯,最近邻点:
如果实例点是随机分布的, kd 树搜索的平均计算复杂度是 , 这里 是 训练实例数。 kd 树更适用于训练实例数远大于空间维数时的 k 近邻搜索。当空间维委 接近训练实例数时,它的效率会迅速下降, 几乎接近线性扫描。
代码
import torch
import random
import matplotlib.pyplot as plt
class DrawTool():
"""画图类"""
# 画点[数据集,x点,离 x点 最近的点]
def drawPoint(self, points, x, nearestPoint):
XMax = max(points[:, 0]) # X 轴范围
YMax = max(points[:, 1]) # Y 轴范围
precision = max(XMax, YMax) // 10 + 1 # 坐标轴精度
#plt.rcParams['font.sans-serif'] = ['SimHei'] # 防止中文乱码
plt.scatter(points[:, 0], points[:, 1], label="data")
plt.scatter(x[0], x[1], c='c', marker='*', s=100, label="x(input)")
plt.scatter(nearestPoint[0], nearestPoint[1], c='r', label="nearest")
plt.xticks(torch.arange(0, XMax, precision)) # 设置 X 轴
plt.yticks(torch.arange(0, YMax, precision)) # 设置 Y 轴
plt.legend(loc='upper left')
plt.show()
class DataLoader():
""""数据加载类"""
# 初始化[creat:人造数据集,random:随机数据集]
def __init__(self, kind="creat"):
self.points = None
if (kind == "creat"):
self.x = [2, 5, 9, 4, 8, 7]
self.y = [3, 4, 6, 7, 1, 2]
elif kind == "random":
nums = random.randint(20, 40)
self.x = random.sample(range(0, 40), nums)
self.y = random.sample(range(0, 40), nums)
# 处理数据
def getData(self):
self.points = [[self.x[i], self.y[i]] for i in range(len(self.x))]
return self.points
# 得到一个与数据集不重复的点,作为 x 点
def getRandomPoint(self):
points = torch.tensor(self.points)
x, y, i = -1, -1, 0
while x == -1 or y == -1:
if x == -1 and i not in points[:, 0]:
x = i
if y == -1 and i not in points[:, 1]:
y = i
i += 2
return x, y
class KDNode():#二叉树
""""节点类"""
def __init__(self, point):
self.point = point
self.left = None
self.right = None
class KDTree():
"""KD树"""
def __init__(self):
self.root = None
self.nearestPoint = None
self.nearestDis = float('inf')
# 创造和搜索 KD树[数据集,x]
def creatAndSearch(self, points, x):
self.root = self.creatTree(points)
self.searchTree(self.root, x)
# 创造 KD树[数据集,维度]
def creatTree(self, points, col=0):
if len(points) == 0:
return None
points = sorted(points, key=lambda point: point[col])
mid = len(points) >> 1
node = KDNode(points[mid])
node.left = self.creatTree(points[0:mid], col ^ 1)
node.right = self.creatTree(points[mid + 1:len(points)], col ^ 1)
return node
# 搜索 KD树[KD树,x,维度]
def searchTree(self, tree, x, col=0):
if tree == None:
return
# 对应算法中第 1 步
if x[col] < tree.point[col]:
self.searchTree(tree.left, x, col ^ 1)
else:
self.searchTree(tree.right, x, col ^ 1)
disCurAndX = self.dis(tree.point, x)
if disCurAndX < self.nearestDis:
self.nearestDis = disCurAndX
self.nearestPoint = tree.point
# 判断目前最小圆是否与其他区域相交,即判断 |x(按轴读值)-节点(按轴读值)| < 最近的值(圆的半径)
# 对应算法中第 3 步中的 (b)
if abs(tree.point[col] - x[col]) < self.nearestDis:
if tree.point[col] < x[col]:
self.searchTree(tree.right, x, col ^ 1)
else:
self.searchTree(tree.left, x, col ^ 1)
# 两点间距离[a点,b点]
def dis(self, a, b):
return sum([(a[i] - b[i]) ** 2 for i in range(len(a))]) ** 0.5#欧氏距离
# 前序遍历 KD树(测试使用)[KD树]
def printTree(self, root):
if root != None:
print(root.point)
self.printTree(root.left)
self.printTree(root.right)
if __name__ == '__main__':
drawTool = DrawTool()
dataLoader = DataLoader("random")
kdTree = KDTree()
points = dataLoader.getData()
x = dataLoader.getRandomPoint()
kdTree.creatAndSearch(points, x)
drawTool.drawPoint(torch.tensor(points), x, kdTree.nearestPoint)