pytorch中遇见的那些函数

学习的资料主要包括《Pytorch深度学习实战》和深度学习爱好者 公众号后台回复:Pytorch常用函数手册下载的资料。

torch.no_grad()

参考链接:

  1. https://blog.csdn.net/jiang_huixin/article/details/112669244
  2. https://blog.csdn.net/weixin_41658139/article/details/126199040
    该方法一般用于神经网路的推理阶段,表示张量的计算过程中无需计算梯度。
  3. https://zhuanlan.zhihu.com/p/509649964
    参考链接3说的太清楚了,等我对pytorch有更深入的了解后再进行解读。
    目前的理解只能达到:为了方式更新系数太多,该方法可保证每次更新完就归零,不会积累。

call()和forward

参考链接:https://blog.csdn.net/qq_43745026/article/details/125537774

很显然,开头是双下划线,结尾也是双下划线,必然是魔法方法(上一次学习魔法方法还是在看__init__)。一般在特殊情况下自动调用。
这篇参考链接里的东西很是熟悉,抽个一半天再复习一遍! https://www.jb51.net/article/214097.htm

简言之,一般使用的魔法方法有__init__()来得到方法参数并进行赋值,在根据类生成对象时自动调用。call()方法则是让python中的类的实例能够像方法一样被调用,具备加括号的功能。

class A():
  def __init__(self):
    print('init函数')

  def __call__(self, param):
    print('call函数', param)

a = A()
a(1)

运行结果输出结果是init函数和call 函数 1。
call()中可以调用其它函数,如forward函数

class A():
  def __init__(self):
    print('init函数')

  def __call__(self, param):
    print('call函数', param)
    res = self.forward(param)
    return res+2

  def forward(self, input_):
    print('forward 函数', input_)
    return input_
a = A()
b = a(1)
print('结果b=', b)

运行结果是输出init函数、call函数1、forward函数1 和 结果b=3。
call方法成功调用了forward(),且返回值给了b。

nn.linear()

该方法用于构建网络全连接层。
三个参数包括 in_features,out_features,bias=True。具体的示例见参考链接。
参考链接:https://blog.csdn.net/zhaohongfei_358/article/details/122797190

getarr()函数

get attribute 获得对应属性值
参考链接:https://www.runoob.com/python/python-func-getattr.html

log_softmax()函数

看名字就知道是在进行softmax归一化
公式:$ - \log \frac{{\exp ({p_j})}}{{\sum\limits_i {\exp ({p_i})} }}$
参考链接:https://blog.csdn.net/liu16659/article/details/122018549
上述链接探究了自己计算这个公式和直接调用Pytorch内置函数的区别,在正常情况下是没有区别的。
但是,如果遇到特殊情况如外溢等情况,常规算法会出现问题。内置函数考虑了多种情况,对公式进行了变换。
所以,能使用已经封装好的函数就使用封装好的函数,可靠性更高。

gather()函数

调用该方法时,可以找到原对象中对应索引的对应位置的值
参考链接:https://blog.csdn.net/weixin_42899627/article/details/122816250

super()函数

先调用父类的对应方法,再调用子类的方法
参考链接:https://www.runoob.com/python/python-func-super.html

神经网络的保存

参考链接:1. https://blog.csdn.net/u013075024/article/details/106250264
2. https://blog.csdn.net/baishuiniyaonulia/article/details/100039845
一种是全部保存(结构和参数) 一种是只保存参数(结构需要自己构建)

cat()函数

参考链接:https://blog.csdn.net/xinjieyuan/article/details/105208352#SnippetTab
和Python内置的cat函数很像,起到连接作用。dim参数代表连接方式。

logsumexp()函数

参考链接:https://blog.csdn.net/liu16659/article/details/115217082
顾名思义 将传入的tensor作为exp指数,然后求和,最后取对数。选择的维度需要根据函数的dim参数来判断。
dim表示的是张量的维度,dim可取的值从0开始,从外边往里边剥来看。
当指定参数为dim=0的时候,可看做要去观察行数变化的情况所以其他维度不变,这一维度降低为1。
参考链接:https://blog.csdn.net/qq_41375609/article/details/106078474

.view(-1,1)函数

参考链接:https://blog.csdn.net/cf_jack/article/details/103092006
view主要是来改变张量形状,负数代表不确定。所以 view(-1,1)表示改为行数不确定列数为1的张量。

posted @ 2022-10-31 10:15  芋圆院长  阅读(13)  评论(0编辑  收藏  举报