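Writing a BP Neural Network in Java (Part 4)
Continued from the previous post.
In parts (1) and (2), the program was organized around Net, Propagation, Trainer, Learner, and DataProvider. This post refactors that structure.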
Net
First comes Net. After the activation and cost functions were redefined in the previous post, its contents look roughly like this:

    List<DoubleMatrix> weights = new ArrayList<>();
    List<DoubleMatrix> bs = new ArrayList<>();
    List<ActivationFunction> activations = new ArrayList<>();
    CostFunction costFunc;
    CostFunction accuracyFunc;
    int[] nodesNum;
    int layersNum;

    public CompactDoubleMatrix getCompact() {
        return new CompactDoubleMatrix(this.weights, this.bs);
    }

The getCompact() method builds the corresponding hyper-matrix from the weights and biases.
DataProvider
DataProvider supplies the data.

    public interface DataProvider {
        DoubleMatrix getInput();
        DoubleMatrix getTarget();
    }

If the inputs are vectors, the provider also carries a vector dictionary.

    public interface DictDataProvider extends DataProvider {
        public DoubleMatrix getIndexs();
        public DoubleMatrix getDict();
    }

Each column is one sample. getIndexs() returns the indexes of the input vectors in the dictionary.
I wrote a utility class, BatchDataProviderFactory, that splits the samples into minibatches.

    int batchSize;
    int dataLen;
    DataProvider originalProvider;
    List<Integer> endPositions;
    List<DataProvider> providers;

    public BatchDataProviderFactory(int batchSize, DataProvider originalProvider) {
        super();
        this.batchSize = batchSize;
        this.originalProvider = originalProvider;
        // One sample per column, so the column count is the number of samples.
        this.dataLen = this.originalProvider.getTarget().columns;
        this.initEndPositions();
        this.initProviders();
    }

    public BatchDataProviderFactory(DataProvider originalProvider) {
        this(4, originalProvider);
    }

    public List<DataProvider> getProviders() {
        return providers;
    }

batchSize specifies how many batches the data is split into, getProviders() returns the generated minibatches, and originalProvider holds the original data being split.
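The bodies of initEndPositions() and initProviders() are not shown above. As a rough illustration of the splitting step only, here is a minimal sketch that computes the exclusive end position of each batch. The class BatchSplitSketch, the method endPositions, and the parameter name batchCount are hypothetical, and the even-split policy is an assumption rather than what the factory necessarily does.

```java
import java.util.ArrayList;
import java.util.List;

class BatchSplitSketch {

    /**
     * Splits dataLen columns into contiguous ranges and returns the exclusive
     * end position of each range. The first (dataLen % count) ranges get one
     * extra column, so batch sizes differ by at most one.
     */
    static List<Integer> endPositions(int dataLen, int batchCount) {
        if (dataLen <= 0 || batchCount <= 0) {
            throw new IllegalArgumentException("dataLen and batchCount must be positive");
        }
        int count = Math.min(batchCount, dataLen); // never create an empty batch
        int base = dataLen / count;
        int extra = dataLen % count;
        List<Integer> ends = new ArrayList<>();
        int end = 0;
        for (int i = 0; i < count; i++) {
            end += base + (i < extra ? 1 : 0);
            ends.add(end);
        }
        return ends;
    }

    public static void main(String[] args) {
        // 10 samples in 4 batches: sizes 3, 3, 2, 2
        System.out.println(endPositions(10, 4)); // prints [3, 6, 8, 10]
    }
}
```

Batch i then covers the columns from the previous end position (or 0) up to, but not including, its own end position.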
Propagation
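Propagation handles the forward and backward passes through the network. The interface is defined as follows:

    public interface Propagation {
        public PropagationResult propagate(Net net, DataProvider provider);
    }

The propagate method runs the given data through the given network and returns the result.
BasePropagation implements this interface with plain backpropagation: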
    public class BasePropagation implements Propagation {

        // Forward pass over multiple samples (one sample per column).
        protected ForwardResult forward(Net net, DoubleMatrix input) {
            ForwardResult result = new ForwardResult();
            result.input = input;
            DoubleMatrix currentResult = input;
            int index = -1;
            for (DoubleMatrix weight : net.weights) {
                index++;
                DoubleMatrix b = net.bs.get(index);
                final ActivationFunction activation = net.activations.get(index);
                currentResult = weight.mmul(currentResult).addColumnVector(b);
                result.netResult.add(currentResult);
                // Activation derivative at the pre-activation values,
                // kept for the backward pass.
                DoubleMatrix derivative = activation.derivativeAt(currentResult);
                result.derivativeResult.add(derivative);
                currentResult = activation.valueAt(currentResult);
                result.finalResult.add(currentResult);
            }
            result.netResult = null; // No longer needed.
            return result;
        }

        // Backward pass; gradients are averaged over all samples.
        protected BackwardResult backward(Net net, DoubleMatrix target,
                ForwardResult forwardResult) {
            BackwardResult result = new BackwardResult();
            DoubleMatrix output = forwardResult.getOutput();
            DoubleMatrix outputDerivative = forwardResult.getOutputDerivative();
            result.cost = net.costFunc.valueAt(output, target);
            DoubleMatrix outputDelta = net.costFunc.derivativeAt(output, target)
                    .muli(outputDerivative);
            if (net.accuracyFunc != null) {
                result.accuracy = net.accuracyFunc.valueAt(output, target);
            }
            result.deltas.add(outputDelta);
            for (int i = net.layersNum - 1; i >= 0; i--) {
                DoubleMatrix pdelta = result.deltas.get(result.deltas.size() - 1);
                // Weight gradient, averaged over all samples.
                DoubleMatrix layerInput = i == 0
                        ? forwardResult.input
                        : forwardResult.finalResult.get(i - 1);
                DoubleMatrix gradient = pdelta.mmul(layerInput.transpose())
                        .div(target.columns);
                result.gradients.add(gradient);
                // Bias gradient.
                result.biasGradients.add(pdelta.rowMeans());
                // Delta for the previous layer. When i == 0 this is the
                // input-layer error, i.e. the gradient with respect to the
                // input; it is not averaged.
                DoubleMatrix delta = net.weights.get(i).transpose().mmul(pdelta);
                if (i > 0) {
                    delta = delta.muli(forwardResult.derivativeResult.get(i - 1));
                }
                result.deltas.add(delta);
            }
            Collections.reverse(result.gradients);
            Collections.reverse(result.biasGradients);
            // Only the input-layer delta is kept; the others are no longer needed.
            DoubleMatrix inputDeltas = result.deltas.get(result.deltas.size() - 1);
            result.deltas.clear();
            result.deltas.add(inputDeltas);
            return result;
        }

        @Override
        public PropagationResult propagate(Net net, DataProvider provider) {
            ForwardResult forwardResult = this.forward(net, provider.getInput());
            BackwardResult backwardResult = this.backward(net,
                    provider.getTarget(), forwardResult);
            PropagationResult result = new PropagationResult(backwardResult);
            result.output = forwardResult.getOutput();
            return result;
        }
    }
The PropagationResult we define looks, in outline, like this:

    public class PropagationResult {
        DoubleMatrix output;   // output matrix: outputLen x sampleLength
        DoubleMatrix cost;     // cost matrix: 1 x sampleLength
        DoubleMatrix accuracy; // accuracy matrix: 1 x sampleLength
        private List<DoubleMatrix> gradients;     // weight gradient matrices
        private List<DoubleMatrix> biasGradients; // bias gradient matrices
        DoubleMatrix inputDeltas; // input-layer delta matrix: inputLen x sampleLength

        public CompactDoubleMatrix getCompact() {
            return new CompactDoubleMatrix(gradients, biasGradients);
        }
    }

Another class implementing this interface is MiniBatchPropagation. Internally it propagates the samples in parallel and then combines the results of the individual minibatches, using BatchDataProviderFactory and BasePropagation to do so.
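MiniBatchPropagation itself is not listed here. To illustrate the idea of propagating batches in parallel and then combining them, here is a minimal sketch using plain arrays and a parallel stream. Everything in it is hypothetical: the class MiniBatchCombineSketch, the toy one-weight model standing in for BasePropagation, and the choice to combine by a sample-count-weighted average are all assumptions, not the author's implementation.

```java
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

class MiniBatchCombineSketch {

    /** Result of propagating one minibatch: its mean gradient and sample count. */
    static final class BatchResult {
        final double[] meanGradient;
        final int samples;

        BatchResult(double[] meanGradient, int samples) {
            this.meanGradient = meanGradient;
            this.samples = samples;
        }
    }

    /**
     * Toy stand-in for a per-batch propagation: gradient of 0.5 * (w*x - y)^2
     * with respect to w, averaged over the samples in [from, to).
     */
    static BatchResult propagate(double w, double[] xs, double[] ys, int from, int to) {
        double sum = 0;
        for (int i = from; i < to; i++) {
            sum += (w * xs[i] - ys[i]) * xs[i];
        }
        return new BatchResult(new double[] {sum / (to - from)}, to - from);
    }

    /** Weights each batch mean by its sample count, giving the mean over all samples. */
    static double[] combine(List<BatchResult> results) {
        int total = 0;
        for (BatchResult r : results) {
            total += r.samples;
        }
        double[] combined = new double[results.get(0).meanGradient.length];
        for (BatchResult r : results) {
            for (int j = 0; j < combined.length; j++) {
                combined[j] += r.meanGradient[j] * r.samples / total;
            }
        }
        return combined;
    }

    /** Propagates every batch in parallel; ends holds exclusive end positions. */
    static double[] propagateInBatches(double w, double[] xs, double[] ys, int[] ends) {
        List<BatchResult> results = IntStream.range(0, ends.length)
                .parallel()
                .mapToObj(b -> propagate(w, xs, ys, b == 0 ? 0 : ends[b - 1], ends[b]))
                .collect(Collectors.toList());
        return combine(results);
    }

    public static void main(String[] args) {
        double[] xs = {1, 2, 3, 4, 5};
        double[] ys = {2, 4, 6, 8, 10};
        // Two batches: samples [0, 2) and [2, 5).
        double[] g = propagateInBatches(0.0, xs, ys, new int[] {2, 5});
        System.out.println(g[0]); // close to -22, the mean gradient over all five samples
    }
}
```

Weighting by sample count matters when the batches are not all the same size; a plain average of the batch means would over-weight the smaller batches.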
Trainer
The Trainer interface is defined as:

    public interface Trainer {
        public void train(Net net, DataProvider provider);
    }

A simple implementation:

    public class CommonTrainer implements Trainer {
        int epochs;
        Learner learner;
        Propagation propagation;
        List<Double> costs = new ArrayList<>();
        List<Double> accuracys = new ArrayList<>();

        // One training iteration: propagate, then let the learner update the weights.
        public void trainOne(Net net, DataProvider provider) {
            PropagationResult propResult = this.propagation.propagate(net, provider);
            learner.learn(net, propResult, provider);
            Double cost = propResult.getMeanCost();
            Double accuracy = propResult.getMeanAccuracy();
            if (cost != null) {
                costs.add(cost);
            }
            if (accuracy != null) {
                accuracys.add(accuracy);
            }
        }

        @Override
        public void train(Net net, DataProvider provider) {
            for (int i = 0; i < this.epochs; i++) {
                System.out.println("epochs:" + i);
                this.trainOne(net, provider);
            }
        }
    }

It simply iterates for the configured number of epochs, with no smart stopping criterion; each iteration uses the Learner to adjust the weights.
Learner
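Learner adjusts the network weights based on the result of each propagation. The interface is defined as follows:

    public interface Learner<N extends Net, P extends DataProvider> {
        public void learn(N net, PropagationResult propResult, P provider);
    }

A simple implementation that adjusts the weights using a momentum factor and an adaptive learning rate: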
    public class MomentAdaptLearner<N extends Net, P extends DataProvider>
            implements Learner<N, P> {
        double moment = 0.7;
        double lmd = 1.05;
        double preCost = 0;
        double eta = 0.01;
        double currentEta = eta;
        double currentMoment = moment;
        CompactDoubleMatrix preGradient;

        public MomentAdaptLearner(double moment, double eta) {
            super();
            this.moment = moment;
            this.eta = eta;
            this.currentEta = eta;
            this.currentMoment = moment;
        }

        public MomentAdaptLearner() {
        }

        @Override
        public void learn(N net, PropagationResult propResult, P provider) {
            if (this.preGradient == null) {
                init(net, propResult, provider);
            }
            double cost = propResult.getMeanCost();
            this.modifyParameter(cost);
            System.out.println("current eta:" + this.currentEta);
            System.out.println("current moment:" + this.currentMoment);
            this.updateGradient(net, propResult, provider);
        }

        // Blends the new gradient with the previous step, then applies it in place.
        public void updateGradient(N net, PropagationResult propResult, P provider) {
            CompactDoubleMatrix netCompact = this.getNetCompact(net, propResult, provider);
            CompactDoubleMatrix gradCompact = this.getGradientCompact(net, propResult, provider);
            gradCompact = gradCompact.mul(currentEta * (1 - currentMoment))
                    .addi(preGradient.mul(currentMoment));
            netCompact.subi(gradCompact);
            this.preGradient = gradCompact;
        }

        public CompactDoubleMatrix getNetCompact(N net, PropagationResult propResult,
                P provider) {
            return net.getCompact();
        }

        public CompactDoubleMatrix getGradientCompact(N net, PropagationResult propResult,
                P provider) {
            return propResult.getCompact();
        }

        // Adapts the learning rate and momentum from the change in cost.
        public void modifyParameter(double cost) {
            if (this.currentEta > 10) {
                this.currentEta = 10;
            } else if (this.currentEta < 0.0001) {
                this.currentEta = 0.0001;
            } else if (cost < this.preCost) {
                // Cost decreased: grow the learning rate, restore the momentum.
                this.currentEta *= 1.05;
                this.currentMoment = moment;
            } else if (cost < 1.04 * this.preCost) {
                // Cost rose slightly: shrink both.
                this.currentEta *= 0.7;
                this.currentMoment *= 0.7;
            } else {
                // Cost rose sharply: reset the learning rate, cut the momentum.
                this.currentEta = eta;
                this.currentMoment = 0.1;
            }
            this.preCost = cost;
        }

        public void init(Net net, PropagationResult propResult, P provider) {
            PropagationResult pResult = new PropagationResult(net);
            preGradient = pResult.getCompact().dup();
        }
    }
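Read off the code above, the update performed by learn() can be written compactly. The symbols are my own notation, not the author's: θ is the hyper-vector of weights and biases, g_t the current gradient, v_t the applied step (preGradient), η the learning rate, μ the momentum, C the mean cost, and η_0, μ_0 the configured eta and moment.

```latex
\begin{aligned}
v_t      &= \eta_t\,(1-\mu_t)\,g_t + \mu_t\,v_{t-1},\\
\theta_t &= \theta_{t-1} - v_t,
\end{aligned}
\qquad
(\eta_t,\ \mu_t) =
\begin{cases}
(1.05\,\eta_{t-1},\ \mu_0)          & \text{if } C_t < C_{t-1},\\
(0.7\,\eta_{t-1},\ 0.7\,\mu_{t-1})  & \text{if } C_{t-1} \le C_t < 1.04\,C_{t-1},\\
(\eta_0,\ 0.1)                      & \text{otherwise.}
\end{cases}
```

The case analysis is skipped on any step where the learning rate has left the range [0.0001, 10]; on that step modifyParameter only clamps it back to the nearest bound.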
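In the code above, the CompactDoubleMatrix class wraps the weight variables and makes the code much more concise: here it behaves as a single hyper-matrix, or hyper-vector, and its internal structure can be ignored entirely.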
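CompactDoubleMatrix itself was introduced in an earlier post and is not repeated here. To show the idea in isolation, here is a minimal sketch of a hyper-vector over plain double[] arrays. The class CompactSketch is hypothetical, and the assumption that in-place operations write through to the wrapped arrays is inferred from how netCompact.subi(...) is used to update the network, not taken from the real class.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class CompactSketch {
    // The wrapped arrays are shared with the caller, not copied.
    private final List<double[]> parts = new ArrayList<>();

    CompactSketch(List<double[]> initial) {
        parts.addAll(initial);
    }

    /** Adds one more block to the end of the hyper-vector. */
    void append(double[] part) {
        parts.add(part);
    }

    /** Returns a scaled copy; this hyper-vector is left untouched. */
    CompactSketch mul(double factor) {
        List<double[]> scaled = new ArrayList<>();
        for (double[] p : parts) {
            double[] c = p.clone();
            for (int i = 0; i < c.length; i++) {
                c[i] *= factor;
            }
            scaled.add(c);
        }
        return new CompactSketch(scaled);
    }

    /** In-place subtraction; the change is visible through the wrapped arrays. */
    CompactSketch subi(CompactSketch other) {
        if (other.parts.size() != parts.size()) {
            throw new IllegalArgumentException("block count mismatch");
        }
        for (int k = 0; k < parts.size(); k++) {
            double[] mine = parts.get(k);
            double[] theirs = other.parts.get(k);
            if (mine.length != theirs.length) {
                throw new IllegalArgumentException("block " + k + " length mismatch");
            }
            for (int i = 0; i < mine.length; i++) {
                mine[i] -= theirs[i];
            }
        }
        return this;
    }

    public static void main(String[] args) {
        double[] weights = {1.0, 2.0};
        double[] bias = {0.5};
        CompactSketch net = new CompactSketch(Arrays.asList(weights, bias));
        CompactSketch grad = new CompactSketch(
                Arrays.asList(new double[] {1.0, 2.0}, new double[] {0.5}));
        net.subi(grad.mul(0.5)); // one update step over all blocks at once
        System.out.println(Arrays.toString(weights) + " " + Arrays.toString(bias));
        // prints [0.5, 1.0] [0.25]
    }
}
```

Because the learner only ever sees the hyper-vector, appending one more block (for example a dictionary) requires no change to the update code.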
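A subclass of MomentAdaptLearner, DictMomentLearner, additionally updates the dictionary in step with the weights. Its code is just as concise: it simply appends the matrices that need adjusting to the hyper-matrix, and the parent class then adjusts them all uniformly: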
    public class DictMomentLearner extends MomentAdaptLearner<Net, DictDataProvider> {

        public DictMomentLearner(double moment, double eta) {
            super(moment, eta);
        }

        public DictMomentLearner() {
            super();
        }

        @Override
        public CompactDoubleMatrix getNetCompact(Net net, PropagationResult propResult,
                DictDataProvider provider) {
            CompactDoubleMatrix result = super.getNetCompact(net, propResult, provider);
            result.append(provider.getDict());
            return result;
        }

        @Override
        public CompactDoubleMatrix getGradientCompact(Net net, PropagationResult propResult,
                DictDataProvider provider) {
            CompactDoubleMatrix result = super.getGradientCompact(net, propResult, provider);
            result.append(DictUtil.getDictGradient(provider, propResult));
            return result;
        }

        @Override
        public void init(Net net, PropagationResult propResult, DictDataProvider provider) {
            DoubleMatrix preDictGradient = DoubleMatrix.zeros(
                    provider.getDict().rows, provider.getDict().columns);
            super.init(net, propResult, provider);
            this.preGradient.append(preDictGradient);
        }
    }