yangyang12138

导航

deeplearning4j

1.概述

  一个基于 Java 实现的深度学习框架，用于深度学习神经网络的搭建和模型训练。

2.demo

  

public class Demo {

  public static void main(String[] args) throws Exception {
    int height = 16;
    int width = 16;
    int channels = 1; 
    int outputNum = 10; 
    int batchSize = 54;
    int nEpochs = 1;
    int iterations = 1;

    int seed = 1234;
    Random randNumGen = new Random(seed);

    File trainData = new File("./tf");
    FileSplit trainSplit = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);
    ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator(); 
    ImageRecordReader trainRR = new ImageRecordReader(height, width, channels, labelMaker);
    trainRR.initialize(trainSplit);
    DataSetIterator trainIter = new RecordReaderDataSetIterator(trainRR, batchSize, 1, outputNum);

    DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
    scaler.fit(trainIter);
    trainIter.setPreProcessor(scaler);
    

    File testData = new File("./t");
    FileSplit testSplit = new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS, randNumGen);
    ImageRecordReader testRR = new ImageRecordReader(height, width, channels, labelMaker);
    testRR.initialize(testSplit);
    DataSetIterator testIter = new RecordReaderDataSetIterator(testRR, batchSize, 1, outputNum);
    testIter.setPreProcessor(scaler); 

    Map<Integer, Double> lrs = new HashMap<>();
    lrs.put(0, 0.06); 
    lrs.put(200, 0.05);
    lrs.put(600, 0.028);
    lrs.put(800, 0.0060);
    lrs.put(1000, 0.001);

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .seed(seed)
        .iterations(iterations)
        .regularization(true).l2(0.0005)
        .learningRate(.01)
        .learningRateDecayPolicy(LearningRatePolicy.Schedule)
        .learningRateSchedule(lrSchedule) 
        .weightInit(WeightInit.XAVIER)
        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
        .updater(Updater.NESTEROVS)
        .list()
        .layer(0, new ConvolutionLayer.Builder(5, 5)
            .nIn(channels)
            .stride(1, 1)
            .nOut(20)
            .activation(Activation.IDENTITY)
            .build())
        .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
            .kernelSize(2, 2)
            .stride(2, 2)
            .build())
        .layer(2, new ConvolutionLayer.Builder(5, 5)
            .stride(1, 1) 
            .nOut(50)
            .activation(Activation.IDENTITY)
            .build())
        .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
            .kernelSize(2, 2)
            .stride(2, 2)
            .build())
        .layer(4, new DenseLayer.Builder().activation(Activation.RELU)
            .nOut(500).build())
        .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
            .nOut(outputNum)
            .activation(Activation.SOFTMAX)
            .build())
        .setInputType(InputType.convolutionalFlat(28, 28, 1)) 
        .backprop(true).pretrain(false).build();

    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
    net.setListeners(new ScoreIterationListener(10));
    log.debug("Total num of params: {}", net.numParams());

    for (int i = 0; i < nEpochs; i++) {
      net.fit(trainIter);
      Evaluation eval = net.evaluate(testIter);
      log.info(eval.stats());
      trainIter.reset();
      testIter.reset();
    }
    ModelSerializer.writeModel(net, new File(basePath + "/minist-model.zip"), true);
  }
}

 

posted on 2020-09-11 01:59  杨杨09265  阅读(357)  评论(0)  编辑  收藏  举报