机器学习-KNN

57 阅读1分钟

KNN模型的使用

1. 加载数据集

  • 使用sklearn 加载鸢尾花数据集
from sklearn.datasets import load_iris
load_iris(return_X_y=True)

-参数return_X_y=True只返回样本与标签, X=样本, y=标签(结果)

X, y = load_iris(return_X_y=True)

2. 拆分数据

-使用train_test_split函数切分数据

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

-参数test_size代表测试集与训练集的比例

3. 创建模型

  • 实例化knn模型对象
from sklearn.neighbors import KNeighborsClassifier
knn = KNeighborsClassifier(n_neighbors=5)
  • 模型训练
knn.fit(X_train, y_train)
  • 模型预测
y_pred = knn.predict(X_test)

4. 模型评估

  • 分类问题,可以通过计算准确率来评估模型
(y_pred == y_test).mean()
  • 通过比较预测结果与验证集结果, 然后求平均值 mean

5. 模型的保存与加载

  • 保存dump
import joblib
joblib.dump(knn, "knn.model")
  • 加载load
knn = joblib.load("knn.model")
knn.predict(X_test)