基于朴素贝叶斯的中文垃圾电子邮件分类

训练数据和测试数据

本次主要使用了github上的开源数据

数据处理

首先使用正则表达式对训练集中的中文邮件的内容进行过滤,去除其中中文之外的其他字符,再将剩下的内容使用jieba进行分词,过滤掉中文停用字表中的内容。将垃圾邮件以及正常邮件的处理结果分别存放。

用两个字典spam_voca和normal_voca用于保存不同邮件中不同词语的词频,数据处理完成。

训练与预测

训练和预测的过程主要使计算$ P(Spam|word_1,word_2,…,word_n) $的概率,当这个概率大于某个阈值的时候,这个邮件被分类为垃圾邮件。

根据朴素贝叶斯的条件独立假设,并设先验概率$P(s)=P(s^{’})=0.5$,则有: P(s|w_1,w_2,…,w_n)=$\frac {P(s,w_1,w_2,…,w_n)}{P(w_1,w_2,…,w_n)}$

=$\frac {P(w_1,w_2,…,w_n|s)P(s)}{P(w_1,w_2,…,w_n)}$=$\frac {P(w_1,w_2,…,w_n|s)P(s)}{P(w_1,w_2,…,w_n|s)\cdot p(s)+P(w_1,w_2,…,w_n|s^{’})\cdot p(s^{’})}\qquad\qquad$

因为P(spam)=P(not spam),则有

$\frac {\prod\limits_{j=1}^nP(w_j|s)}{\prod\limits_{j=1}^nP(w_j|s)+\prod\limits_{j=1}^nP(w_j|s^{’})}$ 再利用贝叶斯$P(w_j|s)=\frac{P(s|w_j)\cdot P(w_j)}{P(s)}$,式子化为 $\frac {\prod\limits_{j=1}^nP(s|w_j)}{\prod\limits_{j=1}^nP(s|w_j)+\prod\limits_{j=1}^nP(s^{’}|w_j)}$

具体流程

  • 对测试集中的每一封邮件做同样的处理,并计算得到$P(s|w)$最高的n个词,在计算过程中,若该词只出现在垃圾邮件的词典中,则令$P(w|s^{’})=0.01$,反之亦然;若都未出现,则令$P(s|w)=0.4$。PS.这里做的几个假设基于前人做的一些研究工作得出的。
  • 对得到的每封邮件中重要的15个词利用式2计算概率,若概率$>$阈值$\alpha(一般设为0.9)$,则判为垃圾邮件,否则判为正常邮件。

具体的代码细节可以参见附加的代码。

结果

通过调整预测是使用的词汇的数量,针对本次数据集,最佳的结果是

1
选择29个词汇 : 0.9642857142857143

附加

  • 项目结构 ├── data │ ├── 中文停用词表.txt │ ├── normal │ ├── spam │ └── test ├── main.py ├── normal_voca.json ├── pycache │ └── utils.cpython-36.pyc ├── spam_voca.json └── utils.py

  • 具体代码

      1
      2
      3
      4
      5
      6
      7
      8
      9
     10
     11
     12
     13
     14
     15
     16
     17
     18
     19
     20
     21
     22
     23
     24
     25
     26
     27
     28
     29
     30
     31
     32
     33
     34
     35
     36
     37
     38
     39
     40
     41
     42
     43
     44
     45
     46
     47
     48
     49
     50
     51
     52
     53
     54
     55
     56
     57
     58
     59
     60
     61
     62
     63
     64
     65
     66
     67
     68
     69
     70
     71
     72
     73
     74
     75
     76
     77
     78
     79
     80
     81
     82
     83
     84
     85
     86
     87
     88
     89
     90
     91
     92
     93
     94
     95
     96
     97
     98
     99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    152
    153
    154
    155
    156
    157
    158
    159
    160
    161
    162
    163
    164
    165
    166
    167
    168
    169
    170
    171
    172
    173
    174
    175
    176
    177
    178
    179
    180
    181
    182
    183
    184
    185
    186
    187
    188
    189
    190
    191
    192
    
    # utils.py
    import jieba
    import numpy
    import re
    import os
    import json
    from collections import defaultdict
    
    spam_file_num=7775 
    normal_file_num=7063
    
    #todo 原始邮件数据的预处理
    '''
    获取中文的停用词表
    【停用词】(https://zh.wikipedia.org/wiki/%E5%81%9C%E7%94%A8%E8%AF%8D)
    '''
    
    def get_stopwords():
        return [
            i.strip() for i in open('./data/中文停用词表.txt', encoding='gbk')
        ]
    
    
    '''
        读取原始文件,进行简单的处理和分词,返回分词列表
    
    '''
    
    
    def get_raw_str_list(path):
        stop_list = get_stopwords()
        with open(path, encoding='gbk') as f:
            raw_str = f.read()
        pattern = '[^\u4E00-\u9FA5]'  #中文的unicode编码的范围
        regex = re.compile(pattern)
        ret = regex.findall(raw_str)
        handled_str = re.sub(pattern, '', raw_str)
        str_list = [
            word for word in jieba.cut(handled_str) if word not in stop_list
        ]
        return str_list
    
    #分词以及统计词频,得到词汇表
    '''
    返回一个字典,记录了分词后的词汇表
    path:k可以使dir或者file路径,默认为存放文本的dir
    is_file_path:表示输入路径是不是文件
    
    '''
    
    
    def get_voca(path, is_file_path=False):
        if is_file_path:
            return read_voca_from_file(path)
    
        voca = defaultdict(int)
        # 获取垃圾邮件列表
        file_list = [file for file in os.listdir(path)]
        #获得词汇表
        for file in file_list:  #测试用
            raw_str_list = get_raw_str_list(path + str(file))
            for raw_str in raw_str_list:
                voca[raw_str] = voca[raw_str] + 1
    
        return voca
    
    
    
    '''
        将得到的词汇表保存到json文件中,以便下次读取
        voca:词汇表字典
        path:文件的保存路径
        sort_by_value:是否按value值排序
    
    '''
    
    
    def save_voca2json(voca, path, sort_by_value=False,indent_=4):
        if sort_by_value:
            sorted_by_value(voca)
    
        with open(path, 'w+') as f:
            f.write(json.dumps(voca, ensure_ascii=False, indent=indent_))
    
    
    '''
    从json文件中读取voca
    '''
    
    def read_voca_from_file(path):
        voca = None
        with open(path) as f:
            voca = json.load(f)
        return voca
    
    
    '''
    将字典基于value排序
    '''
    
    def sorted_by_value(_dict):
        _dict = dict(sorted(spam_voca.items(), key=lambda x: x[1], reverse=True))
    
    
    #计算 P(Spam|word)
    
    '''
        计算邮件和邮件分类最相关词语及其 P(spam|word)
        words_size:最终使用词语的数量,用于最后的预测
    
    '''
    
    def get_top_words_prob(path,spam_voca,normal_voca,words_size=30):
        critical_words=[]
        for word in get_raw_str_list(path):
            if word in spam_voca.keys() and word in normal_voca.keys():
                # 如果word在两边都出现
                p_w_s=spam_voca[word]/spam_file_num
                p_w_n=normal_voca[word]/normal_file_num
                p_s_w=p_w_s/(p_w_n+p_w_s)
    
            elif word in spam_voca.keys() and word not in normal_voca.keys():
                # 如果word只在spam出现
                p_w_s=spam_voca[word]/spam_file_num
                p_w_n=0.01
                p_s_w=p_w_s/(p_w_n+p_w_s)
    
            elif word not in spam_voca.keys() and word in normal_voca.keys():
                # 如果word只在normal出现
                p_w_s=0.01
                p_w_n=normal_voca[word]/normal_file_num
                p_s_w=p_w_s/(p_w_n+p_w_s)
            else:
                #如果都不出现
                p_s_w=0.4
            #以上设定的数值均是基于前人的研究得到的数据
            critical_words.append([word,p_s_w])
    
        return dict(sorted(critical_words[:words_size],key=lambda x:x[1],reverse=True))
    
    '''
        计算贝叶斯概率
        words_prob:包含词汇以及其P(spam|word)概率的一个字典
        spam_voca:垃圾邮件词汇表
        normal_voca:正常邮件词汇表
    '''
    
    
    def caculate_bayes(words_prob,spam_voca,normal_voca):
        p_s_w=1
        p_s_nw=1
        for word,prob in words_prob.items():
            p_s_w*=prob
            p_s_nw*=(1-prob)
    
        return p_s_w/(p_s_w+p_s_nw)
    
    def predict(bayes,threshold=0.9):
        return bayes>=threshold    
    
    
    '''
        返回文件名和label
    '''
    
    def get_files_labels(dir_path,is_spam=True):
        raw_files_list=os.listdir(dir_path)
        files_list= [dir_path+file for file in raw_files_list]
        labels=[is_spam for i in range(len(files_list))]
        return files_list,labels
    
    
    '''
        预测测试集结果并打印
        file_list:测试集文件路径集合
        y:测试集结果标签
        word_size:预测的指标
    '''
    
    def predict_result(file_list,y,spam_voca,normal_voca,word_size=30):
        ret=[]
        right=0
        for file in file_list:
    #         raw_strs=get_raw_str_list(str(file))
            words_prob=get_top_words_prob(file,spam_voca,normal_voca,words_size=word_size)
            bayes=caculate_bayes(words_prob,spam_voca,normal_voca)
            ret.append(predict(bayes))
        for i in range(len(ret)):
            if ret[i]==y[i]:
                right+=1
    
        print(right/len(y))
    
     1
     2
     3
     4
     5
     6
     7
     8
     9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    
    #main.py
    from utils import *
    
    if __name__=='__main__':
        #获取词典并保存,一遍下次读取
        spam_voca=get_voca('./spam_voca.json',is_file_path=True)
        normal_voca=get_voca('./normal_voca.json',is_file_path=True)
        #保存方便下次读取
        save_voca2json(spam_voca,'./spam_voca.json')
        save_voca2json(normal_voca,'./snormal_voca.json')
        #预处理测试数据
        s_x,s_y=get_files_labels('./data/test/spam/')
        n_x,n_y=get_files_labels('./data/test/normal/',is_spam=False)
        x,y=list(s_x+n_x),s_y+n_y
        #预测结果
    
        for i in range(10,40):
            print(str(i)+' : ',end='')
            predict_result(x,y,spam_voca,normal_voca,word_size=i)
    

源码

github