deeplearning4j训练MNIST数据集以及验证
训练模型官方示例
MNIST数据下载地址: http://github.com/myleott/mnist_png/raw/master/mnist_png.tar.gz
GitHub示例地址: https://github.com/deeplearning4j/deeplearning4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/quickstart/modeling/convolution/LeNetMNISTReLu.java
/******************************************************************************* * * * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * See the NOTICE file distributed with this work for additional * information regarding copyright ownership. * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ package org.deeplearning4j.examples.quickstart.modeling.convolution; import org.datavec.api.io.labels.ParentPathLabelGenerator; import org.datavec.api.split.FileSplit; import org.datavec.image.loader.NativeImageLoader; import org.datavec.image.recordreader.ImageRecordReader; import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.conf.layers.SubsamplingLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.deeplearning4j.examples.utils.DataUtilities; import org.deeplearning4j.util.ModelSerializer; import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.linalg.activations.Activation; import 
org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization; import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler; import org.nd4j.linalg.learning.config.Nesterovs; import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.schedule.MapSchedule; import org.nd4j.linalg.schedule.ScheduleType; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.File; import java.util.HashMap; import java.util.Map; import java.util.Random; /** * Implementation of LeNet-5 for handwritten digits image classification on MNIST dataset (99% accuracy) * <a href="http://yann.lecun.com/exdb/publis/pdf/lecun-01a.pdf">[LeCun et al., 1998. Gradient based learning applied to document recognition]</a> * Some minor changes are made to the architecture like using ReLU and identity activation instead of * sigmoid/tanh, max pooling instead of avg pooling and softmax output layer. * <p> * This example will download 15 Mb of data on the first run. 
* * @author hanlon * @author agibsonccc * @author fvaleri * @author dariuszzbyrad */ public class LeNetMNISTReLu { private static final Logger LOGGER = LoggerFactory.getLogger(LeNetMNISTReLu.class); // private static final String BASE_PATH = System.getProperty("java.io.tmpdir") + "/mnist"; private static final String BASE_PATH = "D:\\Documents\\Downloads\\mnist_png"; private static final String DATA_URL = "http://github.com/myleott/mnist_png/raw/master/mnist_png.tar.gz"; public static void main(String[] args) throws Exception { // 图片高度 int height = 28; // height of the picture in px // 图片宽度 int width = 28; // width of the picture in px // 通道 1 表示 黑白 int channels = 1; // single channel for grayscale images // 可能出现的结果数量 0-9 10个数字 int outputNum = 10; // 10 digits classification // 批处理数量 int batchSize = 54; // number of samples that will be propagated through the network in each iteration // 迭代次数 int nEpochs = 1; // number of training epochs // 随机数生成器 int seed = 1234; // number used to initialize a pseudorandom number generator. 
Random randNumGen = new Random(seed); LOGGER.info("Data load..."); if (!new File(BASE_PATH + "/mnist_png").exists()) { LOGGER.debug("Data downloaded from {}", DATA_URL); String localFilePath = BASE_PATH + "/mnist_png.tar.gz"; if (DataUtilities.downloadFile(DATA_URL, localFilePath)) { DataUtilities.extractTarGz(localFilePath, BASE_PATH); } } LOGGER.info("Data vectorization..."); // vectorization of train data File trainData = new File(BASE_PATH + "/mnist_png/training"); FileSplit trainSplit = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, randNumGen); ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator(); // use parent directory name as the image label ImageRecordReader trainRR = new ImageRecordReader(height, width, channels, labelMaker); trainRR.initialize(trainSplit); // MNIST中的数据 DataSetIterator trainIter = new RecordReaderDataSetIterator(trainRR, batchSize, 1, outputNum); // pixel values from 0-255 to 0-1 (min-max scaling) DataNormalization imageScaler = new ImagePreProcessingScaler(); imageScaler.fit(trainIter); trainIter.setPreProcessor(imageScaler); // vectorization of test data File testData = new File(BASE_PATH + "/mnist_png/testing"); FileSplit testSplit = new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS, randNumGen); ImageRecordReader testRR = new ImageRecordReader(height, width, channels, labelMaker); testRR.initialize(testSplit); DataSetIterator testIter = new RecordReaderDataSetIterator(testRR, batchSize, 1, outputNum); testIter.setPreProcessor(imageScaler); // same normalization for better results LOGGER.info("Network configuration and training..."); // reduce the learning rate as the number of training epochs increases // iteration #, learning rate Map<Integer, Double> learningRateSchedule = new HashMap<>(); learningRateSchedule.put(0, 0.06); learningRateSchedule.put(200, 0.05); learningRateSchedule.put(600, 0.028); learningRateSchedule.put(800, 0.0060); learningRateSchedule.put(1000, 0.001); 
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() .seed(seed) .l2(0.0005) // ridge regression value .updater(new Nesterovs(new MapSchedule(ScheduleType.ITERATION, learningRateSchedule))) .weightInit(WeightInit.XAVIER) .list() .layer(new ConvolutionLayer.Builder(5, 5) .nIn(channels) .stride(1, 1) .nOut(20) .activation(Activation.IDENTITY) .build()) .layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX) .kernelSize(2, 2) .stride(2, 2) .build()) .layer(new ConvolutionLayer.Builder(5, 5) .stride(1, 1) // nIn need not specified in later layers .nOut(50) .activation(Activation.IDENTITY) .build()) .layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX) .kernelSize(2, 2) .stride(2, 2) .build()) .layer(new DenseLayer.Builder().activation(Activation.RELU) .nOut(500) .build()) .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) .nOut(outputNum) .activation(Activation.SOFTMAX) .build()) .setInputType(InputType.convolutionalFlat(height, width, channels)) // InputType.convolutional for normal image .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); net.setListeners(new ScoreIterationListener(10)); LOGGER.info("Total num of params: {}", net.numParams()); // evaluation while training (the score should go down) for (int i = 0; i < nEpochs; i++) { net.fit(trainIter); LOGGER.info("Completed epoch {}", i); Evaluation eval = net.evaluate(testIter); LOGGER.info(eval.stats()); trainIter.reset(); testIter.reset(); } File ministModelPath = new File(BASE_PATH + "/minist-model.zip"); ModelSerializer.writeModel(net, ministModelPath, true); LOGGER.info("The MINIST model has been saved in {}", ministModelPath.getPath()); } }
验证模型
package org.deeplearning4j.examples.quickstart.modeling.convolution; import org.datavec.image.loader.NativeImageLoader; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization; import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler; import java.io.File; import java.io.IOException; /** * @description: * @author: Mr.Fang * @create: 2023-07-14 15:06 **/ public class VerifyMNSIT { public static void main(String[] args) throws IOException { // 加载训练好的模型 File modelFile = new File("D:\\Documents\\Downloads\\mnist_png\\minist-model.zip"); MultiLayerNetwork model = MultiLayerNetwork.load(modelFile, true); // 加载待验证的图像 File imageFile = new File("D:\\Documents\\Downloads\\mnist_png\\mnist_png\\testing\\8\\1717.png"); NativeImageLoader loader = new NativeImageLoader(28, 28, 1); INDArray image = loader.asMatrix(imageFile); DataNormalization scaler = new ImagePreProcessingScaler(0, 1); scaler.transform(image); // 对图像进行预测 INDArray output = model.output(image); int predictedLabel = output.argMax().getInt(); // 在这行代码中,`output.argMax()`用于找到`output`中具有最大值的索引。`output`是一个包含模型的输出概率的NDArray对象。对于MNIST模型,输出是一个长度为10的向量,表示数字0到9的概率分布。 // //`.argMax()`方法返回具有最大值的索引。例如,如果`output`的值为[0.1, 0.3, 0.2, 0.05, 0.25, 0.05, 0.05, 0.1, 0.05, 0.05],则`.argMax()`将返回索引1,因为在位置1处的值0.3是最大的。 // //最后,`.getInt()`方法将获取`.argMax()`的结果并将其转换为一个整数,表示预测的标签。在这个例子中,`predictedLabel`将包含模型预测的数字标签。 // //简而言之,这行代码的作用是找到输出中概率最高的数字标签,以进行预测。 System.out.println("Predicted label: " + predictedLabel); } }
输出结果
o.n.l.f.Nd4jBackend - Loaded [CpuBackend] backend
o.n.n.NativeOpsHolder - Number of threads used for linear algebra: 6
o.n.l.c.n.CpuNDArrayFactory - Binary level Generic x86 optimization level AVX/AVX2
o.n.n.Nd4jBlas - Number of threads used for OpenMP BLAS: 6
o.n.l.a.o.e.DefaultOpExecutioner - Backend used: [CPU]; OS: [Windows 10]
o.n.l.a.o.e.DefaultOpExecutioner - Cores: [12]; Memory: [4.0GB];
o.n.l.a.o.e.DefaultOpExecutioner - Blas vendor: [OPENBLAS]
o.n.l.c.n.CpuBackend - Backend build information: GCC: "12.1.0" STD version: 201103L DEFAULT_ENGINE: samediff::ENGINE_CPU HAVE_FLATBUFFERS HAVE_OPENBLAS
o.d.n.m.MultiLayerNetwork - Starting MultiLayerNetwork with WorkspaceModes set to [training: ENABLED; inference: ENABLED], cacheMode set to [NONE]
Predicted label: 8
哇!又赚了一天人民币
分类:
java
, deeplearning4j
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】凌霞软件回馈社区,博客园 & 1Panel & Halo 联合会员上线
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】博客园社区专享云产品让利特惠,阿里云新客6.5折上折
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 本地部署 DeepSeek:小白也能轻松搞定!
· 如何给本地部署的DeepSeek投喂数据,让他更懂你
· 在缓慢中沉淀,在挑战中重生!2024个人总结!
· 大人,时代变了! 赶快把自有业务的本地AI“模型”训练起来!
· 从 Windows Forms 到微服务的经验教训