DJL 训练简单模板
这是一个使用 DJL（Deep Java Library）进行训练的简单模板。
import java.io.IOException;
import ai.djl.Model;
import ai.djl.basicdataset.cv.classification.FashionMnist;
import ai.djl.metric.Metrics;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.Blocks;
import ai.djl.nn.Parameter;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.core.Linear;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.dataset.Batch;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.initializer.NormalInitializer;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.training.tracker.Tracker;
import ai.djl.translate.TranslateException;
/**
 * Minimal DJL training template: trains a small multilayer perceptron on the Fashion-MNIST
 * dataset and runs a validation pass on the test split after every epoch.
 *
 * <p>Two JVM system properties tune the run: {@code MAX_EPOCH} (number of epochs, default 20)
 * and {@code DATASET_LIMIT} (maximum number of samples per dataset, default
 * {@link Long#MAX_VALUE}).
 */
public class MulClassMain {

    /** Number of samples per mini-batch, used for both the training and the test dataset. */
    private static final int BATCH_SIZE = 256;

    /** Flattened input width fed to the first layer and used to initialize the trainer. */
    private static final int NUM_INPUTS = 784;

    /** Width of the hidden layer. */
    private static final int NUM_HIDDEN = 256;

    /** Number of output units (one per class). */
    private static final int NUM_OUTPUTS = 10;

    /** Fixed learning rate handed to the SGD optimizer. */
    private static final float LEARNING_RATE = 0.5f;

    /**
     * Builds the network, loads the datasets and runs the train/validate loop.
     *
     * @param args ignored
     * @throws IOException if a dataset cannot be prepared or read
     * @throws TranslateException if a batch cannot be produced from a dataset
     */
    public static void main(String[] args) throws IOException, TranslateException {
        int numEpochs = Integer.getInteger("MAX_EPOCH", 20);

        // Define the model architecture.
        SequentialBlock net = buildNetwork();

        // Load the data used by the training process.
        FashionMnist trainIter = loadDataset(Dataset.Usage.TRAIN);
        FashionMnist testIter = loadDataset(Dataset.Usage.TEST);

        DefaultTrainingConfig config = buildTrainingConfig();

        // Resources are closed in reverse order: trainer, then model, then the manager.
        try (NDManager nm = NDManager.newBaseManager();
                Model model = Model.newInstance("mlp")) {
            model.setBlock(net);
            try (Trainer trainer = model.newTrainer(config)) {
                trainer.initialize(new Shape(1, NUM_INPUTS));
                trainer.setMetrics(new Metrics());
                for (int epoch = 0; epoch < numEpochs; ++epoch) {
                    System.out.printf("Epoch %d \n", epoch);
                    trainEpoch(trainer, trainIter, nm);
                    // Run validation before onEpoch so the listeners see validation results
                    // for this epoch; previously the test dataset was prepared but never used.
                    validateEpoch(trainer, testIter, nm);
                    trainer.notifyListeners(l -> l.onEpoch(trainer));
                }
            }
        }
    }

    /** Builds the flatten -> linear -> ReLU -> linear network with normally-initialized weights. */
    private static SequentialBlock buildNetwork() {
        SequentialBlock net = new SequentialBlock();
        net.add(Blocks.batchFlattenBlock(NUM_INPUTS));
        net.add(Linear.builder().setUnits(NUM_HIDDEN).build());
        net.add(Activation::relu);
        net.add(Linear.builder().setUnits(NUM_OUTPUTS).build());
        net.setInitializer(new NormalInitializer(), Parameter.Type.WEIGHT);
        return net;
    }

    /** Builds and prepares the Fashion-MNIST split for {@code usage} with random sampling. */
    private static FashionMnist loadDataset(Dataset.Usage usage)
            throws IOException, TranslateException {
        FashionMnist dataset = FashionMnist.builder()
                .optUsage(usage)
                .setSampling(BATCH_SIZE, true)
                .optLimit(Long.getLong("DATASET_LIMIT", Long.MAX_VALUE))
                .build();
        dataset.prepare();
        return dataset;
    }

    /** Assembles the training configuration: loss, SGD optimizer, accuracy and logging. */
    private static DefaultTrainingConfig buildTrainingConfig() {
        // Define the optimization algorithm.
        Tracker lrt = Tracker.fixed(LEARNING_RATE);
        Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();
        // Define the loss function.
        Loss loss = Loss.softmaxCrossEntropyLoss();
        return new DefaultTrainingConfig(loss)
                .optOptimizer(sgd)
                .addEvaluator(new Accuracy())
                .addTrainingListeners(TrainingListener.Defaults.logging());
    }

    /** Trains on every batch of {@code dataset} once, updating parameters after each batch. */
    private static void trainEpoch(Trainer trainer, FashionMnist dataset, NDManager manager)
            throws IOException, TranslateException {
        for (Batch batch : dataset.getData(manager)) {
            try {
                EasyTrain.trainBatch(trainer, batch);
                // Update the parameters.
                trainer.step();
            } finally {
                // Close the batch even when training fails so it is not leaked.
                batch.close();
            }
        }
    }

    /** Runs a validation pass over every batch of {@code dataset}. */
    private static void validateEpoch(Trainer trainer, FashionMnist dataset, NDManager manager)
            throws IOException, TranslateException {
        for (Batch batch : dataset.getData(manager)) {
            try {
                EasyTrain.validateBatch(trainer, batch);
            } finally {
                batch.close();
            }
        }
    }

    /**
     * Hand-written ReLU: element-wise maximum of {@code X} and zero.
     *
     * <p>Not called within this class; the network uses {@code Activation::relu} instead.
     *
     * @param X input array
     * @return the result of {@code X.maximum(0.0f)}
     */
    public NDArray relu(NDArray X) {
        return X.maximum(0.0f);
    }
}