FedR代码学习文档

main.py

参数设置,进入主函数

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # parser.add_argument('--data_path', default='Fed_data/WN18RR-Fed3.pkl', type=str)
    parser.add_argument('--data_path', default='Fed_data/DDB14-Fed3.pkl', type=str)
    parser.add_argument('--name', default='wn18rr_fed3_fed_TransE', type=str)
    parser.add_argument('--state_dir', '-state_dir', default='./state', type=str)
    parser.add_argument('--log_dir', '-log_dir', default='./log', type=str)
    parser.add_argument('--tb_log_dir', '-tb_log_dir', default='./tb_log', type=str)
    parser.add_argument('--run_mode', default='FedR', choices=['FedE', 'Single', 'test_pretrain'])
    parser.add_argument('--num_multi', default=3, type=int)

    parser.add_argument('--model', default='TransE', choices=['TransE', 'RotatE', 'DistMult', 'ComplEx'])

    # one task hyperparam
    parser.add_argument('--one_client_idx', default=0, type=int)
    parser.add_argument('--max_epoch', default=10000, type=int)
    parser.add_argument('--log_per_epoch', default=1, type=int)
    parser.add_argument('--check_per_epoch', default=10, type=int)


    parser.add_argument('--batch_size', default=512, type=int)
    parser.add_argument('--test_batch_size', default=16, type=int)
    parser.add_argument('--num_neg', default=256, type=int)
    parser.add_argument('--lr', default=0.001, type=int)

    # for FedE
    parser.add_argument('--num_client', default=3, type=int)
    parser.add_argument('--max_round', default=10000, type=int)
    parser.add_argument('--local_epoch', default=3, type=int)
    parser.add_argument('--fraction', default=1, type=float)
    parser.add_argument('--log_per_round', default=1, type=int)
    parser.add_argument('--check_per_round', default=5, type=int)

    parser.add_argument('--early_stop_patience', default=5, type=int)
    parser.add_argument('--gamma', default=10.0, type=float)
    parser.add_argument('--epsilon', default=2.0, type=float)
    parser.add_argument('--hidden_dim', default=128, type=int)
    parser.add_argument('--gpu', default='0', type=str)
    parser.add_argument('--num_cpu', default=10, type=int)
    parser.add_argument('--adversarial_temperature', default=1.0, type=float)

    # parser.add_argument('--negative_adversarial_sampling', default=True, type=bool)
    parser.add_argument('--seed', default=12345, type=int)

    args = parser.parse_args()
    args_str = json.dumps(vars(args))

    args.gpu = torch.device('cuda:' + args.gpu)
    # args.gpu = torch.device(("cuda:" + args.gpu) if torch.cuda.is_available() else "cpu")

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    init_dir(args)
    writer = SummaryWriter(os.path.join(args.tb_log_dir, args.name))
    args.writer = writer
    init_logger(args)
    logging.info(args_str)

    if args.run_mode == 'FedR':
        all_data = pickle.load(open(args.data_path, 'rb'))
        learner = FedR(args, all_data)
        learner.train()
    elif args.run_mode == 'Single':
        all_data = pickle.load(open(args.data_path, 'rb'))
        data = all_data[args.one_client_idx]
        learner = KGERunner(args, data)
        learner.train()

数据导入

.pkl形式的数据 (通过csv的代码可以进行转换)
这里的数据分给三个客户端,每个客户端当中又有train,valid,test

  • edge_index:是一个二维数组,表示第i个三元组的起始节点和终止节点
  • edge_type:表示第i个三元组的relation
  • edge_index_ori:
  • edge_type_ori:
,train,test,valid
0,"{'edge_index': array([[3515, 3614, 3299, ...,  246, 3912, 2853],
       [ 961, 2501,  703, ...,  211, 1904,  442]], dtype=int64), 'edge_type': array([1, 9, 1, ..., 7, 1, 1], dtype=int64), 'edge_index_ori': array([[1796, 4767, 3939, ...,  345, 3215, 4054],
       [1787, 3036,  950, ...,  341, 3204,  537]], dtype=int64), 'edge_type_ori': array([2, 8, 2, ..., 7, 2, 2], dtype=int64)}","{'edge_index': array([[ 392,  822,  331, ..., 1207,  247,  902],
       [ 199,  261,  175, ...,  802,  195,  540]], dtype=int64), 'edge_type': array([1, 1, 1, ..., 1, 2, 1], dtype=int64), 'edge_index_ori': array([[ 424,  655,  373, ..., 1531,  364, 1047],
       [ 416,  530,  366, ..., 1530,  360,  527]], dtype=int64), 'edge_type_ori': array([2, 2, 2, ..., 2, 3, 2], dtype=int64)}","{'edge_index': array([[ 358, 1204, 2395, ...,  210, 1139,  371],
       [2581,  813, 1564, ...,  211,  583,  288]], dtype=int64), 'edge_type': array([1, 1, 6, ..., 8, 1, 1], dtype=int64), 'edge_index_ori': array([[ 393, 1601, 2983, ...,  229, 1223,  572],
       [2692, 1580, 1188, ...,  341,  644,  554]], dtype=int64), 'edge_type_ori': array([2, 2, 5, ..., 6, 2, 2], dtype=int64)}"
1,"{'edge_index': array([[4881, 5080, 2512, ...,  531,  876, 3547],
       [4882,   30,   38, ...,  532,  574,   95]], dtype=int64), 'edge_type': array([10,  0,  1, ...,  1,  4,  0], dtype=int64), 'edge_index_ori': array([[8515, 2880, 4337, ..., 3921,  721, 1996],
       [8391, 2695, 4333, ...,  234, 1556, 2442]], dtype=int64), 'edge_type_ori': array([0, 2, 4, ..., 4, 5, 2], dtype=int64)}","{'edge_index': array([[ 309, 1661, 2880, ..., 1861, 1831,  652],
       [  72, 2083, 2154, ...,  127, 2730, 3940]], dtype=int64), 'edge_type': array([0, 0, 1, ..., 0, 0, 0], dtype=int64), 'edge_index_ori': array([[ 402, 4842,  827, ..., 2229, 2742,  228],
       [ 379, 5955,  826, ..., 2256,  890, 5910]], dtype=int64), 'edge_type_ori': array([2, 2, 4, ..., 2, 2, 2], dtype=int64)}","{'edge_index': array([[1821,   91, 1049, ...,   59,  749,  560],
       [1800,  398, 2353, ...,  511,   34,  381]], dtype=int64), 'edge_type': array([0, 0, 0, ..., 0, 0, 0], dtype=int64), 'edge_index_ori': array([[4224,   62,  620, ..., 1059, 2398, 1696],
       [5502, 3330, 4833, ...,  266,   13,  125]], dtype=int64), 'edge_type_ori': array([2, 2, 2, ..., 2, 2, 2], dtype=int64)}"
2,"{'edge_index': array([[1048, 5151, 2026, ..., 3552, 1835,  897],
       [ 286,   33, 2180, ..., 4712, 1836,   56]], dtype=int64), 'edge_type': array([2, 2, 0, ..., 8, 6, 2], dtype=int64), 'edge_index_ori': array([[5172, 4261, 1779, ..., 6148,  222, 1803],
       [6817, 6663, 1859, ..., 9069, 2987, 6810]], dtype=int64), 'edge_type_ori': array([ 0,  0,  2, ...,  8, 12,  0], dtype=int64)}","{'edge_index': array([[ 508, 5263, 1230, ...,  577, 1646,  439],
       [ 649, 4329,  649, ..., 1496, 2298,  598]], dtype=int64), 'edge_type': array([0, 0, 0, ..., 0, 0, 0], dtype=int64), 'edge_index_ori': array([[ 630, 1847, 1297, ...,  266,  576,  876],
       [1256, 8628, 1256, ..., 4247, 1734, 3295]], dtype=int64), 'edge_type_ori': array([2, 2, 2, ..., 2, 2, 2], dtype=int64)}","{'edge_index': array([[  91, 1798, 1622, ..., 2358, 4665,  427],
       [ 672, 2482,   82, ...,  506,  136, 1011]], dtype=int64), 'edge_type': array([0, 3, 0, ..., 0, 0, 0], dtype=int64), 'edge_index_ori': array([[ 527, 3702,  566, ..., 1209,  251,   39],
       [ 224, 2336, 4523, ...,  666, 1172, 4816]], dtype=int64), 'edge_type_ori': array([2, 5, 2, ..., 2, 2, 2], dtype=int64)}"

数据分发

1.将隐私数据分发到客户机 (客户拥有),初始化服务器
2.统计客户机测试集、验证集的数量,以及权重数量

class FedR(object):
    def __init__(self, args, all_data):
        self.args = args

        train_dataloader_list, valid_dataloader_list, test_dataloader_list, \
            self.rel_freq_mat, ent_embed_list, nrelation = get_all_clients(all_data, args)

        self.args.nrelation = nrelation # question

        # client
        self.num_clients = len(train_dataloader_list)
        # Create client objects for each client
        self.clients = []
        for i in range(self.num_clients):
            client = Client(args, i, all_data[i], train_dataloader_list[i], valid_dataloader_list[i],
                            test_dataloader_list[i], ent_embed_list[i])
            self.clients.append(client)

        # Create the server object
        self.server = Server(args, nrelation)

        #   统计客户机测试集、验证集的数量,以及权重数量
        # Calculate total test data size and test evaluation weights
        self.total_test_data_size = 0
        for client in self.clients:
            self.total_test_data_size += len(client.test_dataloader.dataset)

        self.test_eval_weights = []
        for client in self.clients:
            weight = len(client.test_dataloader.dataset) / self.total_test_data_size
            self.test_eval_weights.append(weight)

        # Calculate total valid data size and valid evaluation weights
        self.total_valid_data_size = 0
        for client in self.clients:
            self.total_valid_data_size += len(client.valid_dataloader.dataset)

        self.valid_eval_weights = []
        for client in self.clients:
            weight = len(client.valid_dataloader.dataset) / self.total_valid_data_size
            self.valid_eval_weights.append(weight)
对初始数据集进行分发

这段代码定义了一个函数 get_all_clients(all_data, args),用于为每个客户端(client)创建数据加载器(dataloader),准备训练、验证和测试数据,以及创建实体嵌入(entity embeddings)和关系频率(relation frequency)信息。函数的具体步骤如下:

  1. all_rel = np.array([], dtype=int)初始化一个空的 NumPy 数组 all_rel 用于存储所有客户端训练数据中的关系类型。
  2. for data in all_data:遍历所有客户端的数据。
  3. all_rel = np.union1d(all_rel, data['train']['edge_type_ori']).reshape(-1)将当前客户端的训练数据中的关系类型与之前收集的关系类型进行合并(去重),并更新 all_rel 数组。
  4. nrelation = len(all_rel)计算所有客户端训练数据中的不同关系类型数量,并将结果存储在变量 nrelation 中。
  5. 初始化用于存储数据加载器、实体嵌入和关系频率的列表:train_dataloader_listvalid_dataloader_listtest_dataloader_listent_embed_listrel_freq_list
  6. for data in tqdm(all_data):再次遍历所有客户端的数据,并为每个客户端构建数据加载器和相关数据。
  7. nentity = len(np.unique(data['train']['edge_index']))计算当前客户端训练数据中的不同实体数量。
  8. 构建训练、验证和测试数据集:使用客户端数据中的边索引、边类型等信息,创建对应的数据集对象(TrainDatasetTestDataset)。
    9.client_mask_rel = np.setdiff1d(np.arange(nrelation), np.unique(data['train']['edge_type_ori'].reshape(-1)), assume_unique=True):这段代码的作用是计算在整个数据集中存在的所有可能关系类型(np.arange(nrelation))中,但在特定客户端的训练数据中缺失的关系类型。计算结果将存储在变量 client_mask_rel 中。(客户端中特有的数据集)
  9. 创建训练、验证和测试数据加载器:使用数据集对象,设置批量大小等参数,并创建对应的数据加载器(train_dataloadervalid_dataloadertest_dataloader),用于在模型训练和评估中使用。该代码段的目的是创建一个用于训练的数据加载器 train_dataloader,它将按照指定的批量大小和洗牌选项加载训练数据,并使用 TrainDataset.collate_fn 函数对数据进行处理。这样,在模型训练时,可以通过遍历 train_dataloader 来获取训练数据的批量,便于对模型进行训练。
  10. 创建实体嵌入向量:根据客户端训练数据中的实体数量和模型的隐藏维度,初始化对应大小的实体嵌入向量,其中实体嵌入向量的值在一定范围内随机初始化。
  11. 统计关系频率:根据客户端训练数据中的关系类型,统计每个关系类型在整个数据集中出现的频率,将结果存储在 rel_freq 中,并添加到 rel_freq_list 列表中。
  12. rel_freq_mat = torch.stack(rel_freq_list).to(args.gpu)将关系频率列表转换为一个 PyTorch 张量 rel_freq_mat,并将其移到 GPU(如果可用)。
  13. 返回 train_dataloader_listvalid_dataloader_listtest_dataloader_listrel_freq_matent_embed_listnrelation
def get_all_clients(all_data, args):
    all_rel = np.array([], dtype=int)
    for data in all_data:
        all_rel = np.union1d(all_rel, data['train']['edge_type_ori']).reshape(-1)
    nrelation = len(all_rel) # all relations of training set in all clients

    train_dataloader_list = []
    test_dataloader_list = []
    valid_dataloader_list = []

    ent_embed_list = []

    rel_freq_list = []

    for data in tqdm(all_data): # in a client
        nentity = len(np.unique(data['train']['edge_index'])) # entities of training in a client

        train_triples = np.stack((data['train']['edge_index'][0],
                                  data['train']['edge_type_ori'],
                                  data['train']['edge_index'][1])).T

        valid_triples = np.stack((data['valid']['edge_index'][0],
                                  data['valid']['edge_type_ori'],
                                  data['valid']['edge_index'][1])).T

        test_triples = np.stack((data['test']['edge_index'][0],
                                 data['test']['edge_type_ori'],
                                 data['test']['edge_index'][1])).T

        client_mask_rel = np.setdiff1d(np.arange(nrelation),
                                       np.unique(data['train']['edge_type_ori'].reshape(-1)), assume_unique=True)

        all_triples = np.concatenate([train_triples, valid_triples, test_triples]) # in a client
        train_dataset = TrainDataset(train_triples, nentity, args.num_neg)
        valid_dataset = TestDataset(valid_triples, all_triples, nentity, client_mask_rel)
        test_dataset = TestDataset(test_triples, all_triples, nentity, client_mask_rel)

        # dataloader,数据划分
        train_dataloader = DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            collate_fn=TrainDataset.collate_fn
        )
        train_dataloader_list.append(train_dataloader)

        valid_dataloader = DataLoader(
            valid_dataset,
            batch_size=args.test_batch_size,
            collate_fn=TestDataset.collate_fn
        )
        valid_dataloader_list.append(valid_dataloader)

        test_dataloader = DataLoader(
            test_dataset,
            batch_size=args.test_batch_size,
            collate_fn=TestDataset.collate_fn
        )
        test_dataloader_list.append(test_dataloader)

        embedding_range = torch.Tensor([(args.gamma + args.epsilon) / args.hidden_dim])

        '''use n of entity in train or all (train, valid, test)?'''
        if args.model in ['RotatE', 'ComplEx']:
            ent_embed = torch.zeros(nentity, args.hidden_dim*2).to(args.gpu).requires_grad_()
        else:
            ent_embed = torch.zeros(nentity, args.hidden_dim).to(args.gpu).requires_grad_()
        nn.init.uniform_(
            tensor=ent_embed,
            a=-embedding_range.item(),
            b=embedding_range.item()
        )
        ent_embed_list.append(ent_embed)

        rel_freq = torch.zeros(nrelation)
        for r in data['train']['edge_type_ori'].reshape(-1):
            rel_freq[r] += 1
        rel_freq_list.append(rel_freq)

    rel_freq_mat = torch.stack(rel_freq_list).to(args.gpu)

    return train_dataloader_list, valid_dataloader_list, test_dataloader_list, \
           rel_freq_mat, ent_embed_list, nrelation

这里提供了两个类 TrainDatasetTestDataset,它们都继承自 PyTorch 的 Dataset 类,用于处理训练、验证和测试数据。

  1. TrainDataset 类:
    • __init__(self, triples, nentity, negative_sample_size): 类的构造函数,接收 triples(包含训练三元组的数组)、nentity(实体的数量)和 negative_sample_size(负样本数目)作为输入参数。
    • self.hr2t: 该字典用于存储从头实体-关系对 (h, r) 到尾实体 t 的映射。它是一个 defaultdict,对于每个 (h, r) 键,值是一个包含所有尾实体 t 的数组。这样,可以通过 (h, r) 键来快速找到对应的尾实体数组。
  2. TestDataset 类:
    • __init__(self, triples, all_true_triples, nentity, rel_mask=None): 类的构造函数,接收 triples(包含测试或验证三元组的数组)、all_true_triples(包含整个数据集中所有正确三元组的数组)、nentity(实体的数量)和 rel_mask(关系掩码数组)作为输入参数。
    • self.hr2t_all: 该字典类似于 TrainDataset 中的 self.hr2t,用于存储从头实体-关系对 (h, r) 到尾实体 t 的映射。它是一个 defaultdict,对于每个 (h, r) 键,值是一个包含整个数据集中所有尾实体 t 的数组。
    • self.rel_mask: 这是一个可选的参数,用于控制在特定客户端的验证和测试数据中是否考虑某些关系。如果提供了 rel_mask,它会被用于过滤掉在客户端的训练数据中存在的关系类型。这些类用于处理数据集,其中 TrainDataset 主要用于训练数据,而 TestDataset 主要用于验证和测试数据。它们提供了一种组织和访问数据的方式,以便在训练和评估模型时使用。
class TrainDataset(Dataset):
    def __init__(self, triples, nentity, negative_sample_size):
        self.len = len(triples)
        self.triples = triples
        self.nentity = nentity
        self.negative_sample_size = negative_sample_size

        self.hr2t = ddict(set)
        for h, r, t in triples:
            self.hr2t[(h, r)].add(t)
        for h, r in self.hr2t:
            self.hr2t[(h, r)] = np.array(list(self.hr2t[(h, r)]))
class TestDataset(Dataset):
    def __init__(self, triples, all_true_triples, nentity, rel_mask = None):
        self.len = len(triples)
        self.triple_set = all_true_triples
        self.triples = triples
        self.nentity = nentity

        self.rel_mask = rel_mask

        self.hr2t_all = ddict(set)
        for h, r, t in all_true_triples:
            self.hr2t_all[(h, r)].add(t)
客户端的数据分发

每个客户端都有数据,并且拥有自己的模型

class Client(object):
    def __init__(self, args, client_id, data, train_dataloader,
                 valid_dataloader, test_dataloader, ent_embed):
        self.args = args
        self.data = data
        self.train_dataloader = train_dataloader
        self.valid_dataloader = valid_dataloader
        self.test_dataloader = test_dataloader
        self.ent_embed = ent_embed
        self.client_id = client_id

        self.score_local = []
        self.score_global = []

        self.kge_model = KGEModel(args, args.model)
        self.rel_embed = None

这段代码定义了一个 KGEModel 类,并在构造函数 __init__ 中初始化该类的一些属性。

  1. def __init__(self, args, model_name):: 这是 KGEModel 类的构造函数,它接收两个参数 argsmodel_name
  2. super(KGEModel, self).__init__(): 这是调用父类的构造函数,即 nn.Module 类的构造函数,以确保正确初始化继承的属性。
  3. self.model_name = model_name: 这一行将传递给构造函数的 model_name 参数赋值给类的属性 self.model_name。它用于在模型中标识模型的名称。
  4. self.embedding_range = torch.Tensor([(args.gamma + args.epsilon) / args.hidden_dim]): 这一行计算并设置了模型中的 embedding_range 属性。embedding_range 是一个大小为 1 的张量,其值为 (args.gamma + args.epsilon) / args.hidden_dim,它用于限制实体嵌入和关系嵌入的初始化范围。
  5. self.gamma = nn.Parameter(torch.Tensor([args.gamma]), requires_grad=False): 这一行定义了一个名为 gamma 的参数(nn.Parameter),其初始值是 args.gammann.Parameter 是 PyTorch 中的特殊张量类型,它允许将张量标记为模型参数,并可以在模型训练过程中自动优化。在这里,requires_grad=False 表示 gamma 参数在训练过程中不需要计算梯度。
    综上所述,这段代码初始化了 KGEModel 类的属性 model_nameembedding_rangegamma。其中,embedding_range 用于限制实体嵌入和关系嵌入的初始化范围,gamma 是一个固定的模型参数,不需要在训练过程中计算梯度。这些属性在后续模型的构建和训练中可能会被用到。
class KGEModel(nn.Module):
    def __init__(self, args, model_name):
        super(KGEModel, self).__init__()
        self.model_name = model_name
        self.embedding_range = torch.Tensor([(args.gamma + args.epsilon) / args.hidden_dim])
        self.gamma = nn.Parameter(
            torch.Tensor([args.gamma]),
            requires_grad=False
        )
服务器的数据分发

1.embedding_range = torch.Tensor([(args.gamma + args.epsilon) / args.hidden_dim]):这行代码计算了关系嵌入向量初始化的范围 embedding_range。参数 args.gamma 和 args.epsilon 是模型的一些超参数,用于控制关系嵌入向量初始化范围的大小。args.hidden_dim 是模型中嵌入向量的维度。
2.self.rel_embed = torch.zeros(nrelation, args.hidden_dim*2).to(args.gpu).requires_grad_():如果模型类型是 'ComplEx',则创建一个形状为 (nrelation, args.hidden_dim2) 的全零张量 self.rel_embed,用于存储关系嵌入向量。nrelation 是关系的数量,args.hidden_dim2 是每个关系嵌入向量的维度。通过 .to(args.gpu) 将张量放置在指定的 GPU 上(如果使用了 GPU)。最后,通过 requires_grad_() 方法指定张量需要计算梯度,用于后续的模型训练和优化。
3.nn.init.uniform_(tensor=self.rel_embed, a=-embedding_range.item(), b=embedding_range.item()):这行代码使用均匀分布初始化关系嵌入向量 self.rel_embed。关系嵌入向量的值被随机采样自均匀分布,范围是从-embedding_range.item() 到 embedding_range.item()。

    def __init__(self, args, nrelation):
        self.args = args
        embedding_range = torch.Tensor([(args.gamma + args.epsilon) / args.hidden_dim])
        if args.model in ['ComplEx']:
            self.rel_embed = torch.zeros(nrelation, args.hidden_dim*2).to(args.gpu).requires_grad_()
        else:
            self.rel_embed = torch.zeros(nrelation, args.hidden_dim).to(args.gpu).requires_grad_()
        nn.init.uniform_(
            tensor=self.rel_embed,
            a=-embedding_range.item(),
            b=embedding_range.item()
        )
        self.nrelation = nrelation

模型训练

1.best_epoch = 0, best_mrr = 0, bad_count = 0:初始化一些变量,best_epoch 用于存储在训练过程中获得最佳性能的轮次,best_mrr 用于存储最佳的 Mean Reciprocal Rank (MRR) 值,bad_count 用于记录模型性能没有提升的轮次数量。
2.mrr_plot_result = [], loss_plot_result = []:初始化空列表,用于保存每轮评估时的 MRR 值和每轮训练时的损失值。
3.for num_round in range(self.args.max_round):外层循环是训练的主循环,根据 self.args.max_round 指定的最大轮次进行训练。
4.n_sample = max(round(self.args.fraction * self.num_clients), 1):根据 self.args.fraction 和客户端的数量 self.num_clients 计算出本轮次选择的客户端数量 n_sample,保证选择的客户端数量不少于 1。
5.self.send_emb():将服务器relation向量传到客户机中
6.round_loss = 0: 初始化一个变量 round_loss 用于记录当前轮次的总损失值。
7.for k in iter(sample_set):这是内层循环,遍历本轮次选择的客户端。
8.self.server.aggregation(self.clients, self.rel_freq_mat):执行一个函数 aggregation(),该函数可能是用于在服务器端聚合从客户端接收到的更新,以更新全局模型参数。在分布式学习中,通常需要将不同客户端的更新进行聚合,以得到全局的模型。
9.if num_round % self.args.check_per_round == 0 and num_round != 0:检查是否到了评估模型的轮次,self.args.check_per_round 是指定的评估间隔。num_round != 0 确保从第一轮之后才进行评估。
10.eval_res = self.evaluate(): 执行一个函数 evaluate(),该函数用于评估当前轮次模型在验证集上的性能,并返回评估结果。
11.if eval_res['mrr'] > best_mrr:判断当前轮次的 MRR 是否优于最佳 MRR,如果是,则更新最佳 MRR 和最佳轮次 best_mrr 和 best_epoch。
12.bad_count += 1logging.info('best model is at round {0}, mrr {1:.4f}, bad count {2}'.format(best_epoch, best_mrr, bad_count)):如果当前轮次 MRR 不如最佳 MRR,则增加 bad_count 记录性能没有提升的轮次数量,并打印当前最佳轮次的信息。
13.if bad_count >= self.args.early_stop_patience:检查是否达到早停止条件。self.args.early_stop_patience 是设定的早停止容忍度,如果连续 bad_count 轮模型性能没有提升,则触发早停止。
14.self.save_model(best_epoch):保存获得最佳性能的模型参数。
15.self.before_test_load()self.evaluate(istest=True):在最后完成训练后,加载之前保存的最佳模型参数,并在测试集上进行最终的模型评估。

    def train(self):
        best_epoch = 0
        best_mrr = 0
        bad_count = 0

        mrr_plot_result = []
        loss_plot_result = []

        for num_round in range(self.args.max_round):
            n_sample = max(round(self.args.fraction * self.num_clients), 1)
            sample_set = np.random.choice(self.num_clients, n_sample, replace=False)

            self.send_emb()
            round_loss = 0
            for k in iter(sample_set):#不同客户机的损失值相加
                client_loss = self.clients[k].client_update()
                round_loss += client_loss
            round_loss /= n_sample
            self.server.aggregation(self.clients, self.rel_freq_mat)

            logging.info('round: {} | loss: {:.4f}'.format(num_round, np.mean(round_loss)))
            self.write_training_loss(np.mean(round_loss), num_round)

            loss_plot_result.append(np.mean(round_loss))

            if num_round % self.args.check_per_round == 0 and num_round != 0:
                eval_res = self.evaluate()
                self.write_evaluation_result(eval_res, num_round)

                if eval_res['mrr'] > best_mrr:
                    best_mrr = eval_res['mrr']
                    best_epoch = num_round
                    logging.info('best model | mrr {:.4f}'.format(best_mrr))
                    self.save_checkpoint(num_round)
                    bad_count = 0
                else:
                    bad_count += 1
                    logging.info('best model is at round {0}, mrr {1:.4f}, bad count {2}'.format(
                        best_epoch, best_mrr, bad_count))

                mrr_plot_result.append(eval_res['mrr'])

            if bad_count >= self.args.early_stop_patience:
                logging.info('early stop at round {}'.format(num_round))

                loss_file_name = 'loss/' + self.args.name + '_loss.pkl'
                with open(loss_file_name, 'wb') as fp:
                    pickle.dump(loss_plot_result, fp)

                mrr_file_name = 'loss/' + self.args.name + '_mrr.pkl'
                with open(mrr_file_name, 'wb') as fp:
                    pickle.dump(mrr_plot_result, fp)

                break

        logging.info('finish training')
        logging.info('save best model')
        self.save_model(best_epoch)
        self.before_test_load()
        self.evaluate(istest=True)

客户机的训练

在这段代码中,GPU计算是通过PyTorch的自动并行化机制实现的。当使用GPU时,PyTorch会自动将涉及到GPU上的张量的操作转换为GPU计算。在此过程中,你不需要手动编写特定的GPU计算代码,而只需将相关的张量和模型放置在GPU上,PyTorch会自动处理计算的细节。

以下是GPU计算的关键步骤:

  1. 将张量放置在GPU上:在代码中,通过 .to(self.args.gpu) 将张量 self.rel_embedself.ent_embed 放置在GPU上,这样在后续的计算中,它们就会在GPU上执行。

  2. GPU并行计算:在 forward 函数中,当使用 head, relation, tail 这些在GPU上的张量进行计算时,PyTorch会自动在GPU上并行地执行张量操作。这样, score 张量也会在GPU上计算得到。
    s

  3. 自动求导:client_update 函数中,优化器 optim.Adam 用于执行梯度下降。你提供了要优化的参数列表 [{'params': self.rel_embed}, {'params': self.ent_embed}] 给优化器。这样,当 loss.backward() 被调用时,PyTorch会自动计算梯度,并将梯度信息存储在 self.rel_embedself.ent_embed 这些张量的 .grad 属性中。优化器在调用 optimizer.step() 时使用这些梯度来更新模型参数。

loss.backward() 是PyTorch中用于计算梯度的方法。在深度学习中,我们通常通过最小化损失函数来优化模型参数,使得模型的预测结果尽可能接近真实标签。梯度表示损失函数关于模型参数的变化率,它告诉我们在当前参数值下,向哪个方向改变参数能够减少损失函数的值。在训练过程中,我们首先计算出损失函数的值 loss,然后通过调用 loss.backward() 方法来自动计算梯度。具体地,PyTorch会根据 loss 进行反向传播,沿着计算图反向计算每个参数的梯度,将梯度信息存储在参数张量的 .grad 属性中。在之后的优化过程中,我们将使用优化器(例如Adam优化器)根据计算得到的梯度信息来更新模型参数,使得损失函数逐渐减小,从而优化模型。所以,loss.backward() 与Adam优化器的参数更新过程是密切相关的。loss.backward() 计算得到梯度,而优化器根据这些梯度来更新模型参数,进而优化模型。在调用 loss.backward() 后,PyTorch会自动计算 loss 相对于模型参数 self.rel_embed 和 self.ent_embed 的梯度,并将梯度信息存储在对应张量的 .grad 属性中。然后,在调用 optimizer.step() 时,优化器将使用这些梯度信息来更新模型参数。optimizer.step() 方法会根据优化器的算法和学习率来执行参数更新。
总结:PyTorch提供了自动的GPU计算和自动求导机制,通过简单的代码调用,可以在GPU上执行张量操作和自动计算梯度。这样,你可以更专注于模型设计和训练过程的控制,而不必手动编写GPU计算和梯度更新的代码。

在这段代码中,for batch in self.train_dataloader: 是一个迭代器,用于从训练数据加载器 self.train_dataloader 中逐批次获取数据。每次迭代,batch 变量就会得到一批训练样本,其中包含三个张量:positive_samplenegative_samplesample_idx
这是由于在构造数据加载器时,collate_fn 参数指定为 TrainDataset.collate_fn 函数。这个函数用于对数据进行收集和处理,确保每个批次中的样本有相同的形状,方便模型进行批量处理。在 TrainDataset.collate_fn 函数中,将 positive_samplenegative_samplesample_idx 组合成一个元组并返回,因此在迭代器中,每次获取的 batch 就是这个元组。
如果想进一步了解 TrainDataset.collate_fn 函数的具体实现,可以查看数据集类 TrainDataset 中的 collate_fn 方法。这个方法通常会进行填充、对齐或其他数据处理操作,确保生成的 positive_samplenegative_samplesample_idx 在每个批次中都具有相同的维度和形状。这样在模型训练时,每次获取的数据都是可以直接用于模型计算的批量样本。
这三个张量的内容表示了一个批次(batch)的训练数据:

  1. positive_sample: 这个张量的内容是一个大小为 (batch_size, 3) 的整数张量,其中 batch_size 是批次大小,每一行包含一个正样本的三元组。每一行的三个元素分别表示正样本的头实体、关系和尾实体。例如,第一行的三个元素是 [1415, 11, 271],表示第一个正样本的头实体是实体 1415,关系是 11,尾实体是 271。
  2. negative_sample: 这个张量的内容是一个大小为 (batch_size, num_neg) 的整数张量,其中 num_neg 是每个正样本对应的负样本数量。每一行包含一个正样本对应的负样本三元组。负样本是通过将正样本中的头实体或尾实体替换为其他实体生成的。例如,第一行的内容是一个包含多个实体的列表,表示第一个正样本对应的负样本的头实体集合。
  3. sample_idx: 这个张量的内容是一个大小为 (batch_size,) 的整数张量,表示每个样本的索引或标识符。这个张量可能在代码中并未详细使用,但它通常用于跟踪每个样本的来源和标识符。在这个例子中,这个张量是一个从 8759 到 10290 的整数序列,对应于该批次中每个样本的索引。
    请注意,这些张量的具体内容可能因数据不同而变化,但它们的维度和含义保持不变:positive_sample 包含了正样本的三元组,negative_sample 包含了负样本的实体集合,sample_idx 是样本的索引或标识符。这样的数据格式允许模型在一个批次内同时处理正样本和负样本,以及跟踪每个样本的来源和标识符。

在这段代码中,positive_score是一个形状为(batch_size, 1)的二维张量(或者称为矩阵),其中batch_size表示批次大小,而1表示每个样本的得分。F.logsigmoid()是一个PyTorch函数,用于计算输入张量的元素的log sigmoid值。然后,使用.squeeze(dim=1)函数将positive_score张量从形状(batch_size, 1)压缩成形状(batch_size,)

为什么要使用.squeeze(dim=1)呢?由于positive_score的形状是(batch_size, 1),即每个样本的得分值都被包装在一个额外的维度中,但在这个上下文中,我们希望得到一个形状为(batch_size,)的一维张量,即每个样本的单个得分值。使用.squeeze(dim=1)的作用就是去掉那个单一的维度,使得positive_score变成一个形状为(batch_size,)的一维张量,其中每个元素是对应样本的log sigmoid得分值。

    def client_update(self):
        optimizer = optim.Adam([{'params': self.rel_embed},
                                {'params': self.ent_embed}], lr=self.args.lr)

        losses = []
        for i in range(self.args.local_epoch):
            for batch in self.train_dataloader:
                positive_sample, negative_sample, sample_idx = batch

                positive_sample = positive_sample.to(self.args.gpu)
                negative_sample = negative_sample.to(self.args.gpu)
                #这里会调用kge_model的forward函数
                negative_score = self.kge_model((positive_sample, negative_sample),
                                                 self.rel_embed, self.ent_embed)

                negative_score = (F.softmax(negative_score * self.args.adversarial_temperature, dim=1).detach()
                                  * F.logsigmoid(-negative_score)).sum(dim=1)

                positive_score = self.kge_model(positive_sample,
                                                self.rel_embed, self.ent_embed, neg=False)

                positive_score = F.logsigmoid(positive_score).squeeze(dim=1)

                positive_sample_loss = - positive_score.mean()
                negative_sample_loss = - negative_score.mean()

                loss = (positive_sample_loss + negative_sample_loss) / 2

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                losses.append(loss.item())

        return np.mean(losses)
计算分数f

分别为positive_sample和negative_sample计算分数、先执行forward函数,再执行TransE函数
这段代码实现了一个KGE(Knowledge Graph Embedding)模型的前向传播过程。这个KGE模型根据不同的model_name来选择不同的子模型函数进行计算,包括TransEDistMultComplExRotatE
在前向传播的过程中,输入sample是一个张量,包含了一批样本的三元组信息,每一行代表一个三元组,分别是头实体、关系和尾实体的索引。relation_embedding是关系嵌入矩阵,entity_embedding是实体嵌入矩阵。
根据neg参数的不同取值,前向传播过程分为两种情况:

  1. neg=False时,表示计算正样本的得分,即头实体和尾实体之间的关系得分。通过索引操作从entity_embeddingrelation_embedding中选择对应的头实体、关系和尾实体的嵌入向量,并使用模型函数进行计算得分。
  2. neg=True时,表示计算负样本的得分。这种情况下,sample参数被划分为head_parttail_part两部分,分别对应着正样本的头实体和尾实体部分。首先,根据head_part索引选择头实体和关系的嵌入向量。然后,根据tail_part索引选择负样本对应的尾实体的嵌入向量,如果tail_part为None,则直接将整个entity_embedding作为负样本的嵌入向量。最后,使用模型函数进行计算得分。
    对于每个模型函数(TransEDistMultComplExRotatE),前向传播计算的是模型的得分(分数),用于衡量给定三元组的合理性。
    总之,这段代码实现了KGE模型的前向传播过程,通过选择不同的模型函数计算头实体、关系和尾实体之间的得分,用于训练和评估KGE模型的性能。

正样本的计算:
这段代码是前向传播过程中的一部分,用于计算正样本的得分。
在这段代码中,neg参数为False,表示计算正样本的得分。sample是一个张量,包含了一批正样本的三元组信息,每一行代表一个三元组,分别是头实体、关系和尾实体的索引。entity_embedding是实体嵌入矩阵,relation_embedding是关系嵌入矩阵。
代码分别进行以下操作:

  1. entity_embedding中根据sample中的头实体索引选择对应的头实体嵌入向量,然后使用unsqueeze(1)将头实体的维度扩展为(batch_size, 1, embedding_dim),其中batch_size为当前批次中的样本数量,embedding_dim为嵌入向量的维度。这样操作的目的是为了与后续的计算保持统一的维度。
  2. relation_embedding中根据sample中的关系索引选择对应的关系嵌入向量,同样使用unsqueeze(1)将关系的维度扩展为(batch_size, 1, embedding_dim)
  3. entity_embedding中根据sample中的尾实体索引选择对应的尾实体嵌入向量,同样使用unsqueeze(1)将尾实体的维度扩展为(batch_size, 1, embedding_dim)
    最终,得到的headrelationtail分别表示当前批次中所有正样本的头实体、关系和尾实体的嵌入向量,这些嵌入向量将用于后续模型函数的计算,以得到正样本的得分(分数)。

让我们用一个简单的例子来解释dim=0unsqueeze(1)的含义。
假设我们有一个entity_embedding的张量,其形状为(5, 3),其中包含5个实体的3维嵌入向量:

entity_embedding = torch.tensor([[1, 2, 3],
                                 [4, 5, 6],
                                 [7, 8, 9],
                                 [10, 11, 12],
                                 [13, 14, 15]])

现在,我们有一个sample张量,形状为(2, 3),表示两个正样本的索引信息:

sample = torch.tensor([[2, 1, 4],
                       [3, 0, 2]])

现在,我们来解释dim=0unsqueeze(1)的含义:

  1. dim=0:在PyTorch中,dim参数用于指定在哪个维度上进行操作。在我们的例子中,dim=0表示沿着第0维度(行)进行操作。例如,如果我们想从entity_embedding中选择行索引为2的向量,我们可以使用torch.index_select(entity_embedding, dim=0, index=torch.tensor([2]))
    2.unsqueeze(1)操作:unsqueeze(1)是PyTorch中的一个函数,用于在指定位置插入新的维度。它的作用是将现有维度扩展为更高维度。例如,对于一个形状为(3,)的张量,执行unsqueeze(1)操作后,形状将变为(3, 1)。
    让我们看一个具体的例子:
import torch

# 定义一个一维张量
x = torch.tensor([1, 2, 3])

# 使用unsqueeze(1)在列维度上插入新的维度
y = x.unsqueeze(1)

print(x.shape)  # 输出: torch.Size([3])
print(y.shape)  # 输出: torch.Size([3, 1])

print(x)
# 输出: tensor([1, 2, 3])
print(y)
# 输出: tensor([[1],
#               [2],
#               [3]])

在上面的例子中,我们有一个形状为(3,)的张量x,然后我们使用unsqueeze(1)在列维度上插入新的维度,得到了形状为(3, 1)的张量y。可以看到,y成为了一个列向量。

在你提供的代码中,unsqueeze(1)的作用是将一个一维张量扩展为列向量,这在处理矩阵运算时很常见,尤其在深度学习中。

负样本的计算:
当计算负样本时(neg=True),else 代码块负责计算这些负样本的得分。以下是对 else 代码块内部 Python 语法的详细解释:

  1. head_part, tail_part = sample:这行代码解包 sample 张量,获取头实体和尾实体的索引。head_part 包含负样本的头实体索引,tail_part 包含负样本的尾实体索引。这些索引是从提供的负样本张量 sample 中获得的。
  2. batch_size = head_part.shape[0]batch_size 变量被设置为当前批次中负样本的数量。它通过获取 head_part 张量的第一个维度的大小(即负样本的数量)来获得。
  3. head = torch.index_select(...):通过使用索引 head_part[:, 0],从 entity_embedding 张量中选择适当的行(即负样本的头实体)来计算头实体的张量 head。然后使用 unsqueeze(1) 增加维度,使得 head 张量的形状变为 (batch_size, 1, embedding_dim)。这样处理是为了后续计算方便,因为张量 head 现在是三维的,其中第一个维度表示批次大小,第二个维度为 1(因为每个头实体只有一个),第三个维度为嵌入的维度。
  4. relation = torch.index_select(...):类似于步骤 3,通过使用索引 head_part[:, 1],从 relation_embedding 张量中选择适当的行(即负样本的关系)来计算关系的张量 relation。然后同样使用 unsqueeze(1) 增加维度,使得 relation 张量的形状变为 (batch_size, 1, embedding_dim)
  5. if tail_part is None::这个条件块检查 tail_part 是否为 None。如果为 None,意味着在这批负样本中没有提供负尾实体。在这种情况下,tail 张量被设置为 entity_embedding.unsqueeze(0),这意味着尾实体嵌入保持与原始实体嵌入相同。这种情况发生在未应用负采样时。
  6. else::如果 tail_part 不是 None,则表示在这批负样本中提供了负尾实体。通过使用 tail_part.view(-1) 索引 entity_embedding 张量,计算 tail 张量以获取尾实体的嵌入。view(-1) 操作用于将 tail_part 张量展平,以便用于索引 entity_embedding
  7. tail = ...view(batch_size, negative_sample_size, -1):在索引 entity_embedding 后,对 tail 张量进行形状重塑,使其大小为 (batch_size, negative_sample_size, embedding_dim),其中 batch_size 是负样本的数量,negative_sample_size 是每个负样本中负尾实体的数量,embedding_dim 是嵌入的维度。这种重塑是为了正确处理批次中每个头实体的多个负尾实体。
    综上所述,此代码块根据提供的索引(head_parttail_part)和嵌入张量(entity_embeddingrelation_embedding)计算负样本的头实体、关系和尾实体的嵌入。然后,使用计算得到的嵌入来计算 KGEModel 中的负样本得分。
    def forward(self, sample, relation_embedding, entity_embedding, neg=True):
        if not neg:
            head = torch.index_select(
                entity_embedding,
                dim=0,
                index=sample[:, 0]
            ).unsqueeze(1)

            relation = torch.index_select(
                relation_embedding,
                dim=0,
                index=sample[:, 1]
            ).unsqueeze(1)

            tail = torch.index_select(
                entity_embedding,
                dim=0,
                index=sample[:, 2]
            ).unsqueeze(1)
        else:
            head_part, tail_part = sample
            batch_size = head_part.shape[0]

            head = torch.index_select(
                entity_embedding,
                dim=0,
                index=head_part[:, 0]
            ).unsqueeze(1)

            relation = torch.index_select(
                relation_embedding,
                dim=0,
                index=head_part[:, 1]
            ).unsqueeze(1)

            if tail_part == None:
                tail = entity_embedding.unsqueeze(0)
            else:
                negative_sample_size = tail_part.size(1)
                tail = torch.index_select(
                    entity_embedding,
                    dim=0,
                    index=tail_part.view(-1)
                ).view(batch_size, negative_sample_size, -1)
            
        model_func = {
            'TransE': self.TransE,
            'DistMult': self.DistMult,
            'ComplEx': self.ComplEx,
            'RotatE': self.RotatE,
        }

        score = model_func[self.model_name](head, relation, tail)
        
        return score

在这段代码中,首先进行了一系列数学运算,然后得到了一个分数(score)作为输出。

  1. (head + relation) - tail: 这一步是将三个输入张量 headrelationtail 进行逐元素的加法和减法操作。这些张量的形状应该是一致的,因为它们是表示实体(entities)和关系(relations)的嵌入向量(embeddings)。这一步的目的是计算每个三元组 (head, relation, tail) 对应的得分。
  2. torch.norm(score, p=1, dim=2): 这一步是计算 score 张量的范数,其中 p=1 表示计算 L1 范数(也称为曼哈顿距离或绝对值范数)。参数 dim=2 表示在第二个维度上计算范数。由于 score 张量的形状是 (batch_size, 1),其中 batch_size 是每次处理的样本数,因此在第二个维度上进行范数计算就是对每个样本计算范数。
  3. self.gamma.item(): 这里的 self.gamma 是一个张量,.item() 方法用于提取这个张量中的标量值。在这个例子中,self.gamma 应该是一个包含一个元素的张量,因此.item() 用于获取该张量的值作为标量。
  4. score = self.gamma.item() - torch.norm(score, p=1, dim=2): 这一步将之前计算得到的 L1 范数从 self.gamma 得到的标量值中减去,从而得到最终的得分 score。这个得分将用于表示特定三元组 (head, relation, tail) 的模型评分。
    最终,这个 score 张量包含了每个样本的评分,用于表示模型对每个三元组的打分,评估其符合事实的程度。
    在第二个维度上进行计算意味着对张量中的第二个维度(也可以称为轴或列)上的元素进行操作。在PyTorch中,dim参数用于指定在哪个维度上进行计算,这允许我们在张量的特定维度上进行各种操作,例如求和、平均、范数等。

假设有一个形状为 (3, 4) 的张量 score,其中包含三个样本,每个样本由四个元素组成。这个张量的第二个维度表示每个样本的维度,也可以理解为每个样本的特征或属性。

示例张量 score

score = [[1.5, 2.2, 0.8, 3.1],
         [2.7, 0.9, 2.4, 1.2],
         [0.9, 1.3, 1.8, 2.6]]

现在,我们要计算 score 的 L1 范数,并且在第二个维度上进行计算,即对每个样本进行 L1 范数的计算。
L1 范数是将向量中每个元素的绝对值相加。对于张量 score,我们分别对每个样本的四个元素进行 L1 范数的计算。
计算结果为:

L1_norm = [1.5 + 2.2 + 0.8 + 3.1,
           2.7 + 0.9 + 2.4 + 1.2,
           0.9 + 1.3 + 1.8 + 2.6]
L1_norm = [7.6, 7.2, 6.6]

因此,通过在第二个维度上进行 L1 范数的计算,我们得到了一个包含三个元素的张量 L1_norm,其中每个元素表示对应样本的 L1 范数。在这个例子中,得到的 L1 范数分别为 7.6、7.2 和 6.6。

在给定的例子中,矩阵score是一个形状为(3, 4)的二维矩阵,其中第一维度有3个元素,第二维度有4个元素。我们来看一下在第一维度求和和在第二维度求和的区别:

  1. 第一维度求和:
    在第一维度求和意味着我们将每一列中的元素相加,得到一个新的一维张量,其长度等于矩阵的第一维度长度。在这个例子中,我们对每一列进行求和,得到了一个长度为4的一维张量:

[1.5 + 2.7 + 0.9, 2.2 + 0.9 + 1.3, 0.8 + 2.4 + 1.8, 3.1 + 1.2 + 2.6]

  1. 第二维度求和:
    在第二维度求和意味着我们将每一行中的元素相加,得到一个新的一维张量,其长度等于矩阵的第二维度长度。在这个例子中,我们对每一行进行求和,得到了一个长度为3的一维张量:

[1.5 + 2.2 + 0.8 + 3.1, 2.7 + 0.9 + 2.4 + 1.2, 0.9 + 1.3 + 1.8 + 2.6]

总结:在第一维度求和会将矩阵的每一列元素相加,得到长度为列数的一维张量;而在第二维度求和会将矩阵的每一行元素相加,得到长度为行数的一维张量。

    def TransE(self, head, relation, tail):
        score = (head + relation) - tail
        score = self.gamma.item() - torch.norm(score, p=1, dim=2)
        return score

服务器的聚合

1.agg_rel_mask = rel_update_weights:agg_rel_mask 是一个与 rel_update_weights 维度相同的张量,它用于保存每个关系在三个客户端中的权重。
2.agg_rel_mask[rel_update_weights != 0] = 1:将 rel_update_weights 中非零元素对应的位置标记为1,即将关系在三个客户端中存在的标记为1,不存在的标记为0。
3.rel_w_sum = torch.sum(agg_rel_mask, dim=0):对 agg_rel_mask 沿着维度0(关系维度)求和,得到每个关系在三个客户端中出现的次数。
4.rel_w = agg_rel_mask / rel_w_sum:将每个关系在三个客户端中出现的次数进行归一化,得到每个关系的权重。
5.rel_w[torch.isnan(rel_w)] = 0:由于某些关系在三个客户端中都不存在,可能导致归一化时得到 NaN(Not a Number)。这里将这些 NaN 设置为0,表示这些关系的权重为0。
6.根据模型类型,初始化 update_rel_embed 张量,形状为 (self.nrelation, self.args.hidden_dim) 或 (self.nrelation, self.args.hidden_dim * 2)。
7.遍历每个客户端,计算每个客户端的局部关系嵌入向量 local_rel_embed,并将其乘以对应的关系权重 rel_w[i],然后加到 update_rel_embed 中。这样做的目的是对所有客户端的关系嵌入向量进行加权聚合,以得到一个全局的关系嵌入向量。
这行代码将当前客户端的局部关系嵌入向量 local_rel_embed 与其对应的关系权重 rel_w[i] 相乘,并加到 update_rel_embed 中。 由于 local_rel_embed 的形状为 (self.nrelation, self.args.hidden_dim),而 rel_w[i] 的形状为 (self.nrelation,),所以通过 rel_w[i].reshape(-1, 1) 将其转换为形状为 (self.nrelation, 1),这样两个张量可以进行逐元素乘法。
8.最后,将聚合后的全局关系嵌入向量 update_rel_embed 赋值给服务器端的 self.rel_embed,并标记 self.rel_embed 为可求导的(requires_grad_()),以便在训练过程中更新关系嵌入向量。这样服务器端的关系嵌入向量就更新为聚合后的全局嵌入向量。
最终,全局关系嵌入向量 update_rel_embed 将作为服务器端的关系嵌入向量,并用于后续的训练过程。

假设有3个客户端,并且每个客户端的关系嵌入向量的维度为2,关系数量为4。

客户端1的关系嵌入向量为:[ [0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8] ]
客户端2的关系嵌入向量为:[ [1.1, 1.2], [1.3, 1.4], [1.5, 1.6] ]
客户端3的关系嵌入向量为:[ [2.1, 2.2], [2.3, 2.4], [2.5, 2.6], [2.7, 2.8] ]

  1. 计算关系在三个客户端中的权重:

    agg_rel_mask = [ [1, 1, 1, 1], [1, 1, 1, 0], [1, 1, 1, 1] ]
    rel_w_sum = [ 3, 2, 3, 2 ]
    rel_w = [ [1/3, 1/2, 1/3, 1/2], [1/3, 1/2, 1/3, 0], [1/3, 1/2, 1/3, 1/2] ]
    
  2. 初始化全局关系嵌入向量:

    update_rel_embed = [ [0, 0], [0, 0], [0, 0], [0, 0] ]
    
  3. 加权聚合关系嵌入向量:

    i = 1:
    local_rel_embed = [ [0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8] ]
    update_rel_embed += local_rel_embed * [1/3, 1/2, 1/3, 1/2] = [ [0.033, 0.100], [0.100, 0.200], [0.167, 0.200], [0.350, 0.400] ]
    
    i = 2:
    local_rel_embed = [ [1.1, 1.2], [1.3, 1.4], [1.5, 1.6] ]
    update_rel_embed += local_rel_embed * [1/3, 1/2, 1/3, 0] = [ [0.500, 0.600], [0.800, 0.800], [1.167, 1.067], [0.350, 0.400] ]
    
    i = 3:
    local_rel_embed = [ [2.1, 2.2], [2.3, 2.4], [2.5, 2.6], [2.7, 2.8] ]
    update_rel_embed += local_rel_embed * [1/3, 1/2, 1/3, 1/2] = [ [1.000, 1.100], [1.900, 2.000], [2.000, 2.200], [3.100, 3.400] ]
    
  4. 更新全局关系嵌入向量:

    self.rel_embed = [ [1.000, 1.100], [1.900, 2.000], [2.000, 2.200], [3.100, 3.400] ]
    

最终,服务器端的全局关系嵌入向量为 [ [1.000, 1.100], [1.900, 2.000], [2.000, 2.200], [3.100, 3.400] ]

    def aggregation(self, clients, rel_update_weights):
        agg_rel_mask = rel_update_weights  #relation在三个客户机的权重
        agg_rel_mask[rel_update_weights != 0] = 1 #非0元素标注为1

        rel_w_sum = torch.sum(agg_rel_mask, dim=0) #对relation求和
        rel_w = agg_rel_mask / rel_w_sum
        rel_w[torch.isnan(rel_w)] = 0 #归一化
        if self.args.model in ['ComplEx']:
            update_rel_embed = torch.zeros(self.nrelation, self.args.hidden_dim * 2).to(self.args.gpu)
        else:
            update_rel_embed = torch.zeros(self.nrelation, self.args.hidden_dim).to(self.args.gpu) #初始化
        for i, client in enumerate(clients):
            local_rel_embed = client.rel_embed.clone().detach()
            # 这行代码将当前客户端的局部关系嵌入向量 local_rel_embed 与其对应的关系权重 rel_w[i] 相乘,
            # 并加到 update_rel_embed 中。
            # 由于 local_rel_embed 的形状为 (self.nrelation, self.args.hidden_dim),
            # 而 rel_w[i] 的形状为 (self.nrelation,),所以通过 rel_w[i].reshape(-1, 1) 将其转换为形状为 (self.nrelation, 1),这样两个张量可以进行逐元素乘法。
            update_rel_embed += local_rel_embed * rel_w[i].reshape(-1, 1)
        self.rel_embed = update_rel_embed.requires_grad_()

记录结果

这段代码是用于保存模型训练过程中的检查点的函数save_checkpoint(self, e),下面逐行解释每一行的内容:

  1. state = {'rel_embed': self.server.rel_embed, 'ent_embed': [client.ent_embed for client in self.clients]}:首先创建一个字典state,其中包含两个键值对:rel_embedent_embedself.server.rel_embed是服务器对象中的关系嵌入,而[client.ent_embed for client in self.clients]是一个列表,包含所有客户机的实体嵌入。
  2. for filename in os.listdir(self.args.state_dir)::遍历指定目录self.args.state_dir下的所有文件名。
  3. if self.args.name in filename.split('.') and os.path.isfile(os.path.join(self.args.state_dir, filename))::判断文件名是否包含self.args.name(模型名称)且文件是否存在。
  4. os.remove(os.path.join(self.args.state_dir, filename)):如果满足条件,删除之前保存的同名检查点文件。
  5. torch.save(state, os.path.join(self.args.state_dir, self.args.name + '.' + str(e) + '.ckpt')):将当前的模型参数保存为检查点文件。函数使用torch.save()来将state字典保存为文件,文件名的格式为args.name + '.' + str(e) + '.ckpt',其中e是当前的轮数,用于在文件名中标记当前的训练轮次。检查点文件保存在self.args.state_dir指定的目录下。
  def save_checkpoint(self, e):
        state = {'rel_embed': self.server.rel_embed,
                 'ent_embed': [client.ent_embed for client in self.clients]}
        # delete previous checkpoint
        for filename in os.listdir(self.args.state_dir):
            if self.args.name in filename.split('.') and os.path.isfile(os.path.join(self.args.state_dir, filename)):
                os.remove(os.path.join(self.args.state_dir, filename))
        # save current checkpoint
        torch.save(state, os.path.join(self.args.state_dir,
                                       self.args.name + '.' + str(e) + '.ckpt'))

预测

训练完用测试集预测
1.def save_model(self, best_epoch):这个函数用于保存在训练过程中获得的性能最佳的模型参数。best_epoch 参数是指定的性能最佳的轮次。函数首先将指定轮次的模型参数文件重命名为 'args.name.best',以保存为最佳模型。
2.def before_test_load(self):这个函数在进行测试之前加载之前保存的最佳模型参数。它首先从磁盘中加载之前保存的最佳模型参数文件 'args.name.best',然后将其中的关系嵌入向量 rel_embed 分配给服务器端的 self.server.rel_embed,并将各个客户端的实体嵌入向量 ent_embed 分配给相应的客户端。
3.def evaluate(self, istest=False):这个函数用于在测试集或验证集上进行模型评估。istest 参数用于指示是否在测试集上评估,如果为 True,则在测试集上进行评估,否则在验证集上评估。首先,该函数调用 self.send_emb(),将服务器端的实体嵌入向量发送给客户端。然后,对每个客户端,调用 client.client_eval(istest) 函数,进行模型在该客户端上的性能评估,得到评估结果 client_res,包括 MRR 和 hits@1、hits@3、hits@10 等指标。
对于每个客户端的评估结果,将其加权求和,其中权重由 weights 决定。weights 是一个列表,保存了每个客户端的权重,通常是根据客户端的数据量或其他指标确定的。最后,输出整体评估结果,包括 MRR 和 hits@1、hits@3、hits@10 等指标,并返回结果 result。

    def save_model(self, best_epoch):
        os.rename(os.path.join(self.args.state_dir, self.args.name + '.' + str(best_epoch) + '.ckpt'),
                  os.path.join(self.args.state_dir, self.args.name + '.best'))
    def before_test_load(self):
        state = torch.load(os.path.join(self.args.state_dir, self.args.name + '.best'), map_location=self.args.gpu)
        self.server.rel_embed = state['rel_embed']
        for idx, client in enumerate(self.clients):
            client.ent_embed = state['ent_embed'][idx]
    def evaluate(self, istest=False):
        self.send_emb()
        result = ddict(int)
        if istest:
            weights = self.test_eval_weights
        else:
            weights = self.valid_eval_weights
        for idx, client in enumerate(self.clients):
            client_res = client.client_eval(istest)

            logging.info('mrr: {:.4f}, hits@1: {:.4f}, hits@3: {:.4f}, hits@10: {:.4f}'.format(
                client_res['mrr'], client_res['hits@1'],
                client_res['hits@3'], client_res['hits@10']))

            for k, v in client_res.items():
                result[k] += v * weights[idx]

        logging.info('mrr: {:.4f}, hits@1: {:.4f}, hits@3: {:.4f}, hits@10: {:.4f}'.format(
                     result['mrr'], result['hits@1'],
                     result['hits@3'], result['hits@10']))

        return result
预测的函数

计算了每个样本的预测得分,并根据这些得分计算了样本的预测排名,以便在后续的评估中使用。同时,对于目标实体为正例的样本,将它们的预测得分设置为较大的负数,以确保在计算排名时,这些得分对正确排名没有影响。
1.在每个批次中 (预测集也是按客户机进行区分),将数据转移到GPU上(如果使用GPU的话)。
2.调用模型 self.kge_model 对批次数据进行前向传播得到预测结果 pred。pred = self.kge_model((triplets, None), self.rel_embed, self.ent_embed)。将三元组数据 triplets 和两个嵌入向量 self.rel_embed 和 self.ent_embed 作为输入,并返回对应样本的预测得分 pred。模型的前向传播是在这里完成的。
3.b_range = torch.arange(pred.size()[0], device=self.args.gpu):这行代码创建一个张量 b_range,其元素从 0 开始递增,范围是从 0 到 pred.size()[0] - 1,其中 pred.size()[0] 是 pred 张量的行数(样本数)。device=self.args.gpu 表示将该张量放置在 GPU 上,如果设置了 GPU,否则将放置在 CPU 上。
4.target_pred = pred[b_range, tail_idx]:这行代码将 pred 张量中的目标实体(tail)对应的预测得分提取出来,并存储在 target_pred 中。b_range 是预测得分张量的行索引,tail_idx 是目标实体的索引,这样通过索引的方式可以获取目标实体的预测得分。
5.pred = torch.where(labels.byte(), -torch.ones_like(pred) * 10000000, pred):这行代码根据 labels 张量中的标签信息(True/False)来处理 pred 张量中的预测得分。如果 labels 中某个位置是 True(1),则说明对应样本的目标实体是正例,将该位置对应的预测得分设置为一个较大的负数,即 -10000000。这样,在后续计算排名时,这些预测得分将排在最后,不会影响对正确排名的计算。
6.pred[b_range, tail_idx] = target_pred:这行代码将之前提取的目标实体的预测得分 target_pred 赋值回 pred 张量中的相应位置,以恢复原始的预测得分。
7.ranks = 1 + torch.argsort(torch.argsort(pred, dim=1, descending=True), dim=1, descending=False)[b_range, tail_idx]:这行代码计算每个样本的预测排名 ranks。首先,torch.argsort(pred, dim=1, descending=True) 对 pred 张量进行降序排序,并返回排序后的索引,即得到每个样本的预测得分从高到低的排列顺序。然后,再使用 torch.argsort() 对排序后的索引进行升序排序,得到原始样本的预测排名。注意,由于排名是从 0 开始的,所以最后再加上 1,使排名从 1 开始。
8.计算每个指标(MRR、hits@1、hits@3、hits@10)所需的统计量:总样本数、平均排名、平均倒数排名、每个指标下的命中次数。
9.最后,将统计量除以总样本数来计算平均值。

    def client_eval(self, istest=False):
        if istest:
            dataloader = self.test_dataloader
        else:
            dataloader = self.valid_dataloader

        results = ddict(float)
        for batch in dataloader:
            triplets, labels = batch
            triplets, labels = triplets.to(self.args.gpu), labels.to(self.args.gpu)
            head_idx, rel_idx, tail_idx = triplets[:, 0], triplets[:, 1], triplets[:, 2]
            pred = self.kge_model((triplets, None),
                                   self.rel_embed, self.ent_embed)
            b_range = torch.arange(pred.size()[0], device=self.args.gpu)
            target_pred = pred[b_range, tail_idx]
            pred = torch.where(labels.byte(), -torch.ones_like(pred) * 10000000, pred)
            pred[b_range, tail_idx] = target_pred

            ranks = 1 + torch.argsort(torch.argsort(pred, dim=1, descending=True),
                                      dim=1, descending=False)[b_range, tail_idx]

            ranks = ranks.float()
            count = torch.numel(ranks)

            results['count'] += count
            results['mr'] += torch.sum(ranks).item()
            results['mrr'] += torch.sum(1.0 / ranks).item()

            for k in [1, 3, 10]:
                results['hits@{}'.format(k)] += torch.numel(ranks[ranks <= k])

        for k, v in results.items():
            if k != 'count':
                results[k] /= results['count']

        return results
posted @   GraphL  阅读(151)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· TypeScript + Deepseek 打造卜卦网站:技术与玄学的结合
· 阿里巴巴 QwQ-32B真的超越了 DeepSeek R-1吗?
· 【译】Visual Studio 中新的强大生产力特性
· 【设计模式】告别冗长if-else语句:使用策略模式优化代码结构
· AI与.NET技术实操系列(六):基于图像分类模型对图像进行分类
点击右上角即可分享
微信分享提示