Jax Jit模式下的Array输出
问题背景
在Python编程中,最简单的Debug方法就是print了,直接print出来,有什么问题一目了然。虽然也有log相关的库和类似于Spyder等IDE有专业的Debug模式,但是print还是最方便的,对于小规模应用场景。
Jax的Jit即时编译模式,使得函数在被调用的时候才会执行编译,很大程度上可以方便快捷的优化程序调试性能。但也有个问题就是,如果在Jit模式下,我们用print
打算去输出一个Array时,打印的结果会变成:
Traced<ShapedArray(float64[32,3])>with<DynamicJaxprTrace(level=0/2)>
jax.debug.print
如参考链接1中的内容所述,Jax自带了一个print方法,可以在Jit模式下也能够正常的打印Array数组内容。使用方式为:
from jax.debug import print as jprint
jprint("{}", jax_array)
需要注意的是,这里jprint
必须传入两个参数,一个str和一个Array数组,跟内置的print
函数在使用方式上还是有一定的差异的,不能直接替换。
总结概要
本文仅介绍一个可以在Jax的Jit即时编译模式下,也能够正常通过print打印函数来输出Jax Array内容的方法。
版权声明
本文首发链接为:https://www.cnblogs.com/dechinphy/p/jprint.html
作者ID:DechinPhy
更多原著文章:https://www.cnblogs.com/dechinphy/
请博主喝咖啡:https://www.cnblogs.com/dechinphy/gallery/image/379634.html
参考链接
本文作者:Dechin的博客
本文链接:https://www.cnblogs.com/dechinphy/p/18570507/jprint
版权声明:本作品采用CC BY-NC-SA 4.0许可协议进行许可。
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】凌霞软件回馈社区,博客园 & 1Panel & Halo 联合会员上线
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步