简述朴素贝叶斯估计

本文深入探讨了使用贝叶斯方法进行垃圾邮件分类的技术细节,包括特征选择、模型训练与预测过程。通过构建统计词典并应用朴素贝叶斯分类算法,实现对大量邮件的有效分类。实例展示了如何从文本数据中提取关键信息,计算概率以区分垃圾邮件与非垃圾邮件,最终达到高准确率的分类效果。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

第一部分

贝叶斯公式的基本形式为:



第二部分

朴素贝叶斯法的过程:

(1)确定特征属性,架设每个属性之间是相互独立的。

(2)分类器训练阶段:A对每个类别计算P(Yi)

     B对每个特征属性计算所有划分的条件概率P(X|Yi);

(3)分类器训练阶段:以P(Yi)* P(X|Yi)  最大分类项作为X所属的类别。

简要的来说:对于给出的待分类项,求解此项出现的条件下,各个类别出现的条件概率,哪一个最大,就认为此待分类项属于哪个类别。


下面搜集了一个贝叶斯分类方法在垃圾邮件分类上的应用例子,首先谢谢此博主的分享,作为参考,日后如果有用到,再来详细学习。


原理分析


总体思想

利用Naive Bayes(后验概率)计算特征所属空间的概率,取其最大者为判定结果。
如下,其中P表示概率,w表示所属类别。


对于Prior,可用如下公式进行计算:

对于Likelihood中独立同分布的各项概率,可用如下公式计算:

训练

输入为上万封电子邮件内容,包含垃圾邮件/非垃圾邮件。提取邮件内单词,改写为小写单词输入字典,过滤出现1次的单词,过滤长度只有1的单词,过滤出现总次数超过1万次的单词,最后形成我们的统计词典以及垃圾/非垃圾邮件词典。

预测

定义probSPAM = probHAM = 1
输入一封邮件,抽取其单词,对上述词典中的每一个单词进行处理:
1)在垃圾邮件词典中,若单词A 出现在当前邮件中,那么probSPAM *= (垃圾邮件中该单词出现次数)/(垃圾邮件数量);
2)在垃圾邮件词典中,若单词A 没有出现在当前邮件中,那么probSPAM *= [1-(垃圾邮件中该单词出现次数)/(垃圾邮件数量)];
3)在非垃圾邮件词典中,若单词A 出现在当前邮件中,那么probHAM *= (正常邮件中该单词出现次数)/(正常邮件数量);
4)在非垃圾邮件词典中,若单词A 没有出现在当前邮件中,那么probHAM *= [1-(正常邮件中该单词出现次数)/(正常邮件数量)];

完成统计后,两个prob变量分别乘以(对应类别的邮件数)/(所有邮件总数),即 Prior。
比较probSPAM以及probHAM,哪个相对较大就判定为对应空间。

实现

数据以及代码下载地址:

用Python代码实现
其中,计算概率使用了log函数简化乘除法,
计算单个特征点概率时使用了smoothing:

代码如下:
[python]  view plain copy
  1. # -*- coding: utf-8 -*-  
  2. """ 
  3. Created on Sun May 19 18:54:36 2013 
  4.  
  5. @author: rk 
  6. """  
  7. import nltk  
  8. import os  
  9. import math  
  10.   
  11. train_data = "./hw1_data/train/"  
  12. test_data = "./hw1_data/test/"  
  13. ham = "ham/"  
  14. spam = "spam/"  
  15. MAX_NUM = 10000   
  16. K = 2  
  17.   
  18. def sort_by_value(d):   
  19.     return sorted(d.items(), lambda x, y: cmp(x[1], y[1]), reverse = True)  
  20.   
  21. def word_process(word):  
  22. #     lemmatizer = nltk.WordNetLemmatizer()      
  23.     #stop words  
  24.     #lower characters  
  25.     word_low = word.strip().lower()  
  26.     #lemmatize word      
  27.     #word_final = lemmatizer.lemmatize(word_low)  
  28.     word_final = word_low  
  29.     return word_final  
  30.   
  31. def add_to_dict(word, dict_name):  
  32.     if(word in dict_name):  
  33.         num = dict_name[word]  
  34.         num += 1  
  35.         dict_name[word] = num  
  36.     else:  
  37.         dict_name[word] = 1  
  38.   
  39. def negative_dict_maker(dictionary):  
  40.     d = dict()  
  41.     for (key, value) in dictionary.items():  
  42.         if(value >= MAX_NUM or value <= 1):  
  43.             d[key] = 1  
  44.     return d  
  45.   
  46. def text_reader(file_name, dict_name):  
  47.     tokenizer = nltk.RegexpTokenizer("[\w']{2,}")   #leave the word with length > 1  
  48.     f = open(file_name, 'r')  
  49.     for line in f:  
  50.         words = tokenizer.tokenize(line)  
  51.         for word in words:  
  52.             word = word_process(word)  
  53.             add_to_dict(word, dict_name)  
  54.     f.close()  
  55.   
  56. def save_dict(dict_name, file_path, all_flag):  
  57.     f = open(file_path, 'w')#"dict_file.data", 'w')  
  58.     word_max = ""  
  59.     value_max = 0;  
  60.     for (key, value) in dict_name.items():  
  61.         if(not all_flag):  
  62.             if value > 1 and value < MAX_NUM:  
  63.                 f.writelines(key+" "+str(value)+"\n")  
  64.             if value > value_max:  
  65.                 word_max = key  
  66.                 value_max = value  
  67.         else:  
  68.             f.writelines(key+" "+str(value)+"\n")  
  69.             if value > value_max:  
  70.                 word_max = key  
  71.                 value_max = value  
  72.     f.close()  
  73.     print("Save_dict-----> Max_key:"+word_max+", Max_value:"+str(value_max))  
  74.           
  75. def load_dict(file_path):  
  76.     dict_loaded = dict()  
  77.     f = open(file_path, 'r')  
  78.     while 1:  
  79.         line = f.readline()  
  80.         if not line:  
  81.             break  
  82.         words = line.split()  
  83.         dict_loaded[words[0]] = int(words[1])  
  84.     f.close()  
  85.     return dict_loaded  
  86.   
  87. def save_file_number(ham, spam, total):  
  88.     f = open("file_number.data"'w')  
  89.     f.writelines(str(ham)+"\n")  
  90.     f.writelines(str(spam)+"\n")  
  91.     f.writelines(str(total)+"\n")  
  92.     f.close()  
  93.   
  94. #make the master dictionary and calculate the number of ham or spam  
  95. def traverse_dictionary_maker(file_path):#the path is the ham/spam's parent  
  96.     dictionary = dict()  
  97.     ham_path = file_path+ham  
  98.     spam_path = file_path+spam  
  99.     path = {ham_path, spam_path}  
  100.     path_order = 0  
  101.     num_ham = 0  
  102.     num_spam = 0  
  103.     for i in path:  
  104.         folders = os.listdir(i)  
  105.         for file_name in folders:  
  106.             if os.path.isfile(i+file_name):  
  107.                 text_reader(i+file_name, dictionary)  
  108.                 if(path_order == 0):  
  109.                     num_ham += 1  
  110.                 else:  
  111.                     num_spam += 1  
  112.         path_order += 1  
  113.     #initialize  
  114.     save_file_number(num_ham, num_spam, num_ham + num_spam)  
  115.     return dictionary  
  116.   
  117. #create the ham/spam email dictionary  
  118. def dict_creator(file_path, negative_dict):  
  119.     dictionary = load_dict("dict_file.data")  
  120.       
  121.     #initialize the dictionary item value  
  122.     for key in dictionary:  
  123.         dictionary[key] = 0  
  124.     if(not os.path.isfile(file_path)): #if file_path is a folder  
  125.         folders = os.listdir(file_path)  
  126.         for file_name in folders:  
  127.             single_dict = dict()  
  128.             if os.path.isfile(file_path+file_name):  
  129.                 text_reader(file_path+file_name, single_dict)  
  130.             for key in single_dict:  
  131.                 if key not in negative_dict:  
  132.                     num = dictionary[key]  
  133.                     num += 1  
  134.                     dictionary[key] = num  
  135.     else:  
  136.         single_dict = dict()  
  137.         text_reader(file_path, single_dict)  
  138.         for key in single_dict:  
  139. #             if key not in negative_dict:  
  140.             if(key in dictionary):  
  141.                 num = dictionary[key]  
  142.                 num += 1  
  143.                 dictionary[key] = num  
  144.     print("Dict_creator")  
  145.     return dictionary  
  146.       
  147.     #output the file data after process  
  148.   
  149. def save_vector(dictionary, target_path):  
  150.     f = open(target_path, 'w')  
  151.     for (key, value) in dictionary.items():  
  152.         if (value != 0) :  
  153.             f.writelines(key+" "+str(value)+"\n")  
  154.     print("test_dict: "+target_path + " written!")  
  155.     f.close()  
  156.   
  157. def vector_creator(file_path, negative_dict, target_path):  
  158.     dictionary = dict_creator(file_path, negative_dict)  
  159.     save_vector(dictionary, target_path)  
  160.   
  161. def vector_loader(target_path):  
  162.     dictionary = load_dict(target_path)  
  163.     return dictionary  
  164.   
  165.   
  166. def read_w_number():  
  167.     f = open("file_number.data"'r')  
  168.     lines = f.readlines()  
  169.     w_num = [int(lines[0]), int(lines[1]), int(lines[2])]  
  170.     f.close()  
  171.     return w_num  
  172.   
  173. def print_top_twenty(list_name):  
  174.     index = 0  
  175.     while(index < 20):  
  176.         print(list_name[index])  
  177.         index += 1  
  178.   
  179. #to calculate the probability of P(xi|w) with smoothing log function  
  180. def calculate_log_p_xi_w(word, dict_name, n_w, exist_flag):  
  181.     if(exist_flag):  
  182.         result = math.log(dict_name[word]+1)  
  183.     else:  
  184.         result = math.log(n_w+K-dict_name[word]-1)  
  185.     return result  
  186.   
  187. #to calculate the probability of P(x|w) with log function  
  188. def calculate_log_p_x_w(vector, dict_name, n_w, n_t, denominator_all):  
  189.     result = 0.0  
  190.     for (key, value) in dict_name.items():  
  191.         exist_flag = (key in vector)  
  192.         result += calculate_log_p_xi_w(key, dict_name, n_w, exist_flag)  
  193.     result -= denominator_all  
  194.     result += math.log(n_w)  
  195.     result -= math.log(n_t)  
  196.     return result  
  197.   
  198. def predict(file_path, w_num, ham_dict, spam_dict, ham_denominator_all, spam_denominator_all):  
  199.     vector = vector_loader(file_path)      
  200.     prob_ham = calculate_log_p_x_w(vector, ham_dict, w_num[0], w_num[2], ham_denominator_all)  
  201.     prob_spam = calculate_log_p_x_w(vector, spam_dict, w_num[1], w_num[2], spam_denominator_all)  
  202.     if prob_ham > prob_spam :  
  203.         return 0  
  204.     else :  
  205.         return 1  
  206.       
  207.   
  208. file_path = train_data  
  209. dictionary = traverse_dictionary_maker(file_path)  
  210. negative_dict = negative_dict_maker(dictionary) #filter the word with number >= MAX_NUM negative[key]=1  
  211. # print ("negative: "+str(len(negative_dict))) #39624  
  212. save_dict(dictionary, "dict_file.data"False)   #include the filtering of number of words  
  213. dictionary = load_dict("dict_file.data")  
  214. w_num = read_w_number() #the number of ham emails and spam emails  
  215. #==================================================================================#  
  216. #print the number of each email set  
  217. print(w_num)  
  218.   
  219. #traverse ham emails, create ham dictionary  
  220. ham_dict = dict_creator(train_data + ham, negative_dict)  
  221. save_dict(ham_dict, "ham_dict.data"True)  
  222. #traverse spam emails, create spam dictionary  
  223. spam_dict = dict_creator(train_data + spam, negative_dict)  
  224. save_dict(spam_dict, "spam_dict.data"True)  
  225.   
  226. #train process  
  227. list_ham = sort_by_value(ham_dict)  
  228. print_top_twenty(list_ham)  
  229. print("-------------------------------------------")  
  230. list_spam = sort_by_value(spam_dict)  
  231. print_top_twenty(list_spam)  
  232.   
  233. ham_length = len(ham_dict)  
  234. # print ham_length #46328  
  235. ham_denominator_all = math.log(w_num[0]+K) * len(dictionary)# w_num[2]  
  236. spam_length = len(spam_dict)  
  237. # print spam_length #46328  
  238. # print len(dictionary) #46328  
  239. spam_denominator_all = math.log(w_num[1]+K) * len(dictionary)# w_num[2]  
  240.   
  241. #initial process for test set  
  242. def test_process(data_set):  
  243.     test_set = [ham, spam]  
  244.     for w in test_set:  
  245.         folder_path = data_set + w  
  246.         files = os.listdir(folder_path)  
  247.     #     total_num = 0  
  248.     #     correct_num = 0  
  249.         for file in files:  
  250.         #     dict_temp = dict()  
  251.             if os.path.isfile(folder_path+file):  
  252.                 vector_creator(folder_path+file, negative_dict, folder_path+"dict/"+file)  
  253.   
  254. def test_prob(data_set):  
  255.     test_set = [ham, spam]  
  256.     for w in test_set:  
  257.         folder_path = data_set + w + "dict/"  
  258.         print folder_path  
  259.         files = os.listdir(folder_path)  
  260.         total_num = 0  
  261.         ham_predict_num = 0  
  262.         spam_predict_num = 0  
  263.         print len(files)  
  264.         f = open("file.data"'w')  
  265.         f.writelines(folder_path+"\n")  
  266.         for file in files:  
  267.         #     dict_temp = dict()  
  268.             if os.path.isfile(folder_path+file):  
  269.                 f.writelines(folder_path+file+"\n")  
  270.                 result = predict(folder_path+file, w_num, ham_dict, spam_dict, ham_denominator_all, spam_denominator_all)  
  271. #                 print result  
  272.                 total_num += 1  
  273.                 if(w == ham and result == 0):  
  274.                     ham_predict_num += 1  
  275.                 elif (w == spam and result == 1):  
  276.                     spam_predict_num += 1  
  277.         print "Total test number: " + str(total_num)  
  278.         print "Ham predicted number: " + str(ham_predict_num)  
  279.         print "Spam predicted number: " + str(spam_predict_num)  
  280. # test_process(test_data) #This function is to pre-process the test file  
  281. # test_process(train_data) #This function is to pre-process the train file  
  282. test_prob(test_data) #This function is to calculate the probability of test_data  
  283. test_prob(train_data) ##This function is to calculate the probability of train_data  
  284.   
  285. def get_top_ten_word_ratio():  
  286.     dict_temp = dict()  
  287.     for key in dictionary:  
  288.         log2 = calculate_log_p_xi_w(key, ham_dict, w_num[0], True) - math.log(w_num[0]+K)  
  289.         log1 = calculate_log_p_xi_w(key, spam_dict, w_num[1], True) - math.log(w_num[1]+K)  
  290.         result = log1-log2  
  291.         dict_temp[key] = result  
  292.     list_temp = sort_by_value(dict_temp)  
  293.     print("--------------TOP 20 spam words-----------------")  
  294.     print_top_twenty(list_temp)  
  295.   
  296. get_top_ten_word_ratio()  
  297.   
  298. print("-------------output ends------------------")  

测试结果

在data set.zip中有训练集+测试集。




准确率均达98%以上。

    

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值