机器学习——使用python进行朴素贝叶斯训练

816 阅读2分钟

本文已参与「新人创作礼」活动,一起开启掘金创作之路。


朴素贝叶斯算法

它是一个生成模型,通过给定的样本数据,得到在每个样本中每个特征出现的概率,由此来计算联合概率,判断在哪个类别中出现的可能性大,它可用于多分类问题

朴素贝叶斯原理

它是基于条件概率得到的

image.png

训练数据和测试数据

训练数据和测试数据分别保存在TextData.txt和ceshiData.txt文件中。

image.png

TextData.txt

C、J表示标签

image.png

ceshiData.txt

image.png

求解过程

带有C标签的这类中有8个单词,其中包含Chinese5个,Beijing1个,Shanghai1个,Macao1个

带有J标签的这类中有3个单词,其中包含Chinese1个,Tokyo1个,Japan1个

所以整个数据中所有类别的特征是:Chinese、Beijing、Shanghai、Macao、Tokyo、Japan,总共6个特征。

接下来统计C标签、J标签以及6个特征在整个文档中出现的次数

之后采用拉普拉斯平滑计算每一特征在每个类别下出现的概率

最后进行预测

代码展示

导入数据函数:

def loadDataSetx(fileName):
	dataMat=[]
	fr=open(fileName)
	# 一行一行地读取数据
	for line in fr.readlines():
		# 将当前行的数据转化为列表
		curLine=line.strip().split(' ')
		# 将数据提取出来保存到dataMat中
		dataMat.append(curLine)
	return dataMat

训练数据函数:

def train(dataMat):
	# 得到每组数据所属的类型
	allType_yuan=[]
	for i in range(len(dataMat)):
		allType_yuan.append(dataMat[i][-1])

	# 去除类型中重复的
	allType=list(set(allType_yuan))

	allType_pinlv=[]
	for i in range(len(allType)):
		allType_pinlv.append(float(allType_yuan.count(allType[i])/len(allType_yuan)))
	# print(allType_pinlv)

	# 样例中所有的属性
	attribute=[]
	for i in range(len(dataMat)):
		for j in range(1,len(dataMat[i])-1):
			attribute.append(dataMat[i][j])

	attribute=list(set(attribute))

	# 按照类型对原始数据进行划分
	type_dataMat=[]
	for i in range(len(allType)):
		data=[]
		data.append(allType[i])
		for j in range(len(dataMat)):
			if dataMat[j][-1]==allType[i]:
				for k in range(1,len(dataMat[j])-1):
					data.append(dataMat[j][k])
		type_dataMat.append(data)
	# print(type_dataMat)

	# 在每个类别中每个单词出现的频率,采用了拉普拉斯平滑
	attr_pinlv=[]
	for i in range(len(allType)):
		attr_pinlv_1={}
		attr_pinlv_1['Type']=allType[i]
		for j in range(len(attribute)):
			attr_pinlv_1[attribute[j]]=float((type_dataMat[i].count(attribute[j])+1)/(len(attribute)+len(type_dataMat[i])-1))
		attr_pinlv.append(attr_pinlv_1)

	return allType_pinlv,attr_pinlv

测试数据函数:

# 测试数据
def ceshi(fileName,allType_pinlv,attr_pinlv):
	ceshi=loadDataSetx(fileName)
	ceshi_result=[]
	# 第i条测试数据
	for i in range(len(ceshi)):
		ceshi_result_1={}
		for j in range(len(allType_pinlv)):
			pinlv=1
			# print(allType_pinlv[j])
			pinlv*=allType_pinlv[j]
			# print(pinlv)
			yangben_data=attr_pinlv[j]
			# 测试数据中的第j个字符串
			for k in range(1,len(ceshi[i])):
				pinlv*=yangben_data[ceshi[i][k]]
			ceshi_result_1[yangben_data['Type']]=pinlv
		ceshi_result.append(ceshi_result_1)
	print(ceshi_result)
	for i in range(len(ceshi_result)):
		maxPinlv=max(ceshi_result[i].values())
		for key,value in ceshi_result[i].items():
			if value==maxPinlv:
				print('第'+str(i+1)+'组数据预测类别为:'+key)
				break


函数调用:

dataMat=loadDataSetx('TextData.txt')
allType_pinlv,attr_pinlv=train(dataMat)
print(allType_pinlv)
print(attr_pinlv)
ceshi('ceshiData.txt',allType_pinlv,attr_pinlv)

训练结果

image.png