import numpy as np
import torch
from torch.autograd import grad
def matmul(x, w):
y = torch.pow(x, 3) * w
return y
def net_y_xxx(x, w): # first-order derivative of u
y = matmul(x, w) # 即使不增加require_grad_(True),也是可以求导
y_x = grad(y.sum(), x, create_graph=True)[0]
# 为了求解高阶导数,必须使用create_graph,不然y_x的requires_grad属性为False,并且计算完这个导数以后就会释放上一层的计算图
y_xx = grad(y_x.sum(), x, create_graph=True)[0]
y_xxx = grad(y_xx.sum(), x)[0]
y.sum().backward()
y_x.sum().backward()
# y_xx.sum().backward() # 由于计算y_xxx时候就释放了计算图,所以会报错
return y, y_x, y_xx, y_xxx
if __name__ == '__main__':
x = np.array([[1.0],
[2.0],
[3.0]])
w = np.array([[1.0],
[2.0],
[3.0]])
# requires_grad_(True)
x = torch.from_numpy(x).requires_grad_(True)
# x = torch.from_numpy(x)
# w = torch.from_numpy(w).requires_grad_(True)
w = torch.from_numpy(w)
y, y_x, y_xx, y_xxx = net_y_xxx(x, w)
求y对x的三阶导,一共生成了3张计算图!!!