代码如下。在torch的自动求导机制下,继承torch.autograd.Function的自定义实现不会被自动保存反向计算所需的tensor,需要在forward中通过ctx(例如ctx.save_for_backward)自行保存并维护这些tensor。
- 保存随机数状态
- 保存输入状态
- backward的时候,首先保存当前的随机数状态,再恢复forward时保存的随机数状态,并重新执行一次前向计算。
- 最后恢复进入backward时的随机数状态,再进行反向计算。
class CheckpointFunction(torch.autograd.Function):
    """Checkpoint Function

    This function is adapted from torch.utils.checkpoint with two main changes:
    1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state`
    2) the states in the model parallel tracker are also properly tracked/set/reset.

    ``forward`` runs ``run_function`` under ``torch.no_grad()`` and keeps only
    the inputs and the RNG states. ``backward`` replays ``run_function`` under
    those RNG states with grad enabled and backpropagates through the
    recomputed outputs.
    """

    @staticmethod
    def forward(ctx, run_function, distribute_saved_activations, *args):
        """Run ``run_function(*args)`` under ``torch.no_grad()``.

        Args:
            ctx: autograd context; receives ``run_function``, the flag, the
                RNG snapshots and the saved inputs.
            run_function: callable invoked as ``run_function(*args)``; it is
                invoked a second time in ``backward``.
            distribute_saved_activations: when truthy, the data of ``args[0]``
                is replaced in place by the result of
                ``split_tensor_into_1d_equal_chunks`` before it is saved.
            *args: inputs forwarded to ``run_function``; all of them are
                passed to ``ctx.save_for_backward``.

        Returns:
            Whatever ``run_function(*args)`` returns.
        """
        ctx.run_function = run_function
        ctx.distribute_saved_activations = distribute_saved_activations

        # Copy the rng states so backward can replay the forward pass with
        # the same random numbers.
        ctx.fwd_cpu_rng_state = torch.get_rng_state()
        ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
        ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        with torch.no_grad():
            outputs = run_function(*args)

        # Divide hidden states across model parallel group and only keep
        # the chunk corresponding to the current rank.
        if distribute_saved_activations:
            # Remember the full shape; backward uses it to view the gathered
            # tensor back into its original layout.
            ctx.input_0_shape = args[0].data.shape
            safely_set_viewless_tensor_data(
                args[0], split_tensor_into_1d_equal_chunks(args[0].data, new_buffer=True)
            )

        # Store everything.
        ctx.save_for_backward(*args)
        return outputs

    @staticmethod
    def backward(ctx, *args):
        """Recompute the forward pass and backpropagate through it.

        Args:
            ctx: autograd context populated by ``forward``.
            *args: incoming gradients, paired positionally with the outputs
                of ``run_function``.

        Returns:
            ``(None, None)`` for ``run_function`` and
            ``distribute_saved_activations``, followed by one entry per saved
            input: the ``.grad`` of its detached counterpart when that is a
            tensor, otherwise the value itself.

        Raises:
            RuntimeError: if ``torch.autograd._is_checkpoint_valid()`` is
                false.
        """
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad(), "
                "please use .backward() if possible"
            )
        inputs = ctx.saved_tensors
        if ctx.distribute_saved_activations:
            # Undo the split done in forward: gather the chunks and view the
            # result with the shape recorded in forward.
            safely_set_viewless_tensor_data(
                inputs[0], gather_split_1d_tensor(inputs[0].data).view(ctx.input_0_shape)
            )

        # Store the current states.
        bwd_cpu_rng_state = torch.get_rng_state()
        bwd_cuda_rng_state = torch.cuda.get_rng_state()
        bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        try:
            # Set the states to what it used to be before the forward pass.
            torch.set_rng_state(ctx.fwd_cpu_rng_state)
            _set_cuda_rng_state(ctx.fwd_cuda_rng_state)
            get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)

            # Compute the forward pass.
            # NOTE(review): detach_variable presumably returns detached copies
            # of the inputs so the recomputed graph is local - confirm against
            # the helper's definition.
            detached_inputs = detach_variable(inputs)
            with torch.enable_grad():
                outputs = ctx.run_function(*detached_inputs)
        finally:
            # Set the states back to what it was at the start of this function.
            # Done in ``finally`` so a failure during recomputation does not
            # leave the generators on the forward-pass states.
            torch.set_rng_state(bwd_cpu_rng_state)
            _set_cuda_rng_state(bwd_cuda_rng_state)
            get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        # filter out non tensor outputs for backward pass
        tensor_pairs = [
            (out, grad) for out, grad in zip(outputs, args) if torch.is_tensor(out)
        ]
        outputs, output_grads = zip(*tensor_pairs)
        torch.autograd.backward(outputs, output_grads)
        grads = tuple(
            inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs
        )
        return (None, None) + grads