消息传递范式
消息传递图神经网络
一、经网络流程
- 为节点生成节点表征(Node Representation)是图计算任务成功的关键
- 利用神经网络来学习节点表征。
- 实现图和神经网络的连接的其中一种方式:消息传递范式
二、 消息传递范式获取节点表征的过程
- 定义:将邻居节点信息传递到中心节点的过程
- 过程:邻居节点通过变换后聚合目标节点,目标节点的原始信息和邻居节点们变换聚合后传递过来的信息一起经过变换聚合后得到新的目标节点信息(所有节点都操作此过程)
- 多次重复该操作,不断更新信息,最终获得节点表征
\(\mathbf{x}^{(k-1)}_i\in\mathbb{R}^F\)表示\((k-1)\)层中节点\(i\)的节点表征,
\(\mathbf{e}_{j,i} \in \mathbb{R}^D\) 表示从节点\(j\)到节点\(i\)的边的属性,
消息传递图神经网络可以描述为
三、MessagePassing
基类初步分析
MessagePassing(aggr="add", flow="source_to_target", node_dim=-2)
(对象初始化方法):aggr
:定义要使用的聚合方案("add"、"mean "或 "max");flow
:定义消息传递的流向("source_to_target "或 "target_to_source");node_dim
:定义沿着哪个维度传播,默认值为-2
,也就是节点表征张量(Tensor)的哪一个维度是节点维度。节点表征张量x
形状为[num_nodes, num_features]
,其第0维度(也是第-2维度)是节点维度,其第1维度(也是第-1维度)是节点表征维度,所以我们可以设置node_dim=-2
。- 注:
MessagePassing(……)
等同于MessagePassing.__init__(……)
MessagePassing.propagate(edge_index, size=None, **kwargs)
:- 开始传递消息的起始调用,在此方法中
message
、update
等方法被调用。 - 它以
edge_index
(边的端点的索引)和flow
(消息的流向)以及一些额外的数据为参数。 - 请注意,
propagate()
不仅限于基于形状为[N, N]
的对称邻接矩阵进行“消息传递过程”。基于非对称的邻接矩阵进行消息传递(当图为二部图时),需要传递参数size=(N, M)
。 - 如果设置
size=None
,则认为邻接矩阵是对称的。
- 开始传递消息的起始调用,在此方法中
MessagePassing.message(...)
:- 首先确定要给节点\(i\)传递消息的边的集合:
- 如果
flow="source_to_target"
,则是\((j,i) \in \mathcal{E}\)的边的集合; - 如果
flow="target_to_source"
,则是\((i,j) \in \mathcal{E}\)的边的集合。
- 如果
- 接着为各条边创建要传递给节点\(i\)的消息,即实现\(\phi\)函数。
MessagePassing.message(...)
方法可以接收传递给MessagePassing.propagate(edge_index, size=None, **kwargs)
方法的所有参数,我们在message()
方法的参数列表里定义要接收的参数,例如我们要接收x,y,z
参数,则我们应定义message(x,y,z)
方法。- 传递给
propagate()
方法的参数,如果是节点的属性的话,可以被拆分成属于中心节点的部分和属于邻接节点的部分,只需在变量名后面加上_i
或_j
。例如,我们自己定义的meassage
方法包含参数x_i
,那么首先propagate()
方法将节点表征拆分成中心节点表征和邻接节点表征,接着propagate()
方法调用message
方法并传递中心节点表征给参数x_i
。而如果我们自己定义的meassage
方法包含参数x_j
,那么propagate()
方法会传递邻接节点表征给参数x_j
。 - 我们用\(i\)表示“消息传递”中的中心节点,用\(j\)表示“消息传递”中的邻接节点。
- 首先确定要给节点\(i\)传递消息的边的集合:
MessagePassing.aggregate(...)
:- 将从源节点传递过来的消息聚合在目标节点上,一般可选的聚合方式有
sum
,mean
和max
。
- 将从源节点传递过来的消息聚合在目标节点上,一般可选的聚合方式有
MessagePassing.message_and_aggregate(...)
:- 在一些场景里,邻接节点信息变换和邻接节点信息聚合这两项操作可以融合在一起,那么我们可以在此方法里定义这两项操作,从而让程序运行更加高效。
MessagePassing.update(aggr_out, ...)
:- 为每个节点\(i \in \mathcal{V}\)更新节点表征,即实现\(\gamma\)函数。此方法以
aggregate
方法的输出为第一个参数,并接收所有传递给propagate()
方法的参数。
- 为每个节点\(i \in \mathcal{V}\)更新节点表征,即实现\(\gamma\)函数。此方法以
四、MessagePassing
子类实例
我们以继承MessagePassing
基类的GCNConv
类为例,学习如何通过继承MessagePassing
基类来实现一个简单的图神经网络。
GCNConv
的数学定义为
其中,邻接节点的表征\(\mathbf{x}_j^{(k-1)}\)首先通过与权重矩阵\(\mathbf{\Theta}\)相乘进行变换,然后按端点的度\(\deg(i), \deg(j)\)进行归一化处理,最后进行求和。这个公式可以分为以下几个步骤:
- 向邻接矩阵添加自环边。
- 对节点表征做线性转换。
- 计算归一化系数。
- 归一化邻接节点的节点表征。
- 将相邻节点表征相加("求和 "聚合)。
步骤1-3通常是在消息传递发生之前计算的。步骤4-5可以使用MessagePassing
基类轻松处理。该层的全部实现如下所示。
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
class GCNConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super(GCNConv, self).__init__(aggr='add', flow='source_to_target')
# "Add" aggregation (Step 5).
# flow='source_to_target' 表示消息从源节点传播到目标节点
self.lin = torch.nn.Linear(in_channels, out_channels)
def forward(self, x, edge_index):
# x has shape [N, in_channels]
# edge_index has shape [2, E]
# Step 1: Add self-loops to the adjacency matrix.
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
# Step 2: Linearly transform node feature matrix.
x = self.lin(x)
# Step 3: Compute normalization.
row, col = edge_index
deg = degree(col, x.size(0), dtype=x.dtype)
deg_inv_sqrt = deg.pow(-0.5)
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
# Step 4-5: Start propagating messages.
return self.propagate(edge_index, x=x, norm=norm)
def message(self, x_j, norm):
# x_j has shape [E, out_channels]
# Step 4: Normalize node features.
return norm.view(-1, 1) * x_j
GCNConv
继承了MessagePassing
并以"求和"作为领域节点信息聚合方式。该层的所有逻辑都发生在其forward()
方法中。在这里,我们首先使用torch_geometric.utils.add_self_loops()
函数向我们的边索引添加自循环边(步骤1),以及通过调用torch.nn.Linear
实例对节点表征进行线性变换(步骤2)。propagate()
方法也在forward
方法中被调用,propagate()
方法被调用后节点间的信息传递开始执行。
归一化系数是由每个节点的节点度得出的,它被转换为每条边的节点度。结果被保存在形状为[num_edges,]
的变量norm
中(步骤3)。
在message()
方法中,我们需要通过norm
对邻接节点表征x_j
进行归一化处理。
通过以上内容的学习,我们便掌握了创建一个仅包含一次“消息传递过程”的图神经网络的方法。如下方代码所示,我们可以很方便地初始化和调用它:
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='dataset/Cora', name='Cora')
data = dataset[0]
net = GCNConv(data.num_features, 64)
h_nodes = net(data.x, data.edge_index)
print(h_nodes.shape)