mxnet中校验backward

在实现MXNet中的自定义OP后,校验backward的正确性至关重要。本文介绍了如何利用`check_symbolic_backward()`和`check_numeric_gradient()`两个函数进行校验。`check_symbolic_backward()`需要输入符号OP、参数和预期梯度,而`check_numeric_gradient()`仅需输入符号OP、参数和允许误差。当在自定义OP的backward中出现额外操作,如L2约束,可能导致校验失败。确保backward仅包含forward的逆运算可使校验成功。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

实现自定义的OP后,难以保证代码正确性,尤其是backward(). mxnet.text_utils中提供两个函数可以对backward()做校验 (在mxnet安装目录下搜索text_utils.py文件,源码比文档详细)

check_symbolic_backward()

这个函数要求输入symbol OP,输入参数以及期望的梯度值, 前面两个容易取得,但期望的梯度值却不是很容易。“鸡生蛋,蛋生鸡”的问题?

check_numeric_gradient()

这个函数只要求输入symbol OP和输入参数,和容许误差,如果误差大于容许范围,会报错,否则零打印。

from mxnet.test_utils import *
class MyOP(nn.Block):
    def __init__(self,in_channels, out_channels,
                 origin_image_size, #size of the network input
                 batchSize,
                 kernel_size=3, **kwargs):
        return
    def verify(self):        

        #x = mx.nd.array( np.random.random((self.batchSize, 1, 7,7)) )
        #w = mx.nd.array( np.random.random((2,1,3,3)) )
        x = mx.nd.array( np.ones((self.batchSize, 1, 3,3)) )
        w = mx.nd.array( 2*np.ones((2,1,3,3)) )        
        y = mx.sym.Custom(mx.sym.Variable('input'), mx.sym.Variable('weight'),op_type="myop")
        check_numeric_gradient(y,(x,w))

net = Proto2DBlock(1, 2, 12, 1)
net.initialize()
net.verify()   

如上MyOP封装了一个custom OP ”myop“,利用mx.sym.Custom()生成op的symbol,输入
x/w是ndarray,不是symbol

  • Note:实验最初总是校验不通过,最后发现是myop的backward中对权重做了L2约束,而在forward是无法体现,导致校验总是失败。代码正确,又保证backward中只有forward的逆操作,即可校验通过
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值