联邦学习:按病态非独立同分布划分Non-IID样本
1 病态不独立同分布(Non-IID)划分算法
在博客《分布式机器学习、联邦学习、多智能体的区别和联系》中我们提到论文[1]联邦学习每个client具有数据不独立同分布(Non-IID)的性质。
联邦学习的论文多是用FEMNIST、CIFAR10、Shakespare、Synthetic等数据集对模型进行测试,这些数据集包括CV、NLP、普通分类/回归这三种不同的任务。在单次实验时,我们对原始数据集进行非独立同分布(Non-IID) 的随机采样,为\(T\)个不同非任务生成\(T\)个不同分布的数据集。
我们在博客《联邦学习:按Dirichlet分布划分Non-IID样本》中已经介绍了按照Dirichlet分布划分non-IID样本。
然而联邦学习最开始采用的数据划分方法却不是这种。这里我们重新回顾联邦学习开山论文[1],它所采用的的是一种病态非独立同分布(Pathological Non-IID)划分算法。以下我们以CIFAR10数据集的生成为例,来详细地对该论文的数据集划分与采样算法进行分析。
首先,如果选择这种划分方式,需要指定则每个client上数据集所需要的标签类型数做为超参, 该划分算法的函数原型一般如下:
def pathological_non_iid_split(dataset, n_classes, n_clients, n_classes_per_client):
我们解释一下函数的参数,这里dataset
是torch.utils.Dataset
类型的数据集,n_classes
表示数据集里样本分类数,n_client
表示client节点的数量,该函数返回一个由n_client
各自所需样本索引组成的列表client_idcs
。
接下来我们看这个函数的内容。该函数完成的功能可以概括为:先将样本按照标签进行排序;再将样本划分为n_client * n_classes_per_client
个shards(每个shard大小相等),对n_clients
中的每一个client分配n_classes_per_client
个shards(分配到client后,每个client中的shards要合并)。
首先,从数据集索引data_idcs
建立一个key为类别\(\{0,1,...,n\_classes-1\}\),value为对应样本集索引列表的字典,这在实际上这就相当于按照label对样本进行排序了。
data_idcs = list(range(len(dataset)))
label2index = {k: [] for k in range(n_classes)}
for idx in data_idcs:
_, label = dataset[idx]
label2index[label].append(idx)
sorted_idcs = []
for label in label2index:
sorted_idcs += label2index[label]
然后该函数将数据分为n_clients * n_classes_per_client
个独立同分布的shards,每个shards大小相等。然后给n_clients
中的每一个client分配n_classes_per_client
个shards(分配到client后,每个client中的shards要合并),代码如下:
def iid_divide(l, g):
"""
将列表`l`分为`g`个独立同分布的group(其实就是直接划分)
每个group都有 `int(len(l)/g)` 或者 `int(len(l)/g)+1` 个元素
返回由不同的groups组成的列表
"""
num_elems = len(l)
group_size = int(len(l) / g)
num_big_groups = num_elems - g * group_size
num_small_groups = g - num_big_groups
glist = []
for i in range(num_small_groups):
glist.append(l[group_size * i: group_size * (i + 1)])
bi = group_size * num_small_groups
group_size += 1
for i in range(num_big_groups):
glist.append(l[bi + group_size * i:bi + group_size * (i + 1)])
return glist
n_shards = n_clients * n_classes_per_client
# 一共分成n_shards个独立同分布的shards
shards = iid_divide(sorted_idcs, n_shards)
np.random.shuffle(shards)
# 然后再将n_shards拆分为n_client份
tasks_shards = iid_divide(shards, n_clients)
clients_idcs = [[] for _ in range(n_clients)]
for client_id in range(n_clients):
for shard in tasks_shards[client_id]:
# 这里shard是一个shard的数据索引(一个列表)
# += shard 实质上是在列表里并入列表
clients_idcs[client_id] += shard
最后,返回clients_idcs
return clients_idcs
2 算法测试与可视化呈现
接下来我们在EMNIST数据集上调用该函数进行测试,并进行可视化呈现。我们设client数量\(N=10\),每个client规定有两种标签类型样本。
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets
from torch.utils.data import ConcatDataset
n_clients = 10
n_classes_per_client = 2 # 每个client规定有两种标签类型
seed = 42
if __name__ == "__main__":
np.random.seed(seed)
train_data = datasets.CIFAR10(
root=".", download=True, train=True)
test_data = datasets.CIFAR10(
root=".", download=True, train=False)
classes = train_data.classes
n_classes = len(classes)
labels = np.concatenate(
[np.array(train_data.targets), np.array(test_data.targets)], axis=0)
dataset = ConcatDataset([train_data, test_data])
# 按照病态非独立同分布对数据进行Non-IID划分
client_idcs = pathological_non_iid_split(
train_data, n_classes, n_clients, n_classes_per_client)
# 展示不同client的不同label的数据分布
plt.figure(figsize=(12, 8))
label_distribution = [[] for _ in range(n_classes)]
for c_id, idc in enumerate(client_idcs):
for idx in idc:
label_distribution[labels[idx]].append(c_id)
plt.hist(label_distribution, stacked=True,
bins=np.arange(-0.5, n_clients + 1.5, 1),
label=classes, rwidth=0.5)
plt.xticks(np.arange(n_clients), ["Client %d" %
c_id for c_id in range(n_clients)])
plt.xlabel("Client ID")
plt.ylabel("Number of samples")
plt.legend(loc="upper right")
plt.title("Display Label Distribution on Different Clients")
plt.show()
最终的可视化结果如下:
可以看到,62个类别标签在不同client上的分布确实不同,且每个client上的标签类别数量为两个。
注意,这里算法保证的是每个client上标签类别的近似数量为两个,而不是保证每个client上标签类别的绝对数量为两个,因为该算法对两个类别的话是直接将按标签排序的样本切分为n_client * 2
个块,然后每个client分得2个块。比如,如果我们不使用CIFAR10数据集,而是对EMNIST数据集(一共62个类别)进行划分,就会得到下面这样的近似划分结果:
不过,该算法相比下面按照\(\alpha=1.0\)的Dirichlet分布划分的样本(EMNIST数据集)仍然具有大大的不同。这证明我们的样本划分算法是有效的。
参考
- [1] McMahan B, Moore E, Ramage D, et al. Communication-efficient learning of deep networks from decentralized data[C]//Artificial intelligence and statistics. PMLR, 2017: 1273-1282.