设为一个待分类项,而每个a为x的一个特征属性。
有类别集合 。
计算。
如果,则
。
那么如何计算第3步中的各个条件概率呢?
1、找到一个已知分类的待分类项集合,这个集合叫做训练样本集。
2、统计得到在各类别下各个特征属性的条件概率。
3、如果各个特征属性是条件独立的,则根据贝叶斯定理有如下推导:
P(yi|x)=P(x|yi)P(i)/P(x)
因为分母对于所有类别为常数,因为我们只要将分子最大化皆可。又因为各特征属性是条件独立的,所以有:
当特征是离散的时候,使用多项式模型。多项式模型在计算先验概率和条件概率时,会做一些平滑处理。与多项式模型一样,伯努利模型适用于离散特征的情况,不同的是,伯努利模型中每个特征的取值只能是0或1。当特征是连续变量的时候,应采用高斯模型,高斯模型假设每一维特征都服从高斯分布(正态分布)。
高斯朴素贝叶斯:特征的可能性(即概率)假设为高斯分布:
多项分布朴素贝叶斯:
其中是第k个类别的第j维特征的第i个取值条件概率。
是训练样本中输出为第k类的样本个数。先验平滑因子
应用于在学习样本中没有出现的特征,以防在将来的计算中出现0概率输出。 把
被称为拉普拉斯平滑(Lapalce smoothing),而
被称为利德斯通(Lidstone smoothing)。
伯努利朴素贝叶斯:有多个特征,但每个特征都假设为一个二元变量。因此,这类算法要求样本以二元特征向量表示;如果样本含有其他类型的数据,一个BernoulliNB实例会将其二值化(取决于binarize参数)。伯努利贝叶斯的决策规则基于:
与多项分布朴素贝叶斯的规则不同,伯努利朴素贝叶斯明确地惩罚类y中没有出现作为预测因子的特征值,而多项分布朴素贝叶斯只是简单地忽略没出现的特征值。
用伯努利朴素贝叶斯训练Kaggle的titanic数据集:
import pandas as pd
from sklearn.naive_bayes import BernoulliNB
from sklearn.model_selection import train_test_split
#读取数据
def read_dataset(fname):
#指定第一列作为行索引
data=pd.read_csv(fname,index_col=0)
#丢弃无用数据
data.drop(['Name','Ticket','Cabin'],axis=1,inplace=True)
#处理性别数据
labels=data['Sex'].unique().tolist()
data['Sex']=[*map(lambda x:labels.index(x),data['Sex'])]
#处理登船港口数据
labels=data['Embarked'].unique().tolist()
data['Embarked']=data['Embarked'].apply(lambda n:labels.index(n))
#处理缺失数据填充0
data=data.fillna(0)
return data
train=read_dataset('E:/Data/titanic/train.csv')
test=read_dataset('E:/Data/titanic/test.csv')
#拆分数据集
# =============================================================================
# x_train=train.drop(['Survived'],axis=1).values
# y_train=train['Survived'].values
# x_test=test.values
# =============================================================================
x=train.drop(['Survived'],axis=1).values
y=train['Survived'].values
x_train,x_test,y_train,y_test=train_test_split(x,y,test_size=0.2)
clf=BernoulliNB()
clf.fit(x_train,y_train)
print(clf.score(x_test,y_test))
# =============================================================================
# predicted=clf.predict(x_test)
# test['Survived'] = predicted.astype(int)
# test.to_csv('submission.csv', sep=',', index=False)
# =============================================================================