Federated Learning for Handwritten Digit Recognition
Main function (train.py)
from sampling import noniid
from Nets import MLP
from params import args_parser
from update import LocalUpdate
from aggregation import FedAvg
from test import img_test
import torch
from torchvision import transforms, datasets
import numpy as np
import copy
import matplotlib.pyplot as plt

# Load hyperparameters
args = args_parser()
args.device = torch.device("cuda:{}".format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')

# Prepare the dataset
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST('../data/mnist', train=True, download=True, transform=trans)
test_dataset = datasets.MNIST('../data/mnist', train=False, download=True, transform=trans)

# Partition the training set across 100 clients (non-IID)
user_dict = noniid(train_dataset, 100)
img_size = train_dataset[0][0].shape

# Build the network
len_in = 1
for x in img_size:
    len_in *= x
glob_net = MLP(dim_in=len_in, dim_hidden=20, dim_out=args.num_classes).to(args.device)
print(glob_net)
glob_net.train()
glob_w = glob_net.state_dict()  # global model parameters

# Training
train_loss = []  # average client loss per round
if args.all_clients:
    print('All clients participate in training')
    local_w = [glob_w for i in range(args.num_users)]
for rnd in range(args.epochs):
    local_loss = []
    if not args.all_clients:
        local_w = []
    # Randomly select m clients to participate in this round
    m = max(int(args.num_users * args.frac), 1)
    idx_users = np.random.choice(range(args.num_users), m, replace=False)
    # Train separately on each selected client
    for idx in idx_users:
        # Pass in the hyperparameters, the training set, and this client's sample indices
        local = LocalUpdate(args=args, dataset=train_dataset, idxs=user_dict[idx])
        w, loss = local.train(net=copy.deepcopy(glob_net).to(args.device))
        if args.all_clients:
            local_w[idx] = copy.deepcopy(w)
        else:
            local_w.append(copy.deepcopy(w))
        local_loss.append(copy.deepcopy(loss))
    # if args.method == 'fedavg':
    # Federated averaging: update the global model parameters
    glob_w = FedAvg(local_w, local_loss)
    glob_net.load_state_dict(glob_w)
    # Compute and print the average loss
    avg_loss = sum(local_loss) / len(local_loss)
    print('Round', rnd, 'average client loss:', avg_loss)
    train_loss.append(avg_loss)

# Plot the training curve
plt.figure()
plt.plot(range(len(train_loss)), train_loss)
plt.ylabel('avg_loss')
plt.xlabel('iterations')
plt.show()

# Evaluate the final global model
glob_net.eval()
train_acc, train_los = img_test(glob_net, train_dataset, args)
test_acc, test_los = img_test(glob_net, test_dataset, args)
print("Training accuracy: {:.2f}".format(train_acc))
print("Training loss: {:.4f}".format(train_los))
print("Testing accuracy: {:.2f}".format(test_acc))
print("Testing loss: {:.4f}".format(test_los))
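For orientation, each pass through the loop above is one communication round of FedAvg (McMahan et al., 2017): sample a fraction C of the K clients, train locally, then average. With the defaults C = 0.1 and K = 100, that is max(int(100 × 0.1), 1) = 10 clients per round. In the paper's notation (w, S_t, and n_k are the paper's symbols, not variables in this code):

$$|S_t| = \max(\lfloor C K \rfloor, 1), \qquad w_{t+1} = \sum_{k \in S_t} \frac{n_k}{\sum_{j \in S_t} n_j}\, w_{t+1}^{k}$$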
Non-IID data partitioning (sampling.py)
import numpy as np

def noniid(dataset, user_num):
    # Split the 60,000 MNIST training images into 200 shards of 300 samples each,
    # sorted by label; each user receives 2 shards (600 samples).
    a, b = 200, 300
    list_a = [i for i in range(a)]
    dict_user = {i: np.array([], dtype='int64') for i in range(user_num)}
    idx = [i for i in range(a * b)]
    label = dataset.targets.numpy()  # `train_labels` is deprecated in torchvision
    # Sort the sample indices by label
    tmp = np.vstack((idx, label))
    tmp = tmp[:, tmp[1, :].argsort()]
    idx = tmp[0, :]
    # Assign each user 2 randomly chosen shards, without replacement
    for i in range(user_num):
        randset = set(np.random.choice(list_a, 2, replace=False))
        list_a = list(set(list_a) - randset)
        for rand in randset:
            dict_user[i] = np.concatenate((dict_user[i], idx[rand * b:(rand + 1) * b]), axis=0)
    return dict_user
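A minimal sanity check of the partition (assuming train.py's variable names; run after train_dataset is loaded): each user should receive 600 indices (2 shards × 300 samples), and because the shards are label-sorted, each user sees only a few digit classes.

import numpy as np
user_dict = noniid(train_dataset, 100)
print(len(user_dict[0]))                        # 600 = 2 shards * 300 samples
labels = train_dataset.targets.numpy()[user_dict[0]]
print(np.unique(labels))                        # typically only 1-2 distinct digits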
Network definition (Nets.py)
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, dim_in, dim_hidden, dim_out):
        super(MLP, self).__init__()
        self.input = nn.Linear(dim_in, dim_hidden)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout()
        self.output = nn.Linear(dim_hidden, dim_out)

    def forward(self, x):
        # Flatten each (C, H, W) image into a vector before the linear layers
        x = x.view(-1, x.shape[1] * x.shape[-2] * x.shape[-1])
        return self.output(self.relu(self.dropout(self.input(x))))
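A quick smoke test of the forward pass, assuming MNIST-sized 1×28×28 inputs (the batch size of 4 is arbitrary):

import torch
net = MLP(dim_in=784, dim_hidden=20, dim_out=10)
net.eval()                               # disable dropout for a deterministic check
out = net(torch.randn(4, 1, 28, 28))     # batch of 4 fake grayscale images
print(out.shape)                         # torch.Size([4, 10])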
Parameter definitions (params.py)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import argparse

def args_parser():
    parser = argparse.ArgumentParser()
    # federated arguments
    parser.add_argument('--epochs', type=int, default=3, help="rounds of training")
    parser.add_argument('--num_users', type=int, default=100, help="number of users: K")
    parser.add_argument('--frac', type=float, default=0.1, help="the fraction of clients: C")
    parser.add_argument('--local_ep', type=int, default=5, help="the number of local epochs: E")
    parser.add_argument('--local_bs', type=int, default=10, help="local batch size: B")
    parser.add_argument('--bs', type=int, default=128, help="test batch size")
    parser.add_argument('--lr', type=float, default=0.01, help="learning rate")
    parser.add_argument('--momentum', type=float, default=0.5, help="SGD momentum (default: 0.5)")
    parser.add_argument('--split', type=str, default='user', help="train-test split type, user or sample")
    parser.add_argument('--method', type=str, default='fedavg', help='aggregation method')
    # model arguments
    parser.add_argument('--model', type=str, default='mlp', help='model name')
    parser.add_argument('--kernel_num', type=int, default=9, help='number of each kind of kernel')
    parser.add_argument('--kernel_sizes', type=str, default='3,4,5', help='comma-separated kernel sizes to use for convolution')
    parser.add_argument('--norm', type=str, default='batch_norm', help="batch_norm, layer_norm, or None")
    parser.add_argument('--num_filters', type=int, default=32, help="number of filters for conv nets")
    parser.add_argument('--max_pool', type=str, default='True', help="whether to use max pooling rather than strided convolutions")
    # other arguments
    parser.add_argument('--dataset', type=str, default='mnist', help="name of dataset")
    parser.add_argument('--iid', action='store_true', help='whether i.i.d. or not')
    parser.add_argument('--num_classes', type=int, default=10, help="number of classes")
    parser.add_argument('--num_channels', type=int, default=3, help="number of channels of images")
    parser.add_argument('--gpu', type=int, default=0, help="GPU ID, -1 for CPU")
    parser.add_argument('--stopping_rounds', type=int, default=10, help='rounds of early stopping')
    parser.add_argument('--verbose', action='store_true', help='verbose print')
    parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')
    parser.add_argument('--all_clients', action='store_true', help='aggregation over all clients')
    args = parser.parse_args()
    return args
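Since args_parser reads the command line, every default above can be overridden at launch. A run matching the 50-round non-IID experiment at the end of this post might look like this (the flag values are illustrative):

python train.py --epochs 50 --num_users 100 --frac 0.1 --gpu 0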
Client-side training (update.py)
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn

class DataSplit(Dataset):
    # Wraps a dataset and exposes only the samples at the given indices
    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = list(idxs)

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        image, label = self.dataset[self.idxs[item]]
        return image, label

class LocalUpdate(object):
    def __init__(self, args, dataset=None, idxs=None):
        self.args = args
        self.loss_fun = nn.CrossEntropyLoss()
        self.train_batch = DataLoader(DataSplit(dataset, idxs), batch_size=self.args.local_bs, shuffle=True)

    def train(self, net):
        # Run local_ep epochs of minibatch SGD on this client's data
        net.train()
        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, momentum=self.args.momentum)
        epoch_loss = []
        for ep in range(self.args.local_ep):
            batch_loss = []
            for batch_idx, (images, labels) in enumerate(self.train_batch):
                images, labels = images.to(self.args.device), labels.to(self.args.device)
                net.zero_grad()
                label_pre = net(images)
                loss = self.loss_fun(label_pre, labels)
                loss.backward()
                optimizer.step()
                batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss) / len(batch_loss))
        return net.state_dict(), sum(epoch_loss) / len(epoch_loss)
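In math terms, train performs E epochs of minibatch SGD with momentum over the client's shard (defaults: E = local_ep = 5, B = local_bs = 10, η = lr = 0.01, μ = momentum = 0.5). Per batch b, with ℓ the cross-entropy loss, the update is:

$$v \leftarrow \mu v + \nabla_w \ell(w; b), \qquad w \leftarrow w - \eta v$$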
Client aggregation: FedAvg (aggregation.py)
import copy
import torch

def FedAvg(w, loss):
    # Element-wise average of the clients' state dicts.
    # `loss` is accepted for interface compatibility but unused by plain FedAvg.
    avg_w = copy.deepcopy(w[0])
    for k in avg_w.keys():
        for i in range(1, len(w)):
            avg_w[k] += w[i][k]
        avg_w[k] = torch.div(avg_w[k], len(w))
    return avg_w
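Note that FedAvg here is an unweighted mean, while the general FedAvg rule weights client k by its sample count, $w = \sum_k \frac{n_k}{n} w_k$. The two coincide in this project because noniid gives every client exactly 600 samples, so $n_k/n = 1/K$. If clients held different amounts of data, a weighted variant could look like the sketch below (the n_k argument is hypothetical, not part of this repo):

import copy

def fedavg_weighted(w, n_k):
    # n_k[i]: number of samples held by client i (hypothetical argument)
    total = sum(n_k)
    avg_w = copy.deepcopy(w[0])
    for key in avg_w.keys():
        avg_w[key] = w[0][key] * (n_k[0] / total)
        for i in range(1, len(w)):
            avg_w[key] += w[i][key] * (n_k[i] / total)
    return avg_w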
Accuracy testing (test.py)
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F

def img_test(glob_net, dataset, args):
    glob_net.eval()
    test_loss = 0
    correct = 0
    dataloader = DataLoader(dataset, batch_size=args.bs)
    with torch.no_grad():  # no gradients needed during evaluation
        for idx, (images, labels) in enumerate(dataloader):
            images, labels = images.to(args.device), labels.to(args.device)
            pre_labels = glob_net(images)
            # Sum the per-sample losses; divide by the dataset size at the end
            test_loss += F.cross_entropy(pre_labels, labels, reduction='sum').item()
            y_pred = pre_labels.data.max(1, keepdim=True)[1]
            correct += y_pred.eq(labels.data.view_as(y_pred)).long().cpu().sum()
    test_loss /= len(dataloader.dataset)
    accuracy = 100.0 * correct.item() / len(dataloader.dataset)
    return accuracy, test_loss
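Because the per-batch losses are summed with reduction='sum' and divided by the dataset size only at the end, the returned loss is the exact per-sample average, independent of the test batch size:

$$\text{test\_loss} = \frac{1}{N} \sum_{i=1}^{N} \mathrm{CE}\big(f(x_i),\, y_i\big), \qquad N = \texttt{len(dataset)}$$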
Experimental results: 50 rounds, non-IID