networkx - 可达节点集合

import torch
import networkx as nx
from torch_geometric.data import Data
from torch_geometric.utils import to_networkx

# 构造一个简单的数据对象
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
data = Data(edge_index=edge_index)

# 将数据对象转换为 NetworkX 图形对象
graph = to_networkx(data)

# 计算可达节点数量
reachable_nodes = nx.descendants(graph, 0)
num_reachable_nodes = len(reachable_nodes)
print(num_reachable_nodes)
posted @ 2023-04-02 15:35  X1OO  阅读(64)  评论(0)    收藏  举报