Loading

DJL 训练简单模板

这是一个使用 DJL 在 Fashion-MNIST 数据集上训练多层感知机（MLP，784 → 256 → 10）并在测试集上验证的简单模板。



import java.io.IOException;

import ai.djl.Model;
import ai.djl.basicdataset.cv.classification.FashionMnist;
import ai.djl.metric.Metrics;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.Blocks;
import ai.djl.nn.Parameter;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.core.Linear;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.dataset.Batch;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.initializer.NormalInitializer;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.training.tracker.Tracker;
import ai.djl.translate.TranslateException;

/**
 * Minimal DJL training template: trains a two-layer MLP (784 -> 256 -> 10) on Fashion-MNIST
 * with SGD and softmax cross-entropy, validating on the test split after every epoch.
 *
 * <p>Tunable through system properties: {@code MAX_EPOCH} (number of epochs, default 20) and
 * {@code DATASET_LIMIT} (maximum number of samples per split, default unlimited).
 */
public class MulClassMain {
    public static void main(String[] args) throws IOException, TranslateException {
        // 1. Model structure: flatten each image to 784 features, one hidden ReLU layer of 256
        //    units, then 10 output logits (one per class).
        SequentialBlock net = new SequentialBlock();
        net.add(Blocks.batchFlattenBlock(784));
        net.add(Linear.builder().setUnits(256).build());
        net.add(Activation::relu);
        net.add(Linear.builder().setUnits(10).build());
        net.setInitializer(new NormalInitializer(), Parameter.Type.WEIGHT);

        // 2. Training hyper-parameters and data.
        int batchSize = 256;
        int numEpochs = Integer.getInteger("MAX_EPOCH", 20);
        long limit = Long.getLong("DATASET_LIMIT", Long.MAX_VALUE);

        // Only the training set is shuffled; evaluation order does not affect the metrics.
        FashionMnist trainIter = loadFashionMnist(Dataset.Usage.TRAIN, batchSize, true, limit);
        FashionMnist testIter = loadFashionMnist(Dataset.Usage.TEST, batchSize, false, limit);

        // 3. Optimizer: plain SGD with a fixed learning rate of 0.5.
        Tracker lrt = Tracker.fixed(0.5f);
        Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();

        // 4. Loss function.
        Loss loss = Loss.softmaxCrossEntropyLoss();

        DefaultTrainingConfig config = new DefaultTrainingConfig(loss)
                .optOptimizer(sgd)
                .addEvaluator(new Accuracy())
                .addTrainingListeners(TrainingListener.Defaults.logging());

        try (Model model = Model.newInstance("mlp")) {
            model.setBlock(net);

            try (Trainer trainer = model.newTrainer(config)) {
                trainer.initialize(new Shape(1, 784));
                trainer.setMetrics(new Metrics());

                for (int epoch = 0; epoch < numEpochs; ++epoch) {
                    System.out.printf("Epoch %d%n", epoch);
                    trainEpoch(trainer, trainIter);
                    // Previously the prepared test set was never used; evaluate it every epoch
                    // so the Accuracy evaluator also reports validation results.
                    validateEpoch(trainer, testIter);
                    // Lets the listeners log and reset this epoch's metrics.
                    trainer.notifyListeners(l -> l.onEpoch(trainer));
                }
            }
        }
    }

    /** Builds and prepares one Fashion-MNIST split with the given sampling options. */
    private static FashionMnist loadFashionMnist(
            Dataset.Usage usage, int batchSize, boolean shuffle, long limit)
            throws IOException, TranslateException {
        FashionMnist dataset = FashionMnist.builder()
                .optUsage(usage)
                .setSampling(batchSize, shuffle)
                .optLimit(limit)
                .build();
        dataset.prepare();
        return dataset;
    }

    /** Runs one training pass over {@code dataset}, updating parameters after every batch. */
    private static void trainEpoch(Trainer trainer, Dataset dataset)
            throws IOException, TranslateException {
        // iterateDataset places batches on the trainer's own manager/devices.
        for (Batch batch : trainer.iterateDataset(dataset)) {
            try {
                EasyTrain.trainBatch(trainer, batch);
                trainer.step(); // update the parameters
            } finally {
                batch.close(); // release the batch's arrays even if training throws
            }
        }
    }

    /** Runs one evaluation pass over {@code dataset} without updating parameters. */
    private static void validateEpoch(Trainer trainer, Dataset dataset)
            throws IOException, TranslateException {
        for (Batch batch : trainer.iterateDataset(dataset)) {
            try {
                EasyTrain.validateBatch(trainer, batch);
            } finally {
                batch.close();
            }
        }
    }

    /**
     * Element-wise ReLU: returns {@code max(X, 0)}.
     *
     * <p>Hand-written equivalent of {@code Activation.relu}; {@link #main} uses
     * {@code Activation::relu} instead, so this method is not called within this class.
     *
     * @param X input array
     * @return a new array holding {@code max(x, 0)} for every element
     */
    public NDArray relu(NDArray X) {
        return X.maximum(0.0f);
    }
}


posted @ 2024-06-27 10:54  青山新雨  阅读(5)  评论(0)  编辑  收藏  举报