1、介绍
Sklearn (全称 Scikit-Learn) 是基于 Python 语言的机器学习工具。它建立在 NumPy, SciPy, Pandas 和 Matplotlib 之上,里面的 API 的设计非常好,所有对象的接口简单,很适合新手上路。
在 Sklearn 里面有六大任务模块:分别是分类、回归、聚类、降维、模型选择和预处理,如下图从其官网的截屏
1.1 常用模块
sklearn中常用的模块有分类、回归、聚类、降维、模型选择、预处理。
分类:识别某个对象属于哪个类别,常用的算法有:SVM(支持向量机)、nearest neighbors(最近邻)、random forest(随机森林),常见的应用有:垃圾邮件识别、图像识别。
回归:预测与对象相关联的连续值属性,常见的算法有:SVR(支持向量机)、 ridge regression(岭回归)、Lasso,常见的应用有:药物反应,预测股价。
聚类:将相似对象自动分组,常用的算法有:k-Means、 spectral clustering、mean-shift,常见的应用有:客户细分,分组实验结果。
降维:减少要考虑的随机变量的数量,常见的算法有:PCA(主成分分析)、feature selection(特征选择)、non-negative matrix factorization(非负矩阵分解),常见的应用有:可视化,提高效率。
模型选择:比较,验证,选择参数和模型,常用的模块有:grid search(网格搜索)、cross validation(交叉验证)、 metrics(度量)。它的目标是通过参数调整提高精度。
预处理:特征提取和归一化,常用的模块有:preprocessing,feature extraction,常见的应用有:把输入数据(如文本)转换为机器学习算法可用的数据。 、
2、基本使用
1、实例化
对于sklearn这个库来说,是用面向对象的思想来使用的。 就是每种算法都是一个对象,要使用某种算法解决问题,可以将其import进来,此时它只是个对象,只有在实例化之后才可以对数据学习和预测
from sklearn.neighbors import KNeighborsClassifier
obj = KNeighborsClassifier()
2、获取数据
from sklearn import datasets
iris = datasets.load_iris()
x = iris.data
y = iris.target
导入的数据是个训练集,并不能直接传递给实例对其学习,因为实例需要特征和标签才能开始学习(这个实例是这个情况的,无监督学习现在还不那么清楚是个什么样的情况)不过特征和标签都存储在data和target标签中,可以轻松的获得。
3、开始训练和预测
这个时候,实例有了,数据也有了,可以开始对数据训练和学习了
obj.fit(x, y)
y_pred = obj.predict(x[:4, :]
print(y - y_pred)
4、划分训练集
一般把训练集划分为训练集和测试集,这样来验证算法的准确性比较好
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3)
其中参数test_size=0.3代表将把训练集的30%数据当做测试集使用
现在有了训练集和测试集,重复之前的步骤,对训练集训练:
5、参数获取
对于每个学习到的模型,也就是一个假设函数,都有一些参数
还有预测的时候也是有参数设置的
对于这两个参数的获取,可以通过实例化后的实例的coef_,intercept_属性和get_params()获得
由于分类好像没有实例的那几个属性,换个线性回归的的模型
from sklearn import datasets
from sklearn.linear_model import LinearRegression
#加载数据
loaded_data = datasets.load_boston()
data_X = loaded_data.data
data_y = loaded_data.target
# 实例化
model = LinearRegression()
# 开始训练
model.fit(data_X, data_y)
# 预测
y_pred = model.predict(data_X[:4, :])
# 打印参数
print(model.coef_)
print(model.intercept_)
6、评估模型
最经常用的评估方法,越接近于1的效果越好:
print(model.score(data_X, data_y))
# 0.7406077428649427