PyG
图的构建
同质图
- 假设我们要构建一个 graph \(\mathcal{G}=\langle \mathcal{V}, \mathcal{E} \rangle\), 其中 \(|\mathcal{V}| = V, |\mathcal{E}| = E\).
class Data:
def __init__(self, x: OptTensor = None, edge_index: OptTensor = None,
edge_attr: OptTensor = None, y: OptTensor = None,
pos: OptTensor = None, **kwargs):
- x: torch.Tensor, (V, *).
- edge_index: torch.Tensor, (2, E)
- ...
graph = Data()
graph.x = torch.empty((3,))
graph.edge_index = torch.tensor([
[0, 1],
[1, 2]
])
>>> graph
Data(x=[3], edge_index=[2, 2])
- 此图实际上是 \(v_0 \rightarrow v_1 \rightarrow v_2\), 这个有向图, 我们可以通过[to_undirected]快捷地将它转换为无向图:
def to_undirected(
edge_index: Tensor,
edge_attr: Optional[Union[Tensor, List[Tensor]]] = None,
num_nodes: Optional[int] = None,
reduce: str = "add",
) -> Union[Tensor, Tuple[Tensor, Tensor], Tuple[Tensor, List[Tensor]]]:
>>> graph.edge_index = to_undirected(graph.edge_index)
>>> graph.edge_index
tensor([[0, 1, 1, 2],
[1, 0, 2, 1]])
异质图
-
异质图主要围绕着边的类型构建, 主要通过 [HeteroData] 构建, 每个边类型的构建和普通图是完全一致的.
-
接下来以推荐系统中的二部图为例:
from torch_geometric.data import HeteroData
graph = HeteroData()
# nodes
graph['User'].x = torch.empty((4,))
graph['Item'].x = torch.empty((5,))
# edge type "(User, click, Item)"
graph['User', 'click', 'Item'].edge_index = torch.tensor([
[0, 1, 2, 3, 3],
[0, 1, 2, 3, 4]
])
>>> graph
HeteroData(
User={ x=[4] },
Item={ x=[5] },
(User, click, Item)={ edge_index=[2, 5] }
)
>>> graph.num_nodes
9
>>> graph['User'].num_nodes
4
>>> graph[('User', 'click', 'Item')]
{'edge_index': tensor([[0, 1, 2, 3, 3],
[0, 1, 2, 3, 4]])}
>>> graph['click']
{'edge_index': tensor([[0, 1, 2, 3, 3],
[0, 1, 2, 3, 4]])}
- 二部图转为同质图:
graph = graph.coalesce() # 抹去重复的边
graph = graph.to_homogeneous()
>>> graph
Data(edge_index=[2, 5], x=[9], node_type=[9], edge_type=[5])
>>> graph.node_type
tensor([0, 0, 0, 0, 1, 1, 1, 1, 1])
>>> graph.edge_type
tensor([0, 0, 0, 0, 0])
- 转为无向图:
graph.edge_index, graph.edge_type = to_undirected(graph.edge_index, edge_attr=graph.edge_type)
>>> graph
Data(edge_index=[2, 10], x=[9], node_type=[9], edge_type=[10])
MessagePassing
[torch_geometric.nn.conv.MessagePassing]
-
一般的 GCN 可以归结为如下形式:
\[x_i = \phi(x_i, \oplus_{j \in \mathcal{N}(i)} \: \varphi(x_i, x_j, e_{j \rightarrow i})). \] -
其中我们需要设定的包括:
- \(\varphi\),
message
: 逐边处理的一个函数; - \(\oplus_{j \in \mathcal{N}(i)}\),
aggr
: 聚合操作, 比如常见的, sum, mean, min, max, 也可以是人为定义的. - \(\phi\),
update
: 更新函数, 通常是一些非线性的变换.
- \(\varphi\),
class MessagePassing:
def __init__(
self,
aggr: Optional[Union[str, List[str], Aggregation]] = "add",
*,
aggr_kwargs: Optional[Dict[str, Any]] = None,
flow: str = "source_to_target",
node_dim: int = -2,
decomposed_layers: int = 1,
**kwargs,
): ...
def propagate(self, edge_index: Adj, size: Size = None, **kwargs): ...
def edge_updater(self, edge_index: Adj, **kwargs): ...
def message(self, x_j: Tensor) -> Tensor: ...
def aggregate(self, inputs: Tensor, index: Tensor,
ptr: Optional[Tensor] = None,
dim_size: Optional[int] = None) -> Tensor: ...
def message_and_aggregate(self, adj_t: SparseTensor) -> Tensor: ...
def update(self, inputs: Tensor) -> Tensor: ...
def edge_update(self) -> Tensor: ...
-
aggr
: 上面提到的, 支持多种方式; -
aggr_kwargs
: 自定义的一些聚合方式可能需要传一些参数; -
flow
:source_to_target
(默认) ortarget_to_source
. 我们知道,edge_index
是 (2, E) 大小的 tensor, 每一列表示一条边 \(e_{ji}\), 如果是前者则这条边方向是 \(e_{j \rightarrow i}\), 否则方向为 \(e_{i \rightarrow j}\). 故如果是source_to_target
, 则x_i
相当于是 x[edge_index[1]],x_j
相当于是 x[edge_index[0]]. 如果是target_to_source
, 则x_i
相当于是 x[edge_index[0]],x_j
相当于是 x[edge_index[1]]. 总i
始终表示 target nodes,j
始终表示 source nodes. -
node_dim
: The axis along which to propagate. (default: :obj:-2
) 这个主要用在 aggregation 的时候. 比如 aggregation 的输入为 (V, D) 大小的 tensor, 默认的 node_dim=-2 就能够保证是将不同的结点的特征聚合起来, 如果 node_dim=-1, 还要起到相同的效果就得输入 (D, V) 格式. 至于为什么不是 dim=0, 大概是因为可能有些时候会遇到是 (B, V, D) 之类的情况. 但是感觉文档没有突出它的重要性啊, 应该是很重要的参数. -
MessagePassing 和普通的 nn.Module 类似, 主要脚本在
forward
中, 一般我们会在forward
中调用propagate
方法来管理卷积过程:edge_index
: (2, E);size
: 如果为None
, 则表示默认处理的是普通的图, 此时要求 source, target 的结点数目是一致的; 如果显式给定 (M, N), 则表示 source, target 的结点数目分别为 (M, N).
-
propagate
的执行流程如下:- 检查输入, 整理得到合适的输入格式, 记为
coll_dict
, 基于此得到适合各函数的输入:msg_kwargs, aggr_kwargs, update_kwargs
, coll_dict 中包含如下的特殊的关键字:if isinstance(edge_index, Tensor): out['adj_t'] = None out['edge_index'] = edge_index out['edge_index_i'] = edge_index[i] out['edge_index_j'] = edge_index[j] out['ptr'] = None elif isinstance(edge_index, SparseTensor): out['adj_t'] = edge_index out['edge_index'] = None out['edge_index_i'] = edge_index.storage.row() out['edge_index_j'] = edge_index.storage.col() out['ptr'] = edge_index.storage.rowptr() if out.get('edge_weight', None) is None: out['edge_weight'] = edge_index.storage.value() if out.get('edge_attr', None) is None: out['edge_attr'] = edge_index.storage.value() if out.get('edge_type', None) is None: out['edge_type'] = edge_index.storage.value() out['index'] = out['edge_index_i'] out['size'] = size out['size_i'] = size[i] if size[i] is not None else size[j] out['size_j'] = size[j] if size[j] is not None else size[i] out['dim_size'] = out['size_i']
- 如果
message_and_aggregate
实现了, 则调用它, 否则向下执行; out = self.message(**msg_kwargs)
;out = self.aggregate(out, **aggr_kwargs)
;out = self.update(out, **update_kwargs)
;- 然后输出
- 检查输入, 整理得到合适的输入格式, 记为
-
message
部分默认接受 \(x_j\), 默认情况下,x_j = x[edge_index[0]]
这是个 \((E, D)\) 大小的 tensor, 这里假设在边上的操作只和 source 有关. 如果我们要弄一个复杂一点的, 比如:
\[\varphi(x_i, x_j) = W[x_i\|x_j], \]就可以这么定义:
def message(x_i: torch.Tensor, x_j: torch.Tensor): return self.mlp(torch.cat((x_i, x_j), dim=-1))
-
aggregate
部分接受message
的输出 (E, D) 大小的 tensor 和一些其它的可选参数:def aggregate(self, inputs: Tensor, index: Tensor, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None) -> Tensor: # inputs: (E, D) # index: edge_index_i, namely target index return self.aggr_module(inputs, index, ptr=ptr, dim_size=dim_size, dim=self.node_dim)
-
update
部分默认接受aggregate
的 (V, D) 的输出:def update(self, inputs: Tensor) -> Tensor: # inputs: (V, D) return inputs
稍微复杂点的, 比如
\[\phi(x_i, x_i^{aggr}) = \text{ReLU}(x_i + x_i^{aggr}). \]def update(self, aggregated: Tensor, x) -> Tensor: # aggregated: (V, D) # x: (V, D) return self.relu(inputs + x)