简洁的BP及RBF神经网络代码
BP神经网络
function [W,err]=BPTrain(data,label,hiddenlayers,nodes,type,W)
%BPTRAIN Train a fully connected tanh back-propagation (BP) neural network.
%   [W,err] = BPTrain(data,label,hiddenlayers,nodes,type) creates (type==1)
%   a randomly initialised network and trains it with online (per-sample)
%   gradient descent.
%   [W,err] = BPTrain(data,label,hiddenlayers,nodes,0,W) continues training
%   the existing network W (e.g. one returned by a previous call); in that
%   case hiddenlayers and nodes are ignored.
%
%   Inputs:
%     data         dim x n matrix, one sample per column.
%     label        out x n matrix of targets, one column per sample
%                  (1 x n for a single output). Every unit, including the
%                  output, is tanh, so targets should lie in (-1,1).
%     hiddenlayers m, the number of hidden layers (0 is allowed).
%     nodes        vector [num_1;...;num_m] of hidden-layer sizes.
%     type         1: create a new network and train it;
%                  otherwise: train the network passed as W.
%     W            (optional; required when type~=1) cell array of weight
%                  matrices in the format returned by this function.
%   Outputs:
%     W            1 x (m+1) cell array. W{i} is (size of layer i + 1) x
%                  (size of layer i+1); row 1 holds the bias weights.
%     err          1 x iter vector; err(t) is the sum of absolute output
%                  errors accumulated while running epoch t.
%
%   Activation: tanh(x) = sinh(x)/cosh(x) = (e^x - e^-x)/(e^x + e^-x),
%   with derivative tanh'(x) = sech(x)^2, sech(x) = 1/cosh(x) = 2/(e^x + e^-x).
%
%   Training stops after maxiter epochs, or as soon as one epoch's summed
%   absolute error is no larger than epision.
if type==1
    nodes=nodes(:);
    if numel(nodes)~=hiddenlayers
        error('BPTrain:nodes', ...
            'nodes must list exactly hiddenlayers (%d) sizes, got %d.', ...
            hiddenlayers, numel(nodes));
    end
    % Layer sizes without bias: [input; hidden_1; ...; hidden_m; output].
    sizes=[size(data,1); nodes; size(label,1)];
    layers=hiddenlayers+2;
    W=cell(1,layers-1);
    for i=1:layers-1
        % One extra row for the constant bias input prepended to layer i.
        % NOTE(review): rand gives non-negative weights in [0,1), which can
        % start tanh units near saturation; a zero-centred init may train better.
        W{i}=rand(sizes(i)+1,sizes(i+1));
    end
else
    if nargin<6 || isempty(W)
        error('BPTrain:noNetwork', ...
            'Training an existing network requires W as the 6th argument.');
    end
    layers=numel(W)+1;
end

% Training hyper-parameters.
maxiter=2000;   % maximum number of epochs
lr=0.1;         % learning rate
epision=0.1;    % stop once an epoch's summed absolute error <= epision

err=zeros(1,maxiter);
iter=0;
epochErr=inf;   % not called "error", so MATLAB's error() stays usable
y=cell(1,layers);
v=cell(1,layers);
tic
while iter<maxiter && epochErr>epision
    iter=iter+1;
    epochErr=0;
    for k=1:size(data,2)
        % Forward pass: y{i} becomes layer i's activation with a leading 1
        % (bias input); v{i+1} is the pre-activation of layer i+1.
        y{1}=data(:,k);
        for i=1:layers-1
            y{i}=[1;y{i}];
            v{i+1}=W{i}'*y{i};
            y{i+1}=tanh(v{i+1});
        end
        outErr=label(:,k)-y{layers};
        epochErr=epochErr+sum(abs(outErr));
        % Backward pass. delta is the local gradient of the layer fed by W{i};
        % each new delta is computed from W{i} BEFORE W{i} is updated.
        delta=outErr.*sech(v{layers}).^2;
        for i=layers-1:-1:1
            grad=y{i}*delta';
            if i>1
                % Skip row 1 of W{i}: it multiplies the constant bias input,
                % which has no upstream unit to propagate into.
                delta=sech(v{i}).^2.*(W{i}(2:end,:)*delta);
            end
            W{i}=W{i}+lr.*grad;
        end
    end
    err(iter)=epochErr;
end
err=err(1:iter);
toc
%% Test code
function res=BPTest(W,data)
%BPTEST Evaluate a network produced by BPTrain on new samples.
%   res = BPTest(W,data) runs the tanh forward pass of network W on every
%   column of data and returns the network outputs, one column per sample.
%
%   Inputs:
%     W     cell array of weight matrices; W{i} has a leading bias row and
%           maps layer i (plus bias) to layer i+1.
%     data  dim x n matrix, one sample per column.
%   Output:
%     res   out x n matrix of outputs (1 x n for a single-output network);
%           out x 0 when data has no columns.
nout=size(W{end},2);
nsamples=size(data,2);
res=zeros(nout,nsamples);
for k=1:nsamples
    y=data(:,k);
    % The same step for every layer, including the output layer: prepend the
    % constant bias input, apply the weights, squash with tanh.
    for i=1:numel(W)
        y=tanh(W{i}'*[1;y]);
    end
    res(:,k)=y;
end
% RBF network training.
% NOTE(review): the function header for this routine is missing from this
% listing. The code reads traindata, trainlabel, traintype, scale and K,
% which are presumably its input arguments -- confirm against the source.
%
%   traindata  n x dim matrix, one sample per ROW (pdist2 treats rows as
%              observations).
%   trainlabel n x c matrix of targets, one row per sample.
%   traintype  'data'    - every training sample is a centre; the weights
%                          solve the square system Phi*w = trainlabel;
%              'cluster' - K k-means centres; least-squares weights;
%              'descend' - placeholder, not implemented.
%   scale      width control: rbf_sigma = (max centre distance) / scale^2.
%   K          number of k-means clusters ('cluster' only).
%
% The fitted model is stored in the globals rbf_sigma, rbf_center and
% rbf_weight. Basis function: phi(d) = exp(-d/rbf_sigma), where d is the
% Euclidean distance returned by pdist2.
global rbf_sigma; global rbf_center; global rbf_weight;
switch traintype
    case 'data'
        traindist=pdist2(traindata,traindata);
        % Alternative divisor kept from the original: /(2*sqrt(sqrt(length(traindata))))
        rbf_sigma=max(traindist(:))/(scale.^2);
        rbf_center=traindata;
        Phi=exp(-traindist./rbf_sigma);
        % Solve Phi*w = y directly; more accurate than forming inv(Phi).
        rbf_weight=Phi\trainlabel;
    case 'cluster'
        [~,C]=kmeans(traindata,K,'emptyaction','singleton');
        traindist=pdist2(traindata,C);
        Cdist=pdist2(C,C);
        % Alternative divisor kept from the original: /(2*sqrt(sqrt(length(traindata))))
        rbf_sigma=max(Cdist(:))/(scale.^2);
        rbf_center=C;
        Phi=exp(-traindist./rbf_sigma);
        % Over-determined (n x K) system: backslash returns the least-squares
        % solution via QR, matching inv(Phi'*Phi)*Phi'*y for full-rank Phi
        % without squaring its condition number.
        rbf_weight=Phi\trainlabel;
    case 'descend'
        % Previously an empty branch that silently kept the old model.
        error('rbf:notImplemented', ...
            'traintype ''descend'' is not implemented.');
    otherwise
        error('rbf:badTrainType', 'Unknown traintype ''%s''.', traintype);
end
%% Test code
function prediction=RBFTest(data)
%RBFTEST Predict with the RBF network stored in the rbf_* globals.
%   prediction = RBFTest(data) evaluates the model held in the globals
%   rbf_sigma, rbf_center and rbf_weight on the rows of data (one sample per
%   row, as pdist2 expects). It returns one row of predictions per sample.
%
%   The basis function is phi(d) = exp(-d/rbf_sigma), the same one the
%   training code uses to fit rbf_weight. The previous exp(-d/(2*rbf_sigma))
%   used a wider kernel than the fitted one, so predictions did not reproduce
%   the trained model (not even on the training points).
global rbf_sigma; global rbf_center; global rbf_weight;
testdist=pdist2(data,rbf_center);
prediction=exp(-testdist./rbf_sigma)*rbf_weight;