# 十、感知机原理实现 (Part 10: implementing the perceptron from scratch)

# 130 阅读 · 阅读约 1 分钟 (blog metadata: 130 reads, about a 1-minute read)

##载入相关模块

import numpy as np

import pandas as pd

from sklearn.datasets import load_iris

from sklearn.model_selection import train_test_split

from collections import Counter

import matplotlib.pyplot as plt

## Load the data

# Build a DataFrame of the four iris measurements under short column names,
# then attach the integer class id as a 'label' column.
iris = load_iris()
df = pd.DataFrame(
    iris.data,
    columns=['sepal length', 'sepal width', 'petal length', 'petal width'],
)
df['label'] = iris.target

## Extract features and labels

# Keep the first 100 rows and only sepal length, sepal width and the label.
data = df.iloc[:100, [0, 1, -1]].to_numpy()

# All columns but the last are features; the last column is the label.
X = data[:, :-1]
y = data[:, -1]

# To hold out 20% of the samples for testing instead, one could use:
# X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

# Recode the labels from {0, 1} to {-1, +1}, the convention the perceptron
# update rule relies on.
y = np.where(y == 1, 1, -1)

class Model:
    """Binary perceptron trained with the primal per-sample update rule.

    Labels are expected in {-1, +1}. A sample (x, y) counts as misclassified
    when ``y * (w . x + b) <= 0``; each such sample triggers the update
    ``w += l_rate * y * x`` and ``b += l_rate * y``.

    Args:
        n_features: number of input features. When omitted, the weight vector
            is created on the first call to ``fit`` from the number of columns
            in the training data (instead of reading a module-level array).
        l_rate: learning rate (step size) applied on every update.
    """

    def __init__(self, n_features=None, l_rate=0.1):
        # Weights start at 1 and the bias at 0.
        self.w = None if n_features is None else np.ones(n_features, dtype=np.float32)
        self.b = 0
        self.l_rate = l_rate

    def lin(self, x, w, b):
        """Return the linear score ``x . w + b`` for one sample or a 2-D batch."""
        return np.dot(x, w) + b

    def fit(self, X_train, y_train, max_iter=None):
        """Run perceptron updates, one sample at a time, until a pass has no mistakes.

        Args:
            X_train: array-like of shape (n_samples, n_features).
            y_train: array-like of n_samples labels in {-1, +1}.
            max_iter: optional cap on the number of passes over the data.
                ``None`` (the default) loops until a mistake-free pass, which
                never happens if the data are not linearly separable.

        Returns:
            The string ``'Perceptron Model!'`` (kept for compatibility).

        Raises:
            ValueError: if X_train and y_train have different lengths.
        """
        X_train = np.asarray(X_train)
        y_train = np.asarray(y_train)
        if len(X_train) != len(y_train):
            raise ValueError('X_train and y_train must have the same length')
        if self.w is None:
            self.w = np.ones(X_train.shape[1], dtype=np.float32)

        epoch = 0
        while max_iter is None or epoch < max_iter:
            epoch += 1
            wrong_count = 0
            for x_i, y_i in zip(X_train, y_train):
                if y_i * self.lin(x_i, self.w, self.b) <= 0:
                    # Misclassified or on the boundary: move the hyperplane
                    # towards the correct side of this sample.
                    self.w = self.w + self.l_rate * (y_i * x_i)
                    self.b = self.b + self.l_rate * y_i
                    wrong_count += 1
            if wrong_count == 0:
                # A clean pass: every training sample is correctly classified.
                break
        return 'Perceptron Model!'

    def score(self, X_test=None, y_test=None):
        """Return the fraction of samples with ``y * (w . x + b) > 0``.

        Uses the same correctness criterion as ``fit``. Called without
        arguments it returns ``None``, like the original placeholder.

        Raises:
            RuntimeError: if called with data before the weights exist.
            ValueError: if the inputs are empty or have different lengths.
        """
        if X_test is None or y_test is None:
            return None
        if self.w is None:
            raise RuntimeError('call fit() before score()')
        X_test = np.asarray(X_test)
        y_test = np.asarray(y_test)
        if len(X_test) != len(y_test):
            raise ValueError('X_test and y_test must have the same length')
        if len(y_test) == 0:
            raise ValueError('cannot score an empty test set')
        margins = y_test * self.lin(X_test, self.w, self.b)
        return float(np.mean(margins > 0))

## Train the model

# Build a perceptron and fit it on all 100 selected samples (no hold-out split).
perceptron = Model()
perceptron.fit(X_train=X, y_train=y)

## 参数估计结果 (estimated parameters)

# Show the learned parameters (w0, w1, b). A bare expression is only echoed in a
# notebook/REPL, so print explicitly to make it visible when run as a script.
print(perceptron.w[0], perceptron.w[1], perceptron.b)

## 绘图 (plotting)

# Plot the separating line w0 * x + w1 * y + b = 0, solved for y, over the
# observed sepal-length range, together with the two classes as points.
# NOTE(review): this divides by w[1]; if training ever yields w[1] == 0 the
# boundary is vertical and this formula breaks.
x_points = np.linspace(data[:, 0].min(), data[:, 0].max(), 10)
y_ = -(perceptron.w[0] * x_points + perceptron.b) / perceptron.w[1]
plt.plot(x_points, y_)

# Select each class by its original 0/1 label (last column of `data`) rather
# than by hard-coded row ranges.
is_class0 = data[:, -1] == 0
is_class1 = data[:, -1] == 1
# The fmt 'o' sets only the marker; the colour comes from `color=`. The old
# 'bo' also set blue, which conflicted with color='orange'.
plt.plot(data[is_class0, 0], data[is_class0, 1], 'o', color='blue', label='0')
plt.plot(data[is_class1, 0], data[is_class1, 1], 'o', color='orange', label='1')
plt.xlabel('sepal length')
plt.ylabel('sepal width')
plt.legend()
# Needed to actually display the figure when run as a plain script.
plt.show()