图神经网络GNN:给图多个 node features和edge features

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
'''
摘自https://docs.dgl.ai/en/0.6.x/guide_cn/graph-feature.html
'''
 
import dgl
import torch as th
 
# ========================= 无权图 ======================================
 
g = dgl.graph(([0, 0, 1, 5], [1, 2, 2, 0])) # 6个节点,4条边
 
# each graph can have many 'node features'
g.ndata['x'] = th.ones(g.num_nodes(), 3)   # 节点特征x, 特征长度为3
g.ndata['y'] = th.randn(g.num_nodes(), 5# 节点特征y,特征长度为5
 
# similarly, each graph can have many 'edge features'
g.edata['x'] = th.ones(g.num_edges(), dtype=th.int32)  # 标量整型边特征x
g.edata['z'] = th.ones(g.num_edges(), dtype=th.float32)  # 浮点型型边特征z
 
print('g:\n', g)
 
print(g.ndata['x'][1])     # 获取节点特征x的节点1特征
print(g.ndata['y'][1])     # 获取节点特征y的节点1特征
 
print(g.edata['x'][th.tensor([0, 3])])  # 获取边特征x下的0和3节点特征
print()
 
 
# ========================= 有权图 ======================================
 
# edges 0->1, 0->2, 0->3, 1->3
edges = th.tensor([0, 0, 0, 1]), th.tensor([1, 2, 3, 3])
weights = th.tensor([0.1, 0.6, 0.9, 0.7])  # weight of each edge
g = dgl.graph(edges)
g.edata['w'] = weights  # give it a name 'w'
print('weighted graph:\n', g)
posted @   Picassooo  阅读(676)  评论(0编辑  收藏  举报
编辑推荐:
· 基于Microsoft.Extensions.AI核心库实现RAG应用
· Linux系列:如何用heaptrack跟踪.NET程序的非托管内存泄露
· 开发者必知的日志记录最佳实践
· SQL Server 2025 AI相关能力初探
· Linux系列:如何用 C#调用 C方法造成内存泄露
阅读排行:
· Manus爆火,是硬核还是营销?
· 终于写完轮子一部分:tcp代理 了,记录一下
· 震惊!C++程序真的从main开始吗?99%的程序员都答错了
· 别再用vector<bool>了!Google高级工程师:这可能是STL最大的设计失误
· 单元测试从入门到精通
点击右上角即可分享
微信分享提示