[源码解析] TensorFlow 分布式之 ClusterCoordinator
[源码解析] TensorFlow 分布式之 ClusterCoordinator
本文我们主要来看看ParameterServerStrategy如何分发计算,也就是ClusterCoordinator如何运作。这是TF分布式的最后一篇。
安利两个github,都是非常好的学习资料,推荐。
https://github.com/yuhuiaws/ML-study
https://github.com/Jack47/hack-SysML
另外推荐西门宇少的最新大作让Pipeline在Transformer LM上沿着Token level并行起来——TeraPipe。
本系列其他文章是:
[翻译] TensorFlow 分布式之论文篇 "Implementation of Control Flow in TensorFlow"
[源码解析] TensorFlow 分布式环境(1) --- 总体架构
[源码解析] TensorFlow 分布式环境(2)---Master 静态逻辑
[源码解析] TensorFlow 分布式环境(3)--- Worker 静态逻辑
[源码解析] TensorFlow 分布式环境(4) --- WorkerCache
[源码解析] TensorFlow 分布式环境(5) --- Session
[源码解析] TensorFlow 分布式环境(7) --- Worker 动态逻辑
[源码解析] TensorFlow 分布式环境(8) --- 通信机制
[源码解析] TensorFlow 分布式 DistributedStrategy 之基础篇
[源码解析] TensorFlow 分布式之 MirroredStrategy
[源码解析] TensorFlow 分布式之 MirroredStrategy 分发计算
[源码解析] TensorFlow 分布式之 ParameterServerStrategy V1
[源码解析] TensorFlow 分布式之 ParameterServerStrategy V2
1. 思路
TensorFlow 2 推荐使用一种基于中央协调的架构来进行参数服务器训练。每个工作者和参数服务器都运行一个 tf.distribution.Server,在此基础上,一个协调者任务负责在工作者和参数服务器上创建资源,调度功能,并协调训练。协调器使用 tf.distribution.experimental.coordinator.ClusterCoordinator 来协调集群,使用 tf.distribution.experimental.ParameterServerStrategy 来定义参数服务器上的变量和工作者的计算。
ClusterCoordinator 是一个用于安排和协调远程函数执行的对象。该类用于创建容错(fault-tolerant)资源和调度函数到远程 TensorFlow 服务器。目前该类不支持独立使用,它应该与旨在与之合作的 tf.distribution 策略一起使用。ClusterCoordinator 类目前只适用于和 tf.distribution.experimental.ParameterServerStrategy 一起工作。
1.1 使用
在使用 ParameterServerStrategy 定义所有的计算后,用户可以使用 tf.distribution.experimental.coordinator.ClusterCoordinator 类来创建资源并将训练步骤分配给远程工作者。
首先,我们来创建一个 ClusterCoordinator 对象并传入策略对象。
strategy = tf.distribute.experimental.ParameterServerStrategy(cluster_resolver=...)
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(strategy)
其次,创一个属于每个工作者(per-worker)的数据集和一个迭代器。在下面代码的 per_worker_dataset_fn 中,建议将 dataset_fn 包裹到 strategy.distribution_datasets_from_function 中,以允许无缝高效的把数据预取(prefetching )到 GPU。
@tf.function
def per_worker_dataset_fn():
return strategy.distribute_datasets_from_function(dataset_fn)
per_worker_dataset = coordinator.create_per_worker_dataset(per_worker_dataset_fn)
per_worker_iterator = iter(per_worker_dataset)
最后一步是使用 ClusterCoordinator.schedule 将计算分配给远程工作者。
- schedule 方法把一个 tf.function 插入队列,并立即返回一个 future-like 的 RemoteValue 。队列之中的函数将被派发给后台线程中的远程工作者,RemoteValue 将被异步填充结果。
- 用户可以使用 join 方法( ClusterCoordinator.join )来等待所有被规划(scheduled)的函数执行。
@tf.function
def step_fn(iterator):
return next(iterator)
num_epoches = 4
steps_per_epoch = 5
for i in range(num_epoches):
accuracy.reset_states()
for _ in range(steps_per_epoch):
coordinator.schedule(step_fn, args=(per_worker_iterator,))
# Wait at epoch boundaries.
coordinator.join()
print ("Finished epoch %d, accuracy is %f." % (i, accuracy.result().numpy()))
下面是如何得到 RemoteValue 的结果。
loss = coordinator.schedule(step_fn, args=(per_worker_iterator,))
print ("Final loss is %f" % loss.fetch())
用户也可以启动所有的步骤(steps),并在等待完成时做一些事情。
for _ in range(total_steps):
coordinator.schedule(step_fn, args=(per_worker_iterator,))
while not coordinator.done():
time.sleep(10)
# Do something like logging metrics or writing checkpoints.
1.2 问题点
依据前面的代码,我们总结出来问题点如下:
- Worker 如何知道使用哪些设备?
- 如何具体执行用户函数?
- 如何获取数据?
接下来我们就尝试通过分析代码来回答这些问题。
2. 定义
ClusterCoordinator 的主要思路如下。
- 协调者不是训练工作者之一,相反,它负责创建资源,如变量和数据集,调度 "tf.function",保存检查点等等。
- 为了使训练工作顺利进行,协调者派遣 "tf.function" 在远程工作者上执行。
- 在收到协调者的请求后,工作者通过从参数服务器读取变量、执行操作和更新参数服务器上的变量来执行 "tf.function"。
- 每个工作者只处理来自协调者的请求,并与参数服务器进行通信。而不与集群中的其他工作者直接互动。
ClusterCoordinator 定义具体如下,我们可以看到,其主要是配置了 _strategy 成员变量,生成了 _cluster 成员变量。
@tf_export("distribute.experimental.coordinator.ClusterCoordinator", v1=[])
class ClusterCoordinator(object):
def __new__(cls, strategy):
# ClusterCoordinator is kept as a single instance to a given Strategy .
if strategy._cluster_coordinator is None:
strategy._cluster_coordinator = super(
ClusterCoordinator, cls).__new__(cls)
return strategy._cluster_coordinator
def __init__(self, strategy):
"""Initialization of a ClusterCoordinator instance.
Args:
strategy: a supported tf.distribute.Strategy object. Currently, only
tf.distribute.experimental.ParameterServerStrategy is supported.
Raises:
ValueError: if the strategy being used is not supported.
"""
if not getattr(self, "_has_initialized", False):
if not isinstance(strategy,
parameter_server_strategy_v2.ParameterServerStrategyV2):
raise ValueError(
"Only tf.distribute.experimental.ParameterServerStrategy "
"is supported to work with "
" tf.distribute.experimental.coordinator.ClusterCoordinator "
"currently.")
self._strategy = strategy
self.strategy.extended._used_with_coordinator = True
self._cluster = Cluster(strategy)
self._has_initialized = True
def __del__(self):
self._cluster.stop()
@property
def strategy(self):
"""Returns the Strategy associated with the ClusterCoordinator ."""
return self._strategy
2.1 Schedule
由 ClusterCoordinator 对象提供的最重要的 API 是 schedule,其会分派 tf.function 到一个工作者,以便异步执行,具体如下:
- 该方法是非阻塞的,因为它把 fn 插入队列,并立即返回 tf.distribution.experimental.coordinator.RemoteValue 对象。fn 排队等待稍后执行。
- 在队列之中排队的函数将被派发给后台线程中的远程工作者来异步执行,他们的 RemoteValue 将被异步赋值。
- 由于 schedule 不需要分配一个工作者,传递进来的 tf.function 可以在任何可用的工作者上执行。
- 可以调用 fetch 来等待函数执行完成,并从远程工作者那里获取其输出。另一方面,也可以调用 tf.distribution.experimental.coordinator.ClusterCoordinator.join 来等待所有预定的函数完成。
失败和容错的策略如下:
- 由于工作者在执行函数的任何时候都可能失败,所以函数有可能被部分执行,但是 tf.distribution.experimental.coordinator.ClusterCoordinator 保证在这些事件中,函数最终将在任何可用的工作者上执行。
- schedule 保证 fn 至少在工作者上执行一次;如果其对应的工作者在执行过程中失败,由于函数的执行不是原子性的,所以一个函数可能被执行多次。
- 如果被执行的工作者在结束之前变得不可用,该函数将在另一个可用的工作者上重试。
- 如果任何先前安排的函数出现错误,schedule 将抛出其中任何一个错误,并清除到目前为止收集的错误。用户可以在返回的 tf.distribution.experimental.coordinator.RemoteValue 上调用 fetch 来检查它们是否已经执行、失败或取消,如果需要,可以重新安排相应的函数。当 schedule 引发异常时,它保证没有任何函数仍在执行。
Schedule 的具体定义如下,数据迭代器作为参数之一会和 fn 一起被传入。
def schedule(self, fn, args=None, kwargs=None):
"""Schedules fn to be dispatched to a worker for asynchronous execution.
This method is non-blocking in that it queues the fn which will be
executed later and returns a
tf.distribute.experimental.coordinator.RemoteValue object immediately.
fetch can be called on it to wait for the function execution to finish
and retrieve its output from a remote worker. On the other hand, call
tf.distribute.experimental.coordinator.ClusterCoordinator.join to wait for
all scheduled functions to finish.
schedule guarantees that fn will be executed on a worker at least once;
it could be more than once if its corresponding worker fails in the middle
of its execution. Note that since worker can fail at any point when
executing the function, it is possible that the function is partially
executed, but tf.distribute.experimental.coordinator.ClusterCoordinator
guarantees that in those events, the function will eventually be executed on
any worker that is available.
If any previously scheduled function raises an error, schedule will raise
any one of those errors, and clear the errors collected so far. What happens
here, some of the previously scheduled functions may have not been executed.
User can call fetch on the returned
tf.distribute.experimental.coordinator.RemoteValue to inspect if they have
executed, failed, or cancelled, and reschedule the corresponding function if
needed.
When schedule raises, it guarantees that there is no function that is
still being executed.
At this time, there is no support of worker assignment for function
execution, or priority of the workers.
args and kwargs are the arguments passed into fn , when fn is
executed on a worker. They can be
tf.distribute.experimental.coordinator.PerWorkerValues and in this case,
the argument will be substituted with the corresponding component on the
target worker. Arguments that are not
tf.distribute.experimental.coordinator.PerWorkerValues will be passed into
fn as-is. Currently, tf.distribute.experimental.coordinator.RemoteValue
is not supported to be input args or kwargs .
Args:
fn: A tf.function ; the function to be dispatched to a worker for
execution asynchronously. Regular python funtion is not supported to be
scheduled.
args: Positional arguments for fn .
kwargs: Keyword arguments for fn .
Returns:
A tf.distribute.experimental.coordinator.RemoteValue object that
represents the output of the function scheduled.
Raises:
Exception: one of the exceptions caught by the coordinator from any
previously scheduled function, since the last time an error was thrown
or since the beginning of the program.
"""
if not isinstance(fn,
(def_function.Function, tf_function.ConcreteFunction)):
raise TypeError(
" tf.distribute.experimental.coordinator.ClusterCoordinator.schedule "
" only accepts a tf.function or a concrete function.")
# Slot variables are usually created during function tracing time; thus
# schedule needs to be called within the strategy.scope() .
with self.strategy.scope():
self.strategy.extended._being_scheduled = True
remote_value = self._cluster.schedule(fn, args=args, kwargs=kwargs)
self.strategy.extended._being_scheduled = False
return remote_value
2.2 Join
Join 方法的作用是阻塞直到所有预定的函数都执行完毕,其具体特点如下:
- 如果任何先前安排的函数产生错误,join 将因为抛出一个错误而失败,并清除到目前为止收集的错误。如果发生这种情况,一些先前安排的函数可能没有被执行。
- 用户可以对返回的 tf.distribution.experimental.coordinator.RemoteValue 调用 fetch 来检查它们是否已经执行、失败或取消了。
- 如果一些已经取消的函数需要重新安排,用户应该再次调用 schedule 。
- 当 join 返回或抛出异常时,它保证没有任何函数仍在执行。
def join(self):
"""Blocks until all the scheduled functions have finished execution.
If any previously scheduled function raises an error, join will fail by
raising any one of those errors, and clear the errors collected so far. If
this happens, some of the previously scheduled functions may have not been
executed. Users can call fetch on the returned
tf.distribute.experimental.coordinator.RemoteValue to inspect if they have
executed, failed, or cancelled. If some that have been cancelled need to be
rescheduled, users should call schedule with the function again.
When join returns or raises, it guarantees that there is no function that
is still being executed.
Raises:
Exception: one of the exceptions caught by the coordinator by any
previously scheduled function since the last time an error was thrown or
since the beginning of the program.
"""
self._cluster.join()
2.3 Done
Done 方法返回所有分发的函数是否已经执行完毕。如果任何先前分发的函数引发错误,done'将会失败。
def done(self):
"""Returns whether all the scheduled functions have finished execution.
If any previously scheduled function raises an error, done will fail by
raising any one of those errors.
When done returns True or raises, it guarantees that there is no function
that is still being executed.
Returns:
Whether all the scheduled functions have finished execution.
Raises:
Exception: one of the exceptions caught by the coordinator by any
previously scheduled function since the last time an error was thrown or
since the beginning of the program.
"""
return self._cluster.done()
2.4 Fetch
Fetch 会获取 remote values 的结果。
def fetch(self, val):
"""Blocking call to fetch results from the remote values.
This is a wrapper around
tf.distribute.experimental.coordinator.RemoteValue.fetch for a
RemoteValue structure; it returns the execution results of
RemoteValue s. If not ready, wait for them while blocking the caller.
Example:
```python
strategy = ...
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
strategy)
def dataset_fn():
return tf.data.Dataset.from_tensor_slices([1, 1, 1])
with strategy.scope():
v = tf.Variable(initial_value=0)
@tf.function
def worker_fn(iterator):
def replica_fn(x):
v.assign_add(x)
return v.read_value()
return strategy.run(replica_fn, args=(next(iterator),))
distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn)
distributed_iterator = iter(distributed_dataset)
result = coordinator.schedule(worker_fn, args=(distributed_iterator,))
assert coordinator.fetch(result) == 1
```
Args:
val: The value to fetch the results from. If this is structure of
tf.distribute.experimental.coordinator.RemoteValue , fetch() will be
called on the individual
tf.distribute.experimental.coordinator.RemoteValue to get the result.
Returns:
If val is a tf.distribute.experimental.coordinator.RemoteValue or a
structure of tf.distribute.experimental.coordinator.RemoteValue s,
return the fetched tf.distribute.experimental.coordinator.RemoteValue
values immediately if they are available, or block the call until they are
available, and return the fetched
tf.distribute.experimental.coordinator.RemoteValue values with the same
structure. If val is other types, return it as-is.
"""
def _maybe_fetch(val):
if isinstance(val, RemoteValue):
return val.fetch()
else:
return val
return nest.map_structure(_maybe_fetch, val)
3. 数据
除了调度远程函数,ClusterCoordinator 还帮助在所有工作者上创建数据集,并当一个工作者从失败中恢复时重建这些数据集。用户可以通过调用 dataset_fn 来在worker设备上创建数据集。使用例子如下:
strategy = tf.distribute.experimental.ParameterServerStrategy(
cluster_resolver=...)
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
strategy=strategy)
@tf.function
def worker_fn(iterator):
return next(iterator)
def per_worker_dataset_fn():
return strategy.distribute_datasets_from_function(
lambda x: tf.data.Dataset.from_tensor_slices([3] * 3))
per_worker_dataset = coordinator.create_per_worker_dataset(
per_worker_dataset_fn)
per_worker_iter = iter(per_worker_dataset)
remote_value = coordinator.schedule(worker_fn, args=(per_worker_iter,))
assert remote_value.fetch() == 3
3.1 建立数据集
上面代码使用了 create_per_worker_dataset 在worker上创建数据集,这些数据集由 dataset_fn 生成,并返回一个代表这些数据集的集合。在这样的数据集集合上调用 iter 会返回一个 tf.distribution.experimental.coordinator.PerWorkerValues,它是一个迭代器的集合,其中的迭代器已经被放置在各个工作者上。
需要注意,不支持在迭代器的 "PerWorkerValues"上直接调用 "next"。该迭代器应该是作为一个参数传递给 tf.distribution.experimental.coordinator.ClusterCoordinator.schedule 。当计划的函数即将被工作者执行时,该函数将收到与该工作者相对应的单个迭代器。该函数可以对该迭代器调用 next 方法。
目前,schedule 方法假定工作者都是相同的,因此假设不同工作者上的数据集是一样的,除非它们包含 dataset.shuffle 操作,并且没有设置随机种子,在这种情况下,它们的洗牌方式会不同。正因为如此,建议将数据集无限地重复,并安排有限的步骤,而不是依赖于数据集的 OutOfRangeError 来结束。
def create_per_worker_dataset(self, dataset_fn):
"""Create dataset on workers by calling dataset_fn on worker devices.
This creates the given dataset generated by dataset_fn on workers
and returns an object that represents the collection of those individual
datasets. Calling iter on such collection of datasets returns a
tf.distribute.experimental.coordinator.PerWorkerValues , which is a
collection of iterators, where the iterators have been placed on respective
workers.
Calling next on a PerWorkerValues of iterator is unsupported. The
iterator is meant to be passed as an argument into
tf.distribute.experimental.coordinator.ClusterCoordinator.schedule . When
the scheduled function is about to be executed by a worker, the
function will receive the individual iterator that corresponds to the
worker. The next method can be called on an iterator inside a
scheduled function when the iterator is an input of the function.
Currently the schedule method assumes workers are all the same and thus
assumes the datasets on different workers are the same, except they may be
shuffled differently if they contain a dataset.shuffle operation and a
random seed is not set. Because of this, we also recommend the datasets to
be repeated indefinitely and schedule a finite number of steps instead of
relying on the OutOfRangeError from a dataset.
Args:
dataset_fn: The dataset function that returns a dataset. This is to be
executed on the workers.
Returns:
An object that represents the collection of those individual
datasets. iter is expected to be called on this object that returns
a tf.distribute.experimental.coordinator.PerWorkerValues of the
iterators (that are on the workers).
"""
return values_lib.get_per_worker_dataset(dataset_fn, self)
get_per_worker_dataset 则返回 PerWorkerDatasetFromDataset 或者 PerWorkerDatasetFromDatasetFunction。
def get_per_worker_dataset(dataset_or_dataset_fn, coordinator):
if callable(dataset_or_dataset_fn):
return PerWorkerDatasetFromDatasetFunction(dataset_or_dataset_fn,
coordinator)
else:
return PerWorkerDatasetFromDataset(dataset_or_dataset_fn, coordinator)
3.2 PerWorkerDistributedDataset
PerWorkerDistributedDataset 代表了从一个数据集建立的工作者使用的分布式数据集。
class PerWorkerDatasetFromDataset(PerWorkerDatasetFromDatasetFunction):
"""Represents worker-distributed datasets created from a dataset."""
def __init__(self, dataset, coordinator):
"""Makes an iterable from datasets created by the given dataset.
It creates a dataset_fn which deserializes a dataset from a graph under the
hood.
Args:
dataset: A tf.data.Dataset, a DistributedDataset or a
DistributedDatasetsFromFunction
coordinator: a ClusterCoordinator object, used to create dataset
resources.
"""
if isinstance(dataset, input_lib.DistributedDataset):
original_dataset = dataset._original_dataset
serialized = serialize_dataset_to_graph(original_dataset)
def dataset_fn():
deserialized = deserialize_dataset_from_graph(
serialized, original_dataset.element_spec)
dataset.build(dataset_to_replace=deserialized)
return dataset
elif isinstance(dataset, input_lib.DistributedDatasetsFromFunction):
def dataset_fn():
dataset.build()
return dataset
elif isinstance(dataset, dataset_ops.Dataset):
serialized = serialize_dataset_to_graph(dataset)
def dataset_fn():
return deserialize_dataset_from_graph(serialized, dataset.element_spec)
else:
raise ValueError("Unexpected dataset type!")
super(PerWorkerDatasetFromDataset, self).__init__(dataset_fn, coordinator)
3.3 PerWorkerDatasetFromDatasetFunction
PerWorkerDistributedDataset 代表了从一个数据集方法建立的工作者使用的分布式数据集。
在 iter 之中有:
-
调用 _create_per_worker_iterator 得到一个 iter(dataset)。
-
调用 self._coordinator._create_per_worker_resources 为每工作者生成一个 iterator。
-
最后返回一个 PerWorkerDistributedIterator。
class PerWorkerDatasetFromDatasetFunction(object):
"""Represents worker-distributed datasets created from dataset function."""
def __init__(self, dataset_fn, coordinator):
"""Makes an iterable from datasets created by the given function.
Args:
dataset_fn: A function that returns a Dataset .
coordinator: a ClusterCoordinator object, used to create dataset
resources.
"""
def disallow_variable_creation(next_creator, **kwargs):
raise ValueError("Creating variables in dataset_fn is not allowed.")
if isinstance(dataset_fn, def_function.Function):
with variable_scope.variable_creator_scope(disallow_variable_creation):
dataset_fn = dataset_fn.get_concrete_function()
elif not isinstance(dataset_fn, tf_function.ConcreteFunction):
with variable_scope.variable_creator_scope(disallow_variable_creation):
dataset_fn = def_function.function(dataset_fn).get_concrete_function()
self._dataset_fn = dataset_fn
self._coordinator = coordinator
self._element_spec = None
def __iter__(self):
# We would like users to create iterators outside tf.function s so that we
# can track them.
if (not context.executing_eagerly() or
ops.get_default_graph().building_function):
raise RuntimeError(
"__iter__() is not supported inside of tf.function or in graph mode.")
def _create_per_worker_iterator():
dataset = self._dataset_fn()
return iter(dataset)
# If PerWorkerDatasetFromDatasetFunction.__iter__ is called multiple
# times, for the same object it should only create and register resource
# once. Using object id to distinguish different iterator resources.
per_worker_iterator = self._coordinator._create_per_worker_resources(
_create_per_worker_iterator)
# Setting type_spec of each RemoteValue so that functions taking these
# RemoteValues as inputs can be traced.
for iterator_remote_value in per_worker_iterator._values:
iterator_remote_value._type_spec = (
input_lib.get_iterator_spec_from_dataset(
self._coordinator.strategy, self._dataset_fn.structured_outputs))
return PerWorkerDistributedIterator(per_worker_iterator._values)
@property
def element_spec(self):
"""The type specification of an element of this dataset.
This property is subject to change without notice.
"""
return self._dataset_fn.structured_outputs.element_spec
3.4 _create_per_worker_resources
_create_per_worker_resources 会调用各个工作者的方法来让每个工作者得到数据。
def _create_per_worker_resources(self, fn, args=None, kwargs=None):
"""Synchronously create resources on the workers.
The resources are represented by
tf.distribute.experimental.coordinator.RemoteValue s.
Args:
fn: The function to be dispatched to all workers for execution
asynchronously.
args: Positional arguments for fn .
kwargs: Keyword arguments for fn .
Returns:
A tf.distribute.experimental.coordinator.PerWorkerValues object, which
wraps a tuple of tf.distribute.experimental.coordinator.RemoteValue
objects.
"""
results = []
for w in self._cluster.workers:
results.append(w.create_resource(fn, args=args, kwargs=kwargs))
return PerWorkerValues(tuple(results))
3.5 PerWorkerValues
PerWorkerValues 是一个容纳 value 列表的容器,每个工作者对应一个 value。Tf.distribution.experimental.coordinator.PerWorkerValues 包含一个值的集合,其中每个值都位于其相应的工作者上,当被用作 tf.distribution.experimental.coordinator.ClusterCoordinator.schedule() 的 args 或 kwargs 时,某一个工作者的特定值将被传递到该工作者上执行的函数中。
创建 tf.distribution.experimental.coordinator.PerWorkerValues 对象的唯一路径是通过在 ClusterCoordinator.create_per_worker_dataset 返回的分布式数据集实例上调用 iter 。目前还不支持创建自定义 tf.distribution.experimental.coordinator.PerWorkerValues 的机制。
@tf_export("distribute.experimental.coordinator.PerWorkerValues", v1=[])
class PerWorkerValues(composite_tensor.CompositeTensor):
"""A container that holds a list of values, one value per worker.
tf.distribute.experimental.coordinator.PerWorkerValues contains a collection
of values, where each of the values is located on its corresponding worker,
and upon being used as one of the args or kwargs of
tf.distribute.experimental.coordinator.ClusterCoordinator.schedule() , the
value specific to a worker will be passed into the function being executed at
that corresponding worker.
Currently, the only supported path to create an object of
tf.distribute.experimental.coordinator.PerWorkerValues is through calling
iter on a ClusterCoordinator.create_per_worker_dataset -returned
distributed dataset instance. The mechanism to create a custom
tf.distribute.experimental.coordinator.PerWorkerValues is not yet supported.
"""
def __init__(self, values):
for v in values:
if not isinstance(v, RemoteValue):
raise AssertionError(
" PerWorkerValues should only take RemoteValue s.")
self._values = tuple(values)
@property
def _type_spec(self):
return PerWorkerValuesTypeSpec(
self._values[0]._type_spec,
type(self))
获取数据的逻辑如下:
4. Cluster
Cluster 才是业务执行者。
4.1 定义
Cluster 是一个工作者集群。在初始化方法之中,会做如下处理:
- 设置如何忽略参数服务器暂时错误。
- 设定工作者的设备名字。
- 生成一系列工作者。
这里要注意的是如何忽略因为工作者瞬时连接错误而报告的故障。
- 工作者和参数服务器之间的瞬时连接问题会由工作者转达给协调者,这将导致协调者认为存在参数服务器故障。
- 瞬时与永久的参数服务器故障之间的区别是工作者报告的数量。当这个环境变量设置为正整数 K 时,协调器忽略最多 K 个失败报告,也就是说,只有超过 K 个执行错误,并且这些错误是因为同一个参数服务器实例导致的,我们才认为参数服务器实例遇到了失败。
class Cluster(object):
"""A cluster with workers.
We assume all function errors are fatal and based on this assumption our
error reporting logic is:
1) Both schedule and join can raise a non-retryable error which is the
first error seen by the coordinator from any previously scheduled functions.
2) When an error is raised, there is no guarantee on how many previously
scheduled functions have been executed; functions that have not been executed
will be thrown away and marked as cancelled.
3) After an error is raised, the internal state of error will be cleared.
I.e. functions can continue to be scheduled and subsequent calls of schedule
or join will not raise the same error again.
Attributes:
failure_handler: The failure handler used to handler worker preemption
failure.
workers: a list of Worker objects in the cluster.
"""
def __init__(self, strategy):
"""Initializes the cluster instance."""
self._num_workers = strategy._num_workers
self._num_ps = strategy._num_ps
# 如何忽略参数服务器暂时错误
self._transient_ps_failures_threshold = int(
os.environ.get("TF_COORDINATOR_IGNORE_TRANSIENT_PS_FAILURES", 3))
self._potential_ps_failures_lock = threading.Lock()
self._potential_ps_failures_count = [0] * self._num_ps
self._closure_queue = _CoordinatedClosureQueue()
self.failure_handler = WorkerPreemptionHandler(context.get_server_def(),
self)
# 设定 worker 的设备名字
worker_device_strings = [
"/job:worker/replica:0/task:%d" % i for i in range(self._num_workers)
]
# 生成 Workers
self.workers = [
Worker(i, w, self) for i, w in enumerate(worker_device_strings)
]
4.2 Schedule
这个类提供的最重要的API是 "schedule"/"join" 这对函数。"schedule" API是非阻塞的,它把一个 "tf.function "插入队列,并立即返回一个 "RemoteValue"。
def schedule(self, function, args, kwargs):
"""Schedules function to be dispatched to a worker for execution.
Args:
function: The function to be dispatched to a worker for execution
asynchronously.
args: Positional arguments for fn .
kwargs: Keyword arguments for fn .
Returns:
A RemoteValue object.
"""
closure = Closure(
function,
self._closure_queue._cancellation_mgr,
args=args,
kwargs=kwargs)
self._closure_queue.put(closure)
return closure.output_remote_value
def join(self):
"""Blocks until all scheduled functions are executed."""
self._closure_queue.wait()
具体逻辑如下,虚线表示数据集被传入,这里的 Queue 是 from six.moves import queue 引入的 queue.Queue,我们接下来在_CoordinatedClosureQueue之中会见到。
或者我们从官方文档图来看,目前完成的是左边圆圈部分。
4.3 停止
停止代码如下,具体是调用队列的处理方法。
def stop(self):
"""Stop worker, worker preemption threads, and the closure queue."""
self.failure_handler.stop()
for worker in self.workers:
worker.stop()
self._closure_queue.stop()
def done(self):
"""Returns true if all scheduled functions are executed."""
return self._closure_queue.done()
5. 任务 Closure
Closure 的作用是把任务封装起来,并且提供了其他功能。
class Closure(object):
"""Hold a function to be scheduled and its arguments."""
def __init__(self, function, cancellation_mgr, args=None, kwargs=None):
if not callable(function):
raise ValueError("Function passed to ClusterCoordinator.schedule must "
"be a callable object.")
self._args = args or ()
self._kwargs = kwargs or {}
_disallow_remote_value_as_input(self._args)
_disallow_remote_value_as_input(self._kwargs)
if isinstance(function, def_function.Function):
replica_args = _select_worker_slice(0, self._args)
replica_kwargs = _select_worker_slice(0, self._kwargs)
# Note: no need to handle function registration failure since this kind of
# failure will not raise exceptions as designed in the runtime. The
# coordinator has to rely on subsequent operations that raise to catch
# function registration failure.
# Record the function tracing overhead. Note that we pass in the tracing
# count of the def_function.Function as a state tracker, so that metrics
# will only record the time for actual function tracing (i.e., excluding
# function cache lookups).
with metric_utils.monitored_timer(
"function_tracing", state_tracker=function._get_tracing_count):
self._concrete_function = function.get_concrete_function(
*nest.map_structure(_maybe_as_type_spec, replica_args),
**nest.map_structure(_maybe_as_type_spec, replica_kwargs))
elif isinstance(function, tf_function.ConcreteFunction):
self._concrete_function = function
if hasattr(self, "_concrete_function"):
# If we have a concrete function, we get to retrieve the output type spec
# via the structured_output.
output_type_spec = func_graph.convert_structure_to_signature(
self._concrete_function.structured_outputs)
self._function = cancellation_mgr.get_cancelable_function(
self._concrete_function)
else:
# Otherwise (i.e. what is passed in is a regular python function), we have
# no such information.
output_type_spec = None
self._function = function
self.output_remote_value = RemoteValueImpl(self, output_type_spec)
5.1 执行
Closure 的 execute_on 负责运行,具体是在指定的设备上执行 self._function,就是用户自定义的 function。需要注意的是,with context.executor_scope(worker.executor) 使用了 context。
def execute_on(self, worker):
"""Executes the closure on the given worker.
Args:
worker: a Worker object.
"""
replica_args = _select_worker_slice(worker.worker_index, self._args)
replica_kwargs = _select_worker_slice(worker.worker_index, self._kwargs)
e = (
_maybe_rebuild_remote_values(worker, replica_args) or
_maybe_rebuild_remote_values(worker, replica_kwargs))
if e:
if not isinstance(e, InputError):
e = InputError(e)
self.output_remote_value._set_error(e)
return
with ops.device(worker.device_name): # 在指定设备上
with context.executor_scope(worker.executor): # 通过上下文
with metric_utils.monitored_timer("closure_execution"):
output_values = self._function( # 运行用户的参数
*nest.map_structure(_maybe_get_remote_value, replica_args),
**nest.map_structure(_maybe_get_remote_value, replica_kwargs))
self.output_remote_value._set_values(output_values)
Self._function 是用户自定义的 function,我们再给出一个方法示例,可以看出来可以使用 strategy.run 把训练方法分发到远端工作者进行训练。
@tf.function
def worker_fn(iterator):
def replica_fn(inputs):
batch_data, labels = inputs
# calculate gradient, applying gradient, metrics update etc.
strategy.run(replica_fn, args=(next(iterator),))
5.2 取消
用户可以设置取消 Closure,就是在返回值之中做下设置。
def mark_cancelled(self):
self.output_remote_value._set_error(
errors.CancelledError(
None, None, "The corresponding function is "
"cancelled. Please reschedule the function."))
5.3 ResourceClosure
ResourceClosure 是派生类,把 Closure 用 RemoteValue 包装起来。实际上使用的都是 ResourceClosure。
class ResourceClosure(Closure):
def build_output_remote_value(self):
if self._output_remote_value_ref is None:
# We need to remember the Closure object in the RemoteValue here.
ret = RemoteValueImpl(self, self._output_type_spec)
self._output_remote_value_ref = weakref.ref(ret)
return ret
else:
return self._output_remote_value_ref()
6. 队列
_CoordinatedClosureQueue 是任务所在的队列。
6.1 定义
from six.moves import queue
class _CoordinatedClosureQueue(object):
"""Manage a queue of closures, inflight count and errors from execution.
This class is thread-safe.
"""
def __init__(self):
# self._inflight_closure_count only tracks the number of inflight closures
# that are "in generation". Once an error occurs, error generation is
# incremented and all subsequent arriving closures (from inflight) are
# considered "out of generation".
self._inflight_closure_count = 0
self._queue_lock = threading.Lock()
# Condition indicating that all pending closures (either queued or inflight)
# have been processed, failed, or cancelled.
self._stop_waiting_condition = threading.Condition(self._queue_lock)
# Condition indicating that an item becomes available in queue (not empty).
self._closures_queued_condition = threading.Condition(self._queue_lock)
self._should_process_closures = True
# Condition indicating that a queue slot becomes available (not full).
# Note that even with "infinite" queue size, there is still a "practical"
# size limit for the queue depending on host memory capacity, and thus the
# queue will eventually become full with a lot of enqueued closures.
self._queue_free_slot_condition = threading.Condition(self._queue_lock)
# Condition indicating there is no inflight closures.
self._no_inflight_closure_condition = threading.Condition(self._queue_lock)
# Use to cancel in-flight closures.
self._cancellation_mgr = cancellation.CancellationManager()
self._queue = queue.Queue(maxsize=_CLOSURE_QUEUE_MAX_SIZE)
self._error = None
# The following is a lock to make sure when wait is called and before it
# returns no put can be executed during this period. It is because wait
# won't know what to do with newly put closures. This lock adds an cutoff
# for wait so that closures put into the queue while waiting would not be
# taken responsible by this wait .
#
# We cannot reuse the self._queue_lock since when wait waits for a
# condition, the self._queue_lock will be released.
#
# We don't use a reader/writer's lock on purpose to reduce the complexity
# of the code.
self._put_wait_lock = threading.Lock()
6.2 插入取出
Put 和 get 方法分别负责插入和取出。
def put(self, closure):
"""Put a closure into the queue for later execution.
If mark_failed was called before put , the error from the first
invocation of mark_failed will be raised.
Args:
closure: The Closure to put into the queue.
"""
with self._put_wait_lock, self._queue_lock:
self._queue_free_slot_condition.wait_for(lambda: not self._queue.full())
self._queue.put(closure, block=False)
self._raise_if_error()
self._closures_queued_condition.notify()
def get(self, timeout=None):
"""Return a closure from the queue to be executed."""
with self._queue_lock:
while self._queue.empty() and self._should_process_closures:
if not self._closures_queued_condition.wait(timeout=timeout):
return None
if not self._should_process_closures:
return None
closure = self._queue.get(block=False)
self._queue_free_slot_condition.notify()
self._inflight_closure_count += 1
return closure
Put_back 则负责把 closure 重新放回queue。
def put_back(self, closure):
"""Put the closure back into the queue as it was not properly executed."""
with self._queue_lock:
if self._inflight_closure_count < 1:
raise AssertionError("There is no inflight closures to put_back.")
if self._error:
closure.mark_cancelled()
else:
self._queue_free_slot_condition.wait_for(lambda: not self._queue.full())
self._queue.put(closure, block=False)
self._closures_queued_condition.notify()
self._inflight_closure_count -= 1
if self._inflight_closure_count == 0:
self._no_inflight_closure_condition.notifyAll()
6.3 等待
方法 wait 会等待所有 closures 结束。
def wait(self, timeout=None):
"""Wait for all closures to be finished before returning.
If mark_failed was called before or during wait , the error from the
first invocation of mark_failed will be raised.
Args:
timeout: A float specifying a timeout for the wait in seconds.
Returns:
True unless the given timeout expired, in which case it returns False.
"""
with self._put_wait_lock, self._queue_lock:
while (not self._error and
(not self._queue.empty() or self._inflight_closure_count > 0)):
if not self._stop_waiting_condition.wait(timeout=timeout):
return False
self._raise_if_error()
return True
6.4 异常&结束
Mark_failed 和 done 则是处理结束和异常的一套组合。
def mark_failed(self, e):
"""Sets error and unblocks any wait() call."""
with self._queue_lock:
# TODO(yuefengz): maybe record all failure and give users more
# information?
if self._inflight_closure_count < 1:
raise AssertionError("There is no inflight closures to mark_failed.")
if self._error is None:
self._error = e
self._inflight_closure_count -= 1
if self._inflight_closure_count == 0:
self._no_inflight_closure_condition.notifyAll()
self._stop_waiting_condition.notifyAll()
def done(self):
"""Returns true if the queue is empty and there is no inflight closure.
If mark_failed was called before done , the error from the first
invocation of mark_failed will be raised.
"""
with self._queue_lock:
self._raise_if_error()
return self._queue.empty() and self._inflight_closure_count == 0
6.5 停止
Stop 和 _cancel_all_closures 负责暂停 closures。
def stop(self):
with self._queue_lock:
self._should_process_closures = False
self._closures_queued_condition.notifyAll()
def _cancel_all_closures(self):
"""Clears the queue and sets remaining closures cancelled error.
This method expects self._queue_lock to be held prior to entry.
"""
self._cancellation_mgr.start_cancel()
while self._inflight_closure_count > 0:
self._no_inflight_closure_condition.wait()
while True:
try:
closure = self._queue.get(block=False)
self._queue_free_slot_condition.notify()
closure.mark_cancelled()
except queue.Empty:
break
# The cancellation manager cannot be reused once cancelled. After all
# closures (queued or inflight) are cleaned up, recreate the cancellation
# manager with clean state.
# Note on thread-safety: this is triggered when one of theses
# ClusterCoordinator APIs are called: schedule , wait , and done . At the
# same time, no new closures can be constructed (which reads the
# _cancellation_mgr to get cancellable functions).
self._cancellation_mgr = cancellation.CancellationManager()
def _raise_if_error(self):
"""Raises the error if one exists.
If an error exists, cancel the closures in queue, raises it, and clear
the error.
This method expects self._queue_lock to be held prior to entry.
"""
if self._error:
logging.error("Start cancelling closures due to error %r: %s",
self._error, self._error)
self._cancel_all_closures()
try:
raise self._error
finally:
self._error = None
7.4 Worker
Worker 是函数的执行者。
7.1 定义
Worker 的定义如下,其启动了一个线程来运行 _process_queue。
class Worker(object):
"""A worker in a cluster.
Attributes:
worker_index: The index of the worker in the cluster.
device_name: The device string of the worker, e.g. "/job:worker/task:1".
executor: The worker's executor for remote function execution.
failure_handler: The failure handler used to handler worker preemption
failure.
"""
def __init__(self, worker_index, device_name, cluster):
self.worker_index = worker_index
self.device_name = device_name
# 这里会有一个executor
self.executor = executor.new_executor(enable_async=False)
self.failure_handler = cluster.failure_handler
self._cluster = cluster
self._resource_remote_value_refs = []
self._should_worker_thread_run = True
# Worker threads need to start after Worker 's initialization.
threading.Thread(target=self._process_queue,
name="WorkerClosureProcessingLoop-%d" % self.worker_index,
daemon=True).start()
New_executor 会调用 TFE_NewExecutor。
def new_executor(enable_async):
handle = pywrap_tfe.TFE_NewExecutor(enable_async)
return Executor(handle)
TFE_NewExecutor 定义在 tensorflow/c/eager/c_api_experimental.cc,其生成了 TFE_Executor。
TFE_Executor* TFE_NewExecutor(bool is_async) {
return new TFE_Executor(is_async);
}
TFE_Executor 定义如下,Executor类是会话执行器的抽象,在 TF2 之中,也有 EagerExecutor。
struct TFE_Executor {
explicit TFE_Executor(bool async)
: owned_executor(new tensorflow::EagerExecutor(async)) {}
explicit TFE_Executor(tensorflow::EagerExecutor* executor)
: owned_executor(nullptr), unowned_executor(executor) {}
tensorflow::EagerExecutor* executor() {
return owned_executor == nullptr ? unowned_executor : owned_executor.get();
}
std::unique_ptr<tensorflow::EagerExecutor> owned_executor;
tensorflow::EagerExecutor* unowned_executor;
};
7.2 处理
_process_queue 方法会从 queue 之中取出 Closure,然后运行任务。
def _process_queue(self):
"""Function running in a worker thread to process closure queues."""
self._maybe_delay()
while self._should_worker_thread_run:
closure = self._cluster._closure_queue.get()
if not self._should_worker_thread_run or closure is None:
return
self._process_closure(closure)
# To properly stop the worker and preemption threads, it is important that
# ClusterCoordinator object is not held onto so its __del__ can be
# called. By removing the reference to the closure that has already been
# processed, we ensure that the closure object is released, while
# getting the next closure at above self._cluster._closure_queue.get()
# call.
del closure
7.2.1 等待
_process_queue 之中首先会调用 _maybe_delay 等待环境变量配置。
def _maybe_delay(self):
"""Delay if corresponding env vars are set."""
# If the following two env vars variables are set. Scheduling for workers
# will start in a staggered manner. Worker i will wait for
# TF_COORDINATOR_SCHEDULE_START_DELAY * i seconds, not exceeding
# TF_COORDINATOR_SCHEDULE_START_DELAY_MAX .
delay_secs = int(os.environ.get("TF_COORDINATOR_SCHEDULE_START_DELAY", "0"))
delay_cap = int(
os.environ.get("TF_COORDINATOR_SCHEDULE_START_DELAY_MAX", "0"))
if delay_cap:
delay_secs = min(delay_secs * self.worker_index, delay_cap)
if delay_secs > 0:
logging.info("Worker %d sleeping for %d seconds before running function",
self.worker_index, delay_secs)
time.sleep(delay_secs)
7.2.2 处理任务
_process_queue 之中接着会调用 _process_closure 来运行 closure。
def _process_closure(self, closure):
"""Runs a closure with preemption handling."""
try:
with self._cluster.failure_handler.wait_on_failure(
on_failure_fn=lambda: self._cluster._closure_queue.put_back(closure),
on_recovery_fn=self._set_resources_aborted,
worker_device_name=self.device_name):
closure.execute_on(self)
with metric_utils.monitored_timer("remote_value_fetch"):
# Copy the remote tensor to local (the coordinator) in case worker
# becomes unavailable at a later time.
closure.output_remote_value.get()
self._cluster._closure_queue.mark_finished()
except Exception as e:
# Avoid logging the derived cancellation error
if not isinstance(e, errors.CancelledError):
logging.error(
"/job:worker/task:%d encountered the following error when "
"processing closure: %r:%s", self.worker_index, e, e)
closure.output_remote_value._set_error(e)
self._cluster._closure_queue.mark_failed(e)
7.3 数据
我们接下来看看如何把数据读取放到工作者上运行。前面提到了,在 _create_per_worker_resources 会调用 create_resource,为每一个工作者建立其自己的资源。
def create_resource(self, function, args=None, kwargs=None):
"""Synchronously creates a per-worker resource represented by a RemoteValue .
Args:
function: the resource function to be run remotely. It should be a
tf.function , a concrete function or a Python function.
args: positional arguments to be passed to the function.
kwargs: keyword arguments to be passed to the function.
Returns:
one or several RemoteValue objects depending on the function return
values.
"""
# Some notes about the concurrency: currently all the activities related to
# the same worker such as creating resources, setting resources' aborted
# status, and executing closures happen on the same thread. This allows us
# to have simpler logic of concurrency.
closure = ResourceClosure(
function,
self._cluster.closure_queue._cancellation_mgr,
args=args,
kwargs=kwargs)
resource_remote_value = closure.build_output_remote_value()
self._register_resource(resource_remote_value)
# The following is a short-term solution to lazily create resources in
# parallel.
resource_remote_value._set_aborted()
return resource_remote_value
_register_resource 则会把每个 Worker 的资源注册到 Worker 之上。
def _register_resource(self, resource_remote_value):
if not isinstance(resource_remote_value, RemoteValue):
raise ValueError("Resource being registered is not of type "
" tf.distribute.experimental.coordinator.RemoteValue .")
self._resource_remote_value_refs.append(weakref.ref(resource_remote_value))
逻辑如下,虚线表述数据流。用户通过 put 方法向队列之中放入 Closure,Worker 通过 put 方法从队列获取 Closure 执行。
7.4 停止
Stop 等一系列方法负责停止。
def stop(self):
"""Ensure the worker thread is closed."""
self._should_worker_thread_run = False
def _set_resources_aborted(self):
for weakref_resource in self._resource_remote_value_refs:
resource = weakref_resource()
if resource:
resource._set_aborted() # pylint: disable=protected-access
def _set_dead(self):
raise NotImplementedError("_set_dead is not implemented.")
7.5 与 Strategy 联系
至此,我们其实还没有正式和 Strategy 联系起来,我们再用一个例子来看看,这里会发现,传递给 coordinator 的方法之中,会调用 strategy.run(replica_fn, args=(next(iterator),)),这样就和 strategy 联系起来了。
strategy = ...
coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
strategy)
def dataset_fn():
return tf.data.Dataset.from_tensor_slices([1, 1, 1])
with strategy.scope():
v = tf.Variable(initial_value=0)
@tf.function
def worker_fn(iterator):
def replica_fn(x):
v.assign_add(x)
return v.read_value()
return strategy.run(replica_fn, args=(next(iterator),)) # 这里正式联系起来
distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn)
distributed_iterator = iter(distributed_dataset)
result = coordinator.schedule(worker_fn, args=(distributed_iterator,))
assert coordinator.fetch(result) == 1
8. Failover
8.1 策略
应对失败的总体策略大致如下:
-
当发现一个工作者失败了,Coordinator 把 function 再次放入队列,然后发给另一个工作者执行,同时启动一个后台线程等待恢复,如果恢复了,则用资源来重建这个工作者,继续分配工作。
-
因此,一些工作者的失败并不妨碍集群继续工作,这使得集群之中的实例可以偶尔不可用(例如,可抢占或spot 实例)。但是协调者和参数服务器必须始终可用,这样集群才能取得进展。
8.2 工作者失败
当发生工作者失败(failure)时候,具体逻辑如下:
- ClusterCoordinator 类与 tf.distribution.experimental.ParameterServerStrategy 一起使用时,具有内置的工作者故障容错功能。也就是说,当一些工作者由于任何原因,协调器无法联系上它们,这些工作者的训练进度将继续由其余工作者完成。
- 在工作者恢复时,之前提供的数据集函数(对于自定义训练循环,可以是 ClusterCoordinator.create_per_worker_dataset,或者是 tf.keras.utils.experimental.DatasetCreator 用于 Model.fit )将被调用到工作者身上,以重新创建数据集。
- 当一个失败的工作者恢复之后,在使用通过 create_per_worker_dataset 创建的数据被重新建立后,它将被添加到函数执行中。
8.3 参数服务器或者协调器故障
当参数服务器失败时,schedule,join 或 done 会引发 tf.errors.UnavailableError。在这种情况下,除了重置失败的参数服务器外,用户还应该重新启动协调器,使其重新连接到工作者和参数服务器,重新创建变量,并加载检查点。如果协调器发生故障,在用户把它重置回来之后,程序会自动连接到工作者和参数服务器,并从检查点继续前进。因为协调器本身也可能变得不可用。因此建议使用某些工具以便不丢失训练进度:
- 因此,在用户的程序中,必须定期保存检查点文件,并在程序开始时恢复。如果 "tf.keras.optimizers.Optimizer" 被应用 checkpoint,在从检查点恢复后,其 "iterations" 属性会大致显示已经进行的步骤数。这可以用来决定在训练完成前还需要多少个 epochs 和步骤(steps)。
- 对于 Model.fit,你应该使用 BackupAndRestore 回调,它可以自动处理进度的保存和恢复。
- 对于一个自定义的训练循环,你应该定期检查模型变量,并在训练开始前从检查点(如果有的话)加载模型变量。如果优化器有检查点,训练进度可以从 optimizer.iterations 中大致推断出来。
checkpoint_manager = tf.train.CheckpointManager(
tf.train.Checkpoint(model=model, optimizer=optimizer),
checkpoint_dir,
max_to_keep=3)
if checkpoint_manager.latest_checkpoint:
checkpoint = checkpoint_manager.checkpoint
checkpoint.restore(
checkpoint_manager.latest_checkpoint).assert_existing_objects_matched()
global_steps = int(optimizer.iterations.numpy())
starting_epoch = global_steps // steps_per_epoch
for _ in range(starting_epoch, num_epoches):
for _ in range(steps_per_epoch):
coordinator.schedule(step_fn, args=(per_worker_iterator,))
coordinator.join()
checkpoint_manager.save()
8.4 返回 RemoteValue
如果一个函数被成功执行,就可以成功获取到 RemoteValue。这是因为目前在执行完一个函数后,返回值会立即被复制到协调器。如果在复制过程中出现任何工作者故障,该函数将在另一个可用的工作者上重试。因此,如果你想优化性能,你可以安排(schedule)一个没有返回值的函数。
8.5 错误报告
一旦协调器发现一个错误,如来自参数服务器的 UnavailableError 或其他应用错误,如来自 tf.debugging.check_numerics 的 InvalidArgument,它将在引发错误之前取消所有 pending 和排队(queued)的函数。获取它们相应的 RemoteValue 将引发一个 CancelledError 。
在引发错误后,协调器将不会引发相同的错误或任何引发一个来自已取消函数的错误。
ClusterCoordinator 假设所有的函数错误都是致命的,基于这个假设,其的错误报告逻辑是:
- Schedule 和 join 都可以引发一个不可重试的错误,这是协调者从任何先前安排的函数中看到的第一个错误。
- 当一个错误被抛出时,不保证有多少先前安排的功能被执行;没有被执行的功能将被丢弃并被标记为取消。
- 在一个错误被抛出后,错误的内部状态将被清除。
8.6 WorkerPreemptionHandler
WorkerPreemptionHandler 是处理失败的主要模块,其定义如下:
class WorkerPreemptionHandler(object):
"""Handles worker preemptions."""
def __init__(self, server_def, cluster):
self._server_def = server_def
self._cluster = cluster
self._cluster_update_lock = threading.Lock()
self._cluster_due_for_update_or_finish = threading.Event()
self._worker_up_cond = threading.Condition(self._cluster_update_lock)
self._error_from_recovery = None
self._should_preemption_thread_run = True
self._preemption_handler_thread = threading.Thread(
target=self._preemption_handler,
name="WorkerPreemptionHandler",
daemon=True)
self._preemption_handler_thread.start()
8.6.1 配置
在 Cluster 生成时,会把 WorkerPreemptionHandler 配置进来。
self.failure_handler = WorkerPreemptionHandler(context.get_server_def(), self)
8.6.2 等待
在处理 closure 时,会用 wait_on_failure 包裹一层用来处理错误。
def _process_closure(self, closure):
"""Runs a closure with preemption handling."""
assert closure is not None
try:
with self._cluster.failure_handler.wait_on_failure(
on_failure_fn=lambda: self._cluster._closure_queue.put_back(closure),
on_recovery_fn=self._set_resources_aborted,
worker_device_name=self.device_name):
closure.execute_on(self)
WorkerPreemptionHandler 的 wait_on_failure 方法如下:
@contextlib.contextmanager
def wait_on_failure(self,
on_failure_fn=None,
on_transient_failure_fn=None,
on_recovery_fn=None,
worker_device_name="(unknown)"):
"""Catches worker preemption error and wait until failed workers are back.
Args:
on_failure_fn: an optional function to run if preemption happens.
on_transient_failure_fn: an optional function to run if transient failure
happens.
on_recovery_fn: an optional function to run when a worker is recovered
from preemption.
worker_device_name: the device name of the worker instance that is passing
through the failure.
Yields:
None.
"""
try:
yield
except (errors.OpError, InputError) as e:
# If the error is due to temporary connectivity issues between worker and
# ps, put back closure, ignore error and do not mark worker as failure.
if self._cluster._record_and_ignore_transient_ps_failure(e):
if on_transient_failure_fn:
on_transient_failure_fn()
return
# Ignoring derived CancelledErrors to tolerate transient failures in
# PS-worker communication, which initially exposed as an UnavailableError
# and then lead to sub-function cancellation, subsequently getting
# reported from worker to chief as CancelledError.
# We do not mark either worker or PS as failed due to only CancelledError.
# If there are real (non-transient) failures, they must also be reported
# as other errors (UnavailableError most likely) in closure executions.
if isinstance(e, errors.CancelledError) and "/job:" in str(e):
if on_transient_failure_fn:
on_transient_failure_fn()
return
# This reraises the error, if it's not considered recoverable; otherwise,
# the following failure recovery logic run. At this time, only worker
# unavailability is recoverable. PS unavailability as well as other
# errors in the user function is not recoverable.
self._validate_preemption_failure(e)
if on_failure_fn:
on_failure_fn()
with self._cluster_update_lock:
self._cluster_due_for_update_or_finish.set()
self._worker_up_cond.wait(_WORKER_MAXIMUM_RECOVERY_SEC)
if self._error_from_recovery:
try:
raise self._error_from_recovery
finally:
self._error_from_recovery = None
if on_recovery_fn:
with self.wait_on_failure(
on_recovery_fn=on_recovery_fn,
on_transient_failure_fn=on_transient_failure_fn,
worker_device_name=worker_device_name):
on_recovery_fn()
_validate_preemption_failure 定义如下:
def _validate_preemption_failure(self, e):
"""Validates that the given exception represents worker preemption."""
# Only categorize the failure as a worker preemption if the cancellation
# manager did not attempt to cancel the blocking operations.
if _is_worker_failure(e) and (
not self._cluster._closure_queue._cancellation_mgr.is_cancelled):
return
raise e
8.6.3 handler
WorkerPreemptionHandler 有一个后台线程 _preemption_handler_thread。
self._preemption_handler_thread = threading.Thread(
target=self._preemption_handler,
name="WorkerPreemptionHandler",
daemon=True)
self._preemption_handler_thread.start()
_preemption_handler 会进行必要的错误处理。
def _preemption_handler(self):
"""A loop that handles preemption.
This loop waits for signal of worker preemption and upon worker preemption,
it waits until all workers are back and updates the cluster about the
restarted workers.
"""
assert self._should_preemption_thread_run
while True:
self._cluster_due_for_update_or_finish.wait()
if not self._should_preemption_thread_run:
break
with self._cluster_update_lock:
try:
context.context().update_server_def(self._server_def)
# Cluster updated successfully, clear the update signal, and notify
# all workers that they are recovered from failure.
self._worker_up_cond.notify_all()
# The check for _should_preemption_thread_run is necessary since the
# stop may have already set _cluster_due_for_update_or_finish.
if self._should_preemption_thread_run:
self._cluster_due_for_update_or_finish.clear()
except Exception as e:
try:
self._validate_preemption_failure(e)
except Exception as ps_e:
# In this case, a parameter server fails. So we raise this error to
# the caller of wait_on_failure .
self._error_from_recovery = ps_e
self._worker_up_cond.notify_all()
if self._should_preemption_thread_run:
self._cluster_due_for_update_or_finish.clear()
# NOTE: Since the first RPC (GetStatus) of update_server_def is
# currently blocking by default, error should only happen if:
# (1) More workers failed while waiting for the previous workers to
# come back;
# (2) Worker failed when exchanging subsequent RPCs after the first
# RPC returns.
# Consider adding backoff retry logic if we see the error logged
# too frequently.
9. 总结
依据前面的代码,我们总结出来问题点如下:
-
Worker 如何知道使用哪些设备?答案是:在集群建立工作者时候,会给每一个工作者设定一个设备。
-
如何具体执行用户函数?答案是:在工作者运行 Closure 时候,会在指定运行在本工作者设备上,然后运行指定的方法(Self._function)。Self._function 是用户自定义的 function,其中可以使用 strategy.run 把训练方法分发到远端工作者进行训练。
-
如何获取数据?答案是:为每个工作者建立一个 PerWorkerValues,PerWorkerValues 是一个容纳 value 列表的容器,每个工作者从对应 PerWorkerValues 之中获取数据。