使用Hopfield神经网络识别被污染（含噪声）的数字字形
该网络的权值由Hebb学习规则确定。
下面给出两段代码：第一段是测试程序，先用数字0、1、2的图案训练网络，再输入一个被污染的“2”进行识别；第二段是自定义的学习规则类StandHopfieldLearning：
package geym.nn.hopfiled; import java.util.Arrays; import org.neuroph.core.data.DataSet; import org.neuroph.core.data.DataSetRow; import org.neuroph.nnet.Hopfield; import org.neuroph.nnet.comp.neuron.InputOutputNeuron; import org.neuroph.nnet.learning.HopfieldLearning; import org.neuroph.util.NeuronProperties; import org.neuroph.util.TransferFunctionType; /** * 识别0 1 2 使用hopfield 全连接结构 * @author Administrator * */ public class HopfieldSample2 { public static double[] format(double[] data){ for(int i=0;i<data.length;i++){ if(data[i]==0)data[i]=-1; } return data; } public static void main(String args[]) { NeuronProperties neuronProperties = new NeuronProperties(); neuronProperties.setProperty("neuronType", InputOutputNeuron.class); neuronProperties.setProperty("bias", new Double(0.0D)); neuronProperties.setProperty("transferFunction", TransferFunctionType.STEP); neuronProperties.setProperty("transferFunction.yHigh", new Double(1.0D)); neuronProperties.setProperty("transferFunction.yLow", new Double(-1.0D)); // create training set (H and T letter in 3x3 grid) DataSet trainingSet = new DataSet(30); trainingSet.addRow(new DataSetRow(format(new double[] { 0,1,1,1,1,0, 1,0,0,0,0,1, 1,0,0,0,0,1, 1,0,0,0,0,1, 0,1,1,1,1,0}))); //0 trainingSet.addRow(new DataSetRow(format(new double[] { 0,0,0,0,0,0, 1,0,0,0,0,0, 1,1,1,1,1,1, 0,0,0,0,0,0, 0,0,0,0,0,0}))); //1 trainingSet.addRow(new DataSetRow(format(new double[] { 1,0,0,0,0,0, 1,0,0,1,1,1, 1,0,0,1,0,1, 1,0,0,1,0,1, 0,1,1,0,0,1}))); //2 // create hopfield network Hopfield myHopfield = new Hopfield(30, neuronProperties); myHopfield.setLearningRule(new StandHopfieldLearning()); // learn the training set myHopfield.learn(trainingSet); // test hopfield network System.out.println("Testing network"); // add one more 'incomplete' H pattern for testing - it will be // recognized as H // DataSetRow h=new DataSetRow(new double[] { 1, 0, 0, 1, 0, 1, 1, 0, 1 // }); // DataSetRow h=new DataSetRow(new double[] { 1, 0, 0, 1, 0, 1, 1, 0, 1 // 
}); DataSetRow h = new DataSetRow(format(new double[] { 1,0,0,0,0,0, 1,0,0,1,1,1, 1,0,0,1,0,1, 1,0,0,1,0,0, 0,1,1,0,0,1})); // 2 bad trainingSet.addRow(h); myHopfield.setInput(h.getInput()); double[] networkOutput = null; double[] preNetworkOutput = null; while (true) { myHopfield.calculate(); networkOutput = myHopfield.getOutput(); if (preNetworkOutput == null) { preNetworkOutput = networkOutput; continue; } if (Arrays.equals(networkOutput, preNetworkOutput)) { break; } preNetworkOutput = networkOutput; } System.out.print("Input: " + Arrays.toString(h.getInput())); System.out.println(" Output: " + Arrays.toString(networkOutput)); System.out.println(Arrays.equals(format(new double[] { 1,0,0,0,0,0, 1,0,0,1,1,1, 1,0,0,1,0,1, 1,0,0,1,0,1, 0,1,1,0,0,1}), networkOutput)); } }
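下面是StandHopfieldLearning类的实现。其中注释为“Hebb学习规则”的那一行就是Hebb学习规则：对任意两个不同的神经元 $i \ne j$，权重 $w_{ij} = w_{ji} = \sum_{k=1}^{M} x_i^{(k)} x_j^{(k)}$，即所有 $M$ 个训练样本中第 $i$ 个分量与第 $j$ 个分量乘积之和（训练样本已被转换为 -1/+1 形式）。注意：该类位于包 com.cgjr.com.hopfield 中，与上面的测试程序不在同一个包，使用时需要导入该类或把两个类放到同一个包下。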
package com.cgjr.com.hopfield; import org.neuroph.core.Connection; import org.neuroph.core.Layer; import org.neuroph.core.Neuron; import org.neuroph.core.data.DataSet; import org.neuroph.core.data.DataSetRow; import org.neuroph.core.learning.LearningRule; /** * Learning algorithm for the Hopfield neural network. * * @author Zoran Sevarac <sevarac@gmail.com> */ public class StandHopfieldLearning extends LearningRule { /** * The class fingerprint that is set to indicate serialization * compatibility with a previous version of the class. */ private static final long serialVersionUID = 1L; /** * Creates new HopfieldLearning */ public StandHopfieldLearning() { super(); } /** * Calculates weights for the hopfield net to learn the specified training * set * * @param trainingSet * training set to learn */ public void learn(DataSet trainingSet) { int M = trainingSet.size(); int N = neuralNetwork.getLayerAt(0).getNeuronsCount(); Layer hopfieldLayer = neuralNetwork.getLayerAt(0); for (int i = 0; i < N; i++) { for (int j = 0; j < N; j++) { if (j == i) continue; Neuron ni = hopfieldLayer.getNeuronAt(i); Neuron nj = hopfieldLayer.getNeuronAt(j); Connection cij = nj.getConnectionFrom(ni); Connection cji = ni.getConnectionFrom(nj); double wij=0; for(int k = 0;k < M;k++){ DataSetRow row=trainingSet.getRowAt(k); double[] inputs=row.getInput(); wij+=inputs[i]*inputs[j];//Hebb学习规则 } cij.getWeight().setValue(wij); cji.getWeight().setValue(wij); }// j } // i } }
大道至简：逻辑起点、记忆关联、直观抽象……