朴素贝叶斯方法(二分类)[机器学习实战]

数据链接

垃圾短信分类

解析

设一个点(x,y),对(x,y)进行分类(1,2),我们可以设每个点分别属于两个类别的概率:

如果p1(x,y) > p2(x,y),那么类别为1
如果p1(x,y) < p2(x,y),那么类别为2

由贝叶斯概率我们有

\[p(c|x,y) = \frac {p(x,y|c)p(c)}{p(x,y)}\dots(1) \]

对于二分类可见
$$p1 \rightarrow p(1|x,y)$$
$$p2 \rightarrow p(2|x,y)$$

观察公式一右边.
根据大数定理,当数据集具有一定规模时,我们可以以频率逼近概率。
右边的概率可以由统计而得

因此朴素贝叶斯法则主要在于对数据的统计,步骤如下:

  1. 分词,生成词向量空间(英文文本无需如此,中文文本可以使用jieba分词工具)
  2. 对于每个向量,计算其向量空间坐标(每个特征词出现次数,即词袋)
  1)计算出p(c),即每个类别出现概率
  2)对于p(x,y),可以统计出所有的单变量,再使用乘法原理即可
  3)对于p(x,y|c)可以统计类别c下的所有(x,y)的出现次数
  1. 计算对于给定词向量的p1,p2,答案为其中值较大者

《机器学习实战中》给出了一个优化:

考虑概率很小以及一些为0的值会导致乘完出现0,所以使用对数代替p(由于对数函数是单调递增函数,因此同样很好度量)

from numpy import *

def textParse1(vec):    
    return 1 if vec[0] == 'spam' else 0,vec[1:];
    
def textParse2(vec): 
    return vec[0],vec[1:];
    
def bagOfWords2VecMN(vocabList, inputSet):
    returnVec = [0]*len(vocabList)
    for word in inputSet:
        if word in vocabList:
            returnVec[vocabList.index(word)] += 1
    return returnVec

def setOfWords2VecMN(vocabList, inputSet):
    returnVec = [0]*len(vocabList)
    for word in inputSet:
        if word in vocabList:
            returnVec[vocabList.index(word)] = 1
    return returnVec

def createVocabList(dataSet):
    vocabSet = set([])  
    for document in dataSet:
        vocabSet = vocabSet | set(document) 
    return list(vocabSet)

def tfIdf(trainMatrix,setMatrix):
    n = len(trainMatrix)
    m = len(trainMatrix[0])
    d = [n]*n;
    tb = sum(trainMatrix,axis=1)
    tc = sum(setMatrix,axis=0)
    b = array(tb,dtype='float')
    c = array(tc,dtype='float')
    weight = []
    for i in range(m):
        a = trainMatrix[:,i]
        tf = a/b
        weight.append(tf * log(d/(c[i])))
    returnVec = array(weight).transpose()
    return returnVec
    

def trainNB0(trainMatrix,trainCategory,weight):
    numTrainDocs = len(trainMatrix)
    numWords = len(trainMatrix[0])
    pAbusive = sum(trainCategory)/float(numTrainDocs)
    p0Num = ones(numWords); p1Num = ones(numWords)     
    p0Denom = 2.0; p1Denom = 2.0      
    a = 0;b = 0
    a += trainMatrix[0];b += sum(trainMatrix[0])
    for i in range(numTrainDocs):
        if trainCategory[i] == 1:
            p1Num += trainMatrix[i]*weight[i]
            p1Denom += sum(trainMatrix[i]*weight[i])
        else:
            p0Num += trainMatrix[i]*weight[i]
            p0Denom += sum(trainMatrix[i]*weight[i])
    p1Vect = log(p1Num/p1Denom)     
    p0Vect = log(p0Num/p0Denom)    
    return p0Vect,p1Vect,pAbusive

def classifyNB(vec2Classify, p0Vec, p1Vec, pClass1):
    p1 = sum(vec2Classify * p1Vec) + log(pClass1)    
    p0 = sum(vec2Classify * p0Vec) + log(1.0 - pClass1)
    if p1 > p0:
        return 1
    else: 
        return 0

def spamTest():
    trainFile = './train.csv'
    testFile = './test.csv'
    import csv
    docList=[]; classList = []; fullText =[]
    in1 = open(trainFile);in1.readline()
    fr1 = csv.reader(in1)
    trainData = [row for row in fr1]
    
   #prepare trainData
    
    n = 0
    for i in trainData:
        label,wordList = textParse1(i)
        docList.append(wordList)
        classList.append(label)
        n += 1
        
    vocabList = createVocabList(docList)
    
    trainMat = [];trainClasses = []
    setMat = []
    
    for docIndex in range(n):
        trainMat.append(bagOfWords2VecMN(vocabList,docList[docIndex]))
        setMat.append(setOfWords2VecMN(vocabList,docList[docIndex]))
        trainClasses.append(classList[docIndex])
    
    #traiing by bayes
    weight = tfIdf(array(trainMat),array(setMat))
    
    p0V,p1V,pSpam = trainNB0(array(trainMat),array(trainClasses),array(weight))
    
    
    #prepare testData
    
    in2 = open(testFile);in2.readline()
    fr2 = csv.reader(in2)
    fw = csv.writer(open('predict.csv', 'w'))
    name = ['SmsId','Label']
    fw.writerow(name)
    testData = [row for row in fr2]
    
    #predict testData
    for i in testData:
        id,wordList = textParse2(i)
        wordVector = bagOfWords2VecMN(vocabList, wordList)
        fw.writerow([id,'spam' if classifyNB(array(wordVector),p0V,p1V,pSpam) else 'ham'])
    
    #print 'fianl point'
posted @ 2018-03-16 10:46  江南何采莲  阅读(2432)  评论(0编辑  收藏  举报