Jax中关于Device ID的配置

问题背景

在不同的框架中对于Device ID的配置方法都略有不同,这里提两种Jax中配置Device ID的方法。

配置环境变量

这个方法是比较流行的,直接在环境变量里面配置:

export CUDA_VISIBLE_DEVICES=1

这样就使得当前shell下运行的程序只能识别到1号显卡,一般就是第二张显卡了。如果需要配置多张显卡,类似的可以指定:

export CUDA_VISIBLE_DEVICES=0,1

当然,如果是在Python程序中运行的话,也可以直接在Python脚本中配置环境变量:

import os
os.environ['CUDA_VISIBLE_DEVICES']='1'

当然,该语句最好在Jax初始化之前执行。

Jit-device参数配置

这种配置方法会更加具体一点,可以直接指定某个即时编译的函数所使用的device id,如下是一个使用的案例:

import os
# 禁用显存预分配
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE']='false'
import time
import numpy as np
np.random.seed(0)
import jax
from jax import numpy as jnp
# 创建CPU上的张量
N = 5000000
crd = np.random.random((N, 3))
# 生成显卡对应的对象
gpus = jax.devices()
# 分配对象到不同的显卡上
crd0 = jax.jit(jnp.array, device=gpus[0])(crd[:3000000])
crd1 = jax.jit(jnp.array, device=gpus[1])(crd[3000000:])
time.sleep(5)

在这个案例中,我们在CPU上初始化一个crd张量,然后通过jax.numpy.array函数将该张量拷贝到指定的GPU环境中。其中向显卡第一张显卡拷贝了3000000组数据,向第二张显卡拷贝了2000000组数据,这样的话如果在nvidia-smi中就可以看到两个不同的显存占用了。

总结概要

本文主要介绍了2个在Jax框架中配置显卡Device ID的方法。第一种方法可以使用环境变量进行配置,对于众多的深度学习框架都是可以兼容的。而第二种方案是在Jax即时编译的过程中通过Jax生成的Device对象来控制数据的传输和函数执行的Device ID。

版权声明

本文首发链接为:https://www.cnblogs.com/dechinphy/p/jax-device-id.html

作者ID:DechinPhy

更多原著文章:https://www.cnblogs.com/dechinphy/

请博主喝咖啡:https://www.cnblogs.com/dechinphy/gallery/image/379634.html

参考链接

  1. https://github.com/jax-ml/jax/discussions/15957
posted @ 2024-11-05 16:47  DECHIN  阅读(28)  评论(0编辑  收藏  举报