使用 Python 实现 KNN 分类红酒数据集的核心算法代码

193 阅读1分钟

KNN 分类红酒数据集

概述

使用 KNN(K 近邻)算法,对红酒(wine)数据集进行分类。

代码

import

导入必要的组件,包括 numpy、pandas、sklearn 等库

from collections import Counter
from typing import Optional

import numpy as np
from pandas import DataFrame
from sklearn.datasets import load_wine
from sklearn.metrics import accuracy_score
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler

数据预处理

  • 数据集划分
  • 归一化(极大提升准确率)
# Load the wine dataset bundled with scikit-learn as a (features, labels) pair.
wine = load_wine(return_X_y=True, as_frame=True)

data: DataFrame = wine[0]
# With as_frame=True the labels are a pandas Series rather than a DataFrame,
# so no DataFrame annotation is applied here.
target = wine[1]

# Split into training and test sets (roughly 2:1).
X_train, X_test, y_train, y_test = train_test_split(data, target, test_size=0.33)

# Min-max normalise every feature.
# The scaler is fitted on the training set ONLY and then reused, unchanged, on
# the test set. Re-fitting on the test set would map both sets to [0, 1] with
# different min/max values, so train and test samples would no longer live on
# the same scale (and test-set statistics would leak into preprocessing).
scalar = MinMaxScaler()
X_train = scalar.fit_transform(X_train)
X_test = scalar.transform(X_test)

# Convert the label Series to plain NumPy arrays for positional indexing.
y_train = y_train.to_numpy()
y_test = y_test.to_numpy()

定义模型

class KNN:
    """Minimal k-nearest-neighbours classifier using Euclidean distance.

    The model simply memorises the training samples; a prediction is the
    majority label among the ``k`` training samples closest to the query.
    """

    def __init__(self, k: int):
        """Create a classifier that votes among the ``k`` nearest samples.

        Args:
            k: Number of neighbours consulted per prediction; must be >= 1.

        Raises:
            ValueError: If ``k`` is smaller than 1.
        """
        # k == 0 would leave nothing to vote on, and a negative k would make
        # the slice in predict() silently select the wrong neighbours.
        if k < 1:
            raise ValueError(f"k must be a positive integer, got {k!r}")
        self.k: int = k
        self.data: Optional[np.ndarray] = None
        self.target: Optional[np.ndarray] = None

    def fit(self, data_: np.ndarray, target_: np.ndarray) -> None:
        """Store the training samples and their labels.

        Args:
            data_: Training features, one sample per row
                (shape ``(n_samples, n_features)``).
            target_: Label of each training sample (length ``n_samples``).

        Raises:
            ValueError: If the number of samples and labels differ.
        """
        # Coerce to ndarrays so that row arithmetic and positional label
        # lookup behave the same for lists, DataFrames and Series.
        samples = np.asarray(data_)
        labels = np.asarray(target_)
        if len(samples) != len(labels):
            raise ValueError(
                f"data_ has {len(samples)} samples but target_ has {len(labels)} labels"
            )
        self.data = samples
        self.target = labels

    def predict(self, data_: np.ndarray) -> int:
        """Predict the label of a single sample.

        Args:
            data_: Feature vector of one sample (flattened before use).

        Returns:
            The most frequent label among the ``k`` nearest training samples.
            When several labels are equally frequent, the one whose first
            occurrence is closest to the query wins.

        Raises:
            RuntimeError: If ``fit`` has not been called yet.
        """
        if self.data is None or self.target is None:
            raise RuntimeError("fit() must be called before predict()")

        # Distance from the query to every training row in one vectorised step.
        query = np.ravel(data_)
        distances = np.linalg.norm(self.data - query, axis=1)

        # Indices of the k smallest distances; a stable sort keeps the result
        # deterministic when several samples are equally far away.
        nearest = np.argsort(distances, kind="stable")[:self.k]

        # Labels are counted in order of increasing distance, so Counter's
        # insertion order resolves ties in favour of the nearer neighbour.
        votes = Counter(self.target[index] for index in nearest)
        return votes.most_common(1)[0][0]

    @staticmethod
    def euclidean(x1: np.ndarray, x2: np.ndarray) -> float:
        """Return the Euclidean distance between two feature vectors.

        Both inputs are flattened first, so row and column vectors are
        treated alike.
        """
        return float(np.linalg.norm(np.ravel(x1) - np.ravel(x2)))

创建 KNN 模型实例

# Number of neighbours consulted for each prediction.
K: int = 4
# Classifier instance used by the training and evaluation steps.
knn: KNN = KNN(K)

训练

# Hand the normalised training features and their labels to the model.
knn.fit(data_=X_train, target_=y_train)

测试

使用 Accuracy(准确率)作为评价指标。由于训练集与测试集是随机划分的,每次运行的结果会略有不同;已达到的最高准确率:0.9830508474576272。

# Classify every test sample and collect the matching ground-truth labels.
predicts: list = [knn.predict(X) for X in X_test]
targets: list = list(y_test)

# accuracy_score expects (y_true, y_pred). Accuracy itself is symmetric, but
# the documented order keeps the call correct if the metric is ever swapped.
accuracy = accuracy_score(targets, predicts)
# Print the score so it is visible when the code runs as a plain script.
print(accuracy)