GNN图神经网络原理解析
以前搞机器学习、数据挖掘,主要是针对文本、图像和结构化的数据。但在现实的物理世界中,还有一类非常重要的数据结构:图(不是图片Image,而是graph)!最常见的graph结构:
- 社交网络了:比如微信、qq这种好友关系的无向图;又比如weibo、x这种关注关系的有向图;
- google搜索引擎早期排序的算法:pageRank中网页被其他网页引用
- 导航的物理地图
- 化学/材料分子结构中原子的连接方式
- 电子元器件的连接方式
如果能较好地处理这类grapg数据,肯定是能产生巨大应用价值的!比如:
- 社交网络的communite detection,看看哪些人是一伙的
- 黑灰产的洗钱、薅羊毛识别
- 新材料、药品的研发合成
- 交通流量的预测
1、(1)传统的机器学习中,结构化数据存放在数据库;文本转成unicode表示和存储,image则以像素的形式表示和存储,那么graph是怎么表示和存储的了?从物理上讲,graph核心的构成无非就是vertex和edge,如下所示:
所以graph最核心的信息就是vertex之间的连接关系了!最简单的表示和存储方式当属 Adjacency matrix了,以social network为例,存储方式如下:A和B两个user如果有interact,那么在矩阵中对应的位置Wab就有值,否则该位置为0;
以这个思路类推,image中像素以graph的形式表示为:
文本text的存储方式:
因为很多vertex并未连接,如果全都用矩阵存储会比较浪费空间,所以也可以用hashmap的形式存储,比如A和B之间有连接,那么就是<A,B>形式,比如:
(2)graph的存储和表示问题解决了,就该考虑怎么做好上层应用了! 对于整张图,常见的应用是有没有成环?
对于单个node,常见的应用是comminute detection,比如社交网络的社团挖掘;金融体系的资金监控、反洗钱等;
对于edge,这条边的属性是啥?比如图片中不同人之间的关系:watching吃瓜,还是fighting打斗?
2、上述的应用,都应该怎么实现了?
传统的机器学习和现在流行的LLM,最核心的就是特征构造了!传统的LR、SVM、贝叶斯、决策树等模型,都需要人工手动生成和构造特征,特征决定了整个效果的下限,使用的模型只是最后的临门一脚,特征构造才是最核心的工作!DNN和大模型流行后,特征构造自动化了(主要是通过多层神经网络自动做特征交叉和筛选),构造好的特征存储embedding向量,所以这时最核心的数据就是保留了原始输入信息的密集embedding向量啦!总结一句话就是:万物皆可embedding!embedding就是世间万物的projection!那么问题来了:graph能借鉴embedding的思路,把不规则的node、edge、整个graph都转成embedding么?graph能compatible神经网络么?主要是借此完成合理的信息表达!原文章作者给这种方法取的名字:"message passing neural network"
(1)DNN和transformer架构最核心的就是把任何输入都转成embedding的形式,这里的graph是不是也能借鉴了?先看一个最简单的GNN:
不论是整个graph,还是vertex,还是edge,都认为随机给一个初始化的embedding,然后各自分别接上MLP,这不就把graph和NN联系起来了么?GNN= graph + NN不就顺利成章了么?至于MLP的hidden层数、每层的神经元个数、激活函数、loss函数等都可以根据实际情况灵活设置啦!并且从layer N到layer N+1,整个graph的形状和结构完全不变,变化的仅仅是embedding和NN的参数!
那么问题来了:这个所谓的graph independent layer一般需要多少层了?有句话怎么说来着?你个任何人之间只隔了6个人,所以这里的层级也不用太多,2~3层就够了!既然embedding都有了,接下来的predict岂不是简单很多了?
vertext、edge、整个graph经过MLP做转换,得到新的embedding;然后根据业务需求,加上classifier,对embedding做分类!整个流程的原理so easy!具体在计算时,可能涉及到大量的聚合操作,比如用node的embedding表示edge或整个图,可以用mean或sum的方式。anyway,不管怎么做聚合操作,真个end2end的流程是不变的!
(2)GNN就这样完结了?哪有这么简单!上述的GNN有严重的问题:node之间不管有没有连接,对最终的结果没有任何影响,和普通的DNN没本质区别!论文作者给这种算法取的名字:message passing neural network; 上述的GNN只体现了neural network,message passing在哪?
假设A结点连接了B和C,那么A的embedding就是包括自身在内的3个node的聚合(sum、mean都行),这才能体现相连结点之间互相的影响嘛!经过这样一番更新,再进入MLP做后续的流程不就行了?
上述的思路是没问题了,但细节还不够完善,因为没有考虑vertex结点本身的权重!试想:一个graph中(比如人类社交网络),肯定有影响大的核心node,也有边缘化的凑数node,相邻node的embedding如果同权相加,很明显不合适啊!怎么做相邻结点的加权求和了?
(3)一个node是否核心,关键指标之一是计算该node的degree:degree越大, 越核心!所以如下如,E的degree是4,A的degree是1,很明显E比A更核心!
- 邻接矩阵与特征矩阵进行乘法操作,表示聚合邻居信息:比如A和E连接,A需要聚合E的embedding,所以 邻接矩阵* feather矩阵,A就把E的embedding信息整合到自己这里了!
上述的做法完全把相邻点的信息拿过来了,但自己的信息又丢了,所以还是要把自己的信息先加上,具体就是在邻接矩阵的对角线都设置为1,如下:
- 上述方式得到的feather矩阵只是把相邻node的embedding求和,还没求平均了!怎么通过矩阵乘法让所有相邻结点的embedding和自己原来的embedding加起来求mean了?既然求平均的除数和degree相关,那就先对degree矩阵做处理呗!
最终embedding加权求和的计算如下:先对周边邻居结点和自己的embedding求和,得到sum of neighborb矩阵,再乘以邻接矩阵的逆矩阵,就得到每个结点周边相邻结点embedding和的均值!这个就是对sum of neighborb矩阵的每行做归一化!
列也要做归一化:(原因后续解释)
但是如果A矩阵左右都乘以邻接矩阵,相当于做了两次scale,数值明显有问题,所以左右相乘各自开平方:
- 对于矩阵A,为啥要用D^(-1/2)来左乘和右乘?业务意义是啥?举个例子:我和王二狗是小学同学,互相加了微信。突然某天,王二狗因为家里拆迁一夜暴富,有了大量的资本,进而结识了很多X二代。而我了,还是默默无闻地坐在格子间继续搬砖。从社交网络图看,王二狗的degree高达10000,而我因为只认识王二狗,所以degree只有1,如下图的绿点和红点;
很明显,因为王二狗的人脉宽广,我对于王二狗而言只是1/10000,所以利用王二狗的embedding更新我自己的embedding的时候,王二狗的贡献会很小,肯定要打折扣,而不是全部!那么问题来了:这个折扣打多少?原论文作者的公式如下:
Vi和Vj分别是我和王二狗的degree;从公式可以看出,王二狗degree很大,我只是他很不起眼的一个关系,所以王二狗对我的贡献很小,刚好可以通过上述公式量化计算贡献值!
(4)整个流程和细节理顺了,先来整体看一下message passing的流程,加深印象:以4层为例图示如下:每个node的下一层会连接新node,经过的layer越多,上层node累计的信息就越多。从原作者实践情况看,2~3层足够了,层数太多反而指标下降!(把八竿子都打不着的远方亲戚拉进来,显然不合适啊!)
以2层的layer为例,计算公式如下:就是里面一层,外面一层;如果层数增加,按照这个规律继续嵌套就行了!
损失函数根据业务目的不同而改变,这里用cross entropy,然后做back proprgation更新参数W和A!2~3层就合适了,不用太多!感受野过大,特征发散,结果反而不准了!
3、总结:GNN和传统的DNN相比,最大的区别在于使用了邻接点的embedding!每个node都会通过邻接点的degree做加权,然后将其embedding的信息整合到自己的embedding中来更新自己的embedding!这个思路和attention这种利用context更新自己embedding的机制类似!
参考:
1、https://distill.pub/2021/gnn-intro/ A Gentle Introduction to Graph Neural Networks
2、https://www.bilibili.com/video/BV1zDppe2Edx?p=12&vd_source=241a5bcb1c13e6828e519dd1f78f35b2
3、https://www.bilibili.com/video/BV1iT4y1d7zP/?spm_id_from=333.337.search-card.all.click&vd_source=241a5bcb1c13e6828e519dd1f78f35b2 图解图神经网络GNN
4、https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch%20geometric.nn.conv.GCNConv
5、https://arxiv.org/pdf/1609.02907 SEMI-SUPERVISED CLASSIFICATION WITH GRAPH CONVOLUTIONAL NETWORKS