Chinese Spam Email Classification Based on Naive Bayes
May 6, 2020 · 897 words · 2 min · ML
Training and Testing Data
This project primarily uses open-source data on GitHub.
Data Processing
First, a regular expression removes every non-Chinese character from each email in the training set. The remaining text is segmented into words with jieba, and stopwords are filtered out using a Chinese stopword list. The processed results for spam and normal emails are stored separately.
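The cleaning steps can be sketched as follows. This is a minimal illustration, assuming the email text is already decoded and that a tokenizer (jieba in this project) has produced a list of words. The helper names `keep_chinese` and `remove_stopwords` are hypothetical; only the character range `\u4E00-\u9FA5` comes from the project's code.

```python
import re

# Matches any character outside the CJK range used by the project (U+4E00 to U+9FA5)
NON_CHINESE = re.compile('[^\u4E00-\u9FA5]')


def keep_chinese(text: str) -> str:
    """Drop every non-Chinese character (letters, digits, punctuation, whitespace)."""
    return NON_CHINESE.sub('', text)


def remove_stopwords(tokens, stopwords):
    """Keep only the tokens that are not in the stopword collection."""
    stop = set(stopwords)
    return [t for t in tokens if t not in stop]


print(keep_chinese('Hello 你好，世界! 123'))  # 你好世界
# Tokens are written by hand here; in the project they come from jieba.cut
print(remove_stopwords(['我们', '的', '发票', '优惠'], ['的', '我们']))  # ['发票', '优惠']
```

Note that full-width Chinese punctuation such as `，` lies outside the range, so it is removed as well.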
Two dictionaries, `spam_voca` and `normal_voca`, record how many times each word occurs across the spam and normal training emails, respectively. This completes the data processing.
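A sketch of the vocabulary step using `collections.Counter`. It is an illustration rather than the project's code: it assumes each email has already been cleaned and tokenized, and, like `get_voca` in the code below, it counts every occurrence of a word rather than the number of emails containing it. The sample token lists are made up.

```python
from collections import Counter


def build_vocabulary(emails):
    """Count how often each word occurs across a list of tokenized emails."""
    voca = Counter()
    for tokens in emails:
        voca.update(tokens)  # every occurrence counts, including repeats within one email
    return dict(voca)


# Hypothetical, already-tokenized training emails
spam_emails = [['发票', '优惠', '发票'], ['优惠', '中奖']]
normal_emails = [['会议', '时间'], ['会议', '报告']]

spam_voca = build_vocabulary(spam_emails)
normal_voca = build_vocabulary(normal_emails)
print(spam_voca)    # {'发票': 2, '优惠': 2, '中奖': 1}
print(normal_voca)  # {'会议': 2, '时间': 1, '报告': 1}
```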
Training and Prediction
Training and prediction both revolve around $P(s \mid w_1, w_2, \dots, w_n)$, the probability that an email is spam ($s$) given the words $w_1, \dots, w_n$ it contains; $s'$ denotes a normal (non-spam) email. When this probability exceeds a threshold, the email is classified as spam.
Under the Naive Bayes assumption that words are conditionally independent given the class, and assuming equal prior probabilities $P(s) = P(s') = 0.5$, we have:
$P(s \mid w_1, \dots, w_n) = \frac{P(s, w_1, \dots, w_n)}{P(w_1, \dots, w_n)} = \frac{P(w_1, \dots, w_n \mid s) P(s)}{P(w_1, \dots, w_n)}$
$= \frac{P(w_1, \dots, w_n \mid s) P(s)}{P(w_1, \dots, w_n \mid s) P(s) + P(w_1, \dots, w_n \mid s') P(s')}$
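Since $P(s) = P(s')$, the priors cancel, and conditional independence gives $P(w_1, \dots, w_n \mid s) = \prod_{j=1}^n P(w_j \mid s)$ (likewise for $s'$), so
$P(s \mid w_1, \dots, w_n) = \frac{\prod\limits_{j=1}^n P(w_j \mid s)}{\prod\limits_{j=1}^n P(w_j \mid s) + \prod\limits_{j=1}^n P(w_j \mid s')}$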
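A quick numeric check of this step, using made-up likelihoods $P(w_j \mid s)$ and $P(w_j \mid s')$ for three words: the full Bayes formula with $P(s) = P(s') = 0.5$ and the simplified product ratio give the same posterior. The function names are illustrative only, and `math.prod` assumes Python 3.8 or later.

```python
import math


def posterior_full(lik_spam, lik_ham, prior_spam=0.5):
    """P(s | w_1..w_n) from Bayes' rule, assuming conditionally independent words."""
    joint_spam = math.prod(lik_spam) * prior_spam
    joint_ham = math.prod(lik_ham) * (1 - prior_spam)
    return joint_spam / (joint_spam + joint_ham)


def posterior_equal_priors(lik_spam, lik_ham):
    """Simplified form, valid only when P(s) = P(s') so the priors cancel."""
    a, b = math.prod(lik_spam), math.prod(lik_ham)
    return a / (a + b)


lik_spam = [0.30, 0.05, 0.20]  # hypothetical P(w_j | s)
lik_ham = [0.02, 0.10, 0.04]   # hypothetical P(w_j | s')
print(posterior_full(lik_spam, lik_ham))          # ≈ 0.974
print(posterior_equal_priors(lik_spam, lik_ham))  # ≈ 0.974
```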
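Finally, applying Bayes' theorem to each word, $P(w_j \mid s) = \frac{P(s \mid w_j) P(w_j)}{P(s)}$ (and likewise for $s'$), the common factor $\prod_j P(w_j) / P(s)^n$ cancels because $P(s) = P(s')$, giving
$P(s \mid w_1, \dots, w_n) = \frac{\prod\limits_{j=1}^n P(s \mid w_j)}{\prod\limits_{j=1}^n P(s \mid w_j) + \prod\limits_{j=1}^n P(s' \mid w_j)}$, where $P(s' \mid w_j) = 1 - P(s \mid w_j)$.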
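The same made-up likelihoods can be used to check this form numerically. Under equal priors each per-word posterior is $P(s \mid w_j) = P(w_j \mid s) / (P(w_j \mid s) + P(w_j \mid s'))$ and $P(s' \mid w_j) = 1 - P(s \mid w_j)$; combining the per-word posteriors should reproduce the likelihood-based result. The helper names are hypothetical.

```python
import math


def word_posterior(p_w_s, p_w_n):
    """P(s | w) for a single word when P(s) = P(s')."""
    return p_w_s / (p_w_s + p_w_n)


def combine(word_probs):
    """Combine per-word P(s | w_j) into P(s | w_1..w_n)."""
    spam = math.prod(word_probs)
    ham = math.prod(1 - p for p in word_probs)  # product of P(s' | w_j)
    return spam / (spam + ham)


lik_spam = [0.30, 0.05, 0.20]  # hypothetical P(w_j | s)
lik_ham = [0.02, 0.10, 0.04]   # hypothetical P(w_j | s')

per_word = [word_posterior(a, b) for a, b in zip(lik_spam, lik_ham)]
from_words = combine(per_word)
from_likelihoods = math.prod(lik_spam) / (math.prod(lik_spam) + math.prod(lik_ham))
print(round(from_words, 6), round(from_likelihoods, 6))
```

Working with per-word posteriors is convenient because each $P(s \mid w)$ lies in $(0, 1)$ and can be stored or ranked independently of the other words in the email.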
Process details:
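- For each email in the test set, apply the same preprocessing, then compute $P(s \mid w)$ for every word. With equal priors, $P(s \mid w) = \frac{P(w \mid s)}{P(w \mid s) + P(w \mid s')}$, where $P(w \mid s)$ is estimated as the word's count in `spam_voca` divided by the number of spam training emails, and $P(w \mid s')$ likewise from `normal_voca`. If a word appears only in the spam dictionary, set $P(w \mid s') = 0.01$; if it appears only in the normal dictionary, set $P(w \mid s) = 0.01$; if it appears in neither, set $P(s \mid w) = 0.4$ directly. These constants are empirical values taken from earlier work on Bayesian spam filtering, not estimates from this dataset.
- Keep the $n$ words with the highest $P(s \mid w)$ for each email (for example $n = 15$; see Results for the value that worked best here) and combine them with the formula above. If the combined probability is at least the threshold $\alpha$ (typically 0.9), classify the email as spam; otherwise, classify it as a normal email.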
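The per-email procedure above can be sketched as a small, self-contained classifier. It follows the rules described here (counts divided by the number of training emails, 0.01 for a missing likelihood, 0.4 for an unseen word, the top $n$ words by $P(s \mid w)$, threshold $\alpha = 0.9$), but the function names, the toy vocabularies, and the email counts are assumptions for illustration, not the project's data.

```python
import math

UNSEEN_PROB = 0.4   # P(s | w) for a word absent from both vocabularies
MISSING_LIK = 0.01  # stand-in for P(w | s) or P(w | s') when one vocabulary lacks the word


def word_spam_prob(word, spam_voca, normal_voca, n_spam, n_normal):
    """Estimate P(s | w), assuming P(s) = P(s')."""
    if word not in spam_voca and word not in normal_voca:
        return UNSEEN_PROB
    p_w_s = spam_voca[word] / n_spam if word in spam_voca else MISSING_LIK
    p_w_n = normal_voca[word] / n_normal if word in normal_voca else MISSING_LIK
    return p_w_s / (p_w_s + p_w_n)


def is_spam(tokens, spam_voca, normal_voca, n_spam, n_normal, n_words=15, alpha=0.9):
    """Score one tokenized email; return (decision, combined probability)."""
    probs = {w: word_spam_prob(w, spam_voca, normal_voca, n_spam, n_normal) for w in set(tokens)}
    top = sorted(probs.values(), reverse=True)[:n_words]  # the n_words highest P(s | w)
    spam = math.prod(top)
    ham = math.prod(1 - p for p in top)
    score = spam / (spam + ham)
    return score >= alpha, score


# Toy vocabularies (hypothetical counts) built from 10 spam and 10 normal emails
spam_voca = {'发票': 6, '优惠': 5, '会议': 1}
normal_voca = {'会议': 7, '报告': 4}
print(is_spam(['发票', '优惠', '会议'], spam_voca, normal_voca, 10, 10))
print(is_spam(['会议', '报告'], spam_voca, normal_voca, 10, 10))
```

Distinct words are scored once (`set(tokens)`), so a word repeated many times in one email does not occupy several of the $n$ slots.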
You can refer to the code for further details.
Results
By varying the number of words $n$ used for prediction, the best accuracy (fraction of test emails classified correctly) on this dataset was obtained with $n = 29$:
Selected 29 words: 0.9642857142857143
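The search over the number of words can be sketched as a simple loop. Here `evaluate` is a hypothetical callback that returns test-set accuracy for a given word count (in this project, that value is what `predict_result` prints); the toy accuracy curve below is invented purely to exercise the loop.

```python
def best_word_count(evaluate, candidates):
    """Return (n, accuracy) for the candidate n with the highest accuracy.

    Candidates are scanned in order, so ties keep the first (smallest) n.
    """
    best_n, best_acc = None, float('-inf')
    for n in candidates:
        acc = evaluate(n)
        if acc > best_acc:
            best_n, best_acc = n, acc
    return best_n, best_acc


# Invented accuracy curve peaking at n = 20, used only to demonstrate the loop
def toy_evaluate(n):
    return 1.0 - abs(n - 20) / 100


print(best_word_count(toy_evaluate, range(5, 41)))  # (20, 1.0)
```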
Project Structure
- data
  - 中文停用词表.txt (Chinese stopword list)
  - normal (folder for normal emails)
  - spam (folder for spam emails)
  - test (folder for test emails)
- main.py (main script)
- normal_voca.json (JSON file for normal email vocabulary)
- __pycache__ (cache folder)
  - utils.cpython-36.pyc
- spam_voca.json (JSON file for spam email vocabulary)
- utils.py (utility functions)
Code
# utils.py
import json
import os
import re
from collections import defaultdict
from functools import lru_cache

import jieba

# Number of training emails in each class, used to estimate P(w|s) and P(w|s')
spam_file_num = 7775
normal_file_num = 7063

# Matches every character outside the Chinese unicode range U+4E00 to U+9FA5
NON_CHINESE = re.compile('[^\u4E00-\u9FA5]')


# Load the stopword list once and cache it; a frozenset gives O(1) lookups
@lru_cache(maxsize=None)
def get_stopwords():
    with open('./data/中文停用词表.txt', encoding='gbk') as f:
        return frozenset(line.strip() for line in f)


# Read raw email content, keep only Chinese characters, segment with jieba, drop stopwords
def get_raw_str_list(path):
    stop_words = get_stopwords()
    # errors='ignore' skips bytes that are not valid GBK instead of raising
    with open(path, encoding='gbk', errors='ignore') as f:
        raw_str = f.read()
    handled_str = NON_CHINESE.sub('', raw_str)
    return [word for word in jieba.cut(handled_str) if word not in stop_words]


# Build vocabulary: total occurrences of each word across all emails in a folder
def get_voca(path, is_file_path=False):
    if is_file_path:
        return read_voca_from_file(path)
    voca = defaultdict(int)
    for file in os.listdir(path):
        for word in get_raw_str_list(os.path.join(path, file)):
            voca[word] += 1
    return voca


# Save vocabulary to JSON file (UTF-8, so Chinese words are stored readably)
def save_voca2json(voca, path, sort_by_value=False, indent_=4):
    if sort_by_value:
        voca = sorted_by_value(voca)
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(voca, f, ensure_ascii=False, indent=indent_)


# Read vocabulary from JSON file
def read_voca_from_file(path):
    with open(path, encoding='utf-8') as f:
        return json.load(f)


# Return a copy of the dictionary sorted by value, largest first
def sorted_by_value(_dict):
    return dict(sorted(_dict.items(), key=lambda x: x[1], reverse=True))


# Compute P(s|w) for each distinct word and keep the words_size words with the highest value
def get_top_words_prob(path, spam_voca, normal_voca, words_size=30):
    word_probs = {}
    for word in get_raw_str_list(path):
        in_spam = word in spam_voca
        in_normal = word in normal_voca
        if not in_spam and not in_normal:
            p_s_w = 0.4  # word never seen in training
        else:
            # P(w|s) and P(w|s'), using 0.01 when one vocabulary lacks the word
            p_w_s = spam_voca[word] / spam_file_num if in_spam else 0.01
            p_w_n = normal_voca[word] / normal_file_num if in_normal else 0.01
            p_s_w = p_w_s / (p_w_n + p_w_s)  # Bayes' rule with P(s) = P(s')
        word_probs[word] = p_s_w
    top_words = sorted(word_probs.items(), key=lambda x: x[1], reverse=True)[:words_size]
    return dict(top_words)


# Combine the per-word probabilities into P(s|w_1, ..., w_n)
# (spam_voca and normal_voca are unused; kept for compatibility with existing callers)
def caculate_bayes(words_prob, spam_voca, normal_voca):
    p_s_w = 1   # product of P(s|w_j)
    p_s_nw = 1  # product of P(s'|w_j) = 1 - P(s|w_j)
    for prob in words_prob.values():
        p_s_w *= prob
        p_s_nw *= (1 - prob)
    return p_s_w / (p_s_w + p_s_nw)


# Classify as spam when the combined probability reaches the threshold
def predict(bayes, threshold=0.9):
    return bayes >= threshold


# Get files and labels (True for spam, False for normal)
def get_files_labels(dir_path, is_spam=True):
    files_list = [os.path.join(dir_path, file) for file in os.listdir(dir_path)]
    labels = [is_spam] * len(files_list)
    return files_list, labels


# Predict every file, then print and return the accuracy
def predict_result(file_list, y, spam_voca, normal_voca, word_size=30):
    ret = []
    for file in file_list:
        words_prob = get_top_words_prob(file, spam_voca, normal_voca, words_size=word_size)
        bayes = caculate_bayes(words_prob, spam_voca, normal_voca)
        ret.append(predict(bayes))
    right = sum(1 for pred, label in zip(ret, y) if pred == label)
    accuracy = right / len(y)
    print(accuracy)
    return accuracy
# main.py
from utils import *

if __name__ == '__main__':
    # Get vocabulary and save for future use
    spam_voca = get_voca('./spam_voca.json', is_file_path=True)
    normal_voca = get_voca('./normal_voca.json', is_file_path=True)
    save