1.torch.cat():连接,dim=1,增加列,dim=0,增加行
ZS = torch.cat([S, Z], dim=1)
2.scipy.io:可以读取.mat文件
scpy.io.loadmat(str(att_path))
3.np.random.choice()
idx = list(np.random.choice(self.data_length, self.batch_size, replace=replace))
numpy.random.choice(a, size=None, replace=True, p=None)
从a中随机抽取数字,并组成指定大小(size)的数组
replace:True表示可以取相同数字,False表示不可以取相同数字
数组p:与数组a相对应,表示取数组a中每个元素的概率,默认为选取每个元素的概率相同。
4.参数后面有个冒号,是类型注释
def get_negative_samples(Y:list, classes):