使用Cross-validation (CV) 调整Extreme learning Machine (ELM) 最优参数的实现(matlab)
ELM算法模型是最近几年得到广泛重视的模型,它不同于现在广为火热的DNN。 ELM使用传统的三层神经网络,只包含一个隐含层,但又不同于传统的神经网络。ELM是一种简单易用、有效的单隐层前馈神经网络SLFNs学习算法。2006年由南洋理工大学黄广斌副教授提出。传统的神经网络学习算法(如BP算法)需要人为设置大量的网络训练参数,并且很容易产生局部最优解。极限学习机只需要设置网络的隐层节点个数,在算法执行过程中不需要调整网络的输入权值以及隐元的偏置,并且产生唯一的最优解,因此具有学习速度快且泛化性能好的优点。但是隐含层节点个数的设置需要经过人工大量实验得到或者通过最常见的CV方法可以得到。 下面,matlab实现10-fold CV 寻找最优隐含层节点个数的。作为以前工作的一个小记录。ELM使用的是主页http://www.ntu.edu.sg/home/egbhuang/ 源码。源程序包含两个脚本文件cv_para.m 和 Data2txt.m
Description:
@cv_para.m 这个是主程序
结构体:function [best_para]=cv_para(data,para_set)。其中使用到两个参数,data代表我们的完整数据,也就是没有划分训练集和测试集的完整数据。para_set代表隐含层节点个数的一个数组,例如在[1:60]之间选择一个最优的隐含层节点个数。
@Data2txt.m 这个脚本文件是为了满足ELM算法训练将数据转化为源码ELM可以使用的文本文件。数据格式在ELM主页已经给出example。
1 function [best_para]=cv_para(data,para_set) 2 3 num_folds=10; % 10-fold cross validation 4 5 n=size(data,1); 6 7 n_paras=length(para_set); 8 9 idx=randperm(n); % idx 代表n个数据中索引的任意排列 10 11 n_test=floor(n/num_folds); % n_test: 测试集包含的数据集的个数 12 13 test_idx=zeros(num_folds,n_test); % test_idx: 储存num_folds次测试集的索引 14 15 train_idx=zeros(num_folds,n-n_test); %train_idx: 原理同test_idx 16 17 18 19 % 下面程序操作的主要是索引,只要将训练集地址和测试集地址划分出来 20 21 for i=1:num_folds 22 23 test_idx(i,:)=idx((i-1)*n_test+1:i*n_test); 24 25 tmp=1:n; 26 27 tmp(test_idx(i,:))=[]; 28 29 train_idx(i,:)=tmp; 30 31 end 32 33 best_accs=inf; 34 35 best_para=1; % 保存最优的隐含节点个数 36 37 for i=1:n_paras 38 39 one_accs=0; 40 41 for j=1:num_folds 42 43 % 这里就是将数据集转化为文本文件形式,以满足elm源码的需求 44 45 train_data=data(train_idx(j,:),:); 46 47 test_data=data(test_idx(j,:),:); 48 49 Data2txt(train_data,'trainfile'); 50 51 Data2txt(test_data,'testfile'); 52 53 54 55 [TrainingTime, TrainingAccuracy] = elm_train('trainfile', 0, para_set(1,i), 'sig'); 56 57 [TestingTime, acc] = elm_predict('testfile'); 58 59 one_accs=one_accs+acc; 60 61 62 63 delete('trainfile'); 64 65 delete('testfile'); 66 67 end 68 69 if(best_accs>one_accs) 70 71 best_para=para_set(1,i); 72 73 best_accs=one_accs; 74 75 end 76 77 end 78 79 end 80 81 82 83 @Data2txt 源码 84 85 function[]=Data2txt(Data,file) 86 87 fid=fopen(file,'w');%дÈëÎļþ·¾¶ 88 89 [m,n]=size(Data); 90 91 for i=1:1:m 92 93 for j=1:1:n 94 95 if j==n 96 97 fprintf(fid,'%g\n',Data(i,j)); 98 99 else 100 101 fprintf(fid,'%g\t',Data(i,j)); 102 103 end 104 105 end 106 107 end 108 109 fclose(fid); 110 111 end