HGP-SL程序理解

一、导入自己的数据集

  • PyTorch 所有的数据集对象都是torch.utils.data.Dataset的子类。在继承它的时候必须要重写其__len____getitem__方法;
  • 为了方便数据的存储和读入,可以将数据存为.pt文件(PyTorch 的标准数据文件);
  • 四个基本函数

    1. torch_geometric.data.InMemoryDataset.raw_file_names(): 返回一个文件列表,包含raw_dir中的文件目录。可以根据此列表来决定哪些需要下载或者已下载的直接跳过。
    2. torch_geometric.data.InMemoryDataset.processed_file_names():
      返回一个处理后的文件列表,包含processed_dir中的文件目录。据此来决定需要跳过。也就说,在你处理完后,你再次运行该程序将不会二次处理。
    3. torch_geometric.data.InMemoryDataset.download():
      将原始数据下载到 raw_dir 文件夹.
    4. torch_geometric.data.InMemoryDataset.process():
      处理原始数据将结果存放至 processed_dir 文件夹. 注意,这里需要将结果存储成Data格式。为解决python处理达标存储慢的的问题,通过torch_geometric.data.InMemoryDataset.collate()将许多Data列表整理成一个很大的Data对象,并且返回一个slices索引字典,因此我们需要设置self.dataself.slice这两个属性。
    5. 引自 https://zhuanlan.zhihu.com/p/132335866

二、运行

  • scatter_add_(dim,index,src ) 将张量src中的所有值添加到张量中self指定的索引index
  • numpy.bincount(x, weights=None, minlength=0) 统计非负整数出现的次数。每一个bin都是给出输入数组中每一个数出现的次数。

  • torch.cumsum() 对于二维输入a,dim=0(第1行不动,将第1行累加到其他行);dim=1(进入最内层,转化成列处理。第1列不动,将第1列累加到其他列;从第一列开始后面的每一列都是前面对应行元素的累加和)
  • torch.cat是将两个张量(tensor)拼接在一起
  • isinstance() 函数来判断一个对象是否是一个已知的类型,类似 type()。相比于type()来说,考虑父类继承关系。
  • unsqueeze()函数 增加一个维度  squeeze()减少一个维度
posted @ 2021-06-10 15:19  sushamu  阅读(116)  评论(0编辑  收藏  举报