朴素贝叶斯(Naive Bayes)是一类基于贝叶斯定理的简单而有效的分类算法。它假设特征之间是相互独立的,即在给定目标变量的情况下,每个特征都不依赖于其他特征。尽管这个假设在实际中很难成立,朴素贝叶斯在许多场景下仍表现得非常好,特别是对于文本分类等高维数据的应用。
1. 贝叶斯定理
贝叶斯定理表明给定一个事件发生的条件下另一个事件发生的概率:
-
P(A∣B) 是在已知B的情况下A发生的概率(后验概率)。
-
P(B∣A)是在已知A的情况下B发生的概率(条件概率)。
-
P(A)是A发生的先验概率。
-
P(B)是 B发生的总概率。
在分类问题中,给定一个特征向量,目标是计算每个类别 的后验概率,选取后验概率最大的类别作为预测结果。根据贝叶斯定理:
并且由于分母都是相等的,所以只需要比较分子 来决定分类结果
2. 朴素假设
朴素贝叶斯的“朴素”假设是特征 xi 之间是条件独立的,因此:
最终,分类器的决策规则变为:
3. 朴素贝叶斯的三种主要类型
- 高斯朴素贝叶斯(Gaussian Naive Bayes) :适用于连续数据,并且假设数据分布呈现高斯分布
- 多项式朴素贝叶斯(Multinomial Naive Bayes) :主要用于文本分类,适用于离散特征,例如词频。
- 伯努利朴素贝叶斯(Bernoulli Naive Bayes) :处理二元特征(即每个特征值为0或1的情况)。
4. 贝叶斯估计
在朴素贝叶斯分类中,贝叶斯估计用于避免由于零概率问题而导致模型失效的情况。通过引入伪计数(通常称为拉普拉斯平滑或贝叶斯平滑),可以确保所有的概率都不会为零。
5. 拉普拉斯平滑
拉普拉斯平滑通过在每个类别的词频上加上一个常数(通常为 1)来防止零概率问题。拉普拉斯平滑后的条件概率公式为:
-
是在类别 中,特征 出现的次数
-
是类别中所有特征出现的总次数。
-
是平滑参数,通常取值为 1(拉普拉斯平滑),但也可以取其他值
-
是特征的总数
平滑后,所有的特征都会有一个非零的概率,即使某些特征在某类别中没有出现过
6. 贝叶斯估计公式
在引入平滑后,分类器的决策规则变为:
通过调整 的值,可以控制平滑的强度。在文本分类中,通常会使用较小的平滑参数来防止过度平滑。
sklearn中的朴素贝叶斯实现
在sklearn
中,提供了多个朴素贝叶斯分类器的实现,包括 GaussianNB
、MultinomialNB
和 BernoulliNB
。下面是一个简单的使用MultinomialNB
处理文本分类问题的示例。
- 导入库和加载数据
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import MultinomialNB
from sklearn.metrics import accuracy_score
# 加载新闻组数据
newsgroups = fetch_20newsgroups(subset='train')
X = newsgroups.data
y = newsgroups.target
# 将文本数据转换为词频矩阵
vectorizer = CountVectorizer()
X_counts = vectorizer.fit_transform(X)
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X_counts, y, test_size=0.25, random_state=42)
- 模型训练
# 初始化多项式朴素贝叶斯分类器
nb = MultinomialNB()
# 训练模型
nb.fit(X_train, y_train)
- 模型预测和评估
# 预测
y_pred = nb.predict(X_test)
# 评估模型
accuracy = accuracy_score(y_test, y_pred)
print(f'Accuracy: {accuracy:.2f}')
4. 进一步优化
通过调整超参数(如 alpha
)可以提高模型的性能。例如,在文本分类中,通常会使用拉普拉斯平滑来处理零概率问题。
# 使用拉普拉斯平滑(alpha)
nb = MultinomialNB(alpha=0.5)
nb.fit(X_train, y_train)
# 预测并评估
y_pred = nb.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f'Accuracy with alpha=0.5: {accuracy:.2f}')