2021年3月13日

torch.gather()

摘要: 作用:收集输入的特定维度指定位置的数值参数:input(tensor): 待操作数。不妨设其维度为(x1, x2, …, xn)dim(int): 待操作的维度。index(LongTensor): 如何对input进行操作。其维度有限定,例如当dim=i时,index的维度为(x1, x2, …y 阅读全文

posted @ 2021-03-13 16:43 cltt 阅读(201) 评论(0) 推荐(0) 编辑

tensor 3维分块乘法

摘要: a = torch.range(1,4) a = a.reshape(2,1,2) b= torch.range(1,12) b = b.reshape(2,2,3) c = torch.bmm(a,b) print('c') print(c) print(c.shape) d = torch.ze 阅读全文

posted @ 2021-03-13 15:31 cltt 阅读(210) 评论(0) 推荐(0) 编辑

导航