基于NMF与决策树的鸢尾花分类

361 阅读1分钟

NMF降维

NMF对鸢尾花降维 此处不再赘述NMF.

# NMF预处理
nmf = decomposition.NMF(n_components=2)

dataset = load_iris()
X = dataset.data
y = dataset.target

reduced_dataset = nmf.fit_transform(X)  # 降到二维后, reduce_X应该是二维的, 第一个维度大小等于样本个数, 第二个维度存储对应的第n个样本的降维后的两个特征
reduced_X = [[row[0], row[1]] for row in reduced_dataset]  # 使用列表生成器, 将降维后的特征打包
reduced_X = np.array(reduced_X)

决策树

具体原理不再赘述. 直接上代码:

X_train, X_test, y_train, y_test = train_test_split(reduced_X, y, test_size=0.3)

clf = tree.DecisionTreeClassifier(criterion="entropy")
clf = clf.fit(X_train, y_train.astype("int"))
score = clf.score(X_test, y_test)

print(score)  # 模型性能

feature_name = ['First Feature', 'Second Feature']
print([*zip(feature_name, clf.feature_importances_)])  # 不同特征对模型的重要(贡献)程度

可视化

使用Graphviz进行可视化. 具体安装参考.

dot_data = tree.export_graphviz(clf,
                                feature_names=feature_name,
                                class_names=['Setosa', 'Versicolour', 'Virginica'],
                                filled=True,
                                rounded=True
                                )

graph = graphviz.Source(dot_data)
graph.render("Visible")

总结

在鸢尾花数据集上, NMF降维后的分类, 与不降维直接用决策树分类, 效果差不多(相差不大或相等), 所用时间开销NMF<原始数据. 所以推荐降维后进行分类训练.