djl加载模型
愚蠢的我以为加载pytorch模型还需要将模型结构在Java中写出来,因此各种写模型,还问gpt怎么把并行的结构用djl写出来,然后又问如何让djl加载模型,结果,我是个傻子。
人家早就把东西都弄好了,几行代码就搞定了。
Criteria<NDList, NDList>criteria = Criteria.builder() .setTypes(NDList.class, NDList.class) .optModelPath(modelDir) .optEngine("PyTorch") .build(); ZooModel<NDList, NDList> model = ModelZoo.loadModel(criteria); Predictor<NDList, NDList> predictor = model.newPredictor();
这里面,NDList是多个NDArray的封装,类似于封装了多个Tensor,optEngine是设置你的模型是从pytorch还是从TensorFlow来的。
optModelPath是指出pytorch生成的scriptTensor--》pt文件或者压缩包--》zip文件。
实际上,pt文件里面已经有模型的结构了,压根不需要自己写,直接调用OK了。
ModelZoo也是直接加载criteria就把模型整好了。
最后结果,我是个傻子。