Jax Jit模式下的Array输出
问题背景
在Python编程中,最简单的Debug方法就是print了,直接print出来,有什么问题一目了然。虽然也有log相关的库和类似于Spyder等IDE有专业的Debug模式,但是print还是最方便的,对于小规模应用场景。
Jax的Jit即时编译模式,使得函数在被调用的时候才会执行编译,很大程度上可以方便快捷的优化程序调试性能。但也有个问题就是,如果在Jit模式下,我们用print
打算去输出一个Array时,打印的结果会变成:
Traced<ShapedArray(float64[32,3])>with<DynamicJaxprTrace(level=0/2)>
jax.debug.print
如参考链接1中的内容所述,Jax自带了一个print方法,可以在Jit模式下也能够正常的打印Array数组内容。使用方式为:
from jax.debug import print as jprint
jprint("{}", jax_array)
需要注意的是,这里jprint
必须传入两个参数,一个str和一个Array数组,跟内置的print
函数在使用方式上还是有一定的差异的,不能直接替换。
总结概要
本文仅介绍一个可以在Jax的Jit即时编译模式下,也能够正常通过print打印函数来输出Jax Array内容的方法。
版权声明
本文首发链接为:https://www.cnblogs.com/dechinphy/p/jprint.html
作者ID:DechinPhy
更多原著文章:https://www.cnblogs.com/dechinphy/
请博主喝咖啡:https://www.cnblogs.com/dechinphy/gallery/image/379634.html