# K-近邻 (K-Nearest Neighbors, KNN)

# 240 次阅读 · 阅读时长约 1 分钟
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from sklearn import neighbors
from sklearn import datasets

# Demo hyper-parameters, consumed by the __main__ section of this script:
# the number of neighbours k, and the neighbour weighting scheme name.
N_neighbors, Weights = 15, 'distance'

def train(n_neighbors, weights):
    """Build an unfitted k-nearest-neighbours classifier.

    Despite its name, this function does not fit anything. The caller must
    still call ``.fit(X, y)`` on the returned estimator.

    Args:
        n_neighbors: Number of neighbours consulted for each prediction.
        weights: Neighbour weighting scheme forwarded to scikit-learn
            (for example ``'uniform'`` or ``'distance'``).

    Returns:
        A new ``sklearn.neighbors.KNeighborsClassifier`` instance.
    """
    # ``weights`` is keyword-only in KNeighborsClassifier. Passing it
    # positionally was deprecated in scikit-learn 0.23 and raises TypeError
    # from 1.0 onward, so pass both arguments by name.
    return neighbors.KNeighborsClassifier(n_neighbors=n_neighbors, weights=weights)
if __name__ == '__main__':
    # Load iris and keep only the first two feature columns, so the decision
    # regions can be drawn on a 2-D plane.
    iris = datasets.load_iris()
    X = iris.data[:, :2]
    y = iris.target

    # Light colours fill the predicted class regions; bold colours mark the
    # training samples.
    cmap_light = ListedColormap(['#FFAAAA', '#AAFFAA', '#AAAAFF'])
    cmap_bold = ListedColormap(['#FF0000', '#00FF00', '#0000FF'])

    n_neighbors = N_neighbors
    weights = Weights
    clf = train(n_neighbors, weights)
    clf.fit(X, y)

    # Build a grid covering the data with a 1-unit margin on every side and
    # predict a class for each grid point.
    step = 0.02
    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, step),
                         np.arange(y_min, y_max, step))
    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)

    plt.figure()
    # xx, yy and Z all have the same shape, so each Z value belongs to a grid
    # point rather than to a cell between four corners. shading='auto' draws
    # one cell centred on each point. The legacy 'flat' default (matplotlib
    # < 3.5) silently dropped Z's last row and column, and warns on 3.3/3.4.
    plt.pcolormesh(xx, yy, Z, cmap=cmap_light, shading='auto')

    # Overlay the training samples, coloured by their true class.
    plt.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap_bold)

    # Restrict the visible area to the prediction grid.
    plt.xlim(xx.min(), xx.max())
    plt.ylim(yy.min(), yy.max())
    plt.title(f"3-Class classification (k={n_neighbors},weights='{weights}')")

    plt.show()