Tensorflow Federated(TFF)框架整理(下)
之前提到的方法,完全没有提供任何的反向传播/优化过程,都是tff.templates.IterativeProcess
帮我们处理好的,我们每次传入当前state
和训练集就可以得到新的state
和metrics
。为了更好的定制我们自己的优化方法,我们需要自己编写tff.template.IterativeProcess
方法,重写initialize
和next
方法,并且自己设定优化过程。
数据类型
Federated Core 提供了以下几种类型:
- 张量类型(
tff.TensorType
)。对象不仅限于在 TensorFlow 计算图中表示 TensorFlow 运算输出的 Python 的tf.Tensor
实例,而是也可能包括可产生的数据单位,例如,作为分布聚合协议的输出。张量类型的紧凑表示法为dtype
或dtype[shape]
。例如,int32
和int32[10]
分别是整数和整数向量的类型。 - 序列类型 (
tff.SequenceType
)。这些是 TFF 中等效于 TensorFlow 中tf.data.Dataset
的具体概念的抽象。用户可以按顺序使用序列的元素,并且可以包含复杂的类型。序列类型的紧凑表示法为T*
,其中T
是元素的类型。例如,int32*
表示整数序列。 - 命名元组类型 (
tff.StructType
)。这些是 TFF 使用指定类型构造具有预定义数量元素的元组或字典式结构(无论命名与否)的方式。重要的一点是,TFF 的命名元组概念包含等效于 Python 参数元组的抽象,即元组的元素集合中有一部分(并非全部)是命名元素,还有一部分是位置元素。命名元组的紧凑表示法为<n_1=T_1, ..., n_k=T_k>
,其中n_k
是可选元素名称,T_k
是元素类型。例如,<int32,int32>
是一对未命名整数的紧凑表示法,<X=float32,Y=float32>
是命名为X
和Y
(可能代表平面上的一个点)的一对浮点数的紧凑表示法。元组可以嵌套,也可以与其他类型混用,例如,<X=float32,Y=float32>*
可能是一系列点的紧凑表示法。 - 函数类型 (
tff.FunctionType
)。TFF 是一个函数式编程框架,其中函数被视为这些函数的紧凑表示法为(T -> U)
,其中T
为参数类型,U
为结果类型;或者,如果没有参数(虽然无参数函数是一个大部分情况下仅在 Python 级别存在的过时概念),则可以表示为( -> U)
。例如,(int32* -> int32)
表示一种将整数序列缩减为单个整数值的函数类型。第一类值。函数最多有一个参数,并且只有一个结果。
以下类型解决 TFF 计算的分布系统方面的问题:
- 布局类型。除了 2 个文字形式的
tff.SERVER
和tff.CLIENTS
(可将其视为这种类型的常量)外,这种类型还没有在公共 API 中公开。它仅供内部使用,但是,将在以后的公共 API 版本中引入。该类型的紧凑表示法为placement
。布局表示扮演特定角色的系统参与者的集合。最初的版本是为了解决客户端-服务器计算的问题,其中有 2 组参与者:客户端和服务器(可将后者视为单一实例组)。但是,在更复杂的架构中,可能还有其他角色,如多层系统中的中间聚合器。这种聚合器可能执行不同类型的聚合,或者使用不同类型的数据压缩/解压缩,而不是服务器或客户端使用的类型。定义布局概念的主要目的是作为定义联合类型的基础。- 联合类型 (
tff.FederatedType
)。联合类型的值是由特定布局(如tff.SERVER
或tff.CLIENTS
)定义的一组系统参与者托管的值。联合类型通过布局值(因此,它是一种依赖类型), 成员组成要素(每个参与者在本地托管的内容类型),以及指定所有参与者是否在本地托管同一项目的附加部分all_equal
进行定义。对于包含T
类型项目(成员组成)的值的联合类型,如果每个项目由组(布局)G
托管,则其紧凑表示法为T@G
或{T}@G
,分别设置或不设置all_equal
位。{int32}@CLIENTS
表示包含一组可能不同的整数;{<X=float32,Y=float32>*}@CLIENTS
表示一个联合数据集;<weights=float32[10,5],bias=float32[5]>@SERVER
表示服务器上的权重和偏差张量的命名元组。我们省略了花括号,这表示已设置all_equal
位。federated_float_on_clients = tff.FederatedType(tf.float32, tff.CLIENTS) # '{float32}@CLIENTS'
函数
Federated Core 的语言是一种 λ 演算,它提供了当前在公共 API 中公开的以下编程抽象:
- TensorFlow 计算 (
tff.tf_computation
)。TFF 中有一些使用tff.tf_computation
装饰器包装为可重用组件的 TensorFlow 代码部分。这些代码一般都是函数式类型,但是与 TensorFlow 中的函数不同,它们可以接受结构化参数或返回序列类型的结构化结果。
# tensor computation is constricted in tff.federated_computation
# should be completion by the following way:
@tff.tf_computation(tff.SequenceType(tf.int32))
def add_up_integers(x):
return x.reduce(np.int32(0), lambda x, y: x + y)
@tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
def get_average_temperature(client_temperatures):
return tff.federated_mean(client_temperatures)
# '({float32}@CLIENTS -> float32@SERVER)'
-
内联函数(
tff.federated_...
)。这是构成大部分 FC API 的函数库,如tff.federated_sum
或tff.federated_broadcast
,其中大多数表示与 TFF 一起使用的分布通信算子。 -
\(\lambda\)表达式 (
tff.federated_computation
)。TFF 中的 \(\lambda\)表达式等效于 Python 中的lambda
或def
;它包含参数名称,以及包含对该参数的引用的主体(表达式)。
# the biggest differnce between tf.computation and tff.federated_computation is the placement
@tff.tf_computation(tf.float32)
def add_half(x):
return tf.add(x, 0.5)
@tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
def add_half_on_clients(x):
return tff.federated_map(add_half, x)
Work flow
一个典型的FL研究code由三种主要逻辑
- 个人层面的TF片段,如
tf.function
可以独立运行在本地,如客户端的训练代码 - TFF编排代码,帮助将个人层面的
tf.function
通过tff_computation
整合在一起,并且通过包含其中的tff.federated_broadcast
和tff.federated_mean
进行orchestating
- 外部的驱动代码,如客户选择。
# data preparation
import nest_asyncio
nest_asyncio.apply()
import tensorflow as tf
import tensorflow_federated as tff
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()
NUM_CLIENTS = 10
BATCH_SIZE = 20
def preprocess(dataset):
def batch_format_fn(element):
"""Flatten a batch of EMNIST data and return a (features, label) tuple."""
return (tf.reshape(element['pixels'], [-1, 784]),
tf.reshape(element['label'], [-1, 1]))
return dataset.batch(BATCH_SIZE).map(batch_format_fn)
client_ids = sorted(emnist_train.client_ids)[:NUM_CLIENTS]
federated_train_data = [preprocess(emnist_train.create_tf_dataset_for_client(x))
for x in client_ids
]
# model preparation
def create_keras_model():
initializer = tf.keras.initializers.GlorotNormal(seed=0)
return tf.keras.models.Sequential([
tf.keras.layers.Input(shape=(784,)),
tf.keras.layers.Dense(10, kernel_initializer=initializer),
tf.keras.layers.Softmax(),
])
def model_fn():
keras_model = create_keras_model()
return tff.learning.from_keras_model(
keras_model,
input_spec=federated_train_data[0].element_spec,
loss=tf.keras.losses.SparseCategoricalCrossentropy(),
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
To build our own Federated Learning algorithm, there are four main components:
- A server-to-clients broadcast step
- A local client update step
- A client-to-server upload step
- A server update step
Meanwhile, we should rewrite initialize
and next
functions.
Method_1
Local training
本地训练是不需要tff
参与的
# step 2 local training
# return client model weights
@tf.function
def client_update(model, dataset, server_weights, client_optimizer):
client_weights = model.trainable_variables
# clone server_weights, which is exactly state meaning in the previous code.
tf.nest.map_structure(lambda x, y: x.assign(y), client_weights, server_weights)
# optimization
for batch in dataset:
with tf.GradientTape() as tape:
outputs = model.forward_pass(batch)
grads = tape.gradient(outputs.loss, client_weights)
grad_and_vars = zip(grads, client_weights)
client_optimizer.apply_gradients(grad_and_vars) # update
return client_weights
输入的参数有model
,dataset
,server_weights
,client_optimizer
,为什么参数这么多呢?是因为tf.function
不涉及任何数据placement
的信息,而关于placement
的部分全交给tff
去处理。
Server update
跟客户端的更新一样,服务器端的更新也是不需要tff
参与的
# step4
@tf.function
def server_update(model, mean_client_weights):
model_weights = model.trainable_variables
tf.nest.map_structure(lambda x,y: x.assign(y), model_weights, mean_client_weights)
return model_weights
TFF snippet
现在就需要tff
进行不同placement
数据的整合,以及重写tff.templates.IterativeProcess
的两个方法了。
# initialize method
@tff.tf_computation
def server_init():
model = model_fn()
return model.trainable_variables
@tff.federated_computation
def initialize_fn():
return tff.federated_value(server_init(), tff.SERVER) # A federated value with the given placement placement, and the member constituent value equal at all locations.
whimsy_model = model_fn()
tf_dataset_type = tff.SequenceType(whimsy_model.input_spec) # inpute specification
model_weights_type = server_init.type_signature.result # output specification
# there are multiple sources data and should use tff.tf_computation decoration
@tff.tf_computation(tf_dataset_type, model_weights_type)
def client_update_fn(tf_dataset, server_weights):
model = model_fn()
client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
return client_update(model, tf_dataset, server_weights, client_optimizer)
@tff.tf_computation(model_weights_type)
def server_update_fn(mean_client_weights):
model = model_fn()
return server_update(model, mean_client_weights)
federated_server_type = tff.FederatedType(model_weights_type, tff.SERVER)
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)
# rewrite next function
# state is server_weights.
@tff.federated_computation(federated_server_type, federated_dataset_type)
def next_fn(server_weights, federated_dataset):
# step1. broadcast
server_weights_at_client = tff.federated_broadcast(server_weights)
# step2. local update
client_weights = tff.federated_map(
client_update_fn, (federated_dataset, server_weights_at_client))
# step3. uploading
mean_client_weights = tff.federated_mean(client_weights)
# step4. server update
server_weights = tff.federated_map(server_update_fn, mean_client_weights)
return server_weights
federated_algorithm = tff.templates.IterativeProcess(
initialize_fn=initialize_fn,
next_fn=next_fn
)
Indeed, model
is only a transitory container to keep the server state and the client model weights and this why we will initialize a model instance in the client_update_fn
and server_update_fn
.
central_emnist_test = emnist_test.create_tf_dataset_from_all_clients()
central_emnist_test = preprocess(central_emnist_test)
def evaluate(server_state):
keras_model = create_keras_model()
keras_model.compile(
loss=tf.keras.losses.SparseCategoricalCrossentropy(),
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]
)
keras_model.set_weights(server_state)
keras_model.evaluate(central_emnist_test)
server_state = federated_algorithm.initialize()
evaluate(server_state)
for round in range(15):
server_state = federated_algorithm.next(server_state, federated_train_data)
evaluate(server_state)
Method_2
In the second method, an optimizer from tff.leraning.optimizers
will supersede the previous one, which has initialize(<Tensorspec>)
and next
functions.
TF snippet
@tf.function
def client_update(model, dataset, server_weights, optimizer):
client_weights = model.trainable_weights
tf.nest.map_structure(lambda x, y: x.assign(y), client_weights, server_weights)
trainable_tensor_specs = tf.nest.map_structure(
lambda v: tf.TensorSpec(v.shape, v.dtype), client_weights)
optimizer_state = client_optimizer.initialize(trainable_tensor_specs)
for batch in iter(dataset):
with tf.GradientTape() as tape:
output = model.forward_pass(batch)
grads = tape.gradient(outputs.loss, client_weights)
optimizer_state, update_weights = client_optimizer.next(
optimizer_state, client_weights, grads)
tf.nest.map_structure(lambda a, b: a.assign(b), client_weights, update_weights)
return tf.nest.map_structure(tf.subtract, client_weights, server_weights) # return the cumulative gradient
# contanier, collecting server weights and server optimizer state.
@attr.s(eq=False, frozen=True, slots=True)
class ServerState(object):
trainable_weights = attr.ib()
optimizer_state = attr.ib()
@tf.function
def server_update(server_state, mean_model_delta, server_optimizer):
negative_weights_delta = tf.nest.map_structure(
lambda w: -1.0 * w, mean_model_delta)
new_optimizer_state, updated_weights = server_optimizer.next(
server_state.optimizer_state, server_state.trainable_weights, negative_weights_delta)
return tff.structure.update_struct(
server_state,
trainable_weights = updated_weights,
optimizer_state = new_optimizer_state)
TFF snippet
server_optimizer = tff.learning.optimizers.build_sgdm(learning_rate=0.05, momentum=0.9)
client_optimizer = tff.learning.optimizers.build_sgdm(learning_rate=0.01)
@tff.tf_computation
def server_init():
model = model_fn()
trainable_tensor_specs = tf.nest.map_structure(
lambda v: tf.TensorSpec(v.shape, v.dtype), model.trainable_variables)
optimizer_state = server_optimizer.initialize(trainable_tensor_specs)
return ServerState(
trainable_weights=model.trainable_variables,
optimizer_state=optimizer_state)
@tff.tff_computation
def server_init_tff():
return tff.federated_value(server_init(), tff.SERVER)
server_state_type = server_init.type_signature.result
trainable_weights_type = server_state_type.trainable_weights
@tff.tf_computation(server_state_type, trainable_weights_type)
def server_update_fn(server_state, model_delta):
return server_update(server_state, model_delta, server_optimizer)
whimsy_model = model_fn()
tf_dataset_type = tff.SequenceType(whimsy_model.input_spec)
@tff.tf_computation(tf_dataset_type, trainable_weights_type)
def client_update_fn(dataset, server_weights):
model = model_fn()
return client_update(model, dataset, server_weights, client_optimizer)
federated_server_type = tff.FederatedType(server_state_type, tff.SERVER)
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)
@tff.federated_computation(federated_server_type, federated_dataset_type)
def run_one_round(server_state, federated_dataset):
server_weights_at_client = tff.federated_broadcast(
server_state.trainable_weights)
model_deltas = tff.federated_map(
client_update_fn, (federated_dataset, server_weights_at_client))
mean_model_delta = tff.federated_mean(model_deltas)
server_state = tff.federated_map(
server_update_fn, (server_state, mean_model_delta))
return server_state
fedavg_process = tff.templates.IterativeProcess(
initialize_fn=server_init_tff, next_fn=run_one_round)
Summary
The process of customizing our own tff.template.IterativeProcess
class:
- Firstly, regardless of the placement constraint, you should complete the
Tensorflow
code to fulfill the client update and server update function. Usually, the input parameters for the client update function should includemodel
,dataset
,server_weights
andoptimizer
and the output should be the cumulative grads or the new client model trainable variables. The input of the server update is rather simple, a new model and the new aggregated changes and its output is the new server state. According to your definition, the serve state can be the model trainable variables or contains other items. Both of this two functions are decorated bytf.function
- Secondly,
server_update_fn
,client_update_fn
andserver_init_fn
should be completed and all of them are decorated bytff.tf_computation
. The decoration shows that the input parameters should be placed in the same position. In theserver_init_fn
, the output should be a new state. In theclient_update_fn
, the input parameters aredataset
andserver_weights
(Note, server_weights are the duplication and placed in thetff.CLIENTS
by thetff.federated_broadcast
function) and it will call the previous client update function. In theserver_update_fn
, the input parameters areserver_state
and thecumulative changes
(Note,cumulative changes
are aggregated by thetff.federated_mean
function and placed intff.SERVER
) and call the previousserver update
function.- Thirdly,
server_init_tff
andnext_fn
will be created and both of them are decorated bytff.federated_computation
to solve the placement issues. In theserver_init_tff
function, it will place the value, output of theserver_init
function, to thetff.SERVER
by thetff.federated_value
function. In thenext_fn
, four steps in the workflow will be completed.