PyG

Graph Construction

Homogeneous Graphs

  • Suppose we want to build a graph \(\mathcal{G}=\langle \mathcal{V}, \mathcal{E} \rangle\) with \(|\mathcal{V}| = V\) and \(|\mathcal{E}| = E\).

[torch_geometric.data.Data]

class Data:

    def __init__(self, x: OptTensor = None, edge_index: OptTensor = None,
                 edge_attr: OptTensor = None, y: OptTensor = None,
                 pos: OptTensor = None, **kwargs):
  • x: torch.Tensor of shape (V, *), the node features.
  • edge_index: torch.Tensor of shape (2, E), one column per edge.
  • ...
import torch
from torch_geometric.data import Data

graph = Data()
graph.x = torch.empty((3,))
graph.edge_index = torch.tensor([
    [0, 1],
    [1, 2]
])
>>> graph
Data(x=[3], edge_index=[2, 2])
  • This graph is actually the directed graph \(v_0 \rightarrow v_1 \rightarrow v_2\); we can conveniently turn it into an undirected graph with [to_undirected]:
def to_undirected(
    edge_index: Tensor,
    edge_attr: Optional[Union[Tensor, List[Tensor]]] = None,
    num_nodes: Optional[int] = None,
    reduce: str = "add",
) -> Union[Tensor, Tuple[Tensor, Tensor], Tuple[Tensor, List[Tensor]]]:
>>> graph.edge_index = to_undirected(graph.edge_index)
>>> graph.edge_index
tensor([[0, 1, 1, 2],
        [1, 0, 2, 1]])
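
  • If edge attributes are present, to_undirected can mirror them along with the edges. A small sketch (edge_weight is an illustrative tensor, not part of the original example):

from torch_geometric.utils import to_undirected

edge_index = torch.tensor([[0, 1],
                           [1, 2]])
edge_weight = torch.tensor([1., 2.])
# every mirrored edge keeps the attribute of its directed original
edge_index, edge_weight = to_undirected(edge_index, edge_attr=edge_weight)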

Heterogeneous Graphs

  • Heterogeneous graphs are organized mainly around edge types and are built with [HeteroData]; the construction for each edge type is exactly the same as for an ordinary graph.

  • Let us take the bipartite graph of a recommender system as an example:

import torch
from torch_geometric.data import HeteroData

graph = HeteroData()

# nodes
graph['User'].x = torch.empty((4,))
graph['Item'].x = torch.empty((5,))

# edge type "(User, click, Item)"
graph['User', 'click', 'Item'].edge_index = torch.tensor([
    [0, 1, 2, 3, 3],
    [0, 1, 2, 3, 4]
])

>>> graph
HeteroData(
  User={ x=[4] },
  Item={ x=[5] },
  (User, click, Item)={ edge_index=[2, 5] }
)

>>> graph.num_nodes
9
>>> graph['User'].num_nodes
4
>>> graph[('User', 'click', 'Item')]
{'edge_index': tensor([[0, 1, 2, 3, 3],
        [0, 1, 2, 3, 4]])}
>>> graph['click']
{'edge_index': tensor([[0, 1, 2, 3, 3],
        [0, 1, 2, 3, 4]])}
  • Converting the bipartite graph to a homogeneous graph:
graph = graph.coalesce() # remove duplicate edges
graph = graph.to_homogeneous()

>>> graph
Data(edge_index=[2, 5], x=[9], node_type=[9], edge_type=[5])
>>> graph.node_type
tensor([0, 0, 0, 0, 1, 1, 1, 1, 1])
>>> graph.edge_type
tensor([0, 0, 0, 0, 0])
  • Converting to an undirected graph:
graph.edge_index, graph.edge_type = to_undirected(graph.edge_index, edge_attr=graph.edge_type)

>>> graph
Data(edge_index=[2, 10], x=[9], node_type=[9], edge_type=[10])

MessagePassing

[torch_geometric.nn.conv.MessagePassing]

  • A general graph convolution can be summarized in the following form:

    \[x_i' = \phi(x_i, \oplus_{j \in \mathcal{N}(i)} \: \varphi(x_i, x_j, e_{j \rightarrow i})). \]

  • The parts we need to specify include:

    • \(\varphi\), message: a function applied to every edge;
    • \(\oplus_{j \in \mathcal{N}(i)}\), aggr: the aggregation operation, e.g. the common sum, mean, min, max, or a user-defined one;
    • \(\phi\), update: the update function, usually some nonlinear transformation.
class MessagePassing:

    def __init__(
        self,
        aggr: Optional[Union[str, List[str], Aggregation]] = "add",
        *,
        aggr_kwargs: Optional[Dict[str, Any]] = None,
        flow: str = "source_to_target",
        node_dim: int = -2,
        decomposed_layers: int = 1,
        **kwargs,
    ): ...

    def propagate(self, edge_index: Adj, size: Size = None, **kwargs): ...
    def edge_updater(self, edge_index: Adj, **kwargs): ...
    def message(self, x_j: Tensor) -> Tensor: ...
    def aggregate(self, inputs: Tensor, index: Tensor,
                  ptr: Optional[Tensor] = None,
                  dim_size: Optional[int] = None) -> Tensor: ...
    def message_and_aggregate(self, adj_t: SparseTensor) -> Tensor: ...
    def update(self, inputs: Tensor) -> Tensor: ...
    def edge_update(self) -> Tensor: ...
  • aggr: as mentioned above, multiple aggregation schemes are supported;

  • aggr_kwargs: custom aggregation schemes may need extra arguments;

  • flow: source_to_target (default) or target_to_source. Recall that edge_index is a (2, E) tensor whose columns each denote an edge \(e_{ji}\); with the former, the edge direction is \(e_{j \rightarrow i}\), otherwise it is \(e_{i \rightarrow j}\). So with source_to_target, x_i corresponds to x[edge_index[1]] and x_j to x[edge_index[0]]; with target_to_source, x_i corresponds to x[edge_index[0]] and x_j to x[edge_index[1]]. In short, i always denotes the target nodes and j always denotes the source nodes.
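
    As a concrete illustration (plain PyTorch, names are only for this note), the gathering under the default flow looks like:

    import torch

    x = torch.randn(3, 8)                 # node features, (V, D)
    edge_index = torch.tensor([[0, 1],    # row 0: source nodes j
                               [1, 2]])   # row 1: target nodes i

    # flow="source_to_target" (default): messages travel j -> i
    x_j = x[edge_index[0]]                # (E, D), features of the source nodes
    x_i = x[edge_index[1]]                # (E, D), features of the target nodes
    # flow="target_to_source": the two rows simply swap roles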

  • node_dim: The axis along which to propagate. (default: -2) This is mainly relevant during aggregation. For instance, when the input to aggregation is an (E, D) tensor, the default node_dim=-2 ensures that the messages of different edges are aggregated into the per-node axis; with node_dim=-1 you would have to feed a (D, E) layout to achieve the same effect. As for why the default is not dim=0, presumably because the features may sometimes look like (B, V, D). The documentation does not really highlight it, but it should be an important parameter.
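
    A rough illustration of what sum aggregation along the node axis does (plain PyTorch, not PyG's actual implementation):

    import torch

    E, V, D = 4, 3, 8
    messages = torch.randn(E, D)          # output of message(), (E, D)
    index = torch.tensor([1, 2, 2, 0])    # target node of every edge (edge_index_i)

    # scatter the E messages into a (V, D) output;
    # for an (E, D) layout, node_dim=-2 is exactly dim 0 here
    out = torch.zeros(V, D)
    out.index_add_(0, index, messages)    # out[i] = sum of messages sent to node i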

  • MessagePassing is used like an ordinary nn.Module: the main logic lives in forward, and inside forward we usually call the propagate method to drive the convolution (a minimal forward is sketched after the two bullets below):

    • edge_index: (2, E);
    • size: if None, an ordinary graph is assumed and the source and target node counts must be equal; if given explicitly as (M, N), the source and target node counts are M and N respectively.
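
    As an illustration of the size argument (the tuple-of-features pattern follows PyG's bipartite convention; the surrounding layer is hypothetical), a forward for a bipartite graph could look like:

    def forward(self, x_src: Tensor, x_dst: Tensor, edge_index: Tensor) -> Tensor:
        # source and target node sets may have different sizes M and N
        return self.propagate(edge_index, x=(x_src, x_dst),
                              size=(x_src.size(0), x_dst.size(0)))
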
  • propagate executes as follows:

    1. Check the inputs and arrange them into a suitable format, denoted coll_dict, from which the inputs for the individual functions are derived: msg_kwargs, aggr_kwargs, update_kwargs. coll_dict contains the following special keys:
          if isinstance(edge_index, Tensor):
              out['adj_t'] = None
              out['edge_index'] = edge_index
              out['edge_index_i'] = edge_index[i]
              out['edge_index_j'] = edge_index[j]
              out['ptr'] = None
          elif isinstance(edge_index, SparseTensor):
              out['adj_t'] = edge_index
              out['edge_index'] = None
              out['edge_index_i'] = edge_index.storage.row()
              out['edge_index_j'] = edge_index.storage.col()
              out['ptr'] = edge_index.storage.rowptr()
              if out.get('edge_weight', None) is None:
                  out['edge_weight'] = edge_index.storage.value()
              if out.get('edge_attr', None) is None:
                  out['edge_attr'] = edge_index.storage.value()
              if out.get('edge_type', None) is None:
                  out['edge_type'] = edge_index.storage.value()
      
          out['index'] = out['edge_index_i']
          out['size'] = size
          out['size_i'] = size[i] if size[i] is not None else size[j]
          out['size_j'] = size[j] if size[j] is not None else size[i]
          out['dim_size'] = out['size_i']
      
    2. If message_and_aggregate is implemented, call it; otherwise continue;
    3. out = self.message(**msg_kwargs);
    4. out = self.aggregate(out, **aggr_kwargs);
    5. out = self.update(out, **update_kwargs);
    6. Return the output.
  • By default, message accepts \(x_j\); in the default setting,

    x_j = x[edge_index[0]]
    

    This is a tensor of shape \((E, D)\); here the per-edge operation is assumed to depend only on the source node. If we want something a bit more complex, such as:

    \[\varphi(x_i, x_j) = W[x_i\|x_j], \]

    we can define it like this:

    def message(self, x_i: torch.Tensor, x_j: torch.Tensor) -> torch.Tensor:
        # x_i, x_j: (E, D), gathered from the target / source nodes respectively
        return self.mlp(torch.cat((x_i, x_j), dim=-1))
    
  • aggregate receives the (E, D) output of message together with a few optional arguments:

    def aggregate(self, inputs: Tensor, index: Tensor,
                  ptr: Optional[Tensor] = None,
                  dim_size: Optional[int] = None) -> Tensor:
        # inputs: (E, D)
        # index: edge_index_i, namely target index
        return self.aggr_module(inputs, index, ptr=ptr, dim_size=dim_size,
                                dim=self.node_dim)
    
  • By default, update receives the (V, D) output of aggregate:

    def update(self, inputs: Tensor) -> Tensor:
        # inputs: (V, D)
        return inputs
    

    A slightly more complex one, for example

    \[\phi(x_i, x_i^{aggr}) = \text{ReLU}(x_i + x_i^{aggr}). \]

    def update(self, aggregated: Tensor, x: Tensor) -> Tensor:
        # aggregated: (V, D)
        # x: (V, D)
        return self.relu(aggregated + x)
    
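  • Putting the pieces together, a minimal sketch of a custom layer that implements exactly the message and update examples above (the class name ConcatConv and the single Linear are illustrative choices, not part of PyG):

    import torch
    import torch.nn as nn
    from torch import Tensor
    from torch_geometric.nn import MessagePassing

    class ConcatConv(MessagePassing):
        def __init__(self, channels: int):
            # sum aggregation, default flow="source_to_target"
            super().__init__(aggr="add")
            self.mlp = nn.Linear(2 * channels, channels)

        def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
            # x: (V, D); edge_index: (2, E)
            return self.propagate(edge_index, x=x)

        def message(self, x_i: Tensor, x_j: Tensor) -> Tensor:
            # varphi(x_i, x_j) = W [x_i || x_j], computed per edge: (E, D)
            return self.mlp(torch.cat((x_i, x_j), dim=-1))

        def update(self, inputs: Tensor, x: Tensor) -> Tensor:
            # phi(x_i, x_i^aggr) = ReLU(x_i + x_i^aggr), per node: (V, D)
            return torch.relu(inputs + x)

    conv = ConcatConv(channels=8)
    x = torch.randn(5, 8)
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [1, 2, 3, 4]])
    out = conv(x, edge_index)   # (5, 8)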