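Two ways to create your own dataset in PyTorch Geometric
- Method 1: build it directly from Data objects
Do I really need to use these dataset interfaces?
No! Just as in regular PyTorch, you do not have to use datasets, for example when you want to create synthetic data on the fly without explicitly saving it to disk. In this case, simply put your torch_geometric.data.Data objects in a regular Python list and pass that list to torch_geometric.loader.DataLoader: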
https://pytorch-geometric.readthedocs.io/en/latest/notes/create_dataset.html
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
data_list = [Data(...), ..., Data(...)]
loader = DataLoader(data_list, batch_size=32)
For how to create the Data objects themselves, see: https://pytorch-geometric.readthedocs.io/en/latest/notes/introduction.html
A graph is used to model pairwise relations (edges) between objects (nodes). A single graph in PyG is described by an instance of torch_geometric.data.Data, which holds the following attributes by default:
data.x: Node feature matrix with shape [num_nodes, num_node_features]
data.edge_index: Graph connectivity in COO format with shape [2, num_edges] and type torch.long
data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features]
data.y: Target to train against (may have arbitrary shape), e.g., node-level targets of shape [num_nodes, *] or graph-level targets of shape [1, *]
data.pos: Node position matrix with shape [num_nodes, num_dimensions]
A simple example:
import torch
from torch_geometric.data import Data
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
data = Data(x=x, edge_index=edge_index)
print(data)  # shows the attribute shapes, e.g. Data(edge_index=[2, 4], x=[3, 1])
- Method 2: based on https://zhuanlan.zhihu.com/p/480173796