随机森林分类器的实现
下面是我实现的简易版本
决策树:ID3Tree.h
//#pragma once
#ifndef ID3
#define ID3

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

// NOTE(review): a using-directive in a header leaks into every includer; it is
// kept only because existing translation units may rely on it.
using namespace std;

// Smoothing term added to counts/sizes so that log() and divisions never see 0.
#define Epsilon 0.000000001

/// One sample: categorical attribute values (as value indices) plus a class label.
class Record {
public:
    std::vector<int> attri;  // attri[k] = value index of attribute k
    int label;               // class index
    static int attributeNums;                 // number of attributes
    static int labelNums;                     // number of classes the data is split into
    static std::vector<int> each_attr_class;  // number of values each attribute can take
};

// Declared for use by other translation units; not referenced in this header.
extern int value_int;
extern std::vector<float> value_vector;

/// ID3 decision tree over categorical attributes, split by information gain.
/// Owns its nodes; copying is disabled because nodes are raw-owned.
class ID3Tree {
    struct ID3Node {
        int attri = -1;                  // attribute this (internal) node splits on
        std::vector<ID3Node*> child;     // child[v] handles attribute value v
        std::vector<int> remain_attri;   // attributes still available for splitting
        int label = -1;                  // predicted class at a leaf; -1 = unset
    };

private:
    ID3Node* root = nullptr;
    int attributeNums = -1;          // number of attributes
    int labelNums = 0;               // number of classes
    int sampleNums = 0;              // number of training samples
    std::vector<int> label_count;    // per-label counts from the last majority() call
    double threshold = 0.99;         // purity above which a node becomes a leaf
    std::vector<int> each_attr_class;  // number of values each attribute can take

    double cal_entropy(std::vector<Record>& Dataset);
    std::vector<std::vector<Record>> splitDataset(std::vector<Record>& Dataset, const int k);
    int majority(std::vector<Record>& Dataset);
    static void destroy_subtree(ID3Node* node);  // frees node and all descendants

public:
    ID3Tree();
    ~ID3Tree();
    ID3Tree(const ID3Tree&) = delete;
    ID3Tree& operator=(const ID3Tree&) = delete;

    void build_tree(ID3Node* node, std::vector<Record>& Dataset);
    int classify(Record& rd);
    bool create_root();
    bool create_root(std::vector<int>& aa);
    ID3Node* get_root() { return root; }
    void set_paras(int attributeNums, int labelNums, int sampleNums) {
        this->attributeNums = attributeNums;
        this->labelNums = labelNums;
        this->sampleNums = sampleNums;
        each_attr_class = Record::each_attr_class;
    }
};

/// Counts the labels of Dataset into label_count and returns the most frequent
/// label (the smallest index on ties; 0 when Dataset is empty).
inline int ID3Tree::majority(std::vector<Record>& Dataset) {
    label_count.assign(static_cast<std::size_t>(labelNums), 0);
    for (const Record& rd : Dataset)
        label_count[rd.label]++;
    return static_cast<int>(std::max_element(label_count.begin(), label_count.end()) -
                            label_count.begin());
}

/// Partitions Dataset by the value of attribute k; result[v] holds the records
/// whose attribute k equals v.
inline std::vector<std::vector<Record>> ID3Tree::splitDataset(std::vector<Record>& Dataset,
                                                              const int k) {
    std::vector<std::vector<Record>> parts(static_cast<std::size_t>(each_attr_class[k]));
    for (const Record& rd : Dataset)
        parts[rd.attri[k]].push_back(rd);
    return parts;
}

/// Recursively grows the subtree rooted at node from Dataset (ID3).
/// A node becomes a leaf when its data is (nearly) pure, when no attributes
/// remain, or when it has no data. Otherwise it splits on the attribute with
/// the highest information gain.
inline void ID3Tree::build_tree(ID3Node* node, std::vector<Record>& Dataset) {
    // An empty partition carries no evidence: keep the label the parent preset
    // (its majority class), or fall back to class 0 at the root.
    if (Dataset.empty()) {
        if (node->label < 0) node->label = 0;
        return;
    }

    const int label = majority(Dataset);

    if (double(label_count[label]) / double(Dataset.size()) > threshold) {
        node->label = label;
        return;
    }
    // No attribute left to split on (the original stopped one attribute early).
    if (node->remain_attri.empty()) {
        node->label = label;
        return;
    }
    if (Dataset.size() == 1) {
        node->label = Dataset[0].label;
        return;
    }

    const double base_entropy = cal_entropy(Dataset);
    double maxgain = -10000;
    int selectAttri = -1;
    std::vector<std::vector<Record>> bb;  // partitions of the best attribute so far
    for (std::size_t i = 0; i < node->remain_attri.size(); i++) {
        std::vector<std::vector<Record>> aa = splitDataset(Dataset, node->remain_attri[i]);
        double gain = base_entropy;
        for (std::vector<Record>& part : aa)
            gain -= double(part.size()) / double(Dataset.size() + Epsilon) * cal_entropy(part);
        if (gain > maxgain) {
            maxgain = gain;
            selectAttri = static_cast<int>(i);
            bb.swap(aa);  // take the partitions without copying them
        }
    }
    assert(selectAttri >= 0);

    std::vector<int> rest = node->remain_attri;
    node->attri = node->remain_attri[selectAttri];
    rest.erase(rest.begin() + selectAttri);
    node->child.reserve(bb.size());
    for (std::vector<Record>& part : bb) {
        ID3Node* nn = new ID3Node;
        nn->remain_attri = rest;
        nn->label = label;  // fallback used if this partition turns out empty
        node->child.push_back(nn);
        build_tree(nn, part);
    }
}

/// Shannon entropy (base 2) of the label distribution of Dataset. Counts are
/// smoothed by Epsilon so empty classes contribute a negligible term.
inline double ID3Tree::cal_entropy(std::vector<Record>& Dataset) {
    const int len = static_cast<int>(Dataset.size());
    std::vector<int> count(static_cast<std::size_t>(labelNums), 0);
    for (const Record& rd : Dataset)
        count[rd.label]++;
    double entropy = 0;
    for (int c : count) {
        const double p = double(c + Epsilon) / double(len + Epsilon);
        entropy += -p * std::log(p) / std::log(double(2));
    }
    assert(entropy >= 0.0);
    return entropy;
}

inline ID3Tree::ID3Tree() {
    root = nullptr;
    attributeNums = -1;
    threshold = 0.99;
}

/// Creates a root that may split on attributes 0..attributeNums-1.
/// Returns false if set_paras() has not provided an attribute count yet.
inline bool ID3Tree::create_root() {
    if (attributeNums < 0) return false;
    std::vector<int> all;
    for (int i = 0; i < attributeNums; i++)
        all.push_back(i);
    return create_root(all);
}

/// Creates a root restricted to the attribute indices in aa, replacing (and
/// freeing) any previously built tree.
inline bool ID3Tree::create_root(std::vector<int>& aa) {
    attributeNums = static_cast<int>(aa.size());
    destroy_subtree(root);
    root = nullptr;
    root = new ID3Node;
    root->remain_attri = aa;
    return true;
}

// Iterative (explicit stack) so that deep trees cannot overflow the call stack.
inline void ID3Tree::destroy_subtree(ID3Node* node) {
    std::vector<ID3Node*> pending;
    if (node != nullptr) pending.push_back(node);
    while (!pending.empty()) {
        ID3Node* cur = pending.back();
        pending.pop_back();
        pending.insert(pending.end(), cur->child.begin(), cur->child.end());
        delete cur;
    }
}

inline ID3Tree::~ID3Tree() { destroy_subtree(root); }

/// Walks from the root to a leaf following rd's attribute values, stores the
/// leaf's label in rd.label, and returns it. Requires create_root() first.
inline int ID3Tree::classify(Record& rd) {
    assert(root != nullptr);
    ID3Node* node = root;
    while (!node->child.empty()) {
        const int v = rd.attri[node->attri];
        assert(v >= 0 && static_cast<std::size_t>(v) < node->child.size());
        node = node->child[v];
    }
    rd.label = node->label;
    return node->label;
}

#endif
#ifndef RANDOMFOREST
#define RANDOMFOREST

#include <cstdlib>
#include <time.h>
#include <vector>

class ID3Tree;
class Record;

/// Random forest of ID3 trees. Each tree is trained on a bootstrap sample of
/// the loaded data using a random subset of the attributes; prediction is a
/// majority vote. Owns its trees, so copying is disabled.
class RandomForest {
private:
    std::vector<ID3Tree*> forest;      // owned; deleted in the destructor
    int treeNums = 0;                  // number of trees to build
    void boostrap();                   // fills subDataSet by sampling wholeDataSet
    std::vector<Record> wholeDataSet;
    std::vector<Record> subDataSet;
    int sizeofwholeDataSet = 0;
    double ratioofsubDataset = 0.0;    // bootstrap sample size as a fraction of the data
    int attributeNums = 0;             // number of attributes
    int sub_attriNums = 0;             // attributes selected per tree
    int labelNums = 0;                 // number of classes
    std::vector<int> ranom_select_feature();
    std::vector<int> vote(Record rd);

public:
    void load_dataset();
    void create_forest();
    int classify(std::vector<int> query);
    void set_paras();
    RandomForest();
    ~RandomForest();
    RandomForest(const RandomForest&) = delete;
    RandomForest& operator=(const RandomForest&) = delete;
};

#endif
randomforest.cpp
#include "stdafx.h"
// Standard headers precede the project headers so that the inline code in
// ID3Tree.h sees them even if that header does not include them itself.
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "ID3Tree.h"
#include "RandomForest.h"
using namespace std;

RandomForest::RandomForest() {
    srand(static_cast<unsigned>(time(nullptr)));
}

RandomForest::~RandomForest() {
    for (ID3Tree* tree : forest)
        delete tree;
}

/// Sets the forest hyper-parameters. Must be called after load_dataset(),
/// because it records the size of the loaded dataset.
void RandomForest::set_paras() {
    ratioofsubDataset = 0.5;
    sizeofwholeDataSet = static_cast<int>(wholeDataSet.size());
    attributeNums = Record::attributeNums;
    sub_attriNums = attributeNums - 1;
    treeNums = 100;
    labelNums = Record::labelNums;
}

/// Appends the non-empty pieces of str separated by sep to ret_. Always returns 0.
int split(const std::string& str, std::vector<std::string>& ret_, std::string sep = ",") {
    if (str.empty()) {
        return 0;
    }
    if (sep.empty()) {  // find("") would never advance and loop forever
        ret_.push_back(str);
        return 0;
    }
    std::string::size_type pos_begin = str.find_first_not_of(sep);
    while (pos_begin != std::string::npos) {
        const std::string::size_type sep_pos = str.find(sep, pos_begin);
        std::string tmp;
        if (sep_pos != std::string::npos) {
            tmp = str.substr(pos_begin, sep_pos - pos_begin);
            pos_begin = sep_pos + sep.length();
        } else {
            tmp = str.substr(pos_begin);
            pos_begin = std::string::npos;
        }
        if (!tmp.empty()) {
            ret_.push_back(tmp);
        }
    }
    return 0;
}

// Dataset schema: 4 attributes (Age, Income, Student, CreditRating), 2 classes.
int Record::attributeNums = 4;
int Record::labelNums = 2;
std::vector<int> Record::each_attr_class = {3, 3, 2, 2};

// Uniform index in [0, n). Scaling rand() to [0,1) first avoids the signed
// overflow of n * rand() on platforms where RAND_MAX is 2^31 - 1.
static int random_index(int n) {
    return static_cast<int>(n * (rand() / (RAND_MAX + 1.0)));
}

// Position of token in names, or -1 if it is not one of them.
template <std::size_t N>
static int index_of(const std::string& token, const char* const (&names)[N]) {
    for (std::size_t i = 0; i < N; i++)
        if (token == names[i]) return static_cast<int>(i);
    return -1;
}

// Parses the text that follows a row id (e.g. " Youth High No Fair No\n")
// into rd. Returns false if the row does not consist of five known tokens.
static bool parse_record(const std::string& text, Record& rd) {
    static const char* const kAge[] = {"Youth", "Senior", "MiddleAged"};
    static const char* const kIncome[] = {"Low", "Medium", "High"};
    static const char* const kNoYes[] = {"No", "Yes"};
    static const char* const kCredit[] = {"Fair", "Excellent"};

    // Treat tabs and line breaks (including '\r' from CRLF files) as separators.
    std::string normalized = text;
    for (char& c : normalized)
        if (c == '\t' || c == '\r' || c == '\n') c = ' ';
    std::vector<std::string> fields;
    split(normalized, fields, std::string(" "));
    if (fields.size() != 5) return false;

    const int age = index_of(fields[0], kAge);
    const int income = index_of(fields[1], kIncome);
    const int student = index_of(fields[2], kNoYes);
    const int credit = index_of(fields[3], kCredit);
    const int label = index_of(fields[4], kNoYes);
    if (age < 0 || income < 0 || student < 0 || credit < 0 || label < 0) return false;

    rd.attri = {age, income, student, credit};
    rd.label = label;
    return true;
}

/// Reads input.txt into wholeDataSet, echoing each row's text to stdout.
void RandomForest::load_dataset() {
    /* Expected layout: a header line, then one row per sample as
       "<id> <Age> <Income> <Student> <CreditRating> <BuysComputer>":

       Rid Age Income Student CreditRating BuysComputer
       1 Youth High No Fair No
       2 Youth High No Excellent No
       3 MiddleAged High No Fair Yes
       4 Senior Medium No Fair Yes
       5 Senior Low Yes Fair Yes
       6 Senior Low Yes Excellent No
       7 MiddleAged Low Yes Excellent Yes
       8 Youth Medium No Fair No
       9 Youth Low Yes Fair Yes
       10 Senior Medium Yes Fair Yes
       11 Youth Medium Yes Excellent Yes
       12 MiddleAged Medium No Excellent Yes
       13 MiddleAged High Yes Fair Yes
       14 Senior Medium No Excellent No

       Rows are recognised by their leading numeric id, so everything may also
       sit on a single line. */
    FILE* fp = fopen("input.txt", "r");
    if (fp == NULL) {
        fprintf(stderr, "load_dataset: cannot open input.txt\n");
        return;
    }
    // int, not char: EOF must stay distinguishable from every byte value.
    int ch = getc(fp);
    while (ch != EOF) {
        if (ch < '0' || ch > '9') {  // not at a row id yet (e.g. the header)
            ch = getc(fp);
            continue;
        }
        while (ch >= '0' && ch <= '9')  // skip the row id
            ch = getc(fp);
        if (ch == EOF) break;

        // The row's text runs up to the next digit (next row id) or EOF.
        std::string str;
        while (ch != EOF && (ch < '0' || ch > '9')) {
            putchar(ch);
            str += static_cast<char>(ch);
            ch = getc(fp);
        }

        Record rd;
        if (parse_record(str, rd))
            wholeDataSet.push_back(rd);
        else
            assert(false && "malformed record in input.txt");  // skipped in release builds
    }
    fclose(fp);
}

/// Draws ratioofsubDataset * sizeofwholeDataSet records with replacement.
void RandomForest::boostrap() {
    subDataSet.clear();
    for (int i = 0; i < ratioofsubDataset * sizeofwholeDataSet; i++)
        subDataSet.push_back(wholeDataSet[random_index(sizeofwholeDataSet)]);
}

/// Picks up to sub_attriNums distinct attribute indices uniformly at random.
std::vector<int> RandomForest::ranom_select_feature() {
    std::vector<int> pool;
    for (int i = 0; i < attributeNums; i++)
        pool.push_back(i);
    std::vector<int> picked;
    for (int i = 0; i < sub_attriNums && !pool.empty(); i++) {
        const int j = random_index(static_cast<int>(pool.size()));
        picked.push_back(pool[j]);
        pool.erase(pool.begin() + j);
    }
    return picked;
}

/// Builds treeNums trees, each on a fresh bootstrap sample and attribute subset.
void RandomForest::create_forest() {
    for (int i = 0; i < treeNums; i++) {
        boostrap();
        std::unique_ptr<ID3Tree> tree(new ID3Tree);
        tree->set_paras(attributeNums, labelNums,
                        static_cast<int>(ratioofsubDataset * sizeofwholeDataSet));
        // create_root takes a non-const reference, so the selection needs a name;
        // binding the temporary directly is ill-formed C++.
        std::vector<int> features = ranom_select_feature();
        tree->create_root(features);
        tree->build_tree(tree->get_root(), subDataSet);
        forest.push_back(tree.get());
        tree.release();  // ownership now lies with forest
    }
    assert(forest.size() == static_cast<std::size_t>(treeNums));
}

/// Returns the class chosen by most trees for the given attribute values
/// (the smallest class index on ties).
int RandomForest::classify(vector<int> query) {
    assert(query.size() == static_cast<std::size_t>(Record::attributeNums));
    Record rd;
    rd.attri = std::move(query);
    const std::vector<int> votes = vote(rd);
    return static_cast<int>(std::max_element(votes.begin(), votes.end()) - votes.begin());
}

/// Per-class vote counts over all trees actually built.
std::vector<int> RandomForest::vote(Record rd) {
    std::vector<int> tally(static_cast<std::size_t>(labelNums), 0);
    for (ID3Tree* tree : forest) {
        const int label = tree->classify(rd);
        if (label >= 0 && label < labelNums) tally[label]++;
    }
    return tally;
}
main.cpp
#include "stdafx.h"
#include "RandomForest.h"
#include <cstdio>
#include <cstdlib>
#include <vector>
using namespace std;

/// Trains the forest on input.txt and prints the prediction for one query.
int _tmain(int argc, _TCHAR* argv[]) {
    RandomForest rf;
    rf.load_dataset();
    rf.set_paras();  // after load_dataset: it records the dataset size
    rf.create_forest();

    // Age=Youth, Income=Low, Student=No, CreditRating=Fair
    // (all encoded as 0); the expected prediction is BuysComputer: No.
    vector<int> query = {0, 0, 0, 0};
    int re = rf.classify(query);
    // Class 0 is "No" and class 1 is "Yes", matching the loader's encoding.
    printf("Predicted BuysComputer: %s\n", re == 1 ? "Yes" : "No");

    system("pause");
    return 0;
}
input.txt
Rid Age Income Student CreditRating BuysComputer
1 Youth High No Fair No
2 Youth High No Excellent No
3 MiddleAged High No Fair Yes
4 Senior Medium No Fair Yes
5 Senior Low Yes Fair Yes
6 Senior Low Yes Excellent No
7 MiddleAged Low Yes Excellent Yes
8 Youth Medium No Fair No
9 Youth Low Yes Fair Yes
10 Senior Medium Yes Fair Yes
11 Youth Medium Yes Excellent Yes
12 MiddleAged Medium No Excellent Yes
13 MiddleAged High Yes Fair Yes
14 Senior Medium No Excellent No
版权声明: