本文已参与「新人创作礼」活动,一起开启掘金创作之路。
训练集下载地址: guaik.github.io/2021/01/29/…
简介
朴素贝叶斯法(Naive Bayes model)是基于贝叶斯定理与特征条件独立假设的分类方法,常用于文本分类、垃圾邮件识别等场景。下面用 Rust 实现一个垃圾短信分类器:先用 jieba 对短信分词,再以“某个词是否出现”构造特征向量,并在统计词频时为每个词和每个类别设置初始计数与初始分母(平滑),以避免出现零概率。
Cargo.toml
[package]
name = "bayes"
version = "0.1.0"
authors = ["Rick.Gu <rick@guaik.io>"]
edition = "2018"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
ndarray = "0.14.0"
jieba-rs = "0.6"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
# [lib]
# name = "bayes"
# crate-type = ["staticlib"]
bayes.rs(其中用到的 `utils::read_lines` 来自项目内的 utils 模块,本文未列出)
//! 朴素贝叶斯算法
use std::collections::{HashMap, HashSet};
use std::fs;
use std::io::Write;
use jieba_rs::Jieba;
use ndarray::{self, arr1, Array1, Array2};
use serde::{Deserialize, Serialize};
use serde_json;
use crate::utils;
/// Serializable snapshot of a trained `Bayes` model; `Bayes::save_model` writes it as JSON
/// and `Bayes::load_model` reads it back.
#[derive(Deserialize, Serialize)]
pub struct BayesModel {
init_val: f64, // initial per-word count given to both classes before training
denom: f64, // initial denominator given to both classes before training
p0: Vec<f64>, // per-word log-probabilities for class 0 (normal messages)
p1: Vec<f64>, // per-word log-probabilities for class 1 (spam messages)
p0_denom: f64, // accumulated denominator for class 0 after training
p1_denom: f64, // accumulated denominator for class 1 after training
pabusive: f64, // prior probability of class 1 (share of training samples labelled 1)
vocabs: Vec<String>, // vocabulary; entry i is the word for entry i of `p0` and `p1`
}
/// Naive Bayes text classifier that segments text with jieba and uses word presence as features.
pub struct Bayes {
init_val: f64, // initial per-word count given to both classes before training
denom: f64, // initial denominator given to both classes before training
p0: Array1<f64>, // after training: smoothed ln P(word | class 0), one entry per vocabulary word
p1: Array1<f64>, // after training: smoothed ln P(word | class 1), one entry per vocabulary word
p0_denom: f64, // accumulated denominator for class 0
p1_denom: f64, // accumulated denominator for class 1
pabusive: f64, // prior probability of class 1 (spam)
vocabs: HashSet<String>, // vocabulary; column i of p0/p1 is the i-th word in this set's iteration order
jieba: Jieba, // segmenter used for both training data and texts to classify
}
impl Bayes {
/// 创建朴素贝叶斯实例
/// # 参数
/// ## init_val: f64
/// 向量的默认值,如果小于0.0,将被设置为1.0
/// ## denom: f64
/// 默认分母,如果小于等于0.0,将被设置为2.0
///
/// # 返回值:Bayes
/// 返回素朴贝叶斯实例
pub fn new(init_val: f64, denom: f64) -> Bayes {
let mut init_val = init_val;
let mut denom = denom;
if init_val < 0.0 {
init_val = 1.0;
}
if denom <= 0.0 {
denom = 2.0;
}
Bayes {
init_val,
denom,
p0: arr1(&[]),
p1: arr1(&[]),
p0_denom: 0.0,
p1_denom: 0.0,
pabusive: 0.0,
vocabs: HashSet::new(),
jieba: Jieba::new(),
}
}
/// 训练模型
/// # 参数
/// ## data_set: &str
/// 该参数为文件路径,带训练的数据集,该数据集需要进行预处理将其中的空格全部去除,然后使用空格将标签与内容进行分隔,如下所示:
///
/// > 0 这是一条正常短信
/// > 1 这是一条垃圾短信
///
/// # 返回值:Result
/// 成功时返回空元组,失败时返回错误提示信息
pub fn train(&mut self, data_set: &str) -> Result<(), &'static str> {
let mut classifys: Vec<f64> = Vec::new(); // 分类
let mut texts: Vec<Vec<String>> = Vec::new();
// 加载数据
if let Ok(lines) = utils::read_lines(data_set) {
for line in lines {
if let Ok(v) = line {
let vs: Vec<&str> = v.split(' ').collect();
// 获取分类
if let Some(c) = vs.get(0) {
if let Ok(c1) = c.parse() {
classifys.push(c1);
} else {
return Err("classify format failed");
}
} else {
return Err("get classify failed");
}
// 获取内容
if let Some(t) = vs.get(1) {
let t1: Vec<String> = self
.jieba
.cut(t, false)
.into_iter()
.map(String::from)
.collect();
texts.push(t1);
} else {
return Err("get text failed");
}
}
}
} else {
return Err("read data_set failed");
}
// -- 加载数据
// 获取唯一词汇
for doc in &texts {
// let doc: HashSet<&String> = doc.iter().collect();
self.vocabs = self
.vocabs
.union(&doc.iter().map(String::from).collect())
.into_iter()
.map(|x| x.to_string())
.collect();
}
// println!("{:?}", self.vocabs);
// 将文本数据转为对应存在的数字向量
// 根据数据内容初始化2D数组
let mut vocab_of_arr = Array2::from_elem((texts.len(), self.vocabs.len()), 0.0);
for (line, doc) in (&texts).iter().enumerate() {
for w in doc {
if let Some(col) = self.vocabs.iter().position(|x| x == w) {
vocab_of_arr.row_mut(line)[col] = 1.0;
}
}
}
// println!("{:?}", vocab_of_arr);
// 将分类数据存储array
let classifys = match Array1::from_shape_vec((&classifys).len(), classifys) {
Ok(vec) => vec,
Err(_) => return Err("gen classifys array failed"),
};
// println!("{:?}", classifys);
// 计算错误占比
self.pabusive = classifys.sum() as f64 / classifys.len() as f64;
self.p0 = Array1::from_elem(self.vocabs.len(), self.init_val);
self.p1 = Array1::from_elem(self.vocabs.len(), self.init_val);
self.p0_denom = self.denom;
self.p1_denom = self.denom;
for (i, classify) in classifys.iter().enumerate() {
if *classify == 1.0 {
self.p1 += &vocab_of_arr.row(i);
self.p1_denom += &vocab_of_arr.row(i).sum();
} else {
self.p0 += &vocab_of_arr.row(i);
self.p0_denom += &vocab_of_arr.row(i).sum();
}
}
let calc = |z: &Array1<f64>, d: f64| -> Array1<f64> {
z.iter().map(|x| (x / d).log(std::f64::consts::E)).collect()
};
self.p0 = calc(&self.p0, self.p0_denom);
self.p1 = calc(&self.p1, self.p1_denom);
Ok(())
}
/// 分类器
/// # 参数
/// ## text: &str
/// 待分类的文本字符串
///
/// # 返回值:Result
/// 如果分类成功,则返回Ok变体。如果分类失败,则返回Err变体。
/// Ok变体中如果p1 > p0则为true,否则为false。
pub fn classify(&self, text: &str) -> Result<bool, &'static str> {
if text == "" {
return Err("text ie empty");
}
let text_vec = self.jieba.cut(text, false);
let mut train_vec = Array1::from_elem(self.vocabs.len(), 0.0);
for w in &text_vec {
if let Some(col) = self.vocabs.iter().position(|x| x == *w) {
train_vec[col] = 1.0;
}
}
let p0: f64 =
(&train_vec * &self.p0).sum() + (1.0 - self.pabusive.log(std::f64::consts::E));
let p1: f64 = (&train_vec * &self.p1).sum() + (self.pabusive.log(std::f64::consts::E));
Ok(p1 > p0)
}
/// 保存训练模型
pub fn save_model(&self, model_name: &str) -> Result<(), &'static str> {
// 转换数据类型
let vocabs: Vec<String> = self.vocabs.iter().map(|x| x.to_string()).collect();
let p0: Vec<f64> = self.p0.to_vec();
let p1: Vec<f64> = self.p1.to_vec();
// json对象
let model = BayesModel {
init_val: self.init_val,
denom: self.denom,
p0,
p1,
p0_denom: self.p0_denom,
p1_denom: self.p1_denom,
pabusive: self.pabusive,
vocabs,
};
let json_str = match serde_json::to_string(&model) {
Ok(v) => v,
Err(_) => return Err("model to json string failed"),
};
let mut file = match fs::File::create(model_name) {
Ok(f) => f,
Err(_) => return Err("create model file failed"),
};
file.write_all(json_str.as_bytes())
.map_err(|_| "write model failed")?;
Ok(())
}
/// 加载模型
pub fn load_model(&mut self, model_path: &str) -> Result<(), &'static str> {
let model_str = match fs::read_to_string(model_path) {
Ok(v) => v,
Err(_) => return Err("read model failed"),
};
let model: BayesModel = match serde_json::from_str(&model_str) {
Ok(v) => v,
Err(_) => return Err("json to model failed"),
};
self.p0 = match Array1::from_shape_vec(model.p0.len(), model.p0) {
Ok(v) => v,
Err(_) => return Err("set p0 failed"),
};
self.p1 = match Array1::from_shape_vec(model.p1.len(), model.p1) {
Ok(v) => v,
Err(_) => return Err("set p1 failed"),
};
self.vocabs = model.vocabs.iter().map(|x| x.to_string()).collect();
self.pabusive = model.pabusive;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a per-process path in the system temp directory so parallel tests and
    /// concurrent `cargo test` runs do not clobber each other's files.
    fn temp_path(name: &str) -> String {
        std::env::temp_dir()
            .join(format!("bayes_test_{}_{}", std::process::id(), name))
            .to_string_lossy()
            .into_owned()
    }

    /// Writes `contents` to a fresh temp file and returns its path.
    fn temp_file(name: &str, contents: &str) -> String {
        let path = temp_path(name);
        std::fs::write(&path, contents).expect("failed to write temp file");
        path
    }

    /// Trains on the real data set (expects `sms_data.txt` in the working directory)
    /// and saves the model to `bayes_model.json`.
    #[test]
    fn test_train_and_save() {
        let mut b = Bayes::new(1.0, 2.0);
        b.train("sms_data.txt").unwrap();
        b.save_model("bayes_model.json").unwrap();
    }

    /// Loads a trained model and classifies a sample SMS.
    ///
    /// Tests run in parallel in no guaranteed order, so this test writes its own model
    /// file instead of depending on `test_train_and_save` having run first.
    #[test]
    fn test_load_and_classify() {
        let model_path = temp_path("load_and_classify_model.json");
        let mut trained = Bayes::new(1.0, 2.0);
        trained.train("sms_data.txt").unwrap();
        trained.save_model(&model_path).unwrap();

        let mut b = Bayes::new(1.0, 2.0);
        b.load_model(&model_path).unwrap();
        println!("star");
        println!("{}", b.classify("验证码417520,用于注册/登录,10分钟内有效。验证码提供给他人可能导致账号被盗,请勿泄漏,谨防被骗。 ").unwrap());
        println!("end");
        let _ = std::fs::remove_file(&model_path);
    }

    /// A saved-then-loaded model must classify exactly like the model it was saved from.
    #[test]
    fn test_save_load_round_trip_preserves_predictions() {
        let data = temp_file(
            "round_trip_data.txt",
            "0 今天天气很好我们去公园散步\n0 明天中午一起吃饭吧\n1 恭喜您中奖了请点击链接领取大奖\n1 免费领取现金红包点击链接\n",
        );
        let model_path = temp_path("round_trip_model.json");
        let mut trained = Bayes::new(1.0, 2.0);
        trained.train(&data).unwrap();
        trained.save_model(&model_path).unwrap();

        let mut loaded = Bayes::new(1.0, 2.0);
        loaded.load_model(&model_path).unwrap();

        for text in &["今天天气很好", "点击链接领取大奖", "明天一起吃饭", "免费现金红包", "zzzz"] {
            assert_eq!(trained.classify(text), loaded.classify(text), "text: {}", text);
        }
        assert!(loaded.classify("").is_err());
        let _ = std::fs::remove_file(&data);
        let _ = std::fs::remove_file(&model_path);
    }

    /// With no known words the decision rests on the class priors alone; with 3 of 4
    /// training samples labelled spam, the spam prior must win.
    #[test]
    fn test_unknown_words_follow_class_prior() {
        let data = temp_file(
            "prior_data.txt",
            "0 今天天气很好\n1 点击链接领取大奖\n1 免费领取现金红包\n1 恭喜您中奖了\n",
        );
        let mut b = Bayes::new(1.0, 2.0);
        b.train(&data).unwrap();
        assert_eq!(b.classify("zzzz"), Ok(true));
        let _ = std::fs::remove_file(&data);
    }
}