实现自定义的OP后,难以保证代码正确性,尤其是backward(). mxnet.text_utils中提供两个函数可以对backward()做校验 (在mxnet安装目录下搜索text_utils.py文件,源码比文档详细)
check_symbolic_backward()
这个函数要求输入symbol OP,输入参数以及期望的梯度值, 前面两个容易取得,但期望的梯度值却不是很容易。“鸡生蛋,蛋生鸡”的问题?
check_numeric_gradient()
这个函数只要求输入symbol OP和输入参数,和容许误差,如果误差大于容许范围,会报错,否则零打印。
from mxnet.test_utils import *
class MyOP(nn.Block):
def __init__(self,in_channels, out_channels,
origin_image_size, #size of the network input
batchSize,
kernel_size=3, **kwargs):
return
def verify(self):
#x = mx.nd.array( np.random.random((self.batchSize, 1, 7,7)) )
#w = mx.nd.array( np.random.random((2,1,3,3)) )
x = mx.nd.array( np.ones((self.batchSize, 1, 3,3)) )
w = mx.nd.array( 2*np.ones((2,1,3,3)) )
y = mx.sym.Custom(mx.sym.Variable('input'), mx.sym.Variable('weight'),op_type="myop")
check_numeric_gradient(y,(x,w))
net = Proto2DBlock(1, 2, 12, 1)
net.initialize()
net.verify()
如上MyOP封装了一个custom OP ”myop“,利用mx.sym.Custom()生成op的symbol,输入
x/w是ndarray,不是symbol
- Note:实验最初总是校验不通过,最后发现是myop的backward中对权重做了L2约束,而在forward是无法体现,导致校验总是失败。代码正确,又保证backward中只有forward的逆操作,即可校验通过