djl加载模型

愚蠢的我以为加载pytorch模型还需要将模型结构在Java中写出来,因此各种写模型,还问gpt怎么把并行的结构用djl写出来,然后又问如何让djl加载模型,结果,我是个傻子。

人家早就把东西都弄好了,几行代码就搞定了。

 Criteria<NDList, NDList>criteria = Criteria.builder()
        .setTypes(NDList.class, NDList.class)
        .optModelPath(modelDir)
        .optEngine("PyTorch")
        .build();
    ZooModel<NDList, NDList> model = ModelZoo.loadModel(criteria);
    Predictor<NDList, NDList> predictor = model.newPredictor();

这里面,NDList是多个NDArray的封装,类似于封装了多个Tensor,optEngine是设置你的模型是从pytorch还是从TensorFlow来的。

optModelPath是指出pytorch生成的scriptTensor--》pt文件或者压缩包--》zip文件。

实际上,pt文件里面已经有模型的结构了,压根不需要自己写,直接调用OK了。

ModelZoo也是直接加载criteria就把模型整好了。

最后结果,我是个傻子。



posted @ 2024-04-01 02:04  KIKIcoo  阅读(194)  评论(0)    收藏  举报