Loading

PYG中的邻居采样NeighborLoader

Example

from torch_geometric.datasets import Planetoid
from torch_geometric.loader import NeighborLoader
import torch

data = Planetoid('./dataset', name='Cora')[0]

# Assign each node its global node index:
data.n_id = torch.arange(data.num_nodes)

loader = NeighborLoader(
    data,
    # Sample 30 neighbors for each node for 2 iterations
    num_neighbors=[30] * 2,
    # Use a batch size of 128 for sampling training nodes
    batch_size=128,
    input_nodes=data.train_mask,
)

sampled_data = next(iter(loader))
print(sampled_data.batch_size)
print(sampled_data.n_id) # NeighborLoader返回的子图中的节点index是local的,而非在原始data中的index,因此我们要给data增加一个n_id属保存原始节点id,并进行映射

完整示例

点这里

API 介绍

文档介绍
更详细的版本
参见注释内的args部分

部分用法讲解(代码取自完整示例)

  1. 加载数据
  • data要求是torch_geometric.data.Data or torch_geometric.data.HeteroData类型
  • input_nodes : 中心节点集合,即一个mini-batch内的节点,如果为None,则代表包含data中的所有节点
  • num_neighbors: 每轮迭代要采样邻居节点的个数,即第i轮要为每个节点采样num_neighbors[i]个节点,如果为-1,则代表所有邻居节点都将被包含。
kwargs = {'batch_size': 1024, 'num_workers': 6, 'persistent_workers': True}
train_loader = NeighborLoader(data, input_nodes=data.train_mask,
                              num_neighbors=[25, 10], shuffle=True, **kwargs)

subgraph_loader = NeighborLoader(copy.copy(data), input_nodes=None,
                                 num_neighbors=[-1], shuffle=False, **kwargs)
  1. 子图index映射
    NeighborLoader返回的子图中的节点index是local的,而非在原始data中的index,因此我们要给data增加一个n_id属保存原始节点id,并进行映射
# Add global node index information.
subgraph_loader.data.num_nodes = data.num_nodes
subgraph_loader.data.n_id = torch.arange(data.num_nodes)
...
# 映射回原始index
for batch in subgraph_loader:
    x = x_all[batch.n_id.to(x_all.device)].to(device)
    ...
  1. 每个minibatch内的节点顺序
    NeighborLoader返回的子图的节点顺序是按照采样顺序排的,即mini-batch内的中心节点是最前面的batch size个,因此取模型结果的时候要取前batch size个
y = batch.y[:batch.batch_size]
y_hat = model(batch.x, batch.edge_index.to(device))[:batch.batch_size] # 前batch_size个节点为中心节点,即mini-batch内的节点。
  1. train和test的区别
    在train的时候依次计算每个batch内的节点的表示,要进行n/batch*layer次卷积;但是在inference的时候是将所有的边传入模型,在每层内分batch计算,即每次卷积之后所有节点的表示都得到更新,共进行layer次卷积,能够加快计算速度。
    @torch.no_grad()
    def inference(self, x_all, subgraph_loader):
        pbar = tqdm(total=len(subgraph_loader.dataset) * len(self.convs))
        pbar.set_description('Evaluating')

        # Compute representations of nodes layer by layer, using *all*
        # available edges. This leads to faster computation in contrast to
        # immediately computing the final representations of each batch:
        for i, conv in enumerate(self.convs):
            xs = []
            for batch in subgraph_loader:
                x = x_all[batch.n_id.to(x_all.device)].to(device)
                x = conv(x, batch.edge_index.to(device))
                if i < len(self.convs) - 1:
                    x = x.relu_()
                xs.append(x[:batch.batch_size].cpu())
                pbar.update(batch.batch_size)
            x_all = torch.cat(xs, dim=0)
        pbar.close()
        return x_all

其他资料

pytorch geometric教程三 GraphSAGE源码详解+实战
pytorch geometric教程一: 消息传递源码详解(MESSAGE PASSING)+实例

posted @ 2022-11-08 15:28  摇头晃脑学知识  阅读(2024)  评论(0编辑  收藏  举报