Adaline网络识别印刷体数字0到9——Java实现

 

    本篇只给出实现的代码，下一篇将讲解实现的原理，以及Adaline网络中LMS算法的原理。

 

 

代码包含两个类：演示类 AdalineDemo 和网络类 Adaline。

 

   

package com.cgjr.com;

import java.security.DigestInputStream;
import java.util.Arrays;

import org.neuroph.core.data.DataSet;
import org.neuroph.core.data.DataSetRow;
import org.neuroph.core.events.LearningEvent;
import org.neuroph.core.events.LearningEventListener;
import org.neuroph.util.TransferFunctionType;

/**
 * Demo that trains an Adaline network to recognise the printed digits 0-9.
 *
 * <p>Each digit is a {@value #CHAR_HEIGHT}-row by {@value #CHAR_WIDTH}-column bitmap in which
 * {@code '0'} marks an ink pixel and any other character marks background. The network has one
 * output per digit; training targets are +1 for the digit's own output and -1 for all others, and
 * recognition reports the index of the largest output.
 */
public class AdalineDemo implements LearningEventListener {
    /** Width of each digit bitmap, in characters. */
    public static final int CHAR_WIDTH = 5;
    /** Height of each digit bitmap, in rows. */
    public static final int CHAR_HEIGHT = 7;

    /**
     * Bitmaps of the digits 0-9; {@code DIGITS[d]} is the image of digit {@code d}. The array
     * reference is final, but its elements can still be modified.
     */
    public static final String[][] DIGITS = {
            {
              " 000 ",
              "0   0",
              "0   0",
              "0   0",
              "0   0",
              "0   0",
              " 000 "
            }, {
              "  0  ",
              " 00  ",
              "0 0  ",
              "  0  ",
              "  0  ",
              "  0  ",
              "  0  "
            }, {
              " 000 ",
              "0   0",
              "    0",
              "   0 ",
              "  0  ",
              " 0   ",
              "00000"
            }, {
              " 000 ",
              "0   0",
              "    0",
              " 000 ",
              "    0",
              "0   0",
              " 000 "
            }, {
              "   0 ",
              "  00 ",
              " 0 0 ",
              "0  0 ",
              "00000",
              "   0 ",
              "   0 "
            }, {
              "00000",
              "0    ",
              "0    ",
              "0000 ",
              "    0",
              "0   0",
              " 000 "
            }, {
              " 000 ",
              "0   0",
              "0    ",
              "0000 ",
              "0   0",
              "0   0",
              " 000 "
            }, {
              "00000",
              "    0",
              "    0",
              "   0 ",
              "  0  ",
              " 0   ",
              "0    "
            }, {
              " 000 ",
              "0   0",
              "0   0",
              " 000 ",
              "0   0",
              "0   0",
              " 000 "
            }, {
              " 000 ",
              "0   0",
              "0   0",
              " 0000",
              "    0",
              "0   0",
              " 000 "
            }
    };

    /**
     * Trains the network on every digit bitmap. It then prints, for each digit, the bitmap, the
     * index of the largest output, and the raw output vector.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        int inputCount = CHAR_WIDTH * CHAR_HEIGHT;
        Adaline ada = new Adaline(inputCount, DIGITS.length, 0.01d, TransferFunctionType.LINEAR);
        DataSet ds = new DataSet(inputCount, DIGITS.length);
        for (int i = 0; i < DIGITS.length; i++) {
            // Each digit bitmap is one training row; digit i is trained to fire output i.
            ds.addRow(createTrainDataRow(DIGITS[i], i));
        }
        // To trace training progress, implement handleLearningEvent and enable this line:
        // ada.getLearningRule().addListener(new AdalineDemo());
        ada.learn(ds);
        for (int i = 0; i < DIGITS.length; i++) {
            ada.setInput(image2data(DIGITS[i]));
            ada.calculate();
            // Read the output once so the printed index and vector come from the same array.
            double[] output = ada.getOutput();
            printDIGITS(DIGITS[i]);
            System.out.println(maxIndex(output));
            System.out.println(Arrays.toString(output));
            System.out.println();
        }
    }

    /**
     * Returns the index of the largest value in {@code output}, i.e. the output neuron that fires
     * most strongly. Ties resolve to the lowest index.
     *
     * @param output network output vector
     * @return index of the maximum element
     * @throws IllegalArgumentException if {@code output} is empty
     */
    private static int maxIndex(double[] output) {
        if (output.length == 0) {
            throw new IllegalArgumentException("output must not be empty");
        }
        int best = 0;
        for (int i = 1; i < output.length; i++) {
            // Strict comparison keeps the first occurrence on ties.
            if (output[i] > output[best]) {
                best = i;
            }
        }
        return best;
    }

    /**
     * Prints a bitmap one row per line. The bitmap is followed by two blank lines.
     *
     * @param image rows of the bitmap
     */
    private static void printDIGITS(String[] image) {
        for (int i = 0; i < image.length; i++) {
            System.out.println(image[i]);
        }
        System.out.println("\n");
    }

    /**
     * Builds one training row: the encoded bitmap as input, and a target vector that is +1 at
     * {@code idealValue} and -1 everywhere else.
     *
     * @param image digit bitmap
     * @param idealValue digit represented by {@code image}; index of the output that should fire
     * @return the training row
     */
    private static DataSetRow createTrainDataRow(String[] image, int idealValue) {
        // Target vector: -1 for every digit except the one this image represents, which gets +1.
        double[] output = new double[DIGITS.length];
        Arrays.fill(output, -1);
        output[idealValue] = 1;
        return new DataSetRow(image2data(image), output);
    }

    /**
     * Converts a digit bitmap into the network input vector in row-major order. {@code '0'}
     * becomes +1 and any other character becomes -1.
     *
     * <p>Rows shorter than {@link #CHAR_WIDTH} are treated as if padded with blanks, for example
     * when an editor has stripped trailing spaces. Characters beyond {@link #CHAR_WIDTH}, and rows
     * beyond {@link #CHAR_HEIGHT}, are ignored.
     *
     * @param image digit bitmap
     * @return input vector of length {@code CHAR_WIDTH * CHAR_HEIGHT}
     * @throws IllegalArgumentException if {@code image} has fewer than {@link #CHAR_HEIGHT} rows
     */
    private static double[] image2data(String[] image) {
        if (image.length < CHAR_HEIGHT) {
            throw new IllegalArgumentException(
                    "image has " + image.length + " rows, expected " + CHAR_HEIGHT);
        }
        double[] input = new double[CHAR_WIDTH * CHAR_HEIGHT];
        for (int row = 0; row < CHAR_HEIGHT; row++) {
            String line = image[row];
            for (int col = 0; col < CHAR_WIDTH; col++) {
                boolean ink = col < line.length() && line.charAt(col) == '0';
                input[row * CHAR_WIDTH + col] = ink ? 1 : -1;
            }
        }
        return input;
    }

    /**
     * No-op. {@link #main} does not register this class as a learning listener (the registration
     * line there is commented out), so this method is currently never invoked.
     *
     * @param event learning event from the LMS rule
     */
    @Override
    public void handleLearningEvent(LearningEvent event) {
        // Intentionally empty; add progress reporting here if the listener is registered.
    }

}

 

网络类（Adaline）：

 

package com.cgjr.com;



import org.neuroph.core.Layer;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.nnet.comp.neuron.BiasNeuron;
import org.neuroph.nnet.learning.LMS;
import org.neuroph.util.ConnectionFactory;
import org.neuroph.util.LayerFactory;
import org.neuroph.util.NeuralNetworkFactory;
import org.neuroph.util.NeuralNetworkType;
import org.neuroph.util.NeuronProperties;
import org.neuroph.util.TransferFunctionType;

/**
 * Adaline network: an input layer plus a bias neuron, fully connected to an output layer, trained
 * with the LMS learning rule.
 */
public class Adaline extends NeuralNetwork {

    /**
     * The class fingerprint that is set to indicate serialization compatibility
     * with a previous version of the class.
     */
    private static final long serialVersionUID = 1L;

    /**
     * Creates a new Adaline network.
     *
     * @param inputNeuronsCount number of neurons in the input layer; a bias neuron is added on top
     * @param outputNeuronsCount number of neurons in the output layer
     * @param learnRate learning rate for the LMS rule; must be positive
     * @param transferFunction {@link TransferFunctionType#LINEAR} selects linear output neurons;
     *     any other value selects a RAMP function configured with slope 1 and x/y limits of -1
     *     and 1
     * @throws IllegalArgumentException if a neuron count or {@code learnRate} is not positive
     */
    public Adaline(int inputNeuronsCount, int outputNeuronsCount, double learnRate, TransferFunctionType transferFunction) {
        this.createNetwork(inputNeuronsCount, outputNeuronsCount, learnRate, transferFunction);
    }

    /**
     * Builds the Adaline architecture and attaches the LMS learning rule.
     *
     * @param inputNeuronsCount number of neurons in the input layer, excluding the bias neuron
     * @param outputNeuronsCount number of neurons in the output layer
     * @param learnRate learning rate for the LMS rule
     * @param transferFunction output transfer function; anything but LINEAR selects RAMP
     * @throws IllegalArgumentException if a neuron count or {@code learnRate} is not positive
     */
    private void createNetwork(int inputNeuronsCount, int outputNeuronsCount, double learnRate,
            TransferFunctionType transferFunction) {
        if (inputNeuronsCount <= 0) {
            throw new IllegalArgumentException(
                    "inputNeuronsCount must be positive: " + inputNeuronsCount);
        }
        if (outputNeuronsCount <= 0) {
            throw new IllegalArgumentException(
                    "outputNeuronsCount must be positive: " + outputNeuronsCount);
        }
        // Written as !(x > 0) so that NaN is rejected as well.
        if (!(learnRate > 0)) {
            throw new IllegalArgumentException("learnRate must be positive: " + learnRate);
        }

        // set network type code
        this.setNetworkType(NeuralNetworkType.ADALINE);

        // Input neurons use a linear transfer function, whatever the output layer uses.
        NeuronProperties inNeuronProperties = new NeuronProperties();
        inNeuronProperties.setProperty("transferFunction", TransferFunctionType.LINEAR);

        // Create the input layer with the requested number of neurons.
        Layer inputLayer = LayerFactory.createLayer(inputNeuronsCount, inNeuronProperties);
        // The bias neuron acts as the bias input for every output neuron.
        inputLayer.addNeuron(new BiasNeuron());
        this.addLayer(inputLayer);

        // Output neuron settings: LINEAR if requested, otherwise a RAMP limited to [-1, 1].
        NeuronProperties outNeuronProperties = new NeuronProperties();
        if (transferFunction == TransferFunctionType.LINEAR) {
            outNeuronProperties.setProperty("transferFunction", TransferFunctionType.LINEAR);
        } else {
            outNeuronProperties.setProperty("transferFunction", TransferFunctionType.RAMP);
            outNeuronProperties.setProperty("transferFunction.slope", Double.valueOf(1d));
            outNeuronProperties.setProperty("transferFunction.yHigh", Double.valueOf(1d));
            outNeuronProperties.setProperty("transferFunction.xHigh", Double.valueOf(1d));
            outNeuronProperties.setProperty("transferFunction.yLow", Double.valueOf(-1d));
            outNeuronProperties.setProperty("transferFunction.xLow", Double.valueOf(-1d));
        }
        // Create the output layer with outputNeuronsCount neurons.
        Layer outputLayer = LayerFactory.createLayer(outputNeuronsCount, outNeuronProperties);
        this.addLayer(outputLayer);

        // Fully connect the input layer (including the bias neuron) to the output layer.
        ConnectionFactory.fullConnect(inputLayer, outputLayer);

        // set input and output cells for network
        NeuralNetworkFactory.setDefaultIO(this);

        // set LMS learning rule for this network
        LMS l = new LMS();
        l.setLearningRate(learnRate);
        this.setLearningRule(l);
    }

}

 

 

运行结果截图（原图已缺失。运行 AdalineDemo 时，程序会对每个数字依次打印其点阵、输出最大值所在的下标以及全部输出值）：

 

 

 

posted @ 2016-05-12 11:17  北宫风晨  阅读(1820)  评论(0)  编辑  收藏  举报