Reshaping with NamedSharding in the JAX framework (namedsharding-gives-a-way-to-express-shardings-with-names)

Official documentation reference:
https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#namedsharding-gives-a-way-to-express-shardings-with-names



This post explains the difference between

jax.device_put(x, mesh_sharding(P(('a', 'b'), None)))

and

jax.device_put(x, mesh_sharding(P(('b', 'a'), None)))
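Before running anything in JAX, the ordering difference can be sketched in plain Python. The helper below is hypothetical (it is not a JAX function), and it assumes the 2x2 mesh holds device IDs [[0, 1], [2, 3]] with axes ('a', 'b'); the layout that mesh_utils.create_device_mesh actually returns may differ on other hardware. The first name in the tuple is treated as the slowest-varying axis.

```python
from itertools import product


def flattened_device_order(mesh, axis_names, spec_axes):
    """Return device IDs in the order they receive consecutive shards.

    mesh: nested list of device IDs, one nesting level per mesh axis.
    axis_names: the name of each mesh axis, e.g. ('a', 'b').
    spec_axes: the tuple used in the PartitionSpec entry, e.g. ('b', 'a').
        It must name every mesh axis; the first name varies slowest.
    """
    # Read the size of each named axis off the nested list.
    sizes = {}
    level = mesh
    for name in axis_names:
        sizes[name] = len(level)
        level = level[0]

    order = []
    # itertools.product varies its last argument fastest, so the first
    # name in spec_axes is the major axis.
    for coords in product(*(range(sizes[name]) for name in spec_axes)):
        position = dict(zip(spec_axes, coords))
        device = mesh
        for name in axis_names:
            device = device[position[name]]
        order.append(device)
    return order


mesh = [[0, 1], [2, 3]]  # assumed device layout, axes ('a', 'b')
row_major = flattened_device_order(mesh, ('a', 'b'), ('a', 'b'))
col_major = flattened_device_order(mesh, ('a', 'b'), ('b', 'a'))
print(row_major)  # [0, 1, 2, 3]
print(col_major)  # [0, 2, 1, 3]
```

Under these assumptions, ('a', 'b') walks the mesh row by row, while ('b', 'a') walks it column by column, so devices 1 and 2 swap places.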


Setup: a host with four CPU devices

Code:

import os
# Assumption: on a CPU-only host, this flag (set before importing jax) exposes four CPU devices.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import jax
from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding

# Create a Sharding object to distribute a value across devices:
sharding = PositionalSharding(mesh_utils.create_device_mesh((4,)))
# Create an array of random values:
x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))
# and use jax.device_put to distribute it across a 2x2 grid of devices:
y = jax.device_put(x, sharding.reshape(2, 2))
# Show which device holds which block of the array:
jax.debug.visualize_array_sharding(y)

Output:

[image: output of jax.debug.visualize_array_sharding(y)]
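Since the output image is not reproduced here, the block layout can be worked out by hand. The sketch below is plain Python with a hypothetical helper, block_slices; it only computes which index ranges fall into each cell of a 2x2 grid and makes no claim about which device ID lands in which cell.

```python
def block_slices(shape, grid):
    """Map each grid cell (i, j) to the (row range, column range) of its block."""
    rows, cols = shape
    grid_rows, grid_cols = grid
    # Assumes the array dimensions divide evenly by the grid dimensions.
    block_h, block_w = rows // grid_rows, cols // grid_cols
    blocks = {}
    for i in range(grid_rows):
        for j in range(grid_cols):
            blocks[(i, j)] = (
                (i * block_h, (i + 1) * block_h),
                (j * block_w, (j + 1) * block_w),
            )
    return blocks


blocks = block_slices((8192, 8192), (2, 2))
for cell, (row_range, col_range) in blocks.items():
    print(cell, "rows", row_range, "cols", col_range)
```

Each of the four cells covers a 4096x4096 block, which is what a (2, 2) sharding of an (8192, 8192) array implies.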




jax.device_put(x, mesh_sharding(P(('a', 'b'), None)))

Code (the mesh devices are flattened in row-major order):

from typing import Optional

import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec, PositionalSharding

P = PartitionSpec

# Arrange the four devices in a 2x2 mesh with axes named 'a' and 'b'.
devices = mesh_utils.create_device_mesh((2, 2))
default_mesh = Mesh(devices, axis_names=('a', 'b'))

# Start from an array whose rows are already split across the four devices.
sharding = PositionalSharding(devices)
x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))
x = jax.device_put(x, sharding.reshape(4, 1))

def mesh_sharding(
    pspec: PartitionSpec, mesh: Optional[Mesh] = None,
) -> NamedSharding:
    # Fall back to the 2x2 mesh defined above.
    if mesh is None:
        mesh = default_mesh
    return NamedSharding(mesh, pspec)

# Shard dimension 0 over both mesh axes, with 'a' as the major (slowest-varying) axis.
y = jax.device_put(x, mesh_sharding(P(('a', 'b'), None)))
jax.debug.visualize_array_sharding(y)

Output:
[image: output of jax.debug.visualize_array_sharding(y)]



jax.device_put(x, mesh_sharding(P(('b', 'a'), None)))

Code (the mesh devices are flattened in column-major order):

from typing import Optional

import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec, PositionalSharding

P = PartitionSpec

# Arrange the four devices in a 2x2 mesh with axes named 'a' and 'b'.
devices = mesh_utils.create_device_mesh((2, 2))
default_mesh = Mesh(devices, axis_names=('a', 'b'))

# Start from an array whose rows are already split across the four devices.
sharding = PositionalSharding(devices)
x = jax.random.normal(jax.random.PRNGKey(0), (8192, 8192))
x = jax.device_put(x, sharding.reshape(4, 1))

def mesh_sharding(
    pspec: PartitionSpec, mesh: Optional[Mesh] = None,
) -> NamedSharding:
    # Fall back to the 2x2 mesh defined above.
    if mesh is None:
        mesh = default_mesh
    return NamedSharding(mesh, pspec)

# Shard dimension 0 over both mesh axes, with 'b' as the major (slowest-varying) axis.
y = jax.device_put(x, mesh_sharding(P(('b', 'a'), None)))
jax.debug.visualize_array_sharding(y)

Output:

[image: output of jax.debug.visualize_array_sharding(y)]
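To make the comparison concrete without the images, here is a rough plain-Python sketch of which rows each device would hold under the two specs. It assumes the mesh coordinates (a, b) map to device IDs 0..3 as written in MESH below, which may not match every backend; rows_per_device is a hypothetical helper, not part of JAX.

```python
NUM_ROWS = 8192
SIZE_A, SIZE_B = 2, 2
# Assumed mapping from mesh coordinates (a, b) to device ID.
MESH = {(0, 0): 0, (0, 1): 1, (1, 0): 2, (1, 1): 3}


def rows_per_device(major):
    """Return {device_id: (start, stop)} when dimension 0 is split over both axes.

    major: 'a' models P(('a', 'b'), None); 'b' models P(('b', 'a'), None).
    """
    shard_rows = NUM_ROWS // (SIZE_A * SIZE_B)
    result = {}
    for (a, b), device in MESH.items():
        # The major axis index is multiplied by the size of the minor axis.
        shard = a * SIZE_B + b if major == 'a' else b * SIZE_A + a
        result[device] = (shard * shard_rows, (shard + 1) * shard_rows)
    return result


ab = rows_per_device('a')
ba = rows_per_device('b')
for device in sorted(ab):
    print(f"device {device}: ('a','b') rows {ab[device]}  ('b','a') rows {ba[device]}")
```

With this layout, devices 0 and 3 keep the same rows in both cases, while devices 1 and 2 exchange the two middle row blocks.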

Posted by Angry_Panda
