Pytorch的backward()相关理解

最近一直在用pytorch做GAN相关的实验,pytorch 框架灵活易用,很适合学术界开展研究工作。
这两天遇到了一些模型参数寻优的问题,才发现自己对pytorch的自动求导和寻优功能没有深刻理解,导致无法灵活的进行实验。于是查阅资料,同时自己做了一点小实验,做了一些总结,虽然好像都是一些显而易见的结论,但是如果不能清晰的理解,对于实验复杂的网络模型程序会造成困扰。以下仅针对pytorch 0.2 版本,如有错误,希望得到指正。(17年写的博文,小作修改)

  • 相关标志位/函数

    • 1
      • requires_grad
      • volatile
      • detach()/detach_()
    • 2
      • retain_graph
      • retain_variables
      • create_graph

    前三个标志位中,最关键的就是 requires_grad,另外两个都可以转化为 requires_grad 来理解。
    后三个标志位,与计算图的保持与建立有关系。其中 retain_variables 与 retain_graph等价,retain_variables 会在pytorch 新版本中被取消掉。

  • requires_grad 的含义及标志位说明

    • 如果对于某Variable 变量 x ,其 x.requires_grad == True , 则表示 它可以参与求导,也可以从它向后求导。
      默认情况下,一个新的Variables 的 requires_grad 和 volatile 都等于 False 。

    • requires_grad == True 具有传递性,如果:
      x.requires_grad == Truey.requires_grad == Falsez=f(x,y)
      则, z.requires_grad == True

    • 凡是参与运算的变量(包括 输入量,中间输出量,输出量,网络权重参数等),都可以设置 requires_grad 。

    • volatile==True 就等价于 requires_grad==Falsevolatile==True 同样具有传递性。一般只用在inference过程中。若是某个过程,从 x 开始 都只需做预测,不需反传梯度的话,那么只需设置x.volatile=True ,那么 x 以后的运算过程的输出均为 volatile==True ,即 requires_grad==False
      虽然inference 过程不必backward(),所以requires_grad 的值为False 或 True,对结果是没有影响的,但是对程序的运算效率有直接影响;所以使用volatile=True ,就不必把运算过程中所有参数都手动设一遍requires_grad=False 了,方便快捷。

    • detach() ,如果 x 为中间输出,x' = x.detach 表示创建一个与 x 相同,但requires_grad==False 的variable, (实际上是把x’ 以前的计算图 grad_fn 都消除了),x’ 也就成了叶节点。原先反向传播时,回传到x时还会继续,而现在回到x’处后,就结束了,不继续回传求到了。另外值得注意, x (variable类型) 和 x’ (variable类型)都指向同一个Tensor ,即 x.data
      detach_() 表示不创建新变量,而是直接修改 x 本身。

    • retain_graph ,每次 backward() 时,默认会把整个计算图free掉。一般情况下是每次迭代,只需一次 forward() 和一次 backward() ,前向运算forward() 和反向传播backward()是成对存在的,一般一次backward()也是够用的。但是不排除,由于自定义loss等的复杂性,需要一次forward(),多个不同loss的backward()来累积同一个网络的grad,来更新参数。于是,若在当前backward()后,不执行forward() 而可以执行另一个backward(),需要在当前backward()时,指定保留计算图,即backward(retain_graph)。

    • create_graph ,这个标志位暂时还未深刻理解,等之后再更新。

  • 反向求导 和 权重更新

    • 求导和优化(权重更新)是两个独立的过程,只不过优化时一定需要对应的已求取的梯度值。所以求得梯度值很关键,而且,经常会累积多种loss对某网络参数造成的梯度,一并更新网络。

    • 反向传播过程中,肯定需要整个过程都链式求导。虽然中间参数参与求导,但是却可以不用于更新该处的网络参数。参数更新可以只更新想要更新的网络的参数。

    • 如果obj是函数运算结果,且是标量,则 obj.backward() (注意,backward()函数中没有填入

评论 17
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值