pytorch如何去定义新的自动求导函数

本文介绍了如何在PyTorch中自定义自动求导函数,通过实现`torch.autograd.Function`并覆盖`forward`和`backward`方法。此外,还探讨了自动求导的概念,并给出了保留计算图以进行多次反向传播的例子,结合逻辑回归的应用进行了说明。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

pytorch定义新的自动求导函数

在pytorch中想自定义求导函数,通过实现torch.autograd.Function并重写forward和backward函数,来定义自己的自动求导运算。参考官网上的demo:传送门

直接上代码,定义一个ReLu来实现自动求导
 

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

import torch

class MyRelu(torch.autograd.Function):

    @staticmethod

    def forward(ctx, input):

        # 我们使用ctx上下文对象来缓存,以便在反向传播中使用,ctx存储时候只能存tensor

        # 在正向传播中,我们接收一个上下文对象ctx和一个包含输入的张量input;

        # 我们必须返回一个包含输出的张量,

        # input.clamp(min = 0)表示讲输入中所有值范围规定到0到正无穷,如input=[-1,-2,3]则被转换成input=[0,0,3]

        ctx.save_for_backward(input)

         

        # 返回几个值,backward接受参数则包含ctx和这几个值

        return input.clamp(min = 0)

    @staticmethod

    def backward(ctx, grad_output):

        # 把ctx中存储的input张量读取出来

        input, = ctx.saved_tensors

         

        # grad_output存放反向传播过程中的梯度

        grad_input = grad_output.clone()

         

        # 这儿就是ReLu的规则,表示原始数据小于0,则relu为0,因此对应索引的梯度都置为0

        grad_input[input < 0] = 0

        return grad_input

 

进行输入数据并测试

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

dtype = torch.float

device = torch.device('cuda' if torch.cuda.is_availabl

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值