torch.jit.ScriptModule导出的ptl文件在Android加载.getDataAsFloatArray出错

环境:

 Implementation 'org.pytorch:pytorch_android_lite:1.10.0'
 Implementation 'org.pytorch:pytorch_android_torchvision:1.9.0'

代码:
【模型加载处】

Publicvoid loadModule(AppCompatActivityapp) {
  try {
    module = LiteModuleLoader.load(assetFilePath(app, model_fname+".ptl"));
  } catch (IOExceptione) {
    Log.e("Pytorch", "Errorreadingassets", e);
    app.finish();
  }
}

【模型使用处】

//running the model
final Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();

//getting tensor content as java array of floats
finalfloat [] scores = outputTensor.getDataAsFloatArray();

错误描述:
模型加载时未出现任何错误,但使用处.getDataAsFloatArray提示输出为IntegerArray类型,且结果恒为[0]
原因及修订方式:
ptl文件名不允许有下划线("_"),加载的模型名称为"test_model.ptl",修改为"testNC.ptl"即恢复正常

posted @ 2021-12-11 16:11  若水茗心  阅读(430)  评论(0编辑  收藏  举报