7. Generate a mask to filter out invalid (padded) tensors
1. Prepare the environment
import torch

# 10 nodes in total, each with a 2-dimensional feature vector
random_tensor = torch.randn(10, 2)
print(random_tensor)
2. Prepare the batch index
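from torch_geometric.utils import to_dense_batch, to_dense_adj, degree

# batch[i] is the graph id of node i: graph 0 has 3 nodes, graph 1 has 2, graph 2 has 5
batch = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 2, 2])
# Pad every graph to the largest graph's node count -> shape [3, 5, 2];
# mask_1 has shape [3, 5] and is True for real nodes, False for padded slots
abstract_features_1, mask_1 = to_dense_batch(random_tensor, batch)
print(abstract_features_1)
print(mask_1)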
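To make the padding and the mask concrete, here is a minimal pure-PyTorch sketch of what `to_dense_batch` does with its default settings. The helper name `pad_by_batch` is hypothetical (it is not part of torch_geometric), and it assumes `batch` is sorted and that padded slots are filled with 0. The second half shows the two common ways to use the mask to ignore invalid entries: boolean indexing to drop the padded rows, and a masked mean per graph.

```python
import torch


def pad_by_batch(x: torch.Tensor, batch: torch.Tensor):
    """Hypothetical helper mirroring to_dense_batch's default behavior.

    Assumes `batch` is sorted (nodes of the same graph are contiguous).
    Returns dense [B, N_max, F] (padding filled with 0) and mask [B, N_max].
    """
    num_graphs = int(batch.max()) + 1
    counts = torch.bincount(batch, minlength=num_graphs)  # nodes per graph
    max_nodes = int(counts.max())
    offsets = torch.cumsum(counts, dim=0) - counts  # start index of each graph
    pos = torch.arange(batch.numel()) - offsets[batch]  # node position inside its graph

    dense = x.new_zeros(num_graphs, max_nodes, x.size(-1))
    mask = torch.zeros(num_graphs, max_nodes, dtype=torch.bool)
    dense[batch, pos] = x  # scatter real nodes into their slots
    mask[batch, pos] = True  # mark those slots as valid
    return dense, mask


x = torch.randn(10, 2)
batch = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 2, 2])
dense, mask = pad_by_batch(x, batch)
print(dense.shape)  # torch.Size([3, 5, 2])
print(mask)

# 1) Filter out padded rows: boolean indexing keeps only real nodes,
#    in their original order, so this recovers x.
recovered = dense[mask]

# 2) Masked mean per graph: multiplying by the mask zeroes out padding
#    (already 0 here, but this stays correct for any fill value),
#    then divide by the true node count rather than N_max.
summed = (dense * mask.unsqueeze(-1)).sum(dim=1)
mean_per_graph = summed / mask.sum(dim=1, keepdim=True)
print(mean_per_graph)
```

With the real output from step 2, the same filtering applies directly: `abstract_features_1[mask_1]` returns the original 10 node rows without the padding.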