Jax Jit模式下的Array输出

问题背景

在Python编程中,最简单的Debug方法就是print了,直接print出来,有什么问题一目了然。虽然也有log相关的库和类似于Spyder等IDE有专业的Debug模式,但是print还是最方便的,对于小规模应用场景。

Jax的Jit即时编译模式,使得函数在被调用的时候才会执行编译,很大程度上可以方便快捷的优化程序调试性能。但也有个问题就是,如果在Jit模式下,我们用print打算去输出一个Array时,打印的结果会变成:

Traced<ShapedArray(float64[32,3])>with<DynamicJaxprTrace(level=0/2)>

jax.debug.print

如参考链接1中的内容所述,Jax自带了一个print方法,可以在Jit模式下也能够正常的打印Array数组内容。使用方式为:

from jax.debug import print as jprint
jprint("{}", jax_array)

需要注意的是,这里jprint必须传入两个参数,一个str和一个Array数组,跟内置的print函数在使用方式上还是有一定的差异的,不能直接替换。

总结概要

本文仅介绍一个可以在Jax的Jit即时编译模式下,也能够正常通过print打印函数来输出Jax Array内容的方法。

版权声明

本文首发链接为:https://www.cnblogs.com/dechinphy/p/jprint.html

作者ID:DechinPhy

更多原著文章:https://www.cnblogs.com/dechinphy/

请博主喝咖啡:https://www.cnblogs.com/dechinphy/gallery/image/379634.html

参考链接

  1. https://jax.ac.cn/en/latest/debugging/print_breakpoint.html
posted @ 2024-11-26 17:01  DECHIN  阅读(33)  评论(0编辑  收藏  举报