一、项目背景
HiDDen的jpeg()的训练,要求的torch版本是1.0,我的是1.11,考虑过降低版本号,但如果要降版本的话还要改python版本、cuda版本,而且刚开始复现代码,希望能提高自己读代码和纠正错误的能力,所以决定就在这个基础上直接改。
def train_on_batch(self, batch: list):
"""
Trains the network on a single batch consisting of images and messages
:param batch: batch of training data, in the form [images, messages]
:return: dictionary of error metrics from Encoder, Decoder, and Discriminator on the current batch
"""
images, messages = batch
batch_size = images.shape[0]
self.encoder_decoder.train()
self.discriminator.train()
with (torch.enable_grad()):
# ---------------- Train the discriminator -----------------------------
self.optimizer_discrim.zero_grad()
# train on cover
d_target_label_cover = torch.full((batch_size, 1), self.cover_label, device=self.device).float()
d_target_label_encoded = torch.full((batch_size, 1), self.encoded_label, device=self.device).float()
g_target_label_encoded = torch.full((batch_size, 1), self.cover_label, device=self.device).float()
d_on_cover