Rust实现朴素贝叶斯算法识别垃圾短信

551 阅读3分钟

本文已参与「新人创作礼」活动,一起开启掘金创作之路。

Image description

训练集下载地址: guaik.github.io/2021/01/29/…

简介

朴素贝叶斯法(Naive Bayes model)是基于贝叶斯定理与特征条件独立假设的分类方法,主要应用于文本分类、垃圾邮件识别等场景。

Cargo.toml

# Manifest of the `bayes` crate (Naive Bayes spam-SMS classifier).
[package]
name = "bayes"
version = "0.1.0"
authors = ["Rick.Gu <rick@guaik.io>"]
edition = "2018"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
# Numeric arrays that hold the per-word statistics of the classifier.
ndarray = "0.14.0"
# Chinese word segmentation (`Jieba::cut`) used to tokenize messages.
jieba-rs = "0.6"
# Derive macros for the (de)serializable `BayesModel` snapshot.
serde = { version = "1.0", features = ["derive"] }
# JSON encoding used when saving and loading a trained model.
serde_json = "1.0"

# Optional static-library build configuration (currently disabled).
# [lib]
# name = "bayes"
# crate-type = ["staticlib"]

bayes.rs

//! 朴素贝叶斯算法

use jieba_rs::Jieba;
use ndarray::{self, arr1, Array1, Array2};
use std::collections::HashSet;
use std::fs;
use std::io::Write;

use serde::{Deserialize, Serialize};
use serde_json;

use crate::utils;

/// Serializable snapshot of a trained classifier.
///
/// `Bayes::save_model` fills this struct and writes it out as JSON;
/// `Bayes::load_model` parses the same JSON back into it.
/// `p0`, `p1` and `vocabs` are parallel lists: entry `i` of `p0`/`p1` belongs
/// to the word stored at `vocabs[i]`.
#[derive(Deserialize, Serialize)]
pub struct BayesModel {
    init_val: f64, // initial occurrence count given to every word before training
    denom: f64,    // initial (default) denominator
    p0: Vec<f64>,  // per-word natural-log probabilities for class 0 (normal messages)
    p1: Vec<f64>,  // per-word natural-log probabilities for class 1 (spam messages)
    p0_denom: f64, // denominator that was used to normalise the class 0 counts
    p1_denom: f64, // denominator that was used to normalise the class 1 counts
    pabusive: f64, // share of training samples labelled 1 (spam)
    vocabs: Vec<String>, // vocabulary, one entry per distinct word seen in training
}

pub struct Bayes {
    init_val: f64,   // 所有词的初始化出现数
    denom: f64,      // 默认分母
    p0: Array1<f64>, // 正确概率数组
    p1: Array1<f64>, // 错误概率数组
    p0_denom: f64,
    p1_denom: f64,
    pabusive: f64, // 错误率
    vocabs: HashSet<String>,
    jieba: Jieba,
}

impl Bayes {
    /// 创建朴素贝叶斯实例
    /// # 参数
    /// ## init_val: f64
    /// 向量的默认值,如果小于0.0,将被设置为1.0
    /// ## denom: f64
    /// 默认分母,如果小于等于0.0,将被设置为2.0
    ///
    /// # 返回值:Bayes
    /// 返回素朴贝叶斯实例
    pub fn new(init_val: f64, denom: f64) -> Bayes {
        let mut init_val = init_val;
        let mut denom = denom;

        if init_val < 0.0 {
            init_val = 1.0;
        }
        if denom <= 0.0 {
            denom = 2.0;
        }
        Bayes {
            init_val,
            denom,
            p0: arr1(&[]),
            p1: arr1(&[]),
            p0_denom: 0.0,
            p1_denom: 0.0,
            pabusive: 0.0,
            vocabs: HashSet::new(),
            jieba: Jieba::new(),
        }
    }

    /// 训练模型
    /// # 参数
    /// ## data_set: &str
    /// 该参数为文件路径,带训练的数据集,该数据集需要进行预处理将其中的空格全部去除,然后使用空格将标签与内容进行分隔,如下所示:
    ///
    /// > 0 这是一条正常短信
    /// > 1 这是一条垃圾短信
    ///
    /// # 返回值:Result
    /// 成功时返回空元组,失败时返回错误提示信息
    pub fn train(&mut self, data_set: &str) -> Result<(), &'static str> {
        let mut classifys: Vec<f64> = Vec::new(); // 分类
        let mut texts: Vec<Vec<String>> = Vec::new();

        // 加载数据
        if let Ok(lines) = utils::read_lines(data_set) {
            for line in lines {
                if let Ok(v) = line {
                    let vs: Vec<&str> = v.split(' ').collect();
                    // 获取分类
                    if let Some(c) = vs.get(0) {
                        if let Ok(c1) = c.parse() {
                            classifys.push(c1);
                        } else {
                            return Err("classify format failed");
                        }
                    } else {
                        return Err("get classify failed");
                    }
                    // 获取内容
                    if let Some(t) = vs.get(1) {
                        let t1: Vec<String> = self
                            .jieba
                            .cut(t, false)
                            .into_iter()
                            .map(String::from)
                            .collect();
                        texts.push(t1);
                    } else {
                        return Err("get text failed");
                    }
                }
            }
        } else {
            return Err("read data_set failed");
        }
        // -- 加载数据

        // 获取唯一词汇
        for doc in &texts {
            // let doc: HashSet<&String> = doc.iter().collect();
            self.vocabs = self
                .vocabs
                .union(&doc.iter().map(String::from).collect())
                .into_iter()
                .map(|x| x.to_string())
                .collect();
        }
        // println!("{:?}", self.vocabs);
        // 将文本数据转为对应存在的数字向量
        // 根据数据内容初始化2D数组
        let mut vocab_of_arr = Array2::from_elem((texts.len(), self.vocabs.len()), 0.0);
        for (line, doc) in (&texts).iter().enumerate() {
            for w in doc {
                if let Some(col) = self.vocabs.iter().position(|x| x == w) {
                    vocab_of_arr.row_mut(line)[col] = 1.0;
                }
            }
        }
        // println!("{:?}", vocab_of_arr);
        // 将分类数据存储array
        let classifys = match Array1::from_shape_vec((&classifys).len(), classifys) {
            Ok(vec) => vec,
            Err(_) => return Err("gen classifys array failed"),
        };
        // println!("{:?}", classifys);
        // 计算错误占比
        self.pabusive = classifys.sum() as f64 / classifys.len() as f64;

        self.p0 = Array1::from_elem(self.vocabs.len(), self.init_val);
        self.p1 = Array1::from_elem(self.vocabs.len(), self.init_val);
        self.p0_denom = self.denom;
        self.p1_denom = self.denom;

        for (i, classify) in classifys.iter().enumerate() {
            if *classify == 1.0 {
                self.p1 += &vocab_of_arr.row(i);
                self.p1_denom += &vocab_of_arr.row(i).sum();
            } else {
                self.p0 += &vocab_of_arr.row(i);
                self.p0_denom += &vocab_of_arr.row(i).sum();
            }
        }
        let calc = |z: &Array1<f64>, d: f64| -> Array1<f64> {
            z.iter().map(|x| (x / d).log(std::f64::consts::E)).collect()
        };
        self.p0 = calc(&self.p0, self.p0_denom);
        self.p1 = calc(&self.p1, self.p1_denom);
        Ok(())
    }

    /// 分类器
    /// # 参数
    /// ## text: &str
    /// 待分类的文本字符串
    ///
    /// # 返回值:Result
    /// 如果分类成功,则返回Ok变体。如果分类失败,则返回Err变体。
    /// Ok变体中如果p1 > p0则为true,否则为false。
    pub fn classify(&self, text: &str) -> Result<bool, &'static str> {
        if text == "" {
            return Err("text ie empty");
        }
        let text_vec = self.jieba.cut(text, false);
        let mut train_vec = Array1::from_elem(self.vocabs.len(), 0.0);
        for w in &text_vec {
            if let Some(col) = self.vocabs.iter().position(|x| x == *w) {
                train_vec[col] = 1.0;
            }
        }
        let p0: f64 =
            (&train_vec * &self.p0).sum() + (1.0 - self.pabusive.log(std::f64::consts::E));
        let p1: f64 = (&train_vec * &self.p1).sum() + (self.pabusive.log(std::f64::consts::E));
        Ok(p1 > p0)
    }

    /// 保存训练模型
    pub fn save_model(&self, model_name: &str) -> Result<(), &'static str> {
        // 转换数据类型
        let vocabs: Vec<String> = self.vocabs.iter().map(|x| x.to_string()).collect();
        let p0: Vec<f64> = self.p0.to_vec();
        let p1: Vec<f64> = self.p1.to_vec();

        // json对象
        let model = BayesModel {
            init_val: self.init_val,
            denom: self.denom,
            p0,
            p1,
            p0_denom: self.p0_denom,
            p1_denom: self.p1_denom,
            pabusive: self.pabusive,
            vocabs,
        };
        let json_str = match serde_json::to_string(&model) {
            Ok(v) => v,
            Err(_) => return Err("model to json string failed"),
        };
        let mut file = match fs::File::create(model_name) {
            Ok(f) => f,
            Err(_) => return Err("create model file failed"),
        };
        file.write_all(json_str.as_bytes())
            .map_err(|_| "write model failed")?;
        Ok(())
    }

    /// 加载模型
    pub fn load_model(&mut self, model_path: &str) -> Result<(), &'static str> {
        let model_str = match fs::read_to_string(model_path) {
            Ok(v) => v,
            Err(_) => return Err("read model failed"),
        };
        let model: BayesModel = match serde_json::from_str(&model_str) {
            Ok(v) => v,
            Err(_) => return Err("json to model failed"),
        };

        self.p0 = match Array1::from_shape_vec(model.p0.len(), model.p0) {
            Ok(v) => v,
            Err(_) => return Err("set p0 failed"),
        };

        self.p1 = match Array1::from_shape_vec(model.p1.len(), model.p1) {
            Ok(v) => v,
            Err(_) => return Err("set p1 failed"),
        };

        self.vocabs = model.vocabs.iter().map(|x| x.to_string()).collect();

        self.pabusive = model.pabusive;

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Trains on `sms_data.txt` (expected in the working directory) and
    /// writes the resulting model to `bayes_model.json`.
    #[test]
    fn test_train_and_save() {
        let mut trainer = Bayes::new(1.0, 2.0);
        trainer.train("sms_data.txt").unwrap();
        trainer.save_model("bayes_model.json").unwrap();
    }

    /// Loads `bayes_model.json` and prints the verdict for a sample message.
    ///
    /// The model file has to exist already, e.g. produced by an earlier run
    /// of `test_train_and_save`.
    #[test]
    fn test_load_and_classify() {
        // A verification-code notification used as the sample input.
        let sample = "验证码417520,用于注册/登录,10分钟内有效。验证码提供给他人可能导致账号被盗,请勿泄漏,谨防被骗。 ";

        let mut classifier = Bayes::new(1.0, 2.0);
        classifier.load_model("bayes_model.json").unwrap();

        println!("star");
        let is_spam = classifier.classify(sample).unwrap();
        println!("{}", is_spam);
        println!("end");
    }
}