0、数据
年龄:0青年,1中年,2老年;
有工作:0否,1是;
有自己的房子:0否,1是;
信贷情况:0一般,1好,2非常好;
类别(是否给贷款):no代表否,yes代表是。
# 数据集
def createDataSet():
dataSet = [[0, 0, 0, 0, 'no'], # 数据集
[0, 0, 0, 1, 'no'],
[0, 1, 0, 1, 'yes'],
[0, 1, 1, 0, 'yes'],
[0, 0, 0, 0, 'no'],
[1, 0, 0, 0, 'no'],
[1, 0, 0, 1, 'no'],
[1, 1, 1, 1, 'yes'],
[1, 0, 1, 2, 'yes'],
[1, 0, 1, 2, 'yes'],
[2, 0, 1, 2, 'yes'],
[2, 0, 1, 1, 'yes'],
[2, 1, 0, 1, 'yes'],
[2, 1, 0, 2, 'yes'],
[2, 0, 0, 0, 'no']]
labelSet = ['yes', 'no'] #标签集合
return dataSet, labelSet # 返回数据集和标签集合
1、计算信息熵
from math import log
import numpy as np
#求解信息熵,其实只需要最后一列labels数据即可
def calcShannonEnt(dataSet):
labels = np.array([row[-1] for row in dataSet]) #获取标签的数组 [no,no,yes.....,no]
numData = len(labels) #样本数
labelSet = set(labels) #创建set集合{},元素不可重复
shannonEnt = 0.0 #信息熵
for lbl in labelSet:
p = float((labels == lbl).sum()) / numData #某个标签的概率
shannonEnt -= p * log(p, 2)
return shannonEnt
if __name__ == '__main__':
dataSet, labelSet = createDataSet()
print(calcShannonEnt(dataSet, labelSet)) #0.9709505944546686
2、选择增益最大的特征
def splitDataSet(dataSet, axis, value):
retDataSet = [] #创建返回的数据集列表
for featVec in dataSet: #遍历数据集
if featVec[axis] == value:
reducedFeatVec = featVec[:axis] #去掉axis特征
reducedFeatVec.extend(featVec[axis+1:]) #将符合条件的添加到返回的数据集
retDataSet.append(reducedFeatVec)
return retDataSet #返回划分后的数据集
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1 #特征数量
baseEntropy = calcShannonEnt(dataSet) #计算数据集的香农熵
bestInfoGain = 0.0 #信息增益
bestFeature = -1 #最优特征的索引值
for i in range(numFeatures): #遍历所有特征
#获取dataSet的第i个所有特征
featList = [example[i] for example in dataSet]
uniqueVals = set(featList) #创建set集合{},元素不可重复
newEntropy = 0.0 #经验条件熵
for value in uniqueVals: #计算信息增益
subDataSet = splitDataSet(dataSet, i, value) #subDataSet划分后的子集
prob = len(subDataSet) / float(len(dataSet)) #计算子集的概率
newEntropy += prob * calcShannonEnt(subDataSet) #根据公式计算经验条件熵
infoGain = baseEntropy - newEntropy #信息增益
print("第%d个特征的增益为%.3f" % (i, infoGain)) #打印每个特征的信息增益
if (infoGain > bestInfoGain): #计算信息增益
bestInfoGain = infoGain #更新信息增益,找到最大的信息增益
bestFeature = i #记录信息增益最大的特征的索引值
return bestFeature #返回信息增益最大的特征的索引值
if __name__ == '__main__':
dataSet, labelSet = createDataSet()
print(calcShannonEnt(dataSet)) #0.9709505944546686
print("最优特征索引值:" + str(chooseBestFeatureToSplit(dataSet)))
3、dd
labelSet = ['yes', 'no']
return dataSet, labelSet