用 Java 实现 GBDT

DATA类

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import java.io.File;
import java.io.FileNotFoundException;
import java.util.ArrayList;
import java.util.Scanner;
 
public class Data {
    /** CSV file read by the no-argument constructor. */
    private static final String DEFAULT_DATA_PATH = "D://javajavajava//dbdt//src//script//data//adult.data.csv";

    /** One entry per non-blank line of the CSV; each entry holds that line's comma-separated fields (fields are not trimmed). */
    private ArrayList<ArrayList<String>> trainData = new ArrayList<ArrayList<String>>();

    /**
     * Returns the parsed rows.
     *
     * @return the internal (mutable) list of rows, not a copy
     */
    public ArrayList<ArrayList<String>> getTrainData() {
        return this.trainData;
    }

    /** Loads rows from {@link #DEFAULT_DATA_PATH}. */
    public Data() {
        this(DEFAULT_DATA_PATH);
    }

    /**
     * Loads every non-blank line of {@code dataPath}, trims it, and splits it on ','.
     *
     * @param dataPath path of the CSV file to read
     * @throws IllegalStateException if the file cannot be opened or reading it fails
     */
    public Data(String dataPath) {
        // try-with-resources: the Scanner (and the underlying file handle) is always closed.
        try (Scanner in = new Scanner(new File(dataPath), "UTF-8")) {
            while (in.hasNextLine()) {
                String line = in.nextLine().trim();
                // Blank lines (e.g. a trailing newline at end of file) would otherwise become
                // one-element rows that break column indexing downstream.
                if (line.isEmpty()) {
                    continue;
                }
                String[] fields = line.split(",");
                ArrayList<String> row = new ArrayList<>(fields.length);
                for (String field : fields) {
                    row.add(field);
                }
                this.trainData.add(row);
            }
            // Scanner swallows read errors; surface them instead of returning a truncated dataset.
            if (in.ioException() != null) {
                throw new IllegalStateException("Error while reading data file: " + dataPath, in.ioException());
            }
        } catch (FileNotFoundException e) {
            throw new IllegalStateException("Cannot open data file: " + dataPath, e);
        }
    }

    public static void main(String[] args) {
        Data d = new Data();
    }

}

  TREE类

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Random;
import java.util.spi.TimeZoneNameProvider;
 
public class Tree {
    /** Shared source of randomness for sampling candidate split points of numeric attributes. */
    private static final Random RANDOM = new Random();

    // Children are only created by constructTree. Initialising them with `new Tree()` here would
    // make every Tree construct another Tree without end (StackOverflowError).
    private Tree leftTree = null;
    private Tree rightTree = null;
    private double loss = -1;
    private int attributeSplit = 0;
    private String attributeSplitType = "";
    boolean isLeaf;
    double leafValue;
    private ArrayList<Integer> leafNodeSet = new ArrayList<>();

    /**
     * Returns the distinct values of column {@code idx} over all rows of {@code trainData}
     * (values are not trimmed; order is HashSet iteration order).
     */
    public ArrayList<String> getAttributeSet(ArrayList<ArrayList<String>> trainData, int idx)
    {
        HashSet<String> distinct = new HashSet<>();
        for (ArrayList<String> row : trainData) {
            distinct.add(row.get(idx));
        }
        return new ArrayList<>(distinct);
    }

    /**
     * Numeric "less than or equal" on two string-encoded numbers (surrounding whitespace ignored).
     * Accepts integers and decimals.
     *
     * @throws NumberFormatException if either string is not a number
     */
    public boolean myCmpLess(String str1, String str2)
    {
        return Double.parseDouble(str1.trim()) <= Double.parseDouble(str2.trim());
    }

    /**
     * Returns the square root of the sum of squared deviations from the mean of {@code values}.
     * An empty list has loss 0.
     */
    public double computeLoss(ArrayList<Double> values)
    {
        if (values.isEmpty()) {
            return 0.0;
        }
        double mean = 0;
        for (double v : values) {
            mean += v;
        }
        mean /= values.size();
        double sse = 0;
        for (double v : values) {
            double diff = v - mean;
            sse += diff * diff;
        }
        return Math.sqrt(sse);
    }

    /**
     * Leaf value for a K-class gradient-boosting leaf:
     * (K-1)/K * sum(r) / sum(|r| * (1 - |r|)), where r = target.get(i) for each i in subIdx.
     *
     * @param K      number of classes
     * @param subIdx sample indices that fall into this leaf
     * @param target residuals indexed by sample index (not by position in subIdx)
     * @return the leaf value, or 0 when the denominator is 0 (including an empty leaf)
     */
    public double getPredictValue(int K, ArrayList<Integer> subIdx, ArrayList<Double> target) {
        double numerator = 0;
        double denominator = 0;
        for (int idx : subIdx) {
            double r = target.get(idx);
            double absR = Math.abs(r);
            numerator += r;
            denominator += absR * (1 - absR);
        }
        if (denominator == 0.0) {
            return 0.0;
        }
        // Floating-point factor: integer (K-1)/K would truncate to 0.
        return (K - 1.0) / K * numerator / denominator;
    }

    /** Returns {@code root.leafValue}, without descending into the tree. */
    public double getPredictValue(Tree root)
    {
        return root.leafValue;
    }

    /**
     * Routes {@code instance} from {@code root} down to a leaf and returns that leaf's value.
     *
     * @param instance feature values of one sample, indexed like the training rows
     * @param isDigit  isDigit[a] is true when attribute a is numeric
     */
    public double getPredictValue(Tree root, ArrayList<String> instance, Boolean isDigit[])
    {
        Tree node = root;
        while (!node.isLeaf) {
            int attr = node.attributeSplit;
            node = goesLeft(instance.get(attr), node.attributeSplitType, isDigit[attr]) ? node.leftTree : node.rightTree;
        }
        return node.leafValue;
    }

    /**
     * Recursively builds a regression tree over the samples in {@code subIdx}.
     *
     * Each internal node chooses the (attribute, value) split that minimises
     * computeLoss(left targets) + computeLoss(right targets). Numeric attributes use "value <= split";
     * categorical attributes use "value equals split". Numeric attributes try at most {@code splitPoints}
     * randomly sampled distinct values. A node becomes a leaf when:
     * - {@code depth >= maxDepth[0]}, or
     * - it holds fewer than 2 samples, or
     * - no candidate split separates its samples.
     *
     * @param leafNodes  output: receives a copy of the sample indices of every leaf, in creation order
     * @param leafValues output: receives each leaf's value, aligned with leafNodes
     * @param target     residuals indexed by sample index (same indexing as trainData)
     * @return the root of the (sub)tree
     */
    public Tree constructTree(ArrayList<ArrayList<Integer>> leafNodes,ArrayList<Double> leafValues,int K,int splitPoints, Boolean isDigit[],ArrayList<Integer> subIdx,ArrayList<ArrayList<String>> trainData,ArrayList<Double> target,int maxDepth[],int depth)
    {
        if (depth >= maxDepth[0] || subIdx.size() < 2) {
            return makeLeaf(leafNodes, leafValues, K, subIdx, target);
        }

        // Only attributes described by isDigit are split on (rows may carry extra columns).
        int dim = Math.min(isDigit.length, trainData.get(subIdx.get(0)).size());
        double bestLoss = -1;
        ArrayList<Integer> bestLeft = null;
        ArrayList<Integer> bestRight = null;
        int bestAttribute = 0;
        String bestValue = "";

        for (int attr = 0; attr < dim; attr++) {
            boolean numeric = isDigit[attr];
            for (String candidate : candidateSplitValues(trainData, attr, numeric, splitPoints)) {
                // Fresh partitions per candidate so earlier candidates do not leak into later ones.
                ArrayList<Integer> left = new ArrayList<>();
                ArrayList<Integer> right = new ArrayList<>();
                for (int idx : subIdx) {
                    if (goesLeft(trainData.get(idx).get(attr), candidate, numeric)) {
                        left.add(idx);
                    } else {
                        right.add(idx);
                    }
                }
                // A split that leaves one side empty does not separate anything.
                if (left.isEmpty() || right.isEmpty()) {
                    continue;
                }
                double candidateLoss = computeLoss(targetsOf(left, target)) + computeLoss(targetsOf(right, target));
                if (bestLeft == null || candidateLoss < bestLoss) {
                    bestLoss = candidateLoss;
                    bestLeft = left;
                    bestRight = right;
                    bestAttribute = attr;
                    bestValue = candidate;
                }
            }
        }

        if (bestLeft == null) {
            return makeLeaf(leafNodes, leafValues, K, subIdx, target);
        }

        Tree node = new Tree();
        node.attributeSplit = bestAttribute;
        node.attributeSplitType = bestValue;
        node.loss = bestLoss;
        node.isLeaf = false;
        node.leftTree = constructTree(leafNodes, leafValues, K, splitPoints, isDigit, bestLeft, trainData, target, maxDepth, depth + 1);
        node.rightTree = constructTree(leafNodes, leafValues, K, splitPoints, isDigit, bestRight, trainData, target, maxDepth, depth + 1);
        return node;
    }

    /** Creates a leaf for subIdx and records its samples and value in the output lists. */
    private Tree makeLeaf(ArrayList<ArrayList<Integer>> leafNodes, ArrayList<Double> leafValues, int K, ArrayList<Integer> subIdx, ArrayList<Double> target) {
        Tree leaf = new Tree();
        leaf.isLeaf = true;
        leaf.leafValue = getPredictValue(K, subIdx, target);
        leaf.leafNodeSet.addAll(subIdx);
        // Store a copy: subIdx may be a list the caller keeps using.
        leafNodes.add(new ArrayList<>(subIdx));
        leafValues.add(leaf.leafValue);
        return leaf;
    }

    /**
     * Candidate split values for one attribute.
     *
     * Categorical attributes (and numeric ones with at most splitPoints distinct values, or
     * splitPoints <= 0) use every distinct value. Otherwise, splitPoints distinct values are sampled at random.
     */
    private ArrayList<String> candidateSplitValues(ArrayList<ArrayList<String>> trainData, int attr, boolean numeric, int splitPoints) {
        ArrayList<String> distinct = getAttributeSet(trainData, attr);
        if (!numeric || splitPoints <= 0 || distinct.size() <= splitPoints) {
            return distinct;
        }
        ArrayList<String> picked = new ArrayList<>(splitPoints);
        for (int s = 0; s < splitPoints; s++) {
            // Remove from the pool so the same value is not drawn twice.
            picked.add(distinct.remove(RANDOM.nextInt(distinct.size())));
        }
        return picked;
    }

    /** True when a sample whose attribute value is value belongs in the left child of a split at splitValue. */
    private boolean goesLeft(String value, String splitValue, boolean numeric) {
        if (numeric) {
            return myCmpLess(value, splitValue);
        }
        // Trim both sides: raw fields may carry surrounding whitespace (Data splits on ',' without trimming).
        return value.trim().equals(splitValue.trim());
    }

    /** Collects target.get(i) for every sample index i in indices. */
    private static ArrayList<Double> targetsOf(ArrayList<Integer> indices, ArrayList<Double> target) {
        ArrayList<Double> values = new ArrayList<>(indices.size());
        for (int idx : indices) {
            values.add(target.get(idx));
        }
        return values;
    }

    public static void main(String[] args) {
        Tree aTree=new Tree();
    }

}

  

GBDT类

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
import java.rmi.server.SkeletonNotFoundException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Random;
import java.util.Set;
 
 
public class GBDT {
     
    private ArrayList<ArrayList<String>> datas=new ArrayList<ArrayList<String>>();
    private ArrayList<String> labelSets=new ArrayList<>();
    private ArrayList<ArrayList<Double>> F=new ArrayList<ArrayList<Double>>();
    private ArrayList<ArrayList<Double>> residual=new ArrayList<ArrayList<Double>>();
    private ArrayList<ArrayList<String>> trainData=new ArrayList<ArrayList<String>>();
    private ArrayList<Integer> labelTrainData=new ArrayList<Integer>();
    private int K;
    private Boolean isDigit[];
    private int dim;
    private int n;
    private double learningRate;
     
    private ArrayList<ArrayList<Tree>> trees=new ArrayList<ArrayList<Tree>>(); //存放所有的树
     
    private int max_iter;
    private double sampleRate;
    private int maxDepth;
    private int splitPoints;
 
    public void computeResidual(ArrayList<Integer> subId)
    {
        for(int i=0;i<subId.size();i++)
        {
            int idx=subId.get(i);
            int y=0;
            if(this.labelTrainData.get(idx)==-1) y=0;
            else y=1;
            double sum=Math.exp(this.F.get(idx).get(0))+Math.exp(this.F.get(idx).get(1));
            double p1=Math.exp(this.F.get(idx).get(0))/sum,p2=Math.exp(this.F.get(idx).get(1))/sum;
            this.residual.get(idx).set(0, y-p1);
            this.residual.get(idx).set(1, y-p2);
        }
    }
    public ArrayList<Integer> myrandom(int maxNum,int num)
    {
        ArrayList<Integer> ans=new ArrayList<>();
        Set<Integer> mySet=new HashSet<>();
        while(mySet.size()<num)
        {
            Random r=new Random();
            int tmp=r.nextInt(maxNum);
            mySet.add(tmp);
        }
        Iterator<Integer> it=mySet.iterator();
        while(it.hasNext())
        {
            ans.add(it.next());
        }
        return ans;
    }
     
    public GBDT()
    {
        this.max_iter=50;
        this.sampleRate=0.8;
        this.K=2;//2分类问题
        this.maxDepth=6;
        this.splitPoints=3;
        this.learningRate=0.01;
        getData();
    }
     
    public void train()
    {
        for(int i=0;i<max_iter;i++)
        {
            ArrayList<Integer> subSet=new ArrayList<>();
            int numSubset=(int)(this.n*this.sampleRate);
            subSet=myrandom(this.n,numSubset);
            computeResidual(subSet);
            ArrayList<Double> target=new ArrayList<>();
            ArrayList<Tree> tmpTree=new ArrayList<>();
            int maxdepths[]={this.maxDepth};
            for(int j=0;j<this.K;j++)
            {
                target.clear();
                for(int k=0;k<subSet.size();k++)
                {
                    target.add(residual.get(subSet.get(k)).get(j));
                }
                ArrayList<ArrayList<Integer>> leafNodes=new ArrayList<ArrayList<Integer>>();
                ArrayList<Double> leafValues=new ArrayList<>();
                Tree treeSub=new Tree();
                Tree iterTree=treeSub.constructTree(leafNodes,leafValues,K,splitPoints, isDigit, subSet, trainData, target,maxdepths,0);
                tmpTree.add(iterTree);
                updateFvalue(isDigit, subSet,leafNodes,leafValues,j,iterTree);
            }
             
            trees.add(tmpTree);
        }
    }
     
    public void updateFvalue(Boolean isDigit[], ArrayList<Integer> subIdx,ArrayList<ArrayList<Integer>> leafNodes,ArrayList<Double> leafValues,int label,Tree root)
    {
        ArrayList<Integer> remainIdx=new ArrayList<>();
        int arr[]=new int[this.n];
        for(int i=0;i<this.n;i++)
            arr[i]=i;
        for(int i=0;i<subIdx.size();i++)
        {
            arr[subIdx.get(i)]=-1;
        }
        //求出不是用来训练树的余下集合
        for(int i=0;i<this.n;i++)
        {
            if(arr[i]!=-1)
                remainIdx.add(i);
        }
        for(int i=0;i<leafNodes.size();i++)
        {
            for(int j=0;j<leafNodes.get(i).size();j++)
            {
                this.F.get(leafNodes.get(i).get(j)).set(label, this.F.get(leafNodes.get(i).get(j)).get(label)+this.learningRate*root.getPredictValue(root));
            }
        }
        for(int i=0;i<remainIdx.size();i++)
        {
            double leafV=root.getPredictValue(root,this.trainData.get(remainIdx.get(i)),isDigit);
            this.F.get(remainIdx.get(i)).set(label, this.F.get(remainIdx.get(i)).get(label)+this.learningRate*leafV);
        }
         
         
    }
     
    public boolean checkDigit(String str) {
        for(int i=0;i<str.length();i++)
        {
            if(!(str.charAt(i)>='0'&&str.charAt(i)<='9'))
            {
                return false;
            }
        }
        return true;
    }
     
    public void getData() {
        Data d =new Data();
        this.datas=d.getTrainData();
        this.dim=this.datas.get(0).size()-1;
        this.isDigit=new Boolean[this.dim];
        //遍历所有样本,去掉中间含有不是正常的数据
        for(int i=0;i<this.datas.get(0).size()-1;i++)
            labelSets.add(this.datas.get(0).get(i));
        //保证数据的第一行是正确的,来判断,特征哪些纬度是数字,哪些纬度是字符串
        for(int i=0;i<this.dim;i++)
        {
            if(checkDigit(this.datas.get(0).get(i)))
                this.isDigit[i]=true;
            else this.isDigit[i]=false;
        }
        //如果字符串==?说明是异常数据,这里做数据的清理
        for(int i=1;i<this.datas.size();i++)
        {
            ArrayList<String> tmp=new ArrayList<>();
            boolean flag=true;
            for(int j=0;j<this.dim;j++)
            {
                if(datas.get(i).get(j).trim().equals("?"))
                {
                    flag=false;
                    break;
                }
            }
            if(!flag) continue;
            if(datas.get(i).get(this.dim).trim().equals("?")) continue;
            trainData.add(tmp);
            if(datas.get(i).get(this.dim).trim().equals("<=50K"))
                labelTrainData.add(-1);
            else
                labelTrainData.add(1);
             
        }
        this.n=this.labelTrainData.size();
         
        for(int i=0;i<this.datas.get(0).size()-1;i++)
            labelSets.add(this.datas.get(0).get(i));
         
        //初始化F矩阵为全0,F矩阵是n*2,是2分类问题,如果要多分类,改下这里就可以了
        for(int i=0;i<this.n;i++)
        {
            ArrayList<Double> arrTmp=new ArrayList<Double>();
            for(int j=0;j<2;j++)
            {
                arrTmp.add(0.0);
            }
            this.F.add(arrTmp);
            this.residual.add(arrTmp);
        }
         
                             
    }
     
    public static void main(String[] args) {
        GBDT dGbdt=new GBDT();
        dGbdt.getData();
        System.err.println(dGbdt.n);
         
    }
}

  

 

posted @ simple_wxl  阅读(2307)  评论(0)  编辑  收藏  举报
点击右上角即可分享
微信分享提示