使用Cross-validation (CV) 调整Extreme learning Machine (ELM) 最优参数的实现(matlab)

ELM算法模型是最近几年得到广泛重视的模型,它不同于现在广为火热的DNN。 ELM使用传统的三层神经网络,只包含一个隐含层,但又不同于传统的神经网络。ELM是一种简单易用、有效的单隐层前馈神经网络SLFNs学习算法。2006年由南洋理工大学黄广斌副教授提出。传统的神经网络学习算法(如BP算法)需要人为设置大量的网络训练参数,并且很容易产生局部最优解。极限学习机只需要设置网络的隐层节点个数,在算法执行过程中不需要调整网络的输入权值以及隐元的偏置,并且产生唯一的最优解,因此具有学习速度快且泛化性能好的优点。但是隐含层节点个数的设置需要经过人工大量实验得到或者通过最常见的CV方法可以得到。 下面,matlab实现10-fold CV 寻找最优隐含层节点个数的。作为以前工作的一个小记录。ELM使用的是主页http://www.ntu.edu.sg/home/egbhuang/ 源码。源程序包含两个脚本文件cv_para.m  和 Data2txt.m

 

Description:

@cv_para.m 这个是主程序

结构体:function [best_para]=cv_para(data,para_set)。其中使用到两个参数,data代表我们的完整数据,也就是没有划分训练集和测试集的完整数据。para_set代表隐含层节点个数的一个数组,例如在[1:60]之间选择一个最优的隐含层节点个数。

 

@Data2txt.m 这个脚本文件是为了满足ELM算法训练将数据转化为源码ELM可以使用的文本文件。数据格式在ELM主页已经给出example。

 

  1 function [best_para]=cv_para(data,para_set)
  2 
  3 num_folds=10;     % 10-fold cross validation
  4 
  5 n=size(data,1);
  6 
  7 n_paras=length(para_set);
  8 
  9 idx=randperm(n);    % idx 代表n个数据中索引的任意排列
 10 
 11 n_test=floor(n/num_folds); % n_test: 测试集包含的数据集的个数
 12 
 13 test_idx=zeros(num_folds,n_test); % test_idx: 储存num_folds次测试集的索引
 14 
 15 train_idx=zeros(num_folds,n-n_test); %train_idx: 原理同test_idx
 16 
 17  
 18 
 19 % 下面程序操作的主要是索引,只要将训练集地址和测试集地址划分出来
 20 
 21 for i=1:num_folds
 22 
 23     test_idx(i,:)=idx((i-1)*n_test+1:i*n_test);  
 24 
 25     tmp=1:n;
 26 
 27     tmp(test_idx(i,:))=[];
 28 
 29     train_idx(i,:)=tmp;
 30 
 31 end
 32 
 33 best_accs=inf;
 34 
 35 best_para=1; % 保存最优的隐含节点个数
 36 
 37 for i=1:n_paras
 38 
 39         one_accs=0;
 40 
 41         for j=1:num_folds  
 42 
 43             % 这里就是将数据集转化为文本文件形式,以满足elm源码的需求
 44 
 45              train_data=data(train_idx(j,:),:);
 46 
 47              test_data=data(test_idx(j,:),:);
 48 
 49              Data2txt(train_data,'trainfile');
 50 
 51              Data2txt(test_data,'testfile');
 52 
 53             
 54 
 55              [TrainingTime, TrainingAccuracy] = elm_train('trainfile', 0, para_set(1,i), 'sig');
 56 
 57              [TestingTime, acc] = elm_predict('testfile');
 58 
 59             one_accs=one_accs+acc;
 60 
 61            
 62 
 63             delete('trainfile');
 64 
 65             delete('testfile');
 66 
 67         end
 68 
 69         if(best_accs>one_accs)
 70 
 71             best_para=para_set(1,i);
 72 
 73             best_accs=one_accs;
 74 
 75         end
 76 
 77 end
 78 
 79 end
 80 
 81  
 82 
 83 @Data2txt 源码
 84 
 85 function[]=Data2txt(Data,file)
 86 
 87     fid=fopen(file,'w');%дÈëÎļþ·¾¶
 88 
 89     [m,n]=size(Data);
 90 
 91      for i=1:1:m
 92 
 93          for j=1:1:n
 94 
 95             if j==n
 96 
 97                 fprintf(fid,'%g\n',Data(i,j));
 98 
 99             else
100 
101                 fprintf(fid,'%g\t',Data(i,j));
102 
103          end
104 
105         end
106 
107      end
108 
109     fclose(fid);
110 
111 end

 

 

 

posted @ 2014-03-20 10:52  愚人_同乐  阅读(2225)  评论(0编辑  收藏  举报