5、SimGNN实战

一、概述

文献标题:SimGNN: A Neural Network Approach to Fast Graph Similarity Computation
来源:WSDM2018( 网络搜索和数据挖掘国际会议)
论文链接:
https://arxiv.org/abs/1808.05689
代码链接:
https://paperswithcode.com/paper/graph-edit-distance-computation-via-graph#code

目的:使用图神经网络的方法计算图相似度,并减轻计算的负担。

创新处:SimGNN的方法结合了两种策略。一、首先设计了一个可学习的嵌入函数,将每个图映射为一个嵌入向量,该向量提供了图的全局摘要,该策略提出了一种新的注意机制来强调特定相似度度量下的重要节点。二、设计了一种成对节点比较方法,用细粒度节点信息补充图级嵌入。

引言

设计了一个基于神经网络的函数,将一对图映射成一个相似度评分。在训练阶段,该函数所涉及的参数将通过最小化预测的相似度分数与事实(真是标签)的差来学习,其中每个训练数据点是一对图及其真实相似度分数。在测试阶段,通过向学习的函数输入任意一对图,我们可以得到一个预测的相似度分数。我们将这种方法命名为SimGNN,即通过图神经网络进行相似性计算。

模型的优势

(1)表示不变。通过改变节点的顺序,可以用不同的邻接矩阵来表示同一个图。所计算的相似性得分对于这种变化应该是不变的。

(2)归纳。相似性计算应该推广到看不见的图,即计算训练图对之外的图的相似性得分。

(3)可学。通过训练调整其参数,该模型应该适应任何相似性度量。


二、背景

Background and Motivation

图相似度搜索具有重要的意义,比如找到与query化合物最相似的化合物等。通常用图编辑距离或者最大共同子图来衡量图的相似度,然而这两个指标的计算复杂度都是很高的(NP-complete)。这篇文章提出了一种基于图神经网络的方法来解决这一问题。

Main idea

神经网络学习的对象是从输入 一对图(a pair of graphs)到输出 两个图的相似度分数 的映射。因此是一种有监督的学习,需要知道输入图对相似度的ground truth。

Network structure

一个简单直接的思想就是:给定一对图,我们需要将图进行向量表示,再根据图对应的向量来计算相似度,也就是 graph embedding。在此基础上,考虑到只利用 graph embedding 可能忽略了局部节点的差异性,因此作者进一步考虑了两个图中节点之间的相关性或者是差异性 (pairwise node comparison)。

 三、代码

1、SimGNN和计算直方图

import torch
import random
import numpy as np
import torch.nn.functional as F
from tqdm import tqdm, trange
from scipy.stats import spearmanr, kendalltau

from layers import AttentionModule, TensorNetworkModule, DiffPool
from utils import calculate_ranking_correlation, calculate_prec_at_k, gen_pairs

from torch_geometric.nn import GCNConv, GINConv
from torch_geometric.data import DataLoader, Batch
from torch_geometric.utils import to_dense_batch, to_dense_adj, degree
from torch_geometric.datasets import GEDDataset
from torch_geometric.transforms import OneHotDegree

import matplotlib.pyplot as plt


class SimGNN(torch.nn.Module):
    """
    SimGNN: A Neural Network Approach to Fast Graph Similarity Computation
    https://arxiv.org/abs/1808.05689
    """

    def __init__(self, args, number_of_labels):
        """
        :param args: Arguments object.
        :param number_of_labels: Number of node labels.
        """
        super(SimGNN, self).__init__()
        self.args = args
        self.number_labels = number_of_labels
        self.setup_layers()

    def calculate_bottleneck_features(self):
        """
        Deciding the shape of the bottleneck layer.
        """
        if self.args.histogram:
            self.feature_count = self.args.tensor_neurons + self.args.bins
        else:
            self.feature_count = self.args.tensor_neurons

    def setup_layers(self):
        """
        Creating the layers.
        """
        self.calculate_bottleneck_features()
        if self.args.gnn_operator == "gcn":
            self.convolution_1 = GCNConv(self.number_labels, self.args.filters_1)
            self.convolution_2 = GCNConv(self.args.filters_1, self.args.filters_2)
            self.convolution_3 = GCNConv(self.args.filters_2, self.args.filters_3)
        elif self.args.gnn_operator == "gin":
            nn1 = torch.nn.Sequential(
                torch.nn.Linear(self.number_labels, self.args.filters_1),
                torch.nn.ReLU(),
                torch.nn.Linear(self.args.filters_1, self.args.filters_1),
                torch.nn.BatchNorm1d(self.args.filters_1),
            )

            nn2 = torch.nn.Sequential(
                torch.nn.Linear(self.args.filters_1, self.args.filters_2),
                torch.nn.ReLU(),
                torch.nn.Linear(self.args.filters_2, self.args.filters_2),
                torch.nn.BatchNorm1d(self.args.filters_2),
            )

            nn3 = torch.nn.Sequential(
                torch.nn.Linear(self.args.filters_2, self.args.filters_3),
                torch.nn.ReLU(),
                torch.nn.Linear(self.args.filters_3, self.args.filters_3),
                torch.nn.BatchNorm1d(self.args.filters_3),
            )

            self.convolution_1 = GINConv(nn1, train_eps=True)
            self.convolution_2 = GINConv(nn2, train_eps=True)
            self.convolution_3 = GINConv(nn3, train_eps=True)
        else:
            raise NotImplementedError("Unknown GNN-Operator.")

        if self.args.diffpool:
            self.attention = DiffPool(self.args)
        else:
            self.attention = AttentionModule(self.args)

        self.tensor_network = TensorNetworkModule(self.args)
        self.fully_connected_first = torch.nn.Linear(
            self.feature_count, self.args.bottle_neck_neurons
        )
        self.scoring_layer = torch.nn.Linear(self.args.bottle_neck_neurons, 1)

    def calculate_histogram(
        self, abstract_features_1, abstract_features_2, batch_1, batch_2
    ):
        """
        Calculate histogram from similarity matrix.
        :param abstract_features_1: Feature matrix for target graphs.
        :param abstract_features_2: Feature matrix for source graphs.
        :param batch_1: Batch vector for source graphs, which assigns each node to a specific example
        :param batch_1: Batch vector for target graphs, which assigns each node to a specific example
        :return hist: Histsogram of similarity scores.
        """
        print(abstract_features_1.shape)#torch.Size([1156, 16])
        print(abstract_features_2.shape)#torch.Size([1156, 16])
        #to_dense_batch意思是稀疏向量转稠密
        #另外,每个图最大10个节点,不足10个的补上。因此,mask返回了一个128行,10列的【true,false】矩阵,一行一个图,10列表示10个节点。补上的节点为false
        abstract_features_1, mask_1 = to_dense_batch(abstract_features_1, batch_1)
        print(abstract_features_1.shape)#torch.Size([128, 10, 16])
        print(mask_1.shape)#torch.Size([128, 10])
        abstract_features_2, mask_2 = to_dense_batch(abstract_features_2, batch_2)
        print(abstract_features_2.shape)#torch.Size([128, 10, 16])
        print(mask_2.shape)#torch.Size([128, 10])
        B1, N1, _ = abstract_features_1.size()
        B2, N2, _ = abstract_features_2.size()
        #b=128 n=10
        mask_1 = mask_1.view(B1, N1)
        mask_2 = mask_2.view(B2, N2)
        num_nodes = torch.max(mask_1.sum(dim=1), mask_2.sum(dim=1))
        #128的数组,数组中每一个数是一对图的最大节点数。
        """
        .detach()方法用于将张量从计算图中分离出来,得到一个新的张量。
        在PyTorch中,计算图是用于自动求导的一种机制。当我们对张量进行操作时,
        PyTorch会自动构建一个计算图,用于跟踪张量之间的依赖关系,并计算梯度。计算图的构建过程会消耗一定的内存和计算资源。
        使用.detach()方法可以将张量从计算图中分离出来,得到一个新的张量,新的张量不再与原始计算图相关联。
        这意味着新的张量不会再参与梯度计算,也不会影响原始张量的梯度计算。
        具体来说,.detach()方法会返回一个新的张量,该张量与原始张量的值相同,
        但是不再具有梯度信息。这对于需要保留中间结果但不需要进行梯度计算的情况非常有用。
        例如,在训练神经网络时,有时我们需要计算某个中间结果,并将其用于后续的计算,但是不希望中间结果对网络参数进行梯度传播。
        这时,可以使用.detach()方法将中间结果从计算图中分离出来,保留其值,并将其用于后续的计算,而不会对网络参数进行梯度计算。
        总之,.detach()方法用于将张量从计算图中分离出来,得到一个新的张量,新的张量不再参与梯度计算。
        """
        scores = torch.matmul(
            abstract_features_1, abstract_features_2.permute([0, 2, 1])
        ).detach()#[128,10,16],[128,16,10]
        print(scores.shape)#[128,10,10]
        hist_list = []#128个1行16列的矩阵
        for i, mat in enumerate(scores):
            #mat[10,10],对于具体一对网络的score矩阵,由于矩阵乘得到的得分矩阵是相似度,得出了10个节点与另10个节点的相似度。
            mat = torch.sigmoid(mat[: num_nodes[i], : num_nodes[i]]).view(-1)#展平100
            print(mat.shape)
            hist = torch.histc(mat, bins=self.args.bins)#bin是16,画出数组的直方图。16堆
            print(hist.shape)#【16】
            hist = hist / torch.sum(hist)
            hist = hist.view(1, -1)#【1,16】
            print(hist.shape)
            hist_list.append(hist)
        print(torch.stack(hist_list).view(-1, self.args.bins).shape)
        """
        import torch
        x = torch.tensor([1, 2, 3])
        y = torch.tensor([4, 5, 6])
        z = torch.stack([x, y], dim=0)
        print(z)
        输出结果为:
        tensor([[1, 2, 3],
                [4, 5, 6]])
        .view(-1, self.args.bins)将堆叠后的张量形状重塑为(-1, self.args.bins),
        其中-1表示根据其他维度的大小自动计算该维度的大小,而self.args.bins表示指定的维度大小。
        """
        return torch.stack(hist_list).view(-1, self.args.bins) #【128,16】

    def convolutional_pass(self, edge_index, features):
        """
        Making convolutional pass.
        :param edge_index: Edge indices.
        :param features: Feature matrix.
        :return features: Abstract feature matrix.
        """
        features = self.convolution_1(features, edge_index)
        features = F.relu(features)
        features = F.dropout(features, p=self.args.dropout, training=self.training)
        features = self.convolution_2(features, edge_index)
        features = F.relu(features)
        features = F.dropout(features, p=self.args.dropout, training=self.training)
        features = self.convolution_3(features, edge_index)
        return features

    def diffpool(self, abstract_features, edge_index, batch):
        """
        Making differentiable pooling.
        :param abstract_features: Node feature matrix.
        :param edge_index: Edge indices
        :param batch: Batch vector, which assigns each node to a specific example
        :return pooled_features: Graph feature matrix.
        """
        x, mask = to_dense_batch(abstract_features, batch)
        adj = to_dense_adj(edge_index, batch)
        return self.attention(x, adj, mask)

    def forward(self, data):
        """
        Forward pass with graphs.
        :param data: Data dictionary.
        :return score: Similarity score.
        """
        edge_index_1 = data["g1"].edge_index
        edge_index_2 = data["g2"].edge_index
        features_1 = data["g1"].x
        print(features_1.shape) #torch.Size([1152, 29])
        features_2 = data["g2"].x
        batch_1 = (
            data["g1"].batch
            if hasattr(data["g1"], "batch")
            else torch.tensor((), dtype=torch.long).new_zeros(data["g1"].num_nodes)
        )
        batch_2 = (
            data["g2"].batch
            if hasattr(data["g2"], "batch")
            else torch.tensor((), dtype=torch.long).new_zeros(data["g2"].num_nodes)
        )
        #两个图过同一个GIN
        abstract_features_1 = self.convolutional_pass(edge_index_1, features_1)
        print(abstract_features_1.shape)#torch.Size([1156, 16])
        abstract_features_2 = self.convolutional_pass(edge_index_2, features_2)

        # 得到直方图向量
        if self.args.histogram:
            hist = self.calculate_histogram(
                abstract_features_1, abstract_features_2, batch_1, batch_2
            )

        # 得到图级别的向量
        if self.args.diffpool:
            pooled_features_1 = self.diffpool(
                abstract_features_1, edge_index_1, batch_1
            )
            pooled_features_2 = self.diffpool(
                abstract_features_2, edge_index_2, batch_2
            )
        else:
            pooled_features_1 = self.attention(abstract_features_1, batch_1)
            print(pooled_features_1.shape)
            pooled_features_2 = self.attention(abstract_features_2, batch_2)

        #TNT模块,意思类似与SVD学习隐向量。例如【老虎和尾巴两个实体之间的关系,用户和商品的某个关系】
        scores = self.tensor_network(pooled_features_1, pooled_features_2)
        print(scores.shape)
        if self.args.histogram:
            scores = torch.cat((scores, hist), dim=1)

        scores = F.relu(self.fully_connected_first(scores))
        print(scores.shape)
        score = torch.sigmoid(self.scoring_layer(scores)).view(-1)
        print(score.shape)
        return score


class SimGNNTrainer(object):
    """
    SimGNN model trainer.
    """

    def __init__(self, args):
        """
        :param args: Arguments object.
        """
        self.args = args
        self.process_dataset()
        self.setup_model()

    def setup_model(self):
        """
        Creating a SimGNN.
        """
        self.model = SimGNN(self.args, self.number_of_labels)

    def save(self):
        """
        Saving model.
        """
        torch.save(self.model.state_dict(), self.args.save)
        print(f"Model is saved under {self.args.save}.")

    def load(self):
        """
        Loading model.
        """
        self.model.load_state_dict(torch.load(self.args.load))
        print(f"Model is loaded from {self.args.save}.")

    def process_dataset(self):
        """
        Downloading and processing dataset.
        """
        print("\nPreparing dataset.\n")

        self.training_graphs = GEDDataset(
            "datasets/{}".format(self.args.dataset), self.args.dataset, train=True
        )
        self.testing_graphs = GEDDataset(
            "datasets/{}".format(self.args.dataset), self.args.dataset, train=False
        )
        self.nged_matrix = self.training_graphs.norm_ged
        self.real_data_size = self.nged_matrix.size(0)

        if self.args.synth:
            # self.synth_data_1, self.synth_data_2, _, synth_nged_matrix = gen_synth_data(500, 10, 12, 0.5, 0, 3)
            self.synth_data_1, self.synth_data_2, _, synth_nged_matrix = gen_pairs(
                self.training_graphs.shuffle()[:500], 0, 3
            )

            real_data_size = self.nged_matrix.size(0)
            synth_data_size = synth_nged_matrix.size(0)
            self.nged_matrix = torch.cat(
                (
                    self.nged_matrix,
                    torch.full((real_data_size, synth_data_size), float("inf")),
                ),
                dim=1,
            )
            synth_nged_matrix = torch.cat(
                (
                    torch.full((synth_data_size, real_data_size), float("inf")),
                    synth_nged_matrix,
                ),
                dim=1,
            )
            self.nged_matrix = torch.cat((self.nged_matrix, synth_nged_matrix))

        if self.training_graphs[0].x is None:
            max_degree = 0
            for g in (
                self.training_graphs
                + self.testing_graphs
                + (self.synth_data_1 + self.synth_data_2 if self.args.synth else [])
            ):
                if g.edge_index.size(1) > 0:
                    max_degree = max(
                        max_degree, int(degree(g.edge_index[0]).max().item())
                    )
            one_hot_degree = OneHotDegree(max_degree, cat=False)
            self.training_graphs.transform = one_hot_degree
            self.testing_graphs.transform = one_hot_degree

            # labeling of synth data according to real data format
            if self.args.synth:
                for g in self.synth_data_1 + self.synth_data_2:
                    g = one_hot_degree(g)
                    g.i = g.i + real_data_size
        elif self.args.synth:
            for g in self.synth_data_1 + self.synth_data_2:
                g.i = g.i + real_data_size
                # g.x = torch.cat((g.x, torch.zeros((g.x.size(0), self.training_graphs.num_features-1))), dim=1)

        self.number_of_labels = self.training_graphs.num_features

    def create_batches(self):
        """
        Creating batches from the training graph list.
        :return batches: Zipped loaders as list.
        """
        if self.args.synth:
            synth_data_ind = random.sample(range(len(self.synth_data_1)), 100)

        source_loader = DataLoader(
            self.training_graphs.shuffle()
            + (
                [self.synth_data_1[i] for i in synth_data_ind]
                if self.args.synth
                else []
            ),
            batch_size=self.args.batch_size,
        )
        target_loader = DataLoader(
            self.training_graphs.shuffle()
            + (
                [self.synth_data_2[i] for i in synth_data_ind]
                if self.args.synth
                else []
            ),
            batch_size=self.args.batch_size,
        )

        return list(zip(source_loader, target_loader))

    def transform(self, data):
        """
        Getting ged for graph pair and grouping with data into dictionary.
        :param data: Graph pair.
        :return new_data: Dictionary with data.
        """
        new_data = dict()

        new_data["g1"] = data[0]
        new_data["g2"] = data[1]

        normalized_ged = self.nged_matrix[
            data[0]["i"].reshape(-1).tolist(), data[1]["i"].reshape(-1).tolist()
        ].tolist()
        new_data["target"] = (
            torch.from_numpy(np.exp([(-el) for el in normalized_ged])).view(-1).float()
        )
        return new_data

    def process_batch(self, data):
        """
        Forward pass with a data.
        :param data: Data that is essentially pair of batches, for source and target graphs.
        :return loss: Loss on the data.
        """
        self.optimizer.zero_grad()
        data = self.transform(data)
        target = data["target"]
        prediction = self.model(data)
        loss = F.mse_loss(prediction, target, reduction="sum")
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def fit(self):
        """
        Training a model.
        """
        print("\nModel training.\n")
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=self.args.learning_rate,
            weight_decay=self.args.weight_decay,
        )
        self.model.train()

        epochs = trange(self.args.epochs, leave=True, desc="Epoch")
        loss_list = []
        loss_list_test = []
        for epoch in epochs:

            if self.args.plot:
                if epoch % 10 == 0:
                    self.model.train(False)
                    cnt_test = 20
                    cnt_train = 100
                    t = tqdm(
                        total=cnt_test * cnt_train,
                        position=2,
                        leave=False,
                        desc="Validation",
                    )
                    scores = torch.empty((cnt_test, cnt_train))

                    for i, g in enumerate(self.testing_graphs[:cnt_test].shuffle()):
                        source_batch = Batch.from_data_list([g] * cnt_train)
                        target_batch = Batch.from_data_list(
                            self.training_graphs[:cnt_train].shuffle()
                        )
                        data = self.transform((source_batch, target_batch))
                        target = data["target"]
                        prediction = self.model(data)

                        scores[i] = F.mse_loss(
                            prediction, target, reduction="none"
                        ).detach()
                        t.update(cnt_train)

                    t.close()
                    loss_list_test.append(scores.mean().item())
                    self.model.train(True)

            batches = self.create_batches()
            main_index = 0
            loss_sum = 0
            for index, batch_pair in tqdm(
                enumerate(batches), total=len(batches), desc="Batches", leave=False
            ):
                loss_score = self.process_batch(batch_pair)
                main_index = main_index + batch_pair[0].num_graphs
                loss_sum = loss_sum + loss_score
            loss = loss_sum / main_index
            epochs.set_description("Epoch (Loss=%g)" % round(loss, 5))
            loss_list.append(loss)

        if self.args.plot:
            plt.plot(loss_list, label="Train")
            plt.plot(
                [*range(0, self.args.epochs, 10)], loss_list_test, label="Validation"
            )
            plt.ylim([0, 0.01])
            plt.legend()
            filename = self.args.dataset
            filename += "_" + self.args.gnn_operator
            if self.args.diffpool:
                filename += "_diffpool"
            if self.args.histogram:
                filename += "_hist"
            filename = filename + str(self.args.epochs) + ".pdf"
            plt.savefig(filename)

    def measure_time(self):
        import time

        self.model.eval()
        count = len(self.testing_graphs) * len(self.training_graphs)

        t = np.empty(count)
        i = 0
        tq = tqdm(total=count, desc="Graph pairs")
        for g1 in self.testing_graphs:
            for g2 in self.training_graphs:
                source_batch = Batch.from_data_list([g1])
                target_batch = Batch.from_data_list([g2])
                data = self.transform((source_batch, target_batch))

                start = time.process_time()
                self.model(data)
                t[i] = time.process_time() - start
                i += 1
                tq.update()
        tq.close()

        print(
            "Average time (ms): {}; Standard deviation: {}".format(
                round(t.mean() * 1000, 5), round(t.std() * 1000, 5)
            )
        )

    def score(self):
        """
        Scoring.
        """
        print("\n\nModel evaluation.\n")
        self.model.eval()

        scores = np.empty((len(self.testing_graphs), len(self.training_graphs)))
        ground_truth = np.empty((len(self.testing_graphs), len(self.training_graphs)))
        prediction_mat = np.empty((len(self.testing_graphs), len(self.training_graphs)))

        rho_list = []
        tau_list = []
        prec_at_10_list = []
        prec_at_20_list = []

        t = tqdm(total=len(self.testing_graphs) * len(self.training_graphs))

        for i, g in enumerate(self.testing_graphs):
            source_batch = Batch.from_data_list([g] * len(self.training_graphs))
            target_batch = Batch.from_data_list(self.training_graphs)

            data = self.transform((source_batch, target_batch))
            target = data["target"]
            ground_truth[i] = target
            prediction = self.model(data)
            prediction_mat[i] = prediction.detach().numpy()

            scores[i] = (
                F.mse_loss(prediction, target, reduction="none").detach().numpy()
            )

            rho_list.append(
                calculate_ranking_correlation(
                    spearmanr, prediction_mat[i], ground_truth[i]
                )
            )
            tau_list.append(
                calculate_ranking_correlation(
                    kendalltau, prediction_mat[i], ground_truth[i]
                )
            )
            prec_at_10_list.append(
                calculate_prec_at_k(10, prediction_mat[i], ground_truth[i])
            )
            prec_at_20_list.append(
                calculate_prec_at_k(20, prediction_mat[i], ground_truth[i])
            )

            t.update(len(self.training_graphs))

        self.rho = np.mean(rho_list).item()
        self.tau = np.mean(tau_list).item()
        self.prec_at_10 = np.mean(prec_at_10_list).item()
        self.prec_at_20 = np.mean(prec_at_20_list).item()
        self.model_error = np.mean(scores).item()
        self.print_evaluation()

    def print_evaluation(self):
        """
        Printing the error rates.
        """
        print("\nmse(10^-3): " + str(round(self.model_error * 1000, 5)) + ".")
        print("Spearman's rho: " + str(round(self.rho, 5)) + ".")
        print("Kendall's tau: " + str(round(self.tau, 5)) + ".")
        print("p@10: " + str(round(self.prec_at_10, 5)) + ".")
        print("p@20: " + str(round(self.prec_at_20, 5)) + ".")

2、注意力得到全局

class AttentionModule(torch.nn.Module):
    """
    SimGNN Attention Module to make a pass on graph.
    """

    def __init__(self, args):
        """
        :param args: Arguments object.
        """
        super(AttentionModule, self).__init__()
        self.args = args
        self.setup_weights()
        self.init_parameters()

    def setup_weights(self):
        """
        Defining weights.
        """
        self.weight_matrix = torch.nn.Parameter(
            torch.Tensor(self.args.filters_3, self.args.filters_3)
        )

    def init_parameters(self):
        """
        Initializing weights.
        """
        torch.nn.init.xavier_uniform_(self.weight_matrix)

    def forward(self, x, batch, size=None):
        """
        Making a forward propagation pass to create a graph level representation.
        :param x: Result of the GNN.
        :param size: Dimension size for scatter_
        :param batch: Batch vector, which assigns each node to a specific example
        :return representation: A graph level representation matrix.
        输入张量input_tensor的形状为(3, 3),内容如下:
        tensor([[1, 2, 3],
                [4, 5, 6],
                [7, 8, 9]])
        索引张量index的形状为(3,),内容如下:
        tensor([0, 1, 0])
        聚合操作的过程如下:
        根据索引张量index的值,将输入张量input_tensor中的值聚合到对应的位置上。在这个例子中,索引张量index的第一个元素为0,
        表示将输入张量的第一行([1, 2, 3])聚合到输出张量的第一行上;
        索引张量的第二个元素为1,表示将输入张量的第二行([4, 5, 6])聚合到输出张量的第二行上;索引张量的第三个元素为0,
        表示将输入张量的第三行([7, 8, 9])聚合到输出张量的第一行上。
        对于每个聚合位置,使用指定的聚合操作进行聚合。在这个例子中,我们使用的是scatter_add()函数,它将输入张量中的值累加到聚合位置上。
        聚合结果保存在输出张量output_tensor中。在这个例子中,输出张量的形状为(2, 3),内容如下:
        tensor([[8, 10, 12],
                [4,  5,  6]])
        输出张量的第一行是将输入张量的第一行和第三行聚合得到的,第二行是将输入张量的第二行聚合得到的。
        通过聚合操作,我们可以将输入张量中的值按照指定的索引聚合到输出张量的指定位置上,从而实现灵活的聚合操作。
        """
        size = batch[-1].item() + 1 if size is None else size#128
        mean = scatter_mean(x, batch, dim=0, dim_size=size)#【128,16】,每个图中所有节点求均值
        print(mean.shape)#X [1151,16],batch[1151], self.weight_matrix[16,16]
        transformed_global = torch.tanh(torch.mm(mean, self.weight_matrix))#[128,16],全局上下文(全局向量)乘以可学习参数
        print(self.weight_matrix.shape)
        print(transformed_global.shape)
        coefs = torch.sigmoid((x * transformed_global[batch]).sum(dim=1))#【1151】。数据乘以全局向量
        print(coefs.shape)
        weighted = coefs.unsqueeze(-1) * x
        
        return scatter_add(weighted, batch, dim=0, dim_size=size)

    def get_coefs(self, x):
        mean = x.mean(dim=0)
        transformed_global = torch.tanh(torch.matmul(mean, self.weight_matrix))

        return torch.sigmoid(torch.matmul(x, transformed_global))

3、TN得到隐向量

class TensorNetworkModule(torch.nn.Module):
    """
    SimGNN Tensor Network module to calculate similarity vector.
    """

    def __init__(self, args):
        """
        :param args: Arguments object.
        """
        super(TensorNetworkModule, self).__init__()
        self.args = args
        self.setup_weights()
        self.init_parameters()

    def setup_weights(self):
        """
        Defining weights.
        """
        self.weight_matrix = torch.nn.Parameter(
            torch.Tensor(
                self.args.filters_3, self.args.filters_3, self.args.tensor_neurons
            )
        )
        self.weight_matrix_block = torch.nn.Parameter(
            torch.Tensor(self.args.tensor_neurons, 2 * self.args.filters_3)
        )
        self.bias = torch.nn.Parameter(torch.Tensor(self.args.tensor_neurons, 1))

    def init_parameters(self):
        """
        Initializing weights.
        """
        torch.nn.init.xavier_uniform_(self.weight_matrix)
        torch.nn.init.xavier_uniform_(self.weight_matrix_block)
        torch.nn.init.xavier_uniform_(self.bias)

    def forward(self, embedding_1, embedding_2):
        """
        Making a forward propagation pass to create a similarity vector.
        :param embedding_1: Result of the 1st embedding after attention.
        :param embedding_2: Result of the 2nd embedding after attention.
        :return scores: A similarity score vector.
        """
        batch_size = len(embedding_1)#【128】,embedding_1, embedding_2都是【128,16】
        #print(self.weight_matrix.view(self.args.filters_3, -1).shape) # 原始输入的两个实体都是16维向量,k中关系16个关系,现在用256维表示他们的某种关系
        scoring = torch.matmul(
            embedding_1, self.weight_matrix.view(self.args.filters_3, -1)
        )#scoring为【128,256】
        #print(self.weight_matrix.view(self.args.filters_3, -1).shape) 【k=16种关系,256为关系矩阵16*16】
        #print(scoring.shape)
        scoring = scoring.view(batch_size, self.args.filters_3, -1).permute([0, 2, 1]) #filters_3可以理解成找多少种关系【128,16,16】,比如两个实体找出16中关系。最后的16是固定的
        #print(scoring.shape)
        scoring = torch.matmul(
            scoring, embedding_2.view(batch_size, self.args.filters_3, 1)
        ).view(batch_size, -1)
        print(scoring.shape)#【128,16】
        combined_representation = torch.cat((embedding_1, embedding_2), 1)#【128,32】
        print(combined_representation.shape)
        block_scoring = torch.t(
            torch.mm(self.weight_matrix_block, torch.t(combined_representation))
        )#【128,16】=(【16,32】*【32,128】)T:拼接块乘以一个可学习权重,下一步加偏执
        print(block_scoring.shape)
        scores = F.relu(scoring + block_scoring + self.bias.view(-1))#【128,16】
        print(scores.shape)
        return scores

 

posted @ 2023-09-26 22:33  jasonzhangxianrong  阅读(428)  评论(2编辑  收藏  举报