DGL用户指南--第3章:构建图神经网络(GNN)模块
DGL NN模块是用户构建GNN模型的基本模块。根据DGL所使用的后端深度神经网络框架, DGL NN模块的父类取决于后端所使用的深度神经网络框架。对于PyTorch后端, 它应该继承 PyTorch的NN模块;对于MXNet后端,它应该继承 MXNet Gluon的NN块; 对于TensorFlow后端,它应该继承 Tensorflow的Keras层。
3.1 DGL NN模块的构造函数
构造函数完成以下几个任务:
-
设置选项。
-
注册可学习的参数或者子模块。
-
初始化参数。
1 import torch.nn as nn 2 3 from dgl.utils import expand_as_pair 4 5 class SAGEConv(nn.Module): 6 def __init__(self, 7 in_feats, 8 out_feats, 9 aggregator_type, 10 bias=True, 11 norm=None, 12 activation=None): 13 super(SAGEConv, self).__init__() 14 15 self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats) 16 self._out_feats = out_feats 17 self._aggre_type = aggregator_type 18 self.norm = norm 19 self.activation = activation
在构造函数中,用户首先需要设置数据的维度:
- 维度通常包括输入的维度、输出的维度和隐层的维度
- 对于图神经网络,输入维度可被分为源节点特征维度和目标节点特征维度
除了数据维度,图神经网络的一个典型选项是聚合类型(self._aggre_type
)。对于特定目标节点,聚合类型决定了如何聚合不同边上的信息。 常用的聚合类型包括 mean
、 sum
、 max
和 min
。一些模块可能会使用更加复杂的聚合函数,比如 lstm
上面代码里的 norm
是用于特征归一化的可调用函数。在SAGEConv论文里,归一化可以是L2归一化:
1 # 聚合类型:mean、pool、lstm、gcn 2 if aggregator_type not in ['mean', 'pool', 'lstm', 'gcn']: 3 raise KeyError('Aggregator type {} not supported.'.format(aggregator_type)) 4 if aggregator_type == 'pool': 5 self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats) 6 if aggregator_type == 'lstm': 7 self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True) 8 if aggregator_type in ['mean', 'pool', 'lstm']: 9 self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias) 10 self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias) 11 self.reset_parameters()
注册参数和子模块。在SAGEConv中,子模块根据聚合类型而有所不同。这些模块是纯PyTorch NN模块,例如 nn.Linear
、 nn.LSTM
等。 构造函数的最后调用了 reset_parameters()
进行权重初始化。
  
1 def reset_parameters(self): 2 """重新初始化可学习的参数""" 3 gain = nn.init.calculate_gain('relu') 4 if self._aggre_type == 'pool': 5 nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain) 6 if self._aggre_type == 'lstm': 7 self.lstm.reset_parameters() 8 if self._aggre_type != 'gcn': 9 nn.init.xavier_uniform_(self.fc_self.weight, gain=gain) 10 nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)
3.2 编写DGL NN模块的forward函数
forward()
函数实际执行了消息传递和计算的步骤
forward()
函数的内容一般可以分为3项操作:
-
检测输入图对象是否符合规范。
-
消息传递和聚合。
-
聚合后,更新特征作为输出
输入图对象的规范检测
1 def forward(self, graph, feat): 2 with graph.local_scope(): 3 # 指定图类型,然后根据图类型扩展输入特征 4 feat_src, feat_dst = expand_as_pair(feat, graph)
forward()
函数需要处理输入的许多极端情况,这些情况可能导致计算和消息传递中的值无效。
DGL NN模块可在不同类型的图输入中重复使用,包括:同构图、异构图(1.5 异构图)和子图块(第6章:在大图上的随机(批次)训练)。
源节点特征 feat_src
和目标节点特征 feat_dst
需要根据图类型被指定。 用于指定图类型并将 feat
扩展为 feat_src
和 feat_dst
的函数是 expand_as_pair()
。 该函数的细节如下所示。
1 def expand_as_pair(input_, g=None): 2 if isinstance(input_, tuple): 3 # 二分图的情况 4 return input_ 5 elif g is not None and g.is_block: 6 # 子图块的情况 7 if isinstance(input_, Mapping): 8 input_dst = { 9 k: F.narrow_row(v, 0, g.number_of_dst_nodes(k)) 10 for k, v in input_.items()} 11 else: 12 input_dst = F.narrow_row(input_, 0, g.number_of_dst_nodes()) 13 return input_, input_dst 14 else: 15 # 同构图的情况 16 return input_, input_
对于同构图上的全图训练,源节点和目标节点相同,它们都是图中的所有节点。
在异构图的情况下,图可以分为几个二分图,每种关系对应一个。关系表示为 (src_type, edge_type, dst_dtype)
。 当输入特征 feat
是1个元组时,图将会被视为二分图。元组中的第1个元素为源节点特征,第2个元素为目标节点特征。
在小批次训练中,计算应用于给定的一堆目标节点所采样的子图。子图在DGL中称为区块(block
)。 在区块创建的阶段,dst nodes
位于节点列表的最前面。通过索引 [0:g.number_of_dst_nodes()]
可以找到 feat_dst
。
确定 feat_src
和 feat_dst
之后,以上3种图类型的计算方法是相同的。
消息传递和聚合
1 import dgl.function as fn 2 import torch.nn.functional as F 3 from dgl.utils import check_eq_shape 4 5 if self._aggre_type == 'mean': 6 graph.srcdata['h'] = feat_src 7 graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh')) 8 h_neigh = graph.dstdata['neigh'] 9 elif self._aggre_type == 'gcn': 10 check_eq_shape(feat) 11 graph.srcdata['h'] = feat_src 12 graph.dstdata['h'] = feat_dst 13 graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh')) 14 # 除以入度 15 degs = graph.in_degrees().to(feat_dst) 16 h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1) 17 elif self._aggre_type == 'pool': 18 graph.srcdata['h'] = F.relu(self.fc_pool(feat_src)) 19 graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh')) 20 h_neigh = graph.dstdata['neigh'] 21 else: 22 raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type)) 23 24 # GraphSAGE中gcn聚合不需要fc_self 25 if self._aggre_type == 'gcn': 26 rst = self.fc_neigh(h_neigh) 27 else: 28 rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
上面的代码执行了消息传递和聚合的计算。这部分代码会因模块而异。请注意,代码中的所有消息传递均使用 update_all()
API和 DGL内置的消息/聚合函数来实现,以充分利用 2.2 编写高效的消息传递代码 里所介绍的性能优化。
聚合后,更新特征作为输出
1 # 激活函数 2 if self.activation is not None: 3 rst = self.activation(rst) 4 # 归一化 5 if self.norm is not None: 6 rst = self.norm(rst) 7 return rst
forward()
函数的最后一部分是在完成消息聚合后更新节点的特征。 常见的更新操作是根据构造函数中设置的选项来应用激活函数和进行归一化。
3.3 异构图上的GraphConv模块
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 【自荐】一款简洁、开源的在线白板工具 Drawnix
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧
· 无需6万激活码!GitHub神秘组织3小时极速复刻Manus,手把手教你使用OpenManus搭建本
· C#/.NET/.NET Core优秀项目和框架2025年2月简报
· DeepSeek在M芯片Mac上本地化部署