Federated Learning for Handwritten Digit Recognition

Main script (train.py)

from sampling import noniid
from Nets import MLP
from params import args_parser
from update import LocalUpdate
from aggregation import FedAvg
from test import img_test

import pdb
import torch
from torchvision import transforms,datasets
import numpy as np
import copy
import matplotlib.pyplot as plt

# Parse hyperparameters
args = args_parser()
args.device = torch.device("cuda:{}".format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')

# Load and preprocess the dataset
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST('../data/mnist', train=True, download=True, transform=trans)
test_dataset = datasets.MNIST('../data/mnist', train=False, download=True, transform=trans)
# Partition the training set across clients (non-IID)
user_dict = noniid(train_dataset, args.num_users)

img_size = train_dataset[0][0].shape

# Build the model; the MLP input size is the flattened image size
len_in = 1
for x in img_size:
    len_in *= x
glob_net = MLP(dim_in=len_in, dim_hidden=20, dim_out=args.num_classes).to(args.device)
print(glob_net)
glob_net.train()

glob_w = glob_net.state_dict()  # global model parameters
# pdb.set_trace()

# Federated training
train_loss = []  # average client loss per round

if args.all_clients:
    print('All clients participate in aggregation')
    # deepcopy so each entry is independent of glob_w and of the others
    local_w = [copy.deepcopy(glob_w) for i in range(args.num_users)]

for rnd in range(args.epochs):
    local_loss = []
    if not args.all_clients:
        local_w = []
    m = max(int(args.num_users * args.frac), 1)
    # Randomly select m clients to participate in this round
    idx_users = np.random.choice(range(args.num_users), m, replace=False)
    # Train separately on each selected client
    for idx in idx_users:
        local = LocalUpdate(args=args, dataset=train_dataset, idxs=user_dict[idx])
        # Pass in the hyperparameters, the training set, and this client's sample indices
        w, loss = local.train(net=copy.deepcopy(glob_net).to(args.device))
        if args.all_clients:
            local_w[idx] = copy.deepcopy(w)
        else:
            local_w.append(copy.deepcopy(w))
        local_loss.append(copy.deepcopy(loss))
    # if args.method == 'fedavg':
    # Federated averaging: update the global model parameters
    glob_w = FedAvg(local_w, local_loss)
    glob_net.load_state_dict(glob_w)

    # Compute and print the average client loss for this round
    avg_loss = sum(local_loss) / len(local_loss)
    print('Round', rnd, 'average client loss:', avg_loss)
    train_loss.append(avg_loss)
    
plt.figure()
plt.plot(range(len(train_loss)),train_loss)
plt.ylabel('avg_loss')
plt.xlabel('iterations')
plt.show()

glob_net.eval()
train_acc,train_los = img_test(glob_net,train_dataset,args)
test_acc,test_los = img_test(glob_net,test_dataset,args)
print('Training accuracy: {:.2f}'.format(train_acc))
print('Training loss: {:.4f}'.format(train_los))
print('Testing accuracy: {:.2f}'.format(test_acc))
print('Testing loss: {:.4f}'.format(test_los))

 
# pdb.set_trace()  # leftover debugging hook, disabled
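For reference, the script can be launched like this (flag names as defined in params.py below; the 50-round setting matches the experiment at the end):

python train.py --epochs 50 --gpu 0                  # 50 non-IID rounds on GPU 0
python train.py --epochs 50 --gpu -1 --all_clients   # CPU run, aggregate over all clients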

Non-IID partition simulation (sampling.py)

import numpy as np

def noniid(dataset, user_num):
    # Split MNIST's 60,000 training samples into a = 200 shards of
    # b = 300 samples each; every user receives 2 shards.
    a, b = 200, 300
    list_a = [i for i in range(a)]
    dict_user = {i: np.array([], dtype='int64') for i in range(user_num)}
    idx = [i for i in range(a * b)]
    # dataset.train_labels is deprecated in recent torchvision; use targets
    label = dataset.targets.numpy()

    # Sort sample indices by label so each shard is (nearly) single-class
    tmp = np.vstack((idx, label))
    tmp = tmp[:, tmp[1, :].argsort()]
    idx = tmp[0, :]

    for i in range(user_num):
        # Draw 2 shards per user, without replacement across users
        randset = set(np.random.choice(list_a, 2, replace=False))
        list_a = list(set(list_a) - randset)
        for rand in randset:
            dict_user[i] = np.concatenate((dict_user[i], idx[rand * b:(rand + 1) * b]), axis=0)

    return dict_user
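Because the indices are label-sorted before sharding, each 300-sample shard is (nearly) single-class, so every user ends up with 600 samples covering only a few digits. A quick sanity check (a sketch; assumes the train_dataset built in train.py is in scope):

user_dict = noniid(train_dataset, 100)
labels = train_dataset.targets.numpy()
for u in (0, 1):
    classes = sorted(set(labels[user_dict[u]]))
    print('user', u, '->', len(user_dict[u]), 'samples, classes:', classes)
# Expected: 600 samples per user, typically 2 distinct digits.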

Network definition (Nets.py)

import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, dim_in, dim_hidden, dim_out):
        super(MLP, self).__init__()
        self.input = nn.Linear(dim_in, dim_hidden)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout()
        self.output = nn.Linear(dim_hidden, dim_out)

    def forward(self, x):
        # Flatten (N, C, H, W) image batches into (N, C*H*W) vectors
        x = x.view(-1, x.shape[1] * x.shape[-2] * x.shape[-1])
        return self.output(self.relu(self.dropout(self.input(x))))
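A quick shape check (a sketch; for MNIST, dim_in = 1*28*28 = 784, matching len_in in train.py):

import torch

net = MLP(dim_in=784, dim_hidden=20, dim_out=10)
out = net(torch.randn(4, 1, 28, 28))  # a fake batch of 4 MNIST-shaped images
print(out.shape)  # torch.Size([4, 10]) -- one logit per class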

Argument definitions (params.py)

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import argparse

def args_parser():
    parser = argparse.ArgumentParser()
    # federated arguments
    parser.add_argument('--epochs', type=int, default=3, help="rounds of training")
    parser.add_argument('--num_users', type=int, default=100, help="number of users: K")
    parser.add_argument('--frac', type=float, default=0.1, help="the fraction of clients: C")
    parser.add_argument('--local_ep', type=int, default=5, help="the number of local epochs: E")
    parser.add_argument('--local_bs', type=int, default=10, help="local batch size: B")
    parser.add_argument('--bs', type=int, default=128, help="test batch size")
    parser.add_argument('--lr', type=float, default=0.01, help="learning rate")
    parser.add_argument('--momentum', type=float, default=0.5, help="SGD momentum (default: 0.5)")
    parser.add_argument('--split', type=str, default='user', help="train-test split type, user or sample")
    parser.add_argument('--method',type=str,default='fedavg',help='aggregation methods')
    
    # model arguments
    parser.add_argument('--model', type=str, default='mlp', help='model name')
    parser.add_argument('--kernel_num', type=int, default=9, help='number of each kind of kernel')
    parser.add_argument('--kernel_sizes', type=str, default='3,4,5',
                        help='comma-separated kernel size to use for convolution')
    parser.add_argument('--norm', type=str, default='batch_norm', help="batch_norm, layer_norm, or None")
    parser.add_argument('--num_filters', type=int, default=32, help="number of filters for conv nets")
    parser.add_argument('--max_pool', type=str, default='True',
                        help="Whether use max pooling rather than strided convolutions")

    # other arguments
    parser.add_argument('--dataset', type=str, default='mnist', help="name of dataset")
    parser.add_argument('--iid', action='store_true', help='whether i.i.d or not')
    parser.add_argument('--num_classes', type=int, default=10, help="number of classes")
    parser.add_argument('--num_channels', type=int, default=3, help="number of channels of images")
    parser.add_argument('--gpu', type=int, default=0, help="GPU ID, -1 for CPU")
    parser.add_argument('--stopping_rounds', type=int, default=10, help='rounds of early stopping')
    parser.add_argument('--verbose', action='store_true', help='verbose print')
    parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')
    parser.add_argument('--all_clients', action='store_true', help='aggregation over all clients')
    args = parser.parse_args()
    return args
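These flags follow the notation of the FedAvg paper (McMahan et al., 2017): K clients (--num_users), a fraction C of them sampled per round (--frac), E local epochs (--local_ep), and local batch size B (--local_bs). With the defaults above, each round trains m = max(int(100 * 0.1), 1) = 10 clients.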

Client training (update.py)

import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn

class DataSplit(Dataset):
    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = list(idxs)

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        image, label = self.dataset[self.idxs[item]]
        return image, label

class LocalUpdate(object):
    def __init__(self, args, dataset=None, idxs=None):
        self.args = args
        self.loss_fun = nn.CrossEntropyLoss()
        self.train_batch = DataLoader(DataSplit(dataset, idxs), batch_size=self.args.local_bs, shuffle=True)
    
    def train(self, net):
        net.train()
        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, momentum=self.args.momentum)

        epoch_loss = []
        for ep in range(self.args.local_ep):
            batch_loss = []
            for images, labels in self.train_batch:
                images, labels = images.to(self.args.device), labels.to(self.args.device)
                net.zero_grad()
                label_pre = net(images)
                loss = self.loss_fun(label_pre, labels)
                loss.backward()
                optimizer.step()
                batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss) / len(batch_loss))
        # Return the updated weights and the loss averaged over local epochs
        return net.state_dict(), sum(epoch_loss) / len(epoch_loss)
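This mirrors the ClientUpdate step of the FedAvg paper: each selected client runs E (= local_ep) epochs of minibatch SGD with batch size B (= local_bs), starting from a copy of the current global weights. Note that net.zero_grad() is interchangeable with optimizer.zero_grad() here, since the optimizer was constructed over all of the network's parameters.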

Local client aggregation: FedAvg (aggregation.py)

import copy
import torch

def FedAvg(w, loss):
    # Element-wise average of the clients' state_dicts with equal weights;
    # the loss argument is currently unused.
    avg_w = copy.deepcopy(w[0])
    for k in avg_w.keys():
        for i in range(1, len(w)):
            avg_w[k] += w[i][k]
        avg_w[k] = torch.div(avg_w[k], len(w))
    return avg_w
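This is the uniform-weight special case of FedAvg. Since noniid() gives every user exactly 600 samples, uniform averaging coincides with the paper's sample-weighted rule w = sum_k (n_k / n) * w_k. For unequal client sizes, a weighted variant might look like the sketch below (the sizes argument is hypothetical, e.g. [len(user_dict[i]) for i in idx_users]):

def FedAvgWeighted(w, sizes):
    # Weight each client's parameters by its sample count n_k, as in
    # McMahan et al. (2017); sizes[i] is client i's number of samples.
    total = float(sum(sizes))
    avg_w = copy.deepcopy(w[0])
    for k in avg_w.keys():
        avg_w[k] = w[0][k] * (sizes[0] / total)
        for i in range(1, len(w)):
            avg_w[k] += w[i][k] * (sizes[i] / total)
    return avg_w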

Accuracy testing (test.py)

import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F

def img_test(glob_net, dataset, args):
    glob_net.eval()
    test_loss = 0
    correct = 0
    dataloader = DataLoader(dataset, batch_size=args.bs)
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(args.device), labels.to(args.device)
            pre_labels = glob_net(images)
            # Sum per-sample losses so the division below gives the exact mean
            test_loss += F.cross_entropy(pre_labels, labels, reduction='sum').item()
            y_pred = pre_labels.data.max(1, keepdim=True)[1]
            correct += y_pred.eq(labels.data.view_as(y_pred)).long().cpu().sum().item()

    test_loss /= len(dataloader.dataset)
    accuracy = 100. * correct / len(dataloader.dataset)

    return accuracy, test_loss

Experimental results: 50 rounds, non-IID

[Figure: results of the 50-round non-IID experiment; original image not preserved]
