决策树ID3算法示例

决策树代码如下:

#include "MyID3.h"
using namespace std;
// Loads the training table from F:\data.txt into DataTable and echoes it.
//
// Reads NUM rows of 6 whitespace-separated fields each; field j of row i is
// stored in DataTable[i][j]. Every field that is read is printed to stdout
// followed by a tab, and every completed row is terminated with a newline.
//
// If the file cannot be opened, or it runs out of fields before NUM rows have
// been read, a diagnostic is written to stderr and loading stops. Cells that
// were not read keep whatever value they had before the call.
void ReadData()
{
    ifstream fin("F:\\data.txt");
    if(!fin)
    {
        cerr<<"ReadData: cannot open F:\\data.txt"<<endl;
        return;
    }

    for(int i=0;i<NUM;i++)
    {
        for(int j=0;j<6;j++)
        {
            // Stop at the first failed extraction instead of silently
            // continuing with stale or empty cells.
            if(!(fin>>DataTable[i][j]))
            {
                cout<<endl;
                cerr<<"ReadData: input ended at row "<<i<<", column "<<j<<endl;
                return;
            }
            cout<<DataTable[i][j]<<"\t";
        }
        cout<<endl;
    }
    // fin is closed by its destructor when it goes out of scope.
}

// Returns the base-2 logarithm of p.
//
// The inputs 0 and 1 are short-circuited to 0: log2(1) is 0 anyway, and
// returning 0 for p == 0 lets callers form products such as 0 * log2(0)
// without producing an infinity.
double ComputLog(double &p)
{
    const bool shortCircuit=(p==0||p==1);
    return shortCircuit ? 0 : log(p)/log(2);
}

// Returns the binary entropy, in bits, of a two-outcome distribution whose
// first outcome has probability p:
//
//     H(p) = p*log2(1/p) + (1-p)*log2(1/(1-p))
//
// A distribution with p <= 0 or p >= 1 carries no uncertainty, so 0 is
// returned for it. Without this guard the formula evaluates 0 * infinity at
// the endpoints and yields NaN.
double ComputInfo(double &p)
{
    if(p<=0||p>=1)
        return 0;

    const double q=1-p;
    const double bitsP=log(1/p)/log(2.0);   // log2(1/p)
    const double bitsQ=log(1/q)/log(2.0);   // log2(1/q)
    return p*bitsP+q*bitsQ;
}

// Tallies the class labels of rows begin..end (both inclusive) of DataTable.
//
// begin, end - first and last row index to examine
// CountP     - receives the number of rows whose column 5 equals "Yes"
// CountN     - receives the number of all remaining rows
void CountInfoNP(int begin,int end,int &CountP,int &CountN)
{
    CountN=0;
    CountP=0;
    for(int row=begin;row<=end;++row)
    {
        // Pick the counter that matches this row's label and bump it.
        int &tally=(DataTable[row][5]=="Yes") ? CountP : CountN;
        ++tally;
    }
}

// Looks for `data` among the first `count` entries of DataValueWeight.
//
// data   - attribute value to look up
// count  - number of DataValueWeight entries currently in use
// result - class label of the row the value came from
//
// Returns true when the value is not recorded yet; nothing is modified in
// that case. Returns false when it is already recorded, after adding 1 to the
// matching entry's ValueWeight and to its ValuePWeight (label "Yes") or
// ValueNWeight (any other label).
bool CompareData(string &data,int &count,string &result)
{
    // Locate the first entry whose name matches, if any.
    int hit=0;
    while(hit<count&&!(data==DataValueWeight[hit].AttriValueName))
        ++hit;

    if(hit>=count)
        return true;

    DataValueWeight[hit].ValueWeight+=1;
    if(result=="Yes")
        DataValueWeight[hit].ValuePWeight+=1;
    else
        DataValueWeight[hit].ValueNWeight+=1;
    return false;
}

// Collects statistics on the distinct values of column k over rows
// begin..end (both inclusive) of DataTable, storing them in DataValueWeight.
//
// Returns the number of distinct values found. For each of those entries, on
// return:
//   AttriValueName - the value itself, in order of first appearance
//   ValueWeight    - number of rows carrying that value
//   ValuePWeight   - ValueWeight divided by the number of those rows labelled
//                    "Yes" in column 5, or 0 when there are none
//   ValueNWeight   - ValueWeight divided by the number of those rows with any
//                    other label, or 0 when there are none
// The three numeric fields of all VALUENUM entries are cleared first.
int SearchData(const int &begin,const int &end,const int &k)
{
    for(int slot=0;slot<VALUENUM;++slot)
    {
        DataValueWeight[slot].ValuePWeight=0;
        DataValueWeight[slot].ValueNWeight=0;
        DataValueWeight[slot].ValueWeight=0;
    }

    int count=0;
    for(int row=begin;row<=end;++row)
    {
        string data=DataTable[row][k];
        string result=DataTable[row][5];

        // CompareData updates the tallies itself when the value is already
        // known; with count == 0 it always reports "not seen", so the first
        // row needs no special handling.
        if(!CompareData(data,count,result))
            continue;

        // First occurrence of this value: open a new entry for it.
        DataValueWeight[count].AttriValueName=data;
        DataValueWeight[count].ValueWeight+=1;
        if(result=="Yes")
            DataValueWeight[count].ValuePWeight+=1;
        else
            DataValueWeight[count].ValueNWeight+=1;
        ++count;
    }

    // Turn the per-label counts into total/count ratios; a zero count stays 0.
    for(int slot=0;slot<count;++slot)
    {
        DataValueWeight[slot].ValueNWeight=
            (DataValueWeight[slot].ValueNWeight!=0)
                ? DataValueWeight[slot].ValueWeight/DataValueWeight[slot].ValueNWeight
                : 0;
        DataValueWeight[slot].ValuePWeight=
            (DataValueWeight[slot].ValuePWeight!=0)
                ? DataValueWeight[slot].ValueWeight/DataValueWeight[slot].ValuePWeight
                : 0;
    }
    return count;
}

// Returns the index (1..4) of the attribute with the largest positive
// InfoResult[].AttriI; ties go to the lowest index.
//
// When no attribute has a value above 0 the first attribute column (1) is
// returned. The result used to be left uninitialized in that case, handing
// the caller an arbitrary array index.
int PickAttri()
{
    double best=0;
    int pos=1;          // valid fallback when nothing beats `best`

    for(int i=1;i<5;i++)
    {
        if(InfoResult[i].AttriI>best)
        {
            pos=i;
            best=InfoResult[i].AttriI;
        }
    }
    return pos;
}
// Regroups rows begin..end (both inclusive) of DataTable so that rows sharing
// the same value in column `temp` become contiguous.
//
// Groups follow the order of the first InfoResult[temp].AttriKind entries of
// DataValueWeight (matched by AttriValueName); rows inside a group keep their
// relative order. CopyDataTable is used as scratch space.
//
// position receives the group boundaries: position[0] is begin and
// position[g+1] is the row index one past the end of group g. The caller must
// supply room for AttriKind + 1 entries.
//
// Returns the number of entries written to position.
int  SortByAttriValue(int &begin,int &end,int &temp,int *position)
{
    const int last=end-begin;

    // Snapshot the range so it can be rewritten in place.
    for(int offset=0;offset<=last;++offset)
        for(int col=0;col<6;++col)
            CopyDataTable[offset][col]=DataTable[begin+offset][col];

    int written=0;      // rows already placed back into DataTable
    int groups=0;       // entries already stored in position
    position[groups++]=begin;

    for(int value=0;value<InfoResult[temp].AttriKind;++value)
    {
        for(int src=0;src<=last;++src)
        {
            if(!(CopyDataTable[src][temp]==DataValueWeight[value].AttriValueName))
                continue;

            for(int col=0;col<6;++col)
                DataTable[begin+written][col]=CopyDataTable[src][col];
            ++written;
        }
        position[groups++]=begin+written;
    }
    return groups;
}

void BuildTree(int begin,int end,Node *parent)
{
    int CountP=0,CountN=0;
    CountInfoNP(begin,end,CountP,CountN);

    cout<<"************************ The data be sorted **************************"<<endl;
    for(int i=begin;i<=end;i++)
    {
        for(int j=0;j<=5;j++)
            cout<<DataTable[i][j]<<"\t";
        cout<<endl;
    }
    cout<<"\n\n\n";

    cout<<parent->AttriName<<" have a look: "<<CountP<<endl;
    if(CountP==0||CountN==0)               //该子集当中只包含Yes或者No时为叶子节点,返回调用处;
    {
        cout<<"creat leaf node"<<endl;
        Node* t=new Node();                                    //建立叶子节点
        if(CountP==0)
            t->AttriName="No";
        else
            t->AttriName="Yes";
        parent->Children.push_back(t);                             //插入孩子节点
        return;
    }
    else
    {
        double p=(double)CountP/(CountP+CountN);
        double InfoH=ComputInfo(p);                            //获得信息熵

        for(int k=1;k<5;k++)                                   //循环计算各个属性的条件信息熵,并计算出互信息
        {
            int KindOfValue=SearchData(begin,end,k);
            int sum=1+end-begin;
            for(int j=0;j<KindOfValue;j++)                     //计算出属性的每种取值的权重的倒数
                DataValueWeight[j].ValueWeight=DataValueWeight[j].ValueWeight/sum;

            double InfoGain=0;
            if(DataValueWeight[0].ValueNWeight!=0&&DataValueWeight[0].ValuePWeight!=0)
                InfoGain=DataValueWeight[0].ValueWeight*(ComputLog(DataValueWeight[0].ValueNWeight)/DataValueWeight[0].ValueNWeight+ComputLog(DataValueWeight[0].ValuePWeight)/DataValueWeight[0].ValuePWeight);

            for(int j=1;j<KindOfValue;j++)                     //计算条件信息
            if(DataValueWeight[j].ValueNWeight!=0&&DataValueWeight[j].ValuePWeight!=0)
                InfoGain+=DataValueWeight[j].ValueWeight*(ComputLog(DataValueWeight[j].ValueNWeight)/DataValueWeight[j].ValueNWeight+ComputLog(DataValueWeight[j].ValuePWeight)/DataValueWeight[j].ValuePWeight);

            InfoResult[k].AttriI=InfoH-InfoGain;               //计算互信息
            InfoResult[k].AttriKind=KindOfValue;
        }
        int temp=PickAttri();                                            //选出互信息最大的属性作为节点建树
        Node* t=new Node();
        t->AttriName=InfoResult[temp].AttriName;
        SearchData(begin,end,temp);
        for(int k=0;k<InfoResult[temp].AttriKind;k++)
        {
            string name=DataValueWeight[k].AttriValueName;
            t->AttriValue.push_back(name);
        }
        t->parent=parent;
        parent->Children.push_back(t);                                   //孩子节点压入vector当中
        int position[NUMOFPOS];

        cout<<"before SortByAttriValue Begin: "<<begin<<",END: "<<end<<endl;


        SortByAttriValue(begin,end,temp,position);                                     //将数据按照选定属性的取值不同进行划分
        int times=InfoResult[temp].AttriKind;
        for(int l=0;l<=times;l++)
            cout<<position[l]<<" ";
        cout<<endl;
        for(int k=0;k<times;k++)
            {
                int head,rear;
                head=position[k];
                int hire=k+1;
                rear=position[hire]-1;
                for(int l=0;l<=times;l++)
                cout<<position[l]<<" ";
                cout<<endl;
                cout<<"Head: "<<head<<" ,Rear: "<<rear<<endl;
                BuildTree(head,rear,t);
            }
    }
}

// Prints the subtree rooted at `root` to stdout in pre-order.
//
// Each node's AttriName is printed on its own line. A node named "Yes" or
// "No" is treated as a leaf and nothing more is printed for it. For any other
// node a second line lists its AttriValue entries, each followed by a space,
// and then every child is printed in order.
void ShowTree(Node *root)
{
    cout<<root->AttriName<<endl;

    const bool isLeaf=(root->AttriName=="Yes"||root->AttriName=="No");
    if(isLeaf)
        return;

    for(size_t v=0;v<root->AttriValue.size();++v)
        cout<<root->AttriValue[v]<<" ";
    cout<<endl;

    for(size_t c=0;c<root->Children.size();++c)
        ShowTree(root->Children[c]);
}
// Entry point: names the four attribute columns, loads the data table,
// builds the decision tree under a fresh root node and prints it.
int main()
{
    // Display names for attribute columns 1..4, in column order.
    // NOTE(review): the fourth name is an empty string in the original
    // source - possibly a character lost in transcription; confirm the
    // intended name.
    const char* const attriNames[4]={"天气","气温","湿度",""};
    for(int col=1;col<=4;++col)
        InfoResult[col].AttriName=attriNames[col-1];

    ReadData();

    // The root is a placeholder; the tree proper hangs off its Children.
    Node *Root=new Node;
    BuildTree(0,NUM-1,Root);
    ShowTree(Root);

    return 0;
}

 

posted @ 2014-05-06 11:49  再见,少年  Views(811)  Comments(1)    收藏  举报