Tensorflow Federated(TFF)框架整理(下)

之前提到的方法,完全没有提供任何的反向传播/优化过程,都是tff.templates.IterativeProcess帮我们处理好的,我们每次传入当前state和训练集就可以得到新的statemetrics。为了更好的定制我们自己的优化方法,我们需要自己编写tff.template.IterativeProcess方法,重写initializenext方法,并且自己设定优化过程。

数据类型

Federated Core 提供了以下几种类型:

  • 张量类型(tff.TensorType)。对象不仅限于在 TensorFlow 计算图中表示 TensorFlow 运算输出的 Python 的 tf.Tensor 实例,而是也可能包括可产生的数据单位,例如,作为分布聚合协议的输出。张量类型的紧凑表示法为 dtypedtype[shape]。例如,int32int32[10] 分别是整数和整数向量的类型。
  • 序列类型 (tff.SequenceType)。这些是 TFF 中等效于 TensorFlow 中 tf.data.Dataset 的具体概念的抽象。用户可以按顺序使用序列的元素,并且可以包含复杂的类型。序列类型的紧凑表示法为 T*,其中 T 是元素的类型。例如,int32* 表示整数序列。
  • 命名元组类型 (tff.StructType)。这些是 TFF 使用指定类型构造具有预定义数量元素的元组或字典式结构(无论命名与否)的方式。重要的一点是,TFF 的命名元组概念包含等效于 Python 参数元组的抽象,即元组的元素集合中有一部分(并非全部)是命名元素,还有一部分是位置元素。命名元组的紧凑表示法为 <n_1=T_1, ..., n_k=T_k>,其中 n_k 是可选元素名称,T_k 是元素类型。例如,<int32,int32> 是一对未命名整数的紧凑表示法,<X=float32,Y=float32> 是命名为 XY(可能代表平面上的一个点)的一对浮点数的紧凑表示法。元组可以嵌套,也可以与其他类型混用,例如,<X=float32,Y=float32>* 可能是一系列点的紧凑表示法。
  • 函数类型 (tff.FunctionType)。TFF 是一个函数式编程框架,其中函数被视为这些函数的紧凑表示法为 (T -> U),其中 T 为参数类型,U 为结果类型;或者,如果没有参数(虽然无参数函数是一个大部分情况下仅在 Python 级别存在的过时概念),则可以表示为 ( -> U)。例如,(int32* -> int32) 表示一种将整数序列缩减为单个整数值的函数类型。第一类值。函数最多有一个参数,并且只有一个结果。

以下类型解决 TFF 计算的分布系统方面的问题:

  • 布局类型。除了 2 个文字形式的 tff.SERVERtff.CLIENTS(可将其视为这种类型的常量)外,这种类型还没有在公共 API 中公开。它仅供内部使用,但是,将在以后的公共 API 版本中引入。该类型的紧凑表示法为 placement。布局表示扮演特定角色的系统参与者的集合。最初的版本是为了解决客户端-服务器计算的问题,其中有 2 组参与者:客户端和服务器(可将后者视为单一实例组)。但是,在更复杂的架构中,可能还有其他角色,如多层系统中的中间聚合器。这种聚合器可能执行不同类型的聚合,或者使用不同类型的数据压缩/解压缩,而不是服务器或客户端使用的类型。定义布局概念的主要目的是作为定义联合类型的基础。
  • 联合类型 (tff.FederatedType)。联合类型的值是由特定布局(如 tff.SERVERtff.CLIENTS)定义的一组系统参与者托管的值。联合类型通过布局值(因此,它是一种依赖类型), 成员组成要素(每个参与者在本地托管的内容类型),以及指定所有参与者是否在本地托管同一项目的附加部分 all_equal 进行定义。对于包含 T 类型项目(成员组成)的值的联合类型,如果每个项目由组(布局)G 托管,则其紧凑表示法为 T@G{T}@G,分别设置或不设置 all_equal 位。{int32}@CLIENTS 表示包含一组可能不同的整数;{<X=float32,Y=float32>*}@CLIENTS 表示一个联合数据集;<weights=float32[10,5],bias=float32[5]>@SERVER 表示服务器上的权重和偏差张量的命名元组。我们省略了花括号,这表示已设置 all_equal 位。
federated_float_on_clients = tff.FederatedType(tf.float32, tff.CLIENTS)  # '{float32}@CLIENTS'

函数

Federated Core 的语言是一种 λ 演算,它提供了当前在公共 API 中公开的以下编程抽象:

  • TensorFlow 计算 (tff.tf_computation)。TFF 中有一些使用 tff.tf_computation 装饰器包装为可重用组件的 TensorFlow 代码部分。这些代码一般都是函数式类型,但是与 TensorFlow 中的函数不同,它们可以接受结构化参数或返回序列类型的结构化结果。
# tensor computation is constricted in tff.federated_computation
# should be completion by the following way:
@tff.tf_computation(tff.SequenceType(tf.int32))
def add_up_integers(x):
  return x.reduce(np.int32(0), lambda x, y: x + y)

@tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
def get_average_temperature(client_temperatures):
  return tff.federated_mean(client_temperatures)
# '({float32}@CLIENTS -> float32@SERVER)'
  • 内联函数(tff.federated_...)。这是构成大部分 FC API 的函数库,如 tff.federated_sumtff.federated_broadcast,其中大多数表示与 TFF 一起使用的分布通信算子。

  • \(\lambda\)表达式 (tff.federated_computation)。TFF 中的 \(\lambda\)表达式等效于 Python 中的 lambdadef;它包含参数名称,以及包含对该参数的引用的主体(表达式)。

# the biggest differnce between tf.computation and tff.federated_computation is the placement
@tff.tf_computation(tf.float32)
def add_half(x):
  return tf.add(x, 0.5)

@tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
def add_half_on_clients(x):
  return tff.federated_map(add_half, x)

Work flow

一个典型的FL研究code由三种主要逻辑

  • 个人层面的TF片段,如tf.function可以独立运行在本地,如客户端的训练代码
  • TFF编排代码,帮助将个人层面的tf.function通过tff_computation整合在一起,并且通过包含其中的tff.federated_broadcasttff.federated_mean进行orchestating
  • 外部的驱动代码,如客户选择。
# data preparation
import nest_asyncio
nest_asyncio.apply()

import tensorflow as tf
import tensorflow_federated as tff

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

NUM_CLIENTS = 10
BATCH_SIZE = 20

def preprocess(dataset):

  def batch_format_fn(element):
    """Flatten a batch of EMNIST data and return a (features, label) tuple."""
    return (tf.reshape(element['pixels'], [-1, 784]), 
            tf.reshape(element['label'], [-1, 1]))

  return dataset.batch(BATCH_SIZE).map(batch_format_fn)

client_ids = sorted(emnist_train.client_ids)[:NUM_CLIENTS]
federated_train_data = [preprocess(emnist_train.create_tf_dataset_for_client(x))
  for x in client_ids
]
# model preparation
def create_keras_model():
  initializer = tf.keras.initializers.GlorotNormal(seed=0)
  return tf.keras.models.Sequential([
      tf.keras.layers.Input(shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer=initializer),
      tf.keras.layers.Softmax(),
  ])

def model_fn():
    keras_model = create_keras_model()
    return tff.learning.from_keras_model(
        keras_model,
        input_spec=federated_train_data[0].element_spec,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

To build our own Federated Learning algorithm, there are four main components:

  1. A server-to-clients broadcast step
  2. A local client update step
  3. A client-to-server upload step
  4. A server update step

Meanwhile, we should rewrite initialize and next functions.

Method_1

Local training

本地训练是不需要tff参与的

# step 2 local training
# return client model weights
@tf.function
def client_update(model, dataset, server_weights, client_optimizer):
    client_weights = model.trainable_variables
    
    # clone server_weights, which is exactly state meaning in the previous code.
    tf.nest.map_structure(lambda x, y: x.assign(y), client_weights, server_weights)
	
    # optimization
    for batch in dataset:
        with tf.GradientTape() as tape:
            outputs = model.forward_pass(batch)
    
        grads = tape.gradient(outputs.loss, client_weights)
        grad_and_vars = zip(grads, client_weights)

        client_optimizer.apply_gradients(grad_and_vars)  # update
    
    return client_weights

输入的参数有modeldatasetserver_weightsclient_optimizer,为什么参数这么多呢?是因为tf.function不涉及任何数据placement的信息,而关于placement的部分全交给tff去处理。

Server update

跟客户端的更新一样,服务器端的更新也是不需要tff参与的

# step4
@tf.function
def server_update(model, mean_client_weights):
    model_weights = model.trainable_variables
    tf.nest.map_structure(lambda x,y: x.assign(y), model_weights, mean_client_weights)
    return model_weights

TFF snippet

现在就需要tff进行不同placement数据的整合,以及重写tff.templates.IterativeProcess的两个方法了。

# initialize method
@tff.tf_computation
def server_init():
    model = model_fn()
    return model.trainable_variables

@tff.federated_computation
def initialize_fn():
    return tff.federated_value(server_init(), tff.SERVER)  # A federated value with the given placement placement, and the member constituent value equal at all locations.
whimsy_model = model_fn()
tf_dataset_type = tff.SequenceType(whimsy_model.input_spec)  # inpute specification
model_weights_type = server_init.type_signature.result  # output specification
# there are multiple sources data and should use tff.tf_computation decoration
@tff.tf_computation(tf_dataset_type, model_weights_type)
def client_update_fn(tf_dataset, server_weights):
    model = model_fn()
    client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
    return client_update(model, tf_dataset, server_weights, client_optimizer)

@tff.tf_computation(model_weights_type)
def server_update_fn(mean_client_weights):
  model = model_fn()
  return server_update(model, mean_client_weights)

federated_server_type = tff.FederatedType(model_weights_type, tff.SERVER)
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)


# rewrite next function
# state is server_weights.
@tff.federated_computation(federated_server_type, federated_dataset_type)
def next_fn(server_weights, federated_dataset):
    # step1. broadcast
    server_weights_at_client = tff.federated_broadcast(server_weights)
	
    # step2. local update
    client_weights = tff.federated_map(
        client_update_fn, (federated_dataset, server_weights_at_client))
	
    # step3. uploading
    mean_client_weights = tff.federated_mean(client_weights)
	
    # step4. server update
    server_weights = tff.federated_map(server_update_fn, mean_client_weights)

    return server_weights

federated_algorithm = tff.templates.IterativeProcess(
    initialize_fn=initialize_fn,
    next_fn=next_fn
)

Indeed, model is only a transitory container to keep the server state and the client model weights and this why we will initialize a model instance in the client_update_fn and server_update_fn.

central_emnist_test = emnist_test.create_tf_dataset_from_all_clients()
central_emnist_test = preprocess(central_emnist_test)

def evaluate(server_state):
  keras_model = create_keras_model()
  keras_model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]  
  )
  keras_model.set_weights(server_state)
  keras_model.evaluate(central_emnist_test)

server_state = federated_algorithm.initialize()
evaluate(server_state)
for round in range(15):
  server_state = federated_algorithm.next(server_state, federated_train_data)
evaluate(server_state)

Method_2

In the second method, an optimizer from tff.leraning.optimizers will supersede the previous one, which has initialize(<Tensorspec>) and next functions.

TF snippet

@tf.function
def client_update(model, dataset, server_weights, optimizer):
    client_weights = model.trainable_weights
    tf.nest.map_structure(lambda x, y: x.assign(y), client_weights, server_weights)
    
    trainable_tensor_specs = tf.nest.map_structure(
        lambda v: tf.TensorSpec(v.shape, v.dtype), client_weights)
    optimizer_state = client_optimizer.initialize(trainable_tensor_specs)
    
    for batch in iter(dataset):
        with tf.GradientTape() as tape:
            output = model.forward_pass(batch)
        grads = tape.gradient(outputs.loss, client_weights)
        optimizer_state, update_weights = client_optimizer.next(
            optimizer_state, client_weights, grads)
        tf.nest.map_structure(lambda a, b: a.assign(b), client_weights, update_weights)
    return tf.nest.map_structure(tf.subtract, client_weights, server_weights)  # return the cumulative gradient

# contanier, collecting server weights and server optimizer state.
@attr.s(eq=False, frozen=True, slots=True)
class ServerState(object):
    trainable_weights = attr.ib()
    optimizer_state = attr.ib()

@tf.function
def server_update(server_state, mean_model_delta, server_optimizer):
    negative_weights_delta = tf.nest.map_structure(
        lambda w: -1.0 * w, mean_model_delta)
    new_optimizer_state, updated_weights = server_optimizer.next(
        server_state.optimizer_state, server_state.trainable_weights, negative_weights_delta)
    return tff.structure.update_struct(
        server_state, 
        trainable_weights = updated_weights, 
        optimizer_state = new_optimizer_state)

TFF snippet

server_optimizer = tff.learning.optimizers.build_sgdm(learning_rate=0.05, momentum=0.9)
client_optimizer = tff.learning.optimizers.build_sgdm(learning_rate=0.01)

@tff.tf_computation
def server_init():
    model = model_fn()
    trainable_tensor_specs = tf.nest.map_structure(
        lambda v: tf.TensorSpec(v.shape, v.dtype), model.trainable_variables)
    optimizer_state = server_optimizer.initialize(trainable_tensor_specs)
    return ServerState(
        trainable_weights=model.trainable_variables,
        optimizer_state=optimizer_state)

@tff.tff_computation
def server_init_tff():
    return tff.federated_value(server_init(), tff.SERVER)

server_state_type = server_init.type_signature.result
trainable_weights_type = server_state_type.trainable_weights

@tff.tf_computation(server_state_type, trainable_weights_type)
def server_update_fn(server_state, model_delta):
    return server_update(server_state, model_delta, server_optimizer)

whimsy_model = model_fn()
tf_dataset_type = tff.SequenceType(whimsy_model.input_spec)

@tff.tf_computation(tf_dataset_type, trainable_weights_type)
def client_update_fn(dataset, server_weights):
    model = model_fn()
    return client_update(model, dataset, server_weights, client_optimizer)

federated_server_type = tff.FederatedType(server_state_type, tff.SERVER)
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)

@tff.federated_computation(federated_server_type, federated_dataset_type)
def run_one_round(server_state, federated_dataset):
    server_weights_at_client = tff.federated_broadcast(
      server_state.trainable_weights)
    
    model_deltas = tff.federated_map(
      client_update_fn, (federated_dataset, server_weights_at_client))
    
    mean_model_delta = tff.federated_mean(model_deltas)
    
    server_state = tff.federated_map(
      server_update_fn, (server_state, mean_model_delta))
    return server_state

fedavg_process = tff.templates.IterativeProcess(
    initialize_fn=server_init_tff, next_fn=run_one_round)

Summary

The process of customizing our own tff.template.IterativeProcess class:

  1. Firstly, regardless of the placement constraint, you should complete the Tensorflow code to fulfill the client update and server update function. Usually, the input parameters for the client update function should include model, dataset, server_weights and optimizer and the output should be the cumulative grads or the new client model trainable variables. The input of the server update is rather simple, a new model and the new aggregated changes and its output is the new server state. According to your definition, the serve state can be the model trainable variables or contains other items. Both of this two functions are decorated by tf.function
  2. Secondly, server_update_fn, client_update_fn and server_init_fn should be completed and all of them are decorated by tff.tf_computation. The decoration shows that the input parameters should be placed in the same position. In the server_init_fn, the output should be a new state. In the client_update_fn, the input parameters are dataset and server_weights(Note, server_weights are the duplication and placed in the tff.CLIENTS by the tff.federated_broadcast function) and it will call the previous client update function. In the server_update_fn, the input parameters are server_state and the cumulative changes(Note, cumulative changes are aggregated by the tff.federated_mean function and placed in tff.SERVER) and call the previous server update function.
  3. Thirdly, server_init_tff and next_fn will be created and both of them are decorated by tff.federated_computation to solve the placement issues. In the server_init_tff function, it will place the value, output of the server_init function, to the tff.SERVER by the tff.federated_value function. In the next_fn, four steps in the workflow will be completed.
posted @ 2021-12-04 18:05  Neo_DH  阅读(581)  评论(0编辑  收藏  举报