python实现朴素贝叶斯分类器
本文用的是sciki-learn库的iris数据集进行测试。用的模型也是最简单的,就是用贝叶斯定理P(A|B)=P(B|A)*P(A)/P(B),计算每个类别在样本中概率(代码中是pLabel变量)
以及每个类下每个特征的概率(代码中是pNum变量)。
写得比较粗糙,对于某个类下没有此特征的情况采用p=1/样本数量。
有什么错误有人发现麻烦提出,谢谢。
[python]viewplaincopy
#-*-coding:utf-8-*-
fromnumpyimport*
fromsklearnimportdatasets
importnumpyasnp
classNaiveBayesClassifier(object):
def__init__(self):
self.dataMat=list()
self.labelMat=list()
self.pLabel={}
self.pNum={}
defloadDataSet(self):
iris=datasets.load_iris()
self.dataMat=iris.data
self.labelMat=iris.target
labelSet=set(iris.target)
labelList=[iforiinlabelSet]
labelNum=len(labelList)
foriinrange(labelNum):
self.pLabel.setdefault(labelList[i])
self.pLabel[labelList[i]]=np.sum(self.labelMat==labelList[i])/float(len(self.labelMat))
defseperateByClass(self):
seperated={}
foriinrange(len(self.dataMat)):
vector=self.dataMat[i]
ifself.labelMat[i]notinseperated:
seperated[self.labelMat[i]]=[]
seperated[self.labelMat[i]].append(vector)
returnseperated
#通过numpyarray二维数组来获取每一维每种数的概率
defgetProbByArray(self,data):
prob={}
foriinrange(len(data[0])):
ifinotinprob:
prob[i]={}
dataSetList=list(set(data[:,i]))
forjindataSetList:
ifjnotinprob[i]:
prob[i][j]=0
prob[i][j]=np.sum(data[:,i]==j)/float(len(data[:,i]))
prob[0]=[1/float(len(data[:,0]))]#防止feature不存在的情况
returnprob
deftrain(self):
featureNum=len(self.dataMat[0])
seperated=self.seperateByClass()
t_pNum={}#存储每个类别下每个特征每种情况出现的概率
forlabel,datainseperated.iteritems():
iflabelnotint_pNum:
t_pNum[label]={}
t_pNum[label]=self.getProbByArray(np.array(data))
self.pNum=t_pNum
defclassify(self,data):
label=0
pTest=np.ones(3)
foriinself.pLabel:
forjinself.pNum[i]:
ifdata[j]notinself.pNum[i][j]:
pTest[i]*=self.pNum[i][0][0]
else:
pTest[i]*=self.pNum[i][j][data[j]]
pMax=np.max(pTest)
ind=np.where(pTest==pMax)
returnind[0][0]
deftest(self):
self.loadDataSet()
self.train()
pred=[]
right=0
fordinself.dataMat:
pred.append(self.classify(d))
foriinrange(len(self.labelMat)):
ifpred[i]==self.labelMat[i]:
right+=1
printright/float(len(self.labelMat))
if__name__=='__main__':
NB=NaiveBayesClassifier()
NB.test()
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持毛票票。