Part one: Reading the data
Taking an image compression model as an example, the input is image data; assume the dataset consists of PNG-format images.
import os

import torch
from PIL import Image
from torch.utils.data import Dataset

image_extensions = ['.png', '.PNG', '.jpg', '.JPG']

class My_Dataset(Dataset):
    def __init__(self, image_root, transform=None):
        super(My_Dataset, self).__init__()
        # collect every file under image_root whose extension is in image_extensions
        images = []
        for filename in os.listdir(image_root):
            if any(filename.endswith(extension) for extension in image_extensions):
                images.append(filename)
        self.root = image_root
        self.images = images
        self.transform = transform

    def __getitem__(self, index):
        filename = self.images[index]
        try:
            img = Image.open(os.path.join(self.root, filename)).convert('RGB')
        except Exception:
            # fall back to a blank tensor if an image cannot be decoded
            return torch.zeros((3, 256, 256))
        if self.transform is not None:
            img = self.transform(img)
        return img

    def __len__(self):
        return len(self.images)
The data-loading part subclasses the Dataset class from the torch.utils.data package and implements its three methods __init__(), __getitem__() and __len__(). __init__() stores the image directory and collects the paths of all qualifying images under it; __getitem__() reads and returns the image at the given index; __len__() returns the length of the dataset, i.e. the total number of images.
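As a short usage sketch (not part of the original code), the class can be wrapped in a standard DataLoader; the directory './data/train', the batch size and the 256x256 random crop below are illustrative placeholders:

from torch.utils.data import DataLoader
from torchvision import transforms

# './data/train' is a hypothetical path; images are assumed to be at least 256x256
transform = transforms.Compose([transforms.RandomCrop((256, 256)), transforms.ToTensor()])
train_set = My_Dataset(image_root='./data/train', transform=transform)
train_loader = DataLoader(train_set, batch_size=16, shuffle=True, num_workers=4)
for img in train_loader:
    print(img.shape)  # torch.Size([16, 3, 256, 256]) for a full batch
    break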
Part two: Building the network model
Taking a U-Net-style model as an example, the model consists of three parts: an encoder, a quantizer and a decoder:
import numpy as np
import torch
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, is_train=True):
        super(Encoder, self).__init__()
        self.is_train = is_train
        # four stride-2 convolutions: 3 -> 64 -> 128 -> 256 -> 512 channels,
        # each halving the spatial resolution
        channel = [64, 128, 256, 512]
        model = [nn.Conv2d(3, channel[0], 3, 2, 1), nn.BatchNorm2d(channel[0]), nn.ReLU(True)]
        for i in range(len(channel) - 1):
            model += [nn.Conv2d(channel[i], channel[i + 1], 3, 2, 1), nn.BatchNorm2d(channel[i + 1]), nn.ReLU(True)]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)
class Quantizer(nn.Module):
    def __init__(self, q_level=13):
        super(Quantizer, self).__init__()
        self.q_level = q_level

    def forward(self, x):
        h = x / 4
        # the steps below modify h.data in place, so autograd only sees the x / 4 above
        # and gradients pass straight through the non-differentiable rounding
        h.data = h.data.clamp(0., 1.)
        h.data = h.data * (self.q_level - 1)
        h.data = h.data - 0.5 * torch.sin(2 * np.pi * h.data) / (2 * np.pi)
        h.data = torch.round(h.data)
        return h
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        channel = [512, 256, 128, 64]
        model = []
        for i in range(len(channel) - 1):
            model += [nn.ConvTranspose2d(channel[i], channel[i + 1], 4, 2, 1), nn.BatchNorm2d(channel[i + 1]),
                      nn.ReLU(True)]
        model += [nn.ConvTranspose2d(channel[-1], 3, 4, 2, 1), nn.BatchNorm2d(3), nn.Tanh()]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)
Note that the quantization step maps the encoder output into the range determined by q_level, which is set to 13 by default here.
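As a quick sanity check (not part of the original code), the three modules can be chained on a random 256x256 input to confirm the tensor shapes and the quantization range; the variable names below are illustrative:

enc, quant, dec = Encoder(), Quantizer(), Decoder()
x = torch.rand(1, 3, 256, 256)                 # dummy image batch with values in [0, 1]
h = enc(x)
print(h.shape)                                 # torch.Size([1, 512, 16, 16]): four stride-2 convs, 256 / 2**4 = 16
h_hat = quant(h)
print(h_hat.min().item(), h_hat.max().item())  # integers inside [0, q_level - 1] = [0, 12]
y = dec(h_hat)
print(y.shape)                                 # torch.Size([1, 3, 256, 256]), values in [-1, 1] because of the Tanh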
Part three: Training and testing
import time
import torch.optim as optim
import torch.optim.lr_scheduler as LS
import torch.utils.data as data
import torchvision.transforms as transforms
from imageio import imsave
from skimage.metrics import peak_signal_noise_ratio as compare_psnr  # older scikit-image: skimage.measure.compare_psnr
# Data preprocessing and loading ('args' is assumed to come from argparse and to be defined elsewhere;
# the datasets use the My_Dataset class from Part one, which returns images without labels)
train_transform = transforms.Compose([transforms.RandomCrop((256, 256)), transforms.ToTensor()])
train_set = My_Dataset(image_root=args.train_image, transform=train_transform)
train_loader = data.DataLoader(dataset=train_set, batch_size=args.batch_size, shuffle=True, num_workers=4)
print('total images: {}; total batches: {}'.format(len(train_set), len(train_loader)))
test_transform = transforms.Compose([transforms.ToTensor()])
test_set = My_Dataset(image_root=args.test_image, transform=test_transform)
test_loader = data.DataLoader(dataset=test_set, batch_size=1, shuffle=False, num_workers=4)
print('test images: {}; test batches: {}'.format(len(test_set), len(test_loader)))
# Model initialization
encoder = Encoder()
quantizer = Quantizer()
decoder = Decoder()
# Use GPU acceleration if it is available
if torch.cuda.is_available():
    devices = torch.device('cuda')
else:
    devices = torch.device('cpu')
encoder.to(devices)
quantizer.to(devices)
decoder.to(devices)
# Optimizer over the encoder and decoder parameters (the quantizer has no learnable parameters)
solver = optim.Adam(
    [
        {'params': encoder.parameters()},
        {'params': decoder.parameters()},
    ],
    lr=args.lr)
# Load model parameters, either to resume training or for the test stage
def resume(epoch=None):
    if epoch is None:
        s = 'iter'
        epoch = 0
    else:
        s = 'epoch'
    encoder.load_state_dict(
        torch.load('checkpoint/encoder_{}_{:08d}.pth'.format(s, epoch)))
    decoder.load_state_dict(
        torch.load('checkpoint/decoder_{}_{:08d}.pth'.format(s, epoch)))
# Save model parameters during training
def save(index, epoch=True):
    if not os.path.exists('checkpoint'):
        os.mkdir('checkpoint')
    if epoch:
        s = 'epoch'
    else:
        s = 'iter'
    torch.save(encoder.state_dict(), 'checkpoint/encoder_{}_{:08d}.pth'.format(s, index))
    torch.save(decoder.state_dict(), 'checkpoint/decoder_{}_{:08d}.pth'.format(s, index))
# Model evaluation (called periodically during training)
def evaluate_model(number, encoder, quantizer, decoder, test_loader):
    encoder.eval()
    decoder.eval()
    msssim_list = []
    psnr_list = []
    # 'number' is treated as an iteration index when it is large, otherwise as an epoch index
    if number > 250:
        s = 'iter'
    else:
        s = 'epoch'
    # Config_Test.test_result is the output directory from the test configuration (defined elsewhere)
    filepath = os.path.join(Config_Test.test_result, '{}_{}'.format(s, number))
    if not os.path.exists(filepath):
        os.mkdir(filepath)
    for i, test_img in enumerate(test_loader):
        with torch.no_grad():
            test_w = encoder(test_img.to(devices))
            test_w_hat = quantizer(test_w)
            recons_img = decoder(test_w_hat)
        # the decoder ends in Tanh, so map its [-1, 1] output to [0, 255]; scale the input likewise
        recons_img = torch.round(((recons_img + 1) / 2) * 255)
        test_img = torch.round(test_img.to(devices) * 255)
        # MS-SSIM is computed with the global 'ms' module defined further below
        msssim = ms.ms_ssim(test_img, recons_img).mean().item()
        msssim_list.append(msssim)
        test_image = np.transpose(np.squeeze(test_img.data.cpu().numpy()), (1, 2, 0)).clip(0, 255).astype(np.uint8)
        recons_img = np.transpose(np.squeeze(recons_img.data.cpu().numpy()), (1, 2, 0)).clip(0, 255).astype(np.uint8)
        # Save the reconstructed image while testing
        imsave(os.path.join(filepath, 'reconstructed_{}.png'.format(i)), recons_img)
        psnr = compare_psnr(test_image, recons_img)
        psnr_list.append(psnr)
    mean_msssim = sum(msssim_list) / len(msssim_list)
    mean_psnr = sum(psnr_list) / len(psnr_list)
    # switch the models back to training mode before returning
    encoder.train()
    decoder.train()
    return mean_msssim, mean_psnr
# Learning-rate schedule
scheduler = LS.MultiStepLR(solver, milestones=[3, 10, 20, 50, 100], gamma=0.5)
# Starting epoch of the training loop; resume from a checkpoint if one was given
last_epoch = 0
if args.checkpoint:
    resume(args.checkpoint)
    last_epoch = args.checkpoint
    scheduler.last_epoch = last_epoch - 1
# MS_SSIM is an MS-SSIM loss module assumed to be defined or imported elsewhere
# (the code below only relies on it exposing an ms_ssim() method)
ms = MS_SSIM().to(devices)
# mse = nn.MSELoss().to(devices)
# Text file used to record the test results
file_writer = open('./model_test_on_Kodak.txt', 'w')
for epoch in range(last_epoch + 1, args.max_epochs + 1):
    # a MultiStepLR schedule is used, so the learning rate is updated once per epoch
    scheduler.step()
    for batch, img in enumerate(train_loader):
        batch_t0 = time.time()
        patches = img.to(devices)
        res = patches
        # gradients must be cleared before every batch
        solver.zero_grad()
        h = encoder(patches)
        h = quantizer(h)
        decoded = decoder(h)
        # MSE loss
        # mse_loss = mse(res, decoded)
        # MS-SSIM loss
        patches.data = patches.data.clamp(0, 1) * 255.0
        decoded.data = decoded.data.clamp(0, 1) * 255.0
        msssim_loss = 1 - (ms.ms_ssim(patches, decoded)).mean()
        # L1 loss
        # loss = (res - decoded).abs().mean()
        # L1_loss = (res - decoded).abs().mean()
        # joint optimization of a lossy model and a lossless model
        # loss = lossless_loss + msssim_loss
        # loss = msssim_loss * 0.5 + mse_loss * 5
        loss = msssim_loss
        # error back-propagation
        loss.backward()
        # update the parameters every batch
        solver.step()
        batch_t1 = time.time()
        print(
            '[TRAIN] Epoch[{}]({}/{}); Loss: {:.6f}; Batch: {:.4f} sec'.
            format(epoch, batch + 1, len(train_loader), loss.data, batch_t1 - batch_t0)
        )
        index = (epoch - 1) * len(train_loader) + batch
        # save a checkpoint every 500 training steps
        # if index % 500 == 0:
        #     save(0, False)
        # save an iteration checkpoint and evaluate every 20000 training steps
        if index % 20000 == 0:
            save(index, epoch=False)
            iter_msssim, iter_psnr = evaluate_model(index, encoder, quantizer, decoder, test_loader)
            # record the test results
            file_writer.write('iter {} model test on Kodak: MSSSIM:{:.6f}, PSNR:{:.6f}\n'.format(index, iter_msssim, iter_psnr))
            file_writer.flush()
            print('iter {} model test on Kodak: MSSSIM:{:.6f}, PSNR:{:.6f}'.format(index, iter_msssim, iter_psnr))
    # save and evaluate the model every epoch
    save(epoch)
    epoch_msssim, epoch_psnr = evaluate_model(epoch, encoder, quantizer, decoder, test_loader)
    file_writer.write('epoch {} model test on Kodak: MSSSIM:{:.6f}, PSNR:{:.6f}\n'.format(epoch, epoch_msssim, epoch_psnr))
    file_writer.flush()
    print('epoch {} model test on Kodak: MSSSIM:{:.6f}, PSNR:{:.6f}'.format(epoch, epoch_msssim, epoch_psnr))
file_writer.close()