```
def _load_optimizer_state(self):
    # Look next to the resumed model checkpoint for a matching optimizer state
    # (opt{step}.pt) and restore it into self.opt if it exists.
    main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
    opt_checkpoint = bf.join(
        bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt"
    )
    if bf.exists(opt_checkpoint):
        logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
        state_dict = dist_util.load_state_dict(
            opt_checkpoint, map_location=dist_util.dev()
        )
        self.opt.load_state_dict(state_dict)
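# Mixed-precision setup: keep fp32 "master" copies of the parameters for the optimizer
# while the model weights themselves are converted to fp16.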
def _setup_fp16(self):
    self.master_params = make_master_params(self.model_params)
    self.model.convert_to_fp16()
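# Main training loop: set up the perceptual losses and FBP helpers once, then repeatedly
# draw a batch, pick a sparsity level, and run one optimization step until done.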
def run_loop(self):
    # Perceptual/structural losses passed to every training step: LPIPS with an AlexNet
    # backbone on the training device, and single-channel SSIM with a 7x7 Gaussian window.
    lpip_loss = lpips.LPIPS(net="alex").to(dist_util.dev())
    ssim_loss = SSIM(win_size=7, win_sigma=1.5, data_range=1, size_average=False, channel=1)
    # Prepare projection/FBP helpers for 288, 144, 72, and 36 views over 736 detector bins;
    # `parallel` is expected to be a module-level flag selecting the beam geometry.
    if parallel:
        radon_288_736 = para_prepare_parallel(2.5)
        radon_144_736 = para_prepare_parallel(4.5)
        radon_72_736 = para_prepare_parallel(8.5)
        radon_36_736 = para_prepare_parallel(16.5)
    else:
        radon_288_736 = para_prepare(2.5)
        # The intermediate view counts are prepared too so the helper dict below is fully defined.
        radon_144_736 = para_prepare(4.5)
        radon_72_736 = para_prepare(8.5)
        radon_36_736 = para_prepare(16.5)
    helper = {"fbp_para_288_736": radon_288_736, "fbp_para_144_736": radon_144_736,
              "fbp_para_72_736": radon_72_736, "fbp_para_36_736": radon_36_736}  # "fbp_para_36_512": radon_36_51
    while (
        not self.lr_anneal_steps
        or self.step + self.resume_step < self.lr_anneal_steps
    ):
        batch, cond = next(self.data)
        # Pick a degradation level uniformly from {0, 1, 2}; with high=None, numpy draws from [0, low).
        timestep = np.random.randint(low=3, high=None, size=None, dtype='l')
        t = th.tensor([timestep, timestep]).to("cuda")  # one entry per sample; a batch size of 2 is assumed
        # Build the conditioning input: the full 288x736 input (presumably a sinogram) is
        # nearest-downsampled to 36, 72, or 144 rows and nearest-upsampled back to 288x736,
        # i.e. the sparsest views at timestep 2 and the densest at timestep 0.
        if timestep == 2:
            cond["x_t"] = F.interpolate(F.interpolate(batch, (36, 736), mode="nearest"), (288, 736), mode="nearest")
        elif timestep == 1:
            cond["x_t"] = F.interpolate(F.interpolate(batch, (72, 736), mode="nearest"), (288, 736), mode="nearest")
        elif timestep == 0:
            cond["x_t"] = F.interpolate(F.interpolate(batch, (144, 736), mode="nearest"), (288, 736), mode="nearest")
        model_output = self.run_step(batch, cond, t, ssim_loss, lpip_loss, helper)
        # Alternative (disabled): cascade through all three levels within one iteration,
        # feeding each level's output back in as the next level's condition.
        # for i in range(2, -1, -1):
        #     t = th.tensor([i, i]).to("cuda")
        #     if i == 2:
        #         cond["x_t"] = cond["low_res"]
        #         model_output = self.run_step(batch, cond, t, ssim_loss, lpip_loss, helper)
        #     else:
        #         cond["x_t"] = model_output
        #         model_output = self.run_step(batch, cond, t, ssim_loss, lpip_loss, helper)
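        # Periodic bookkeeping: dump accumulated logger key/values every log_interval
        # steps and write a checkpoint every save_interval steps.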
        if self.step % self.log_interval == 0:
            logger.dumpkvs()
        if self.step % self.save_interval == 0:
            self.save()
            # Run for a finite amount of time in integration tests.
            if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0:
                return
        self.step += 1
    # Save the last checkpoint if it wasn't already saved.
    if (self.step - 1) % self.save_interval != 0:
        self.save()
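# One optimization step: forward/backward through the model, then either the fp16
# (master-param) update or the standard optimizer update, followed by step logging.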
def run_step(self, batch, cond, t, ssim_loss=None, lpip_loss=None, helper=None):
    model_output = self.forward_backward(batch, cond, t, ssim_loss, lpip_loss, helper)
    if self.use_fp16:
        self.optimize_fp16()
    else:
        self.optimize_normal()
    self.log_step()
    return model_output
```
Explain what this code does.