pytorch中遇见的那些函数
学习的资料主要包括《Pytorch深度学习实战》和深度学习爱好者 公众号后台回复:Pytorch常用函数手册下载的资料。
torch.no_grad()
参考链接:
- https://blog.csdn.net/jiang_huixin/article/details/112669244
- https://blog.csdn.net/weixin_41658139/article/details/126199040
该方法一般用于神经网路的推理阶段,表示张量的计算过程中无需计算梯度。 - https://zhuanlan.zhihu.com/p/509649964
参考链接3说的太清楚了,等我对pytorch有更深入的了解后再进行解读。
目前的理解只能达到:为了方式更新系数太多,该方法可保证每次更新完就归零,不会积累。
call()和forward
参考链接:https://blog.csdn.net/qq_43745026/article/details/125537774
很显然,开头是双下划线,结尾也是双下划线,必然是魔法方法(上一次学习魔法方法还是在看__init__)。一般在特殊情况下自动调用。
这篇参考链接里的东西很是熟悉,抽个一半天再复习一遍! https://www.jb51.net/article/214097.htm
简言之,一般使用的魔法方法有__init__()来得到方法参数并进行赋值,在根据类生成对象时自动调用。call()方法则是让python中的类的实例能够像方法一样被调用,具备加括号的功能。
class A():
def __init__(self):
print('init函数')
def __call__(self, param):
print('call函数', param)
a = A()
a(1)
运行结果输出结果是init函数和call 函数 1。
call()中可以调用其它函数,如forward函数
class A():
def __init__(self):
print('init函数')
def __call__(self, param):
print('call函数', param)
res = self.forward(param)
return res+2
def forward(self, input_):
print('forward 函数', input_)
return input_
a = A()
b = a(1)
print('结果b=', b)
运行结果是输出init函数、call函数1、forward函数1 和 结果b=3。
call方法成功调用了forward(),且返回值给了b。
nn.linear()
该方法用于构建网络全连接层。
三个参数包括 in_features,out_features,bias=True。具体的示例见参考链接。
参考链接:https://blog.csdn.net/zhaohongfei_358/article/details/122797190
getarr()函数
get attribute 获得对应属性值
参考链接:https://www.runoob.com/python/python-func-getattr.html
log_softmax()函数
看名字就知道是在进行softmax归一化
公式:$ - \log \frac{{\exp ({p_j})}}{{\sum\limits_i {\exp ({p_i})} }}$
参考链接:https://blog.csdn.net/liu16659/article/details/122018549
上述链接探究了自己计算这个公式和直接调用Pytorch内置函数的区别,在正常情况下是没有区别的。
但是,如果遇到特殊情况如外溢等情况,常规算法会出现问题。内置函数考虑了多种情况,对公式进行了变换。
所以,能使用已经封装好的函数就使用封装好的函数,可靠性更高。
gather()函数
调用该方法时,可以找到原对象中对应索引的对应位置的值
参考链接:https://blog.csdn.net/weixin_42899627/article/details/122816250
super()函数
先调用父类的对应方法,再调用子类的方法
参考链接:https://www.runoob.com/python/python-func-super.html
神经网络的保存
参考链接:1. https://blog.csdn.net/u013075024/article/details/106250264
2. https://blog.csdn.net/baishuiniyaonulia/article/details/100039845
一种是全部保存(结构和参数) 一种是只保存参数(结构需要自己构建)
cat()函数
参考链接:https://blog.csdn.net/xinjieyuan/article/details/105208352#SnippetTab
和Python内置的cat函数很像,起到连接作用。dim参数代表连接方式。
logsumexp()函数
参考链接:https://blog.csdn.net/liu16659/article/details/115217082
顾名思义 将传入的tensor作为exp指数,然后求和,最后取对数。选择的维度需要根据函数的dim参数来判断。
dim表示的是张量的维度,dim可取的值从0开始,从外边往里边剥来看。
当指定参数为dim=0的时候,可看做要去观察行数变化的情况所以其他维度不变,这一维度降低为1。
参考链接:https://blog.csdn.net/qq_41375609/article/details/106078474
.view(-1,1)函数
参考链接:https://blog.csdn.net/cf_jack/article/details/103092006
view主要是来改变张量形状,负数代表不确定。所以 view(-1,1)表示改为行数不确定列数为1的张量。