Tensoflow API笔记(N) 设备指定
tf.device是tf.Graph.device()的一个包装,是一个用于指定新创建的操作(operation)的默认设备的环境管理器。参数为device_name_or_function,可以传入一个设备字符串或者环境操作函数,如tf.DeviceSpec。
- 不过,如果传入的是一个设备名称字符串,那么在此环境中构造的所有操作都将被分配给带有该名称的设备,除非被其他嵌套的设备环境(其他的tf.device)所覆盖。
- 如果传入的是一个函数,它将被当作一个从操作对象到设备名称字符串的函数,并在每次创建新操作时调用它。操作将被分配给带有返回名称的设备。
- 如果是None,所有的来自代码段上下文的设备调用将被忽略。
1 with g.device('/device:GPU:0'): 2 # All operations constructed in this context will be placed 3 # on GPU 0. 4 with g.device(None): 5 # All operations constructed in this context will have no 6 # assigned device. 7 8 # Defines a function from `Operation` to device string. 9 def matmul_on_gpu(n): 10 if n.type == "MatMul": 11 return "/device:GPU:0" 12 else: 13 return "/cpu:0" 14 15 with g.device(matmul_on_gpu): 16 # All operations of type "MatMul" constructed in this context 17 # will be placed on GPU 0; all other operations will be placed 18 # on CPU 0.
另外,API文档也警告说:
设备范围可能被op包装器或其他库代码覆盖。例如,变量赋值操作v .assign()必须与tf.Variable变量v一起使用,如果变量v和不兼容的设备嵌套将被忽略。
tf.DeviceSpec返回的是部分或者全部的设备指定,在整个graph中来描述状态存储和计算发生的位置,并且允许解析设备规范的字符串,以验证它们的有效性、合并它们或以编码方式组合它们。
1 # Place the operations on device "GPU:0" in the "ps" job. 2 device_spec = DeviceSpec(job="ps", device_type="GPU", device_index=0) 3 with tf.device(device_spec): 4 # Both my_var and squared_var will be placed on /job:ps/device:GPU:0. 5 my_var = tf.Variable(..., name="my_variable") 6 squared_var = tf.square(my_var) 7 如果一个DeviceSpec被部分指定,将根据定义的范围与其他DeviceSpecs合并,在内部内定义的DeviceSpec组件优先于在外层内定义的组件。 8 with tf.device(DeviceSpec(job="train", )): 9 with tf.device(DeviceSpec(job="ps", device_type="GPU", device_index=0): 10 # Nodes created here will be assigned to /job:ps/device:GPU:0. 11 with tf.device(DeviceSpec(device_type="GPU", device_index=1): 12 # Nodes created here will be assigned to /job:train/device:GPU:1.
参数:
- Job: 作业名称
- Replica: 用于复制job的索引.
- Task: 任务索引.
- Device type: 设备类型 ("CPU" or "GPU").
- Device index: 设备索引,如果未指定,则可以使用任意的设备。
方法:
from_string:
接收一个如下形式的字符串:
/job:/replica:/task:/device:CPU
/job:/replica:/task:/device:GPU
其中CPU和GPU是互相排斥的,用于
merge_from:
将另外一个“DeviceSpec”的属性合并到当前DeviceSpec中。
parse_from_string:
将DeviceSpec名称(字符串)解析为组件。
to_string:
返回DeviceSpec的字符串。