深入浅出图神经网络 GCN代码实战
GCN代码实战
书中5.6节的GCN代码实战做的是最经典Cora数据集上的分类,恰当又不恰当的类比Cora之于GNN就相当于MNIST之于机器学习。
有关Cora的介绍网上一搜一大把我就不赘述了,这里说一下Cora这个数据集对应的图是怎么样的。
Cora有2708篇论文,之间有引用关系共5429个,每篇论文作为一个节点,引用关系就是节点之间的边。每篇论文有一个1433维的特征来表示某个词是否在文中出现过,也就是每个节点有1433维的特征。最后这些论文被分为7类。
所以在Cora上训练的目的就是学习节点的特征及其与邻居的关系,根据已知的节点分类对未知分类的节点的类别进行预测。
知道这些应该就OK了,下面来看代码。
数据处理
注释里自己都写了代码引用自PyG我觉得就扫几眼就行了,因为现在常用的数据集两个GNN轮子(DGL和PyG)里都有,现在基本都是直接用,很少自己下原始数据再处理了,所以略过。
GCN层定义
回顾第5章中GCN层的定义:
所以对于一层GCN,就是对输入X,乘一个参数矩阵W,再乘一个算好归一化后的“拉普拉斯矩阵”即可。
来看代码:
定义了一层GCN的输入输出维度和偏置,对于GCN层来说,每一层有自己的W,X是输入给的,˜Lsym是数据集算的,所以只需要定义一个weight
矩阵,注意一下维度就行。
传播的时候只要按照公式X′=σ(˜LsymXW)进行一下矩阵乘法就好,注意一个trick:˜Lsym是稀疏矩阵,所以先矩阵乘法得到XW,再用稀疏矩阵乘法计算˜LsymXW运算效率上更好。
GCN模型定义
知道了GCN层的定义之后堆叠GCN层就可以得到GCN模型了,两层的GCN就可以取得很好的效果(过深的GCN因为过度平滑的问题会导致准确率下降):
这里设置隐藏层维度为16,调到32,64,...都是可以的,我自己试的结果来说没有太大的区别。从隐藏层到输出层直接将输出维度设置为分类的维度就可以得到预测分类。
传播的时候相比于每一层的传播只需要加上激活函数,这里选用ReLU
。
训练
定义模型、损失函数(交叉熵)、优化器:
具体的训练函数注释已经解释的很清楚:
对应的测试函数:
注意模型得到的分类不是one-hot的,而是对应不同种类的预测概率,所以要test_mask_logits.max(1)[1]
取概率最高的一个作为模型预测的类别。
这些都写好之后直接运行训练函数即可。有需要还可以对train_loss
和validation_accuracy
进行画图,书上也给出了相应的代码,比较简单不再赘述。
__EOF__

本文链接:https://www.cnblogs.com/daztricky/p/15010350.html
关于博主:评论和私信会在第一时间回复。或者直接私信我。
版权声明:本博客所有文章除特别声明外,均采用 BY-NC-SA 许可协议。转载请注明出处!
声援博主:如果您觉得文章对您有帮助,可以点击文章右下角【推荐】一下。您的鼓励是博主的最大动力!
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】凌霞软件回馈社区,博客园 & 1Panel & Halo 联合会员上线
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】博客园社区专享云产品让利特惠,阿里云新客6.5折上折
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· PostgreSQL 和 SQL Server 在统计信息维护中的关键差异
· C++代码改造为UTF-8编码问题的总结
· DeepSeek 解答了困扰我五年的技术问题
· 为什么说在企业级应用开发中,后端往往是效率杀手?
· 用 C# 插值字符串处理器写一个 sscanf
· [翻译] 为什么 Tracebit 用 C# 开发
· Deepseek官网太卡,教你白嫖阿里云的Deepseek-R1满血版
· DeepSeek崛起:程序员“饭碗”被抢,还是职业进化新起点?
· 2分钟学会 DeepSeek API,竟然比官方更好用!
· .NET 使用 DeepSeek R1 开发智能 AI 客户端