Torch常用的函数总结
矩阵运算相关:
torch.mul(a,b) 是矩阵a和b对应位相乘,a和b的维度必须相等,比如a的维度是(1, 2),b的维度是(1, 2),返回的仍是(1, 2)的矩阵
torch.mm(a,b) 是矩阵a和b的 矩阵相乘。比如a的维度是(1, 2),b的维度是(2, 3),返回的就是(1, 3)的矩阵
torch.transpose(Phi, 0, 1) 是交换一个tensor的两个维度,返回的类型也是tensor。即“torch.transpose(input, dim0, dim1) → Tensor”,需要注意的是transpose中的两个维度参数的顺序是可以交换位置的,即transpose(x, 0, 1) 和transpose(x, 1, 0)效果是相同的。
view(-1,...,) 在torch里面,view函数相当于numpy的reshape
Note:经常会看见参数为-1,这里-1表示一个不确定的数,让电脑帮我们计算。例如,一个长度的16向量xx.view(-1, 2)等价于x.view(8,2)
模型相关
1. 统计模型参数
model = Net(layer_num) print('Total number of parameters net:', sum(p.numel() for p in model.parameters()))