Torch常用的函数总结

矩阵运算相关:

torch.mul(a,b)  是矩阵a和b对应位相乘,a和b的维度必须相等,比如a的维度是(1, 2),b的维度是(1, 2),返回的仍是(1, 2)的矩阵

 

torch.mm(a,b)  是矩阵a和b的 矩阵相乘。比如a的维度是(1, 2),b的维度是(2, 3),返回的就是(1, 3)的矩阵

 

torch.transpose(Phi, 0, 1)  是交换一个tensor的两个维度,返回的类型也是tensor。即“torch.transpose(input, dim0, dim1) → Tensor”,需要注意的是transpose中的两个维度参数的顺序是可以交换位置的,即transpose(x, 0, 1) 和transpose(x, 1, 0)效果是相同的。

 

 

view(-1,...,) 在torch里面,view函数相当于numpy的reshape

Note:经常会看见参数为-1,这里-1表示一个不确定的数,让电脑帮我们计算。例如,一个长度的16向量xx.view(-1, 2)等价于x.view(8,2)

 

模型相关

1. 统计模型参数

model = Net(layer_num)
print('Total number of parameters net:',
      sum(p.numel() for p in model.parameters()))

 

posted @ 2020-10-20 21:51  不学无墅_NKer  阅读(421)  评论(0编辑  收藏  举报