朴素贝叶斯 - 预测你能找到女朋友吗?

 

朴素贝叶斯

是一种对待分类项进行分类的算法
主要有以下几个步骤
1.人工给出一些训练数据(每个数据有多个属性F1,F2...Fn),并对每条数据标注人工看法下的分类属于哪类(C1,C2,C3...)
2.输入第一步的数据,根据朴素贝叶斯算法,求得分类器(实际上就是求得每个类别之下,各属性所占的概率,比如在分类C1下,属性F1为true 的频率是多少,再求属性为false的频率有多少,这里的F1属性,我们举例的限定是只有true和false2种情况可选)
3.需要对新的待分类数据进行分类时,根据第二步得到的分类器,求得该数据项发生的前提下在各个分类下发生的概率(P1,P2..Pn),找到概率最大的一个,将该数据分为该类


1.中指出的数据中被抽象出来的属性维度,需要自己衡量所需的属性有哪些


接下来说说公式的推导和理解

A,B为独立事件
发生B的概率为 P(B)
发生A的概率为 P(A)
在B发生时,再发生A的概率 P(A|B)
在A发生时,再发生B的概率 P(B|A)
A,B同时发生的概率 P(AB) = P(A∩B)

图:

 

 

 

例如 ,A为 0.8 ,B为 0.6 , A∩B = 0.2

A,B为独立事件
发生B的概率为 P(B) = 0.6
发生A的概率为 P(A) = 0.8
A,B同时发生的概率 P(AB) = P(A∩B) = 0.2
在B发生时,再发生A的概率 P(A|B) = P(A∩B)/P(B) = 0.2 / 0.8 = 0.25
在A发生时,再发生B的概率 P(B|A) = P(A∩B)/P(A) = 0.2 / 0.6 = 0.33


看这2个式子
P(A|B) = P(A∩B)/P(B)
P(B|A) = P(A∩B)/P(A)
我们可以导出朴素贝叶斯的公式
(1). P(A∩B) = P(A|B) * P(B)
(2). P(A∩B) = P(B|A) * P(A)

(1) = (2)
P(A|B) * P(B) = P(B|A) * P(A)   
P(A|B) = P(B|A) * P(A) / P(B)


在含义上也等于
P(A|B) = P(A∩B) / P(B)     //A∩B的部分占B面积的百分比



类似的

P(A|B) = P(A|B) * P(B) / P(A)

举例:
在1984年中的某一天
会下雨的地区为 80%
忘记带伞的人占 20%
在忘记带伞的人中处在下雨地区的人占 10%
问:处于会下雨的地区的人中忘记带伞的人占百分之多少?


P(B) = 0.8
P(A) = 0.2
P(B|A) = 0.1
P(A|B) = ?
根据公式 = P(B|A) * P(A) / P(B) = 0.1 * 0.8 / 0.2 = 0.08 / 0.2 = 0.4




如果事件A有多种可能性 A = {A1,A2}

这时
P(A|B) = P(B|A) * P(A) / P(B)

P(A1|B) = P(B|A1) * P(A1) / P(B)
P(A1|B) = P(B|A1) * P(A1) / P(A1∩B) + P(A2∩B)
P(A1|B) = P(B|A1) * P(A1) / (P(B|A1)*P(A1) + P(B|A2)*P(A2))


在含义上也等于
P(A1|B) = P(A1∩B) / P(A1∩B) + P(A2∩B)

A = {A1,A2...Aj}

P(A1|B) = P(B|A1) * P(A1) / Σnj=1  P(B|Aj) * P(Aj)

 

P(A1|B) = P(A1∩B) / P(A1∩B) + P(A2∩B) + ...P(An ∩ B)

 

 

 



朴素贝叶斯分类器

实际情况下,一个事件往往受多个特征影响, 影响 B 的因素有 n 个,分别是 b1,b2,…,bn。

P(A|B) = P(B|A) * P(A) / P(B)

中的B(先验的)有特征属性 b1,b2,b3。。。bn (对应的 A 也有这些属性 )
P(A|B) = P(A|b1,b2,b3...bn) = P(b1,b2,b3...bn|A) * P(A) / P(b1,b2,b3...Fn)


其中
P(b1,b2,b3...Fn) = P(B)
P(b1,b2,b3...Fn|A) = P(B|A) 


那么 P(b1,b2,b3...bn|A) 应该如何计算呢?
在每个属性 b1,b2...bn 没有关联性的前提下

 

 


P(
b1,b2,b3...bn|A) = P(b1|A) * P(b2|A) ... * P(bn|A)

所以
P(A|b1,b2,b3...bn) = P(b1|A) * P(b2|A) ... * P(bn|A) * P(A) / P(b1,b2...bn)

n
= \prod i=1 P(bi|A) * P(A) /P(b1,b2...bn)






代码:

一个测试你是否会被妹子喜欢的例子:

//朴素贝叶斯
//是一种对待分类项进行分类的算法
//主要有以下几个步骤
//1.人工给出一些训练数据(每个数据有多个属性F1,F2...Fn),并对每条数据标注人工看法下的分类属于哪类(C1,C2,C3...)
//2.输入第一步的数据,根据朴素贝叶斯算法,求得分类器(实际上就是求得每个类别之下,各属性所占的概率,比如在分类C1下,属性F1为true 的频率是多少,再求属性为false的频率有多少,这里的F1属性,我们举例的限定是只有true和false2种情况可选)
//3.需要对新的待分类数据进行分类时,根据第二步得到的分类器,求得该数据项发生的前提下在各个分类下发生的概率(P1,P2..Pn),找到概率最大的一个,将该数据分为该类

//1.数据中被抽象出来的属性维度,需要自己衡量所需的属性有哪些

import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

//离散数据的 朴素贝叶斯
public class NaiveBayesDCT {
    int fieldNum; //包含人工分类的一列维度
    List<List<Integer>> datas;
    List<Integer>[] fieldValueOptions;

    //映射,方便找全局索引(一个类别下的一个属性被分配一个索引号)(索引号从0到+)
    //<fieldCountIndex :<fieldValue :fieldIndex>>
    Map<Integer, Map<Integer, Integer>> fci2fv2fi;
    //反向映射,全局索引找fieldCountIndex(类别索引)
    //fieldIndex 2 fieldCountIndex
    Map<Integer, Integer> fi2fci;
    int fieldIndexCount;

    //P(C) ,每个人工分类类别占训练样本的百分比
    Map<Integer, Float> clazz2Prob;

    //按类别分类的数据
    //<类别:数据[]>
    Map<Integer, List<List<Integer>>> classifyData;
    //<类别:<全局属性索引号:概率统计值>>
    Map<Integer, Map<Integer, Float>> trainData;

    //一个维度一个List,每个list的integer表示该维度下允许的可选值是哪些,比如对野外采集的农产品有重量维度:轻,中,重 ,对应1,2,3
    //最后一列是人工标注的分类
    //f1,f2,...fn,clazzTypes
    public NaiveBayesDCT(List<Integer>... fieldValueOptions) {
        this.fieldNum = fieldValueOptions.length;
        this.fieldValueOptions = fieldValueOptions;
        datas = new LinkedList<>();
        trainData = new HashMap<>();
    }

    public void addData(List<Integer> data) {
        if (data.size() != fieldNum) //check 输入的数据维度要一致
            return;
        datas.add(data);
    }

    private void initTrainData() {
        fci2fv2fi = new HashMap<>();
        fi2fci = new HashMap<>();
        for (int filedCount = 0; filedCount < fieldValueOptions.length - 1; filedCount++) { //每类属性 ,抛开人工分类的属性
            fci2fv2fi.put(filedCount, new HashMap<Integer, Integer>());
            for (int fv : fieldValueOptions[filedCount]) {            //这类属性的每个可能出现的属性值
                fci2fv2fi.get(filedCount).put(fv, fieldIndexCount);   //属性的每一个可能出现的值,被分配一个全局索引
                fi2fci.put(fieldIndexCount, filedCount);
                fieldIndexCount++;
            }
        }

        //初始化人工标注类别属性的 每个类别的默认概率为0
        for (int clazzType : fieldValueOptions[fieldNum - 1]) {
            trainData.put(clazzType, new HashMap<Integer, Float>());
            Map<Integer, Float> field2Prob = trainData.get(clazzType);
            for (int tmpfi = 0; tmpfi < fieldIndexCount; tmpfi++) {
                field2Prob.put(tmpfi, 0f); //默认是 0%
            }
        }
    }

    public void train() {
        trainData = new HashMap<>();
        //先按人工标注的类别分类,再计算该分类下,每个维度属性的可选值的统计概率
        spliteByClass();
        initTrainData();

        //<人工分类类型:[属性索引]=统计次数>
        Map<Integer, int[]> fieldIndexSavedCountsMap = new HashMap<>();

        //按分类遍历数据集合
        for (Map.Entry<Integer, List<List<Integer>>> e : classifyData.entrySet()) {
            int clazzType = e.getKey();
            int[] fieldIndexSavedCounts = new int[fieldNum];
            fieldIndexSavedCountsMap.put(clazzType, fieldIndexSavedCounts);

            Map<Integer, Float> field2Prob = trainData.get(clazzType);

            //遍历该分类下的所有数据
            List<List<Integer>> datas = e.getValue();
            for (int i = 0; i < datas.size(); i++) {
                List<Integer> data = datas.get(i);
                //对该分类下的一个样本的每个属性统计概率(先求出该属性现次数的总和)
                for (int fci = 0; fci < data.size() - 1; fci++) {
                    //属性索引 - 属性值 - 属性全局索引
                    int fieldIndex = -1;
                    //该数据类别的全局属性索引号
                    try {
                        fieldIndex = fci2fv2fi.get(fci).get(data.get(fci));
                    } catch (NullPointerException ex) {
                        //数据中给出了未再构造函数中声明的属性值
                        ex.printStackTrace();
                        throw new RuntimeException("fieldIndex :" + fci + " value : " + data.get(fci) + " not statement in constructor params");
                    }
                    field2Prob.put(fieldIndex, field2Prob.get(fieldIndex) + 1); //统计值+1
                    fieldIndexSavedCounts[fci]++;
                }
            }
        }

        if (datas.size() <= 0)
            return;

        //再算概率 = (统计值 / 总数量),最后几个是人工分类类别(数量=分类类别中的类别个数)
        for (Map.Entry<Integer, Map<Integer, Float>> e : trainData.entrySet()) {
            int clazzType = e.getKey();
            int[] fieldIndexSavedCounts = fieldIndexSavedCountsMap.get(clazzType);
            for (Map.Entry<Integer, Float> e2 : e.getValue().entrySet()) {
                int globalFieldIndex = e2.getKey();
                float sumCount = trainData.get(clazzType).get(globalFieldIndex);
                float prob = sumCount / fieldIndexSavedCounts[fi2fci.get(e2.getKey())]; //prob
                trainData.get(e.getKey()).put(e2.getKey(), prob);
            }
        }
        clazz2Prob = new HashMap<>();
        for (Map.Entry<Integer, List<List<Integer>>> e2 : classifyData.entrySet()) {
            clazz2Prob.put(e2.getKey(), ((float) e2.getValue().size() / datas.size()));
        }


        System.out.println("==== trainData ====");
        System.out.println(trainData);  //分类1={全局属性索引1=概率,全局属性索引2=概率....} , 分类2={全局属性索引1=概率,全局属性索引2=概率....} ....
        System.out.println("==== clazz2Prob ====");
        System.out.println(clazz2Prob);
    }

    private void spliteByClass() {
        classifyData = new HashMap<>();
        for (int clazzType : fieldValueOptions[fieldNum - 1]) {
            classifyData.put(clazzType, new LinkedList<List<Integer>>());
        }

        for (int i = 0; i < datas.size(); i++) {
            List<Integer> data = datas.get(i);
            int clazzType = data.get(fieldNum - 1);
            classifyData.get(clazzType).add(data);
        }
    }

    //对数据进行分类
    public int caculClass(List<Integer> data) {
        //C = CLASS TYPE {c1,c2,c3...cn}
        //B = DATA
        //P(C|B) = P(B|C) * P(C) / P(B)
        //      ∝ P(b1|c1) * P(b2|c2) * ..... * P(C) * P(B)
        //      ∝ P(b1|c1) * P(b2|c2) * ..... * P(C)

        if (data.size() != fieldNum - 1)
            return -1;  //数据的属性维度和定义的不一致

        float maxDataClassProb = 0; //保存各分类中,概率最大的概率值
        int maxDataClassProbClassType = -1; //概率最大的概率值的分类类型
        Map<Integer, Float> dataClazz2Prob = new HashMap<>();   //key:分类类型 c1,c2...cn, value:该分类在data发生情况下的 概率值 P(C|B)
        for (Map.Entry<Integer, Map<Integer, Float>> trainClazzData : trainData.entrySet()) {
            int clazzType = trainClazzData.getKey();
            Map<Integer, Float> filedProbMap = trainClazzData.getValue();
            float probProduct = clazz2Prob.get(clazzType);  //P(C) 的值
            for (int fieldCountIndex = 0; fieldCountIndex < data.size(); fieldCountIndex++) {
                int val = data.get(fieldCountIndex);
                if(!fci2fv2fi.get(fieldCountIndex).containsKey(val))
                    return -1; //没有该属性存在定义中
                int globalFieldIndex = fci2fv2fi.get(fieldCountIndex).get(val);
                probProduct *= filedProbMap.get(globalFieldIndex);
            }
            if (maxDataClassProb < probProduct) {
                maxDataClassProb = probProduct;
                maxDataClassProbClassType = clazzType;
            }
            dataClazz2Prob.put(clazzType, probProduct);
        }
        System.out.println(dataClazz2Prob);

        return maxDataClassProbClassType;
    }

    public static void main(String[] a) {
        //属性1
        List<Integer> face = new LinkedList<Integer>();
        face.add(10);            //好看的脸
        face.add(0xdeadface);    //不好看的脸
        //属性2
        List<Integer> FightCapacity = new LinkedList<Integer>();
        FightCapacity.add(999); //999战斗力
        FightCapacity.add(5);   //5战斗力
        //人工分类属性
        List<Integer> beLiked = new LinkedList<>();
        beLiked.add(1);         //被人喜欢
        beLiked.add(0);         //不被人喜欢

        NaiveBayesDCT nbdct = new NaiveBayesDCT(face, FightCapacity, beLiked);
        //初始化训练数据
        List<Integer> data = new LinkedList<>();
        data.add(10);  //f1
        data.add(999); //f2
        data.add(1);   //clazzType ,人工标注
        nbdct.addData(data);
        data = new LinkedList<>();
        data.add(0xdeadface);
        data.add(5);
        data.add(0);
        nbdct.addData(data);
        //训练
        nbdct.train();  //该妹子几乎只喜欢脸好看,战斗力999的

        System.out.println();
        //进行一次 分类测试
        List<Integer> people1 = new LinkedList<>();
        people1.add(10);
        people1.add(999);
        System.out.println("people1 " + people1);
        System.out.println("是否被喜欢:" + nbdct.caculClass(people1)); //理想输出:1

        //博主:
        List<Integer> me = new LinkedList<>();
        me.add(0xdeadface);
        me.add(5);
        System.out.println("me " + me);
        System.out.println("博主 是否被喜欢:" + nbdct.caculClass(me)); //理想输出:0

        List<Integer> people3 = new LinkedList<>();
        people3.add(0xdeadface);
        people3.add(999);
        System.out.println("people3 " + people3);
        System.out.println("是否被喜欢: " + nbdct.caculClass(people3)); //理想输出:-1 ,不知道

        List<Integer> people4 = new LinkedList<>();
        people4.add(5);
        people4.add(999);
        System.out.println("people4 " + people4);
        System.out.println("是否被喜欢: " + nbdct.caculClass(people4)); //理想输出:-1 ,不知道
    }
}

 

 

输出:

==== trainData ====
{0={0=0.0, 1=1.0, 2=0.0, 3=1.0}, 1={0=1.0, 1=0.0, 2=1.0, 3=0.0}}
==== clazz2Prob ====
{0=0.5, 1=0.5}

people1 [10, 999]
{0=0.0, 1=0.5}
是否被喜欢:1

me [-559023410, 5]
{0=0.5, 1=0.0}
博主 是否被喜欢:0

people3 [-559023410, 999]
{0=0.0, 1=0.0}
是否被喜欢: -1

people4 [5, 999]
是否被喜欢: -1

 




 

参考:

https://blog.csdn.net/u013850277/article/details/83996358

 

posted on 2019-12-21 07:48  jald  阅读(520)  评论(0编辑  收藏  举报

导航