DGL使用手记

DGL简介

DGL(Deep Graph Library)是一个用于搭建图神经网络的框架,支持pytorch, TensorFlow, MXNet等机器学习框架,集成了图神经网络的许多功能.编写这篇文章的契机是由于torch上的spmm速度有点慢,所以想借用一下DGL的.

图构建

一般来说,图由节点,边以及相应的节点特征和边特征组成.我们可以从coo, csr等边列表的形式直接生成一个图:

import dgl
import dgl.function as fn
src_nodes, dst_nodes = torch.tensor([0, 0, 1, 2]), torch.tensor([0, 1, 2, 0])
g = dgl.graph((src_nodes, dst_nodes))

这样我们就得到了一个三节点四条边的有向图.注意:dgl将所有的图都视为有向图.

设定特征

dgl的特征存储方式与其他框架并无差别,用n*m的矩阵存储n节点m维特征.比如上面的三节点图,若要给每个节点一个特征f,则可以这样写:

feature = torch.arange(0, 12, dtype=torch.float).reshape(3, 4)
g.ndata['f'] = feature

其中g.ndata是一个访问节点特征的接口,使用起来类似一个字典.同样的也有g.edata来访问边特征.但注意:

  • 仅允许使用数值类型(如单精度浮点型、双精度浮点型和整型)的特征。这些特征可以是标量、向量或多维张量

  • 每个节点特征具有唯一名称,每个边特征也具有唯一名称。节点和边的特征可以具有相同的名称

  • 通过张量分配创建特征时,DGL会将特征赋给图中的每个节点和每条边。该张量的第一维必须与图中节点或边的数量一致。 不能将特征赋给图中节点或边的子集。(如果特征只有一维,可以是(n, )的形状)

  • 相同名称的特征必须具有相同的维度和数据类型。

  • 特征张量使用”行优先”的原则,即每个行切片储存1个节点或1条边的特征

消息传递

这是dgl的核心.我们知道,在GCN当中,一般有三种操作:

  • 矩阵乘.与普通的神经网络类似,是对节点自身特征的线性变换.
  • 激活函数.与普通的神经网络类似,是对节点自身特征的非线性变换.
  • 聚合操作.这是核心,即节点通过图的邻接矩阵获取邻居信息,并进行一定操作(如取最值,求和,平均等)

聚合操作实际上可用稀疏矩阵乘法实现.在dgl当中,用消息传递机制作为spmm的实现:

DGLGraph.update_all(message_func, reduce_func, apply_node_func=None, etype=None)

可以看出,它接受四个参数,我们主要解释前三个:

  • message_func是消息函数,是边接受源节点的信息.进行节点间消息传递自然是通过边进行,这个函数就是完成源节点特征与边特征的处理.
  • reduce_func是聚合函数,按字面意思应理解为"归约函数",有分布式编程经验的人应该不难理解这个词的含义.它将上面处理过的特征以某种方式归约到目标节点(如mean, max, sum)等.
  • apply_node_func是更新函数,它实际完成的就是激活函数的工作.由于它可以用纯张量操作完成,因此dgl不推荐在update_all内指定更新函数.

我们来看一个实际例子.

>>> g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 4]))
>>> g.ndata['x'] = torch.ones(5, 2)
>>> g.update_all(fn.copy_u('x', 'm'), fn.sum('m', 'h'))
>>> g.ndata['h']

输出为

tensor([[0., 0.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.]])

在这里面,fn.copy_u('x', 'm')的含义就是将边的源节点的x属性复制到边上,命名为m,随后在归约时,每个节点将指向自己的边上的m属性求和,保存到节点属性h上.

posted @   LinXiaoshu  阅读(1036)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 【自荐】一款简洁、开源的在线白板工具 Drawnix
· 没有Manus邀请码?试试免邀请码的MGX或者开源的OpenManus吧
· 无需6万激活码!GitHub神秘组织3小时极速复刻Manus,手把手教你使用OpenManus搭建本
· C#/.NET/.NET Core优秀项目和框架2025年2月简报
· DeepSeek在M芯片Mac上本地化部署
点击右上角即可分享
微信分享提示