论文笔记(8)-"Personalized Federated Learning using Hypernetworks"
这篇是ICML 2021的一篇论文,论文和代码都看了一下,配合着代码简单说一下文章思路。
Motivation
文章说PFL的难点在于用尽量少的通讯成本为每个用户提供个性化模型。然后文章列出的主要贡献也是传输成本和模型复杂度以及可以为不同算力资源的设备提供适应大小的模型,并且在结果上取得了不错的效果。
作者通过在Server
端训练一个hyper net
来为各个用户生成所需要的模型参数来实现解耦传输成本和模型复杂度。
Model Construction
文中的Hyper net
是一个多头网络,每个头输出的都是某一层的权重Tensor
。具体而言,例如对于Cifar10
,它的Hyper Net实际是这个样子
class CNNHyper(nn.Module):
    """Server-side hypernetwork (pFedHN): maps a learned per-client embedding
    to the full parameter set of the client-side ``CNNTarget`` model.

    Each output head emits the flattened weight/bias tensor of one target
    layer; ``forward`` reshapes them and returns an ``OrderedDict`` whose keys
    match ``CNNTarget.state_dict()``, so the result can be loaded directly via
    ``net.load_state_dict(...)``.

    Args:
        n_nodes: int, total number of nodes (users or clients).
        embedding_dim: int, dimension of the per-client embedding.
        in_channels: int, channels of the input image or data.
        out_dim: int, number of output categories.
        n_kernels: int, number of kernels of the target CNN's first conv layer.
        hidden_dim: int, width of the hidden layers of the hypernetwork MLP.
        spec_norm: bool, whether to apply spectral normalization to every
            linear layer (trunk and heads).
        n_hidden: int, number of hidden layers in the MLP trunk.
    """

    def __init__(self, n_nodes, embedding_dim, in_channels=3, out_dim=10,
                 n_kernels=16, hidden_dim=100, spec_norm=False, n_hidden=1):
        super().__init__()
        self.in_channels = in_channels
        self.out_dim = out_dim
        self.n_kernels = n_kernels
        # One trainable embedding vector per client; indexed by the client id.
        self.embeddings = nn.Embedding(num_embeddings=n_nodes, embedding_dim=embedding_dim)

        def linear(in_features, out_features):
            # Plain Linear, optionally wrapped with spectral normalization.
            layer = nn.Linear(in_features, out_features)
            return spectral_norm(layer) if spec_norm else layer

        # MLP trunk: embedding -> shared feature vector consumed by all heads.
        layers = [linear(embedding_dim, hidden_dim)]
        for _ in range(n_hidden):
            layers.append(nn.ReLU(inplace=True))
            layers.append(linear(hidden_dim, hidden_dim))
        self.mlp = nn.Sequential(*layers)

        # Output heads: one per parameter tensor of the target network.
        # Sizes mirror CNNTarget (two 5x5 conv layers, then 120-84-out_dim MLP).
        self.c1_weights = linear(hidden_dim, self.n_kernels * self.in_channels * 5 * 5)
        self.c1_bias = linear(hidden_dim, self.n_kernels)
        self.c2_weights = linear(hidden_dim, 2 * self.n_kernels * self.n_kernels * 5 * 5)
        self.c2_bias = linear(hidden_dim, 2 * self.n_kernels)
        self.l1_weights = linear(hidden_dim, 120 * 2 * self.n_kernels * 5 * 5)
        self.l1_bias = linear(hidden_dim, 120)
        self.l2_weights = linear(hidden_dim, 84 * 120)
        self.l2_bias = linear(hidden_dim, 84)
        self.l3_weights = linear(hidden_dim, self.out_dim * 84)
        self.l3_bias = linear(hidden_dim, self.out_dim)

    def forward(self, idx):
        """Generate the target-network weights for client(s) ``idx``.

        Args:
            idx: LongTensor of client id(s); typically shape ``(1,)``.

        Returns:
            OrderedDict mapping CNNTarget state_dict keys to weight tensors.
        """
        emd = self.embeddings(idx)
        features = self.mlp(emd)
        # Reshape each flat head output into the target layer's tensor shape.
        weights = OrderedDict({
            "conv1.weight": self.c1_weights(features).view(self.n_kernels, self.in_channels, 5, 5),
            "conv1.bias": self.c1_bias(features).view(-1),
            "conv2.weight": self.c2_weights(features).view(2 * self.n_kernels, self.n_kernels, 5, 5),
            "conv2.bias": self.c2_bias(features).view(-1),
            "fc1.weight": self.l1_weights(features).view(120, 2 * self.n_kernels * 5 * 5),
            "fc1.bias": self.l1_bias(features).view(-1),
            "fc2.weight": self.l2_weights(features).view(84, 120),
            "fc2.bias": self.l2_bias(features).view(-1),
            "fc3.weight": self.l3_weights(features).view(self.out_dim, 84),
            "fc3.bias": self.l3_bias(features).view(-1),
        })
        return weights
用户端的Target Network
结构
class CNNTarget(nn.Module):
    """Client-side target network: a LeNet-style CNN whose parameters are
    produced by the server's hypernetwork (keys of ``state_dict`` match the
    ``OrderedDict`` returned by ``CNNHyper.forward``).

    Args:
        in_channels: int, channels of the input image (3 for CIFAR).
        n_kernels: int, number of kernels of the first conv layer; the second
            conv layer uses twice as many.
        out_dim: int, number of output categories.
    """

    def __init__(self, in_channels=3, n_kernels=16, out_dim=10):
        super(CNNTarget, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, n_kernels, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(n_kernels, 2 * n_kernels, 5)
        # 5x5 spatial size after two conv+pool stages on a 32x32 input.
        self.fc1 = nn.Linear(2 * n_kernels * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, out_dim)

    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: float tensor of shape (batch, in_channels, 32, 32).

        Returns:
            Logits tensor of shape (batch, out_dim).
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(x.shape[0], -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
Optimization
然后直接看优化流程吧,对于我比较关心的用户特征向量\(v_i\),他是直接拿用户的node_id
也就是用户的标号,embedding出来的。整个代码只有两个model的实例,分别就是Hyper Network
和Target Network
的,然后每一轮只选择一个用户,Target Network
加载根据node_id
embedding出来的特征向量计算得来的权重,并进行优化。
def train(data_name: str, data_path: str, classes_per_node: int, num_nodes: int,
          steps: int, inner_steps: int, optim: str, lr: float, inner_lr: float,
          embed_lr: float, wd: float, inner_wd: float, embed_dim: int, hyper_hid: int,
          n_hidden: int, n_kernels: int, bs: int, device, eval_every: int, save_path: Path,
          seed: int) -> None:
    '''
    pFedHN optimization loop: every communication round samples one client,
    generates its model weights with the hypernetwork, runs local SGD on the
    client, and uses the resulting weight delta to update the hypernetwork
    (including the client embedding).
    Args:
        data_name: str, [cifar10 or cifar100]
        data_path: the path of the data
        classes_per_node: int, the number of classes chosen by each node
        num_nodes: int, the total number of nodes or users
        steps: int, the number of communication rounds
        inner_steps: int, the number of local gradient steps
        optim: str, sgd or adam
        lr: float, learning rate of the server
        inner_lr: float, learning rate of the node
        embed_lr: float, learning rate of the embedding layer
        wd: float, weight decay of the server
        inner_wd: float, weight decay of the node
        embed_dim: int, the dimension of the embedding layer output (-1 = auto)
        hyper_hid: int, the dimension of the final hidden layer output
        n_hidden: int, the number of latent layers
        n_kernels: int, the number of kernels in CNN
        bs: int, batch_size
        device: torch device to run on
        eval_every: int, evaluation period in rounds
        save_path: Path, directory where the results JSON is written
        seed: int, random seed (used only in the results file name here)
    '''
    ###############################
    # init nodes, hnet, local net #
    ###############################
    nodes = BaseNodes(data_name, data_path, num_nodes, classes_per_node=classes_per_node,
                      batch_size=bs)
    # setting the embedding dim according to the n_nodes
    embed_dim = embed_dim  # NOTE(review): no-op assignment, kept from the original
    if embed_dim == -1:
        logging.info("auto embedding size")
        embed_dim = int(1 + num_nodes / 4)
    # Build the model
    if data_name == "cifar10":
        hnet = CNNHyper(num_nodes, embed_dim, hidden_dim=hyper_hid, n_hidden=n_hidden, n_kernels=n_kernels)
        net = CNNTarget(n_kernels=n_kernels)
    elif data_name == "cifar100":
        hnet = CNNHyper(num_nodes, embed_dim, hidden_dim=hyper_hid,
                        n_hidden=n_hidden, n_kernels=n_kernels, out_dim=100)
        net = CNNTarget(n_kernels=n_kernels, out_dim=100)
    else:
        raise ValueError("choose data_name from ['cifar10', 'cifar100']")
    hnet = hnet.to(device)
    net = net.to(device)
    ##################
    # init optimizer #
    ##################
    embed_lr = embed_lr if embed_lr is not None else lr
    optimizers = {
        # SGD uses two parameter groups so the embedding layer ('embed' in its
        # parameter name) can get its own learning rate.
        'sgd': torch.optim.SGD(
            [
                {'params': [p for n, p in hnet.named_parameters() if 'embed' not in n]},
                {'params': [p for n, p in hnet.named_parameters() if 'embed' in n], 'lr': embed_lr}
            ], lr=lr, momentum=0.9, weight_decay=wd
        ),
        'adam': torch.optim.Adam(params=hnet.parameters(), lr=lr)
    }
    optimizer = optimizers[optim]
    criteria = torch.nn.CrossEntropyLoss()
    ################
    # init metrics #
    ################
    last_eval = -1
    best_step = -1
    best_acc = -1
    test_best_based_on_step, test_best_min_based_on_step = -1, -1
    test_best_max_based_on_step, test_best_std_based_on_step = -1, -1
    step_iter = trange(steps)
    results = defaultdict(list)
    for step in step_iter:
        hnet.train()
        # select a client at random
        node_id = random.choice(range(num_nodes))
        # produce & load local network weights generated from the client's
        # embedding; keys match net.state_dict() so they load directly
        weights = hnet(torch.tensor([node_id], dtype=torch.long).to(device))
        net.load_state_dict(weights)
        # init inner optimizer (fresh per round; local/client-side SGD)
        inner_optim = torch.optim.SGD(
            net.parameters(), lr=inner_lr, momentum=.9, weight_decay=inner_wd
        )
        # storing theta_i for later calculating delta theta
        inner_state = OrderedDict({k: tensor.data for k, tensor in weights.items()})
        # NOTE: evaluation on sent model — measures the freshly generated
        # weights on one test batch before any local training
        with torch.no_grad():
            net.eval()
            batch = next(iter(nodes.test_loaders[node_id]))
            img, label = tuple(t.to(device) for t in batch)
            pred = net(img)
            prvs_loss = criteria(pred, label)
            prvs_acc = pred.argmax(1).eq(label).sum().item() / len(label)
            net.train()
        # inner updates -> obtaining theta_tilda
        for i in range(inner_steps):
            net.train()
            inner_optim.zero_grad()
            optimizer.zero_grad()
            batch = next(iter(nodes.train_loaders[node_id]))
            img, label = tuple(t.to(device) for t in batch)
            pred = net(img)
            loss = criteria(pred, label)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), 50)
            inner_optim.step()
        # clear any hypernetwork gradients accumulated during the local steps
        optimizer.zero_grad()
        final_state = net.state_dict()
        # calculating delta theta = theta_i (sent) - theta_tilda (after local SGD)
        delta_theta = OrderedDict({k: inner_state[k] - final_state[k] for k in weights.keys()})
        # calculating phi gradient: vector-Jacobian product of the generated
        # weights w.r.t. the hypernetwork parameters, seeded with delta_theta
        hnet_grads = torch.autograd.grad(
            list(weights.values()), hnet.parameters(), grad_outputs=list(delta_theta.values())
        )
        # update hnet weights
        for p, g in zip(hnet.parameters(), hnet_grads):
            p.grad = g
        torch.nn.utils.clip_grad_norm_(hnet.parameters(), 50)
        optimizer.step()
        step_iter.set_description(
            f"Step: {step+1}, Node ID: {node_id}, Loss: {prvs_loss:.4f}, Acc: {prvs_acc:.4f}"
        )
        # evaluation (periodic; tracks the best step by validation accuracy)
        if step % eval_every == 0:
            last_eval = step
            step_results, avg_loss, avg_acc, all_acc = eval_model(nodes, num_nodes, hnet, net, criteria, device, split="test")
            logging.info(f"\nStep: {step+1}, AVG Loss: {avg_loss:.4f}, AVG Acc: {avg_acc:.4f}")
            results['test_avg_loss'].append(avg_loss)
            results['test_avg_acc'].append(avg_acc)
            _, val_avg_loss, val_avg_acc, _ = eval_model(nodes, num_nodes, hnet, net, criteria, device, split="val")
            if best_acc < val_avg_acc:
                # new best validation accuracy: record matching test statistics
                best_acc = val_avg_acc
                best_step = step
                test_best_based_on_step = avg_acc
                test_best_min_based_on_step = np.min(all_acc)
                test_best_max_based_on_step = np.max(all_acc)
                test_best_std_based_on_step = np.std(all_acc)
            results['val_avg_loss'].append(val_avg_loss)
            results['val_avg_acc'].append(val_avg_acc)
            results['best_step'].append(best_step)
            results['best_val_acc'].append(best_acc)
            results['best_test_acc_based_on_val_beststep'].append(test_best_based_on_step)
            results['test_best_min_based_on_step'].append(test_best_min_based_on_step)
            results['test_best_max_based_on_step'].append(test_best_max_based_on_step)
            results['test_best_std_based_on_step'].append(test_best_std_based_on_step)
    # final evaluation if the last round wasn't an eval round
    # (duplicates the periodic evaluation block above)
    if step != last_eval:
        _, val_avg_loss, val_avg_acc, _ = eval_model(nodes, num_nodes, hnet, net, criteria, device, split="val")
        step_results, avg_loss, avg_acc, all_acc = eval_model(nodes, num_nodes, hnet, net, criteria, device, split="test")
        logging.info(f"\nStep: {step + 1}, AVG Loss: {avg_loss:.4f}, AVG Acc: {avg_acc:.4f}")
        results['test_avg_loss'].append(avg_loss)
        results['test_avg_acc'].append(avg_acc)
        if best_acc < val_avg_acc:
            best_acc = val_avg_acc
            best_step = step
            test_best_based_on_step = avg_acc
            test_best_min_based_on_step = np.min(all_acc)
            test_best_max_based_on_step = np.max(all_acc)
            test_best_std_based_on_step = np.std(all_acc)
        results['val_avg_loss'].append(val_avg_loss)
        results['val_avg_acc'].append(val_avg_acc)
        results['best_step'].append(best_step)
        results['best_val_acc'].append(best_acc)
        results['best_test_acc_based_on_val_beststep'].append(test_best_based_on_step)
        results['test_best_min_based_on_step'].append(test_best_min_based_on_step)
        results['test_best_max_based_on_step'].append(test_best_max_based_on_step)
        results['test_best_std_based_on_step'].append(test_best_std_based_on_step)
    # persist all collected metrics as JSON
    save_path = Path(save_path)
    save_path.mkdir(parents=True, exist_ok=True)
    with open(str(save_path / f"results_{inner_steps}_inner_steps_seed_{seed}.json"), "w") as file:
        json.dump(results, file, indent=4)
Summary
- 用户特征\(v_i\)的获取是最让我感到奇怪的:用 embedding 来生成虽然很直接,但放在 server 端仅根据 node_id 生成,给人一种 server 和 node 相互对抗的感觉——明明用户本地拥有诸如人口统计特征等用户数据,让各用户基于这些数据自行产生一个\(v_i\)似乎更符合逻辑;
- 关于文中所说的传输成本和模型复杂度的解耦,感觉说得模棱两可:它传输的数据量和普通的 FedAvg 是一样的,确实可以在 server 端训练一个很深的网络,但如果用户本地的模型变复杂,传输成本同样会提高;
- 提供的代码里没有展示如何对不同算力资源的设备生成不同大小的模型。