论文讲解见上篇博客
这篇论文的标题是《Unpaired Unsupervised CT Metal Artifact Reduction》,作者是Bo-Yuan Chen和Chu-Song Chen。这篇论文主要研究了如何使用深度学习技术来减少医学成像中由于金属植入物引起的CT图像伪影。
项目给出了几个不同的unet网络的实验,以pytorch_Net.py举例
train
1、参数如下
# --- optimisation schedule ---
batch_size = 8
num_epoch = 25
lr = 2e-5  # learning rate handed to the Adam optimizers

# --- loss weights ---
lmda_g = 0.05   # scales the supervised generator reconstruction loss
lmda_dnn = 0.1  # scales the supervised denoiser reconstruction loss

# --- image geometry: (channels, height, width) with square images ---
channels = 3
img_size = 320
input_shape = (channels,) + (img_size, img_size)
输入居然是3通道的,大家使用时记得按自己的数据修改
2、获得患者信息
# Collect slice metadata for the training and the test patients. Each call
# returns the "noise" table, the "clear" table and the two row counts.
# NOTE(review): these calls pass (CT_dir, OMA_dir, ids, semi=True), while the
# get_patient_info listing in this post takes (root, patients_id_list) -- the
# listing appears to be a different version of the helper; confirm against the
# repository before reusing it.
train_patient_info_noise, train_patient_info_clear, train_noise_num, train_clear_num = get_patient_info(CT_dir, OMA_dir, patients_id_list_train, semi=True)
test_patient_info_noise, test_patient_info_clear, test_noise_num, test_clear_num = get_patient_info(CT_dir, OMA_dir, patients_id_list_test, semi=True)
def get_patient_info(root, patients_id_list):
    """Index every ``.jpg`` slice of the given patients as noisy or clear.

    For each patient directory ``root/<patient_id>`` the file
    ``MA_slice_num.txt`` is read; each of its lines is a slice identifier.
    A file whose name contains ``.jpg`` and whose prefix (the text before the
    first ``_``) is listed there is recorded as noisy (class 1); every other
    ``.jpg`` file is recorded as clear (class 0).

    Args:
        root: Directory that contains one sub-directory per patient.
        patients_id_list: Iterable of patient sub-directory names.

    Returns:
        Tuple ``(patient_info_noise, patient_info_clear, noise_num, clear_num)``:
        two DataFrames with the columns ``name``, ``path`` and ``class``,
        followed by the number of rows in each.

    Raises:
        FileNotFoundError: If a patient directory or its ``MA_slice_num.txt``
            does not exist.
    """
    columns = ['name', 'path', 'class']
    # Rows are gathered in plain lists and turned into DataFrames once at the
    # end: DataFrame.append copied the whole frame on every call and no longer
    # exists in pandas >= 2.0.
    noise_rows = []  # noise : 1
    clear_rows = []  # clear : 0
    for patient_id in patients_id_list:
        patient_id_path = os.path.join(root, patient_id)
        # The context manager closes the file even if reading fails.
        with open(os.path.join(patient_id_path, 'MA_slice_num.txt')) as f:
            noisy_patients_No = set(f.read().splitlines())
        for item in os.listdir(patient_id_path):
            if '.jpg' not in item:
                continue
            if item.split('_')[0] in noisy_patients_No:
                noise_rows.append({'name': item, 'path': patient_id_path, 'class': 1})
            else:
                clear_rows.append({'name': item, 'path': patient_id_path, 'class': 0})
    patient_info_noise = pd.DataFrame(noise_rows, columns=columns)
    patient_info_clear = pd.DataFrame(clear_rows, columns=columns)
    return patient_info_noise, patient_info_clear, len(noise_rows), len(clear_rows)
包括CT是否是干净的,CT名,CT路径等
3、根据id划分训练、测试集
# Evaluation-time preprocessing: resize to img_size x img_size, convert to a
# tensor, then normalise each channel with the fixed mean/std below.
test_transform = transforms.Compose(
    [
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)


def _make_ct_dataset(transform, patient_info):
    """Build a CTImg dataset over ``patient_info`` with the shared data directories."""
    return CTImg(
        transform=transform,
        patient_info=patient_info,
        CT_dir=CT_dir,
        OMA_dir=OMA_dir,
        Mask_dir=Mask_dir,
    )


# NOTE: train_transform is not defined in this excerpt; it must already exist.
train_set_noise1 = _make_ct_dataset(train_transform, train_patient_info_noise)
# Repeat the noisy training set 4x, then repeat that result 2x (8 copies in total).
train_set_noise = ConcatDataset([train_set_noise1] * 4)
train_set_noise = ConcatDataset([train_set_noise] * 2)
train_set_clear = _make_ct_dataset(train_transform, train_patient_info_clear)
test_set_noise = _make_ct_dataset(test_transform, test_patient_info_noise)
test_set_clear = _make_ct_dataset(test_transform, test_patient_info_clear)

# Only the training loaders shuffle; the test loaders keep a fixed order.
train_noise_loader = DataLoader(train_set_noise, batch_size=batch_size, shuffle=True)
train_clear_loader = DataLoader(train_set_clear, batch_size=batch_size, shuffle=True)
test_noise_loader = DataLoader(test_set_noise, batch_size=batch_size, shuffle=False)
test_clear_loader = DataLoader(test_set_clear, batch_size=batch_size, shuffle=False)
这样就得到了带金属伪影(noise)和干净(clear)两类CT数据各自的训练、测试DataLoader
4、加载损失函数
# Adversarial criterion for the generator step (discriminator output vs. all-ones target).
# NOTE: BCEWithLogitsLoss applies a sigmoid internally, so the discriminator is
# expected to return raw logits rather than probabilities.
g_loss = torch.nn.BCEWithLogitsLoss()
# Supervised reconstruction criterion for the generator output.
g_r_loss = torch.nn.MSELoss()
# Real/fake classification criterion for the discriminator step.
d_loss = torch.nn.BCEWithLogitsLoss()
# Reconstruction criterion between the denoised output and the clear image.
dnn_loss = torch.nn.MSELoss()
# Supervised reconstruction criterion for the denoiser output.
dnn_r_loss = torch.nn.MSELoss()
5、两个生成网络(伪影生成器Gen、去伪影U-Net Dnn)和一个鉴别器Dis
# Gen: generator; its output is added to a clear image to form a synthetic noisy image.
Gen = Generator(input_shape)
# Dis: discriminator that scores real vs. generated noisy inputs.
Dis = Discriminator(input_shape)
# Dnn: U-Net denoiser; its output is subtracted from an image to obtain the denoised result.
Dnn = Denoiser_UNet(input_shape)
6、放入cuda,初始化权重,创建优化器
# Move the three networks and the criteria listed below to the GPU when CUDA is enabled.
if cuda:
    Gen, Dis, Dnn = Gen.cuda(), Dis.cuda(), Dnn.cuda()
    for _criterion in (g_loss, d_loss, dnn_loss):
        _criterion.cuda()

# Initialize weights
for _net in (Gen, Dis, Dnn):
    _net.apply(weights_init_normal)

# Optimizers: the discriminator uses half the learning rate of the other two networks.
_adam_betas = (0.5, 0.999)
optimizer_Gen = torch.optim.Adam(Gen.parameters(), lr=lr, betas=_adam_betas)
optimizer_Dis = torch.optim.Adam(Dis.parameters(), lr=lr/2, betas=_adam_betas)
optimizer_Dnn = torch.optim.Adam(Dnn.parameters(), lr=lr, betas=_adam_betas)

# Input tensor type
if cuda:
    Tensor = torch.cuda.FloatTensor
else:
    Tensor = torch.Tensor

# Fixed noise batch (batch_size samples) that is reused during evaluation.
_fixed_z_shape = [batch_size] + list(input_shape)
fix_batch_sample_z = Tensor(get_random_sample(_fixed_z_shape, method = 'uniform'))
7、开始训练。先训练鉴别器:用生成器生成噪音g_noise,与干净图像相加得到伪造的带噪图像g_img;再把真实带噪图像和伪造图像分别经diff处理后送入鉴别器Dis,计算损失real_loss、fake_loss,反向传播并更新鉴别器。
""" Train D """
optimizer_Dis.zero_grad()

# Fresh uniform noise, one sample per clear image in the batch.
batch_sample_z = Tensor(get_random_sample(([len(clear_img)] + list(input_shape)), method = 'uniform'))

# Only the discriminator is updated in this step. The generator forward pass
# therefore runs under no_grad instead of building an autograd graph that was
# detached straight away (same values, less memory). The deprecated no-op
# Variable wrappers are dropped as well.
with torch.no_grad():
    clear_batch = clear_img.cuda()
    g_noise = Gen(torch.cat((batch_sample_z.cuda(), clear_batch), 1))
    g_img = g_noise + clear_batch          # synthetic noisy image
    noisy_real = diff(noise_img.cuda())    # diff is defined elsewhere in the project
    noisy_fake = diff(g_img)
#if i ==0:
#    print(f"shape of noisy_real: {noisy_real.shape}, shape of noisy_fake: {noisy_fake.shape}")

# Tensors created under no_grad carry no graph, so no .detach() is needed here.
real_logit = Dis(noisy_real)
fake_logit = Dis(noisy_fake)

real_label = noise_cls.float().cuda() #1
fake_label = clear_cls.float().cuda() #0

real_loss = d_loss(real_logit, real_label)
fake_loss = d_loss(fake_logit, fake_label)
loss_D = (real_loss + fake_loss) / 2
loss_D.backward()
optimizer_Dis.step()
再训练生成器Gen和去噪网络Dnn:
# ---- Train the generator (Gen) and the denoiser (Dnn) ----
optimizer_Gen.zero_grad()
optimizer_Dnn.zero_grad()

clear_batch = clear_img.cuda()
batch_sample_z = Tensor(get_random_sample(([len(clear_img)] + list(input_shape)), method = 'uniform'))
g_noise = Gen(torch.cat((batch_sample_z.cuda(), clear_batch), 1))

# semi-part: for samples flagged as supervised, the ground-truth residual
# (noisy image minus its label image) supervises both Gen and Dnn.
loss_g_r, loss_dnn_r = 0, 0
spl = 0  # number of supervised samples in this batch
for ni, s, nl in zip(noise_img, supervised, noise_label):
    # Drawn for every sample (also unsupervised ones), exactly as before, so
    # the sequence of random draws is unchanged.
    b_s_z = Tensor(get_random_sample(([1] + list(input_shape)), method = 'uniform'))
    if s:
        spl += 1
        label_batch = nl[None].cuda()
        residual_GT = ni[None].cuda() - label_batch
        g_n_GT = Gen(torch.cat((b_s_z.cuda(), label_batch), 1))
        loss_g_r += g_r_loss(g_n_GT, residual_GT)
        dnn_p_GT = Dnn(g_n_GT.detach())
        # Fix: accumulate with += like loss_g_r. The previous plain assignment
        # kept only the last supervised sample and then divided it by spl.
        loss_dnn_r += dnn_r_loss(dnn_p_GT, residual_GT)
if spl != 0:
    # Average the supervised terms over the supervised samples.
    loss_g_r /= spl
    loss_dnn_r /= spl

# Generator update: adversarial term (target = all ones) plus the weighted
# supervised reconstruction term.
g_img = g_noise + clear_batch
noisy_fake = diff(g_img)
fake_logit = Dis(noisy_fake)
loss_G = g_loss(fake_logit, torch.ones((len(clear_img))).cuda()) + lmda_g * loss_g_r
loss_G.backward()
optimizer_Gen.step()

# Denoiser update: Dnn predicts the residual from the (detached) generated
# noise; subtracting it from the generated image should give back the clear image.
dnn_pred = Dnn(g_noise.detach())
out = g_img.detach() - dnn_pred
loss_Dnn = dnn_loss(out, clear_batch) + lmda_dnn * loss_dnn_r
loss_Dnn.backward()
optimizer_Dnn.step()
8、验证+保存
# Validation pass: accumulates PSNR / MAE of the noisy input and of the
# denoised output against the label images, and saves a 6-row preview figure
# for the first batch.
with torch.no_grad():
    psnr = PSNR()
    mae = MAE()
    # The SSIM accumulators are initialised but never updated in this excerpt.
    N_GT_psnr, DN_GT_psnr, N_GT_mae, DN_GT_mae, N_GT_ssim, DN_GT_ssim = 0, 0, 0, 0, 0, 0
    for i, ((noise_img, _,_,noise_label,_), (clear_img,_,_,clear_label,_)) in enumerate(zip(test_noise_loader, test_clear_loader)):
        # Move each batch to the GPU once instead of once per use.
        noise_batch = noise_img.cuda()
        label_batch = noise_label.cuda()
        clear_batch = clear_img.cuda()

        '''Gen'''
        # fix_batch_sample_z holds batch_size samples; slice it so a smaller
        # final batch does not break torch.cat with a size mismatch.
        fixed_z = fix_batch_sample_z[:len(clear_img)].cuda()
        g_noise = Gen(torch.cat((fixed_z, clear_batch), 1))
        g_img = g_noise + clear_batch

        '''Dnn'''
        dnn_pred = Dnn(noise_batch)
        out = noise_batch - dnn_pred

        # Each batch contributes its per-sample mean to the running totals.
        batch_len = len(out)
        for noise, label in zip(noise_batch, label_batch):
            N_GT_psnr += psnr(noise, label)/batch_len
            N_GT_mae += mae(noise, label)/batch_len
        for denoise, label in zip(out, label_batch):
            # clp is defined elsewhere (presumably clips to the valid range -- confirm).
            DN_GT_psnr += psnr(clp(denoise), label)/batch_len
            DN_GT_mae += mae(clp(denoise), label)/batch_len

        if i == 0:
            fig = plt.figure(figsize=[8*6,8*4])
            axes = [fig.add_subplot(6, 1, r+1 ) for r in range(0, 6)]
            for ax in axes:
                ax.axis('off')
            plt.gca().xaxis.set_major_locator(plt.NullLocator())
            plt.gca().yaxis.set_major_locator(plt.NullLocator())
            plt.subplots_adjust(top = 1, bottom = 0, right = 1, left = 0, hspace = 0, wspace = 0)
            plt.margins(0,0)
            # Rows, top to bottom: clear input, generated noise, generated
            # noisy image, real noisy input, denoiser prediction, denoised output.
            panels = (clear_img, g_noise, g_img, noise_img, dnn_pred, out)
            for ax, panel in zip(axes, panels):
                ax.imshow(torchvision.utils.make_grid(panel.cpu(), nrow=8).permute(1, 2, 0))
            fig.savefig("results/SS_DNN2UNet/cv{:02d}ep{:02d}.png".format(idx+1,epoch),bbox_inches = 'tight',pad_inches = 0)
            plt.close(fig)
# NOTE(review): the pasted source lost its indentation, so the nesting level of
# this print is a best guess (it most likely precedes the model-saving code
# that is not part of this excerpt).
print("saving...")
model

class Generator(nn.Module):
    """U-Net style generator with two encoder-to-decoder skip connections.

    Args:
        input_shape: ``(channels, height, width)`` of a single image.
        cat: When true, the network input is expected to carry
            ``2 * channels`` channels (the training code concatenates a noise
            tensor with the image along the channel axis).

    The output has ``channels`` channels so that it can be added to the image;
    previously this was hard-coded to 3. For ``channels == 3`` the layers are
    identical to the earlier definition. The final ``Tanh`` keeps output
    values in (-1, 1).
    """

    def __init__(self, input_shape, cat=True):
        super(Generator, self).__init__()
        channels, _, _ = input_shape
        out_channels = channels
        in_channels = channels * 2 if cat else channels

        # Encoder
        self.down1 = G_Down(in_channels, 32, normalize=False)
        self.down2 = G_Down(32, 32)
        self.down3 = G_Down(32, 64, pooling=True, dropout=0.5)
        self.down4 = G_Down(64, 64)
        self.down5 = G_Down(64, 128, pooling=True, dropout=0.5)
        self.down6 = G_Down(128, 128, normalize=False)

        # Decoder; up1 and up3 take inputs that were concatenated with encoder features.
        self.up1 = G_Up(256, 64, uppooling=True, dropout=0.5)
        self.up2 = G_Up(64, 64)
        self.up3 = G_Up(128, 32, uppooling=True, dropout=0.5)
        self.up4 = G_Up(32, 32)
        self.up5 = G_Up(32, out_channels)

        self.final = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size = 3, stride=1, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        """Map ``x`` of shape [B, in_channels, H, W] to a [B, channels, H, W] tensor.

        Channel counts in the comments follow the constructor arguments. The
        spatial sizes assume that G_Down(pooling=True) halves and
        G_Up(uppooling=True) doubles the resolution, as the original shape
        comments indicated -- confirm against G_Down / G_Up.
        """
        # U-Net generator with skip connections from encoder to decoder
        d1 = self.down1(x)               # [B, 32, H, W]
        d2 = self.down2(d1)              # [B, 32, H, W]
        d3 = self.down3(d2)              # [B, 64, H/2, W/2]
        d4 = self.down4(d3)              # [B, 64, H/2, W/2]
        d5 = self.down5(d4)              # [B, 128, H/4, W/4]
        d6 = self.down6(d5)              # [B, 128, H/4, W/4]
        cat1 = torch.cat((d6, d5), 1)    # [B, 256, H/4, W/4]
        u1 = self.up1(cat1)              # [B, 64, H/2, W/2]
        u2 = self.up2(u1)                # [B, 64, H/2, W/2]
        cat2 = torch.cat((u2, d4), 1)    # [B, 128, H/2, W/2]
        u3 = self.up3(cat2)              # [B, 32, H, W]
        u4 = self.up4(u3)                # [B, 32, H, W]
        u5 = self.up5(u4)                # [B, channels, H, W]
        return self.final(u5)            # [B, channels, H, W]

class Discriminator(nn.Module):
    """Convolutional discriminator that returns one raw logit per sample.

    Four stride-2 convolution blocks reduce the input resolution by a factor
    of 16; the flattened features are mapped to a single value by a linear
    layer.

    Args:
        input_shape: ``(channels, height, width)`` of a single image. The
            network itself takes ``2 * channels`` input channels (the training
            code feeds it the output of ``diff``, which presumably doubles the
            channel count -- confirm against ``diff``).

    Changes with respect to the earlier definition:
      * The linear layer size is derived from ``height`` and ``width`` instead
        of being fixed to ``128 * 20 * 20`` (which only matched 320x320
        inputs; for 320x320 the size is unchanged).
      * No trailing ``nn.Sigmoid``: the training code scores the output with
        ``BCEWithLogitsLoss``, which applies the sigmoid itself, so the sigmoid
        used to be applied twice. ``Sigmoid`` has no parameters, so existing
        state dicts still load.
    """

    def __init__(self, input_shape):
        super(Discriminator, self).__init__()
        channels, height, width = input_shape
        self.input_shape = (channels*2, height, width)
        # Kept unchanged for backward compatibility. NOTE(review): this does
        # not describe the actual output, which is one value per sample.
        self.output_shape = (1, height // 2 ** 3, width // 2 ** 3)

        def discriminator_block(in_filters, out_filters, normalization=True):
            """Returns downsampling layers of each discriminator block"""
            layers = [nn.Conv2d(in_filters, out_filters, 4, 2, 1)]
            if normalization:
                layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *discriminator_block(channels*2, 16, normalization=False),  # [B, 16, H/2, W/2]
            *discriminator_block(16, 32),                               # [B, 32, H/4, W/4]
            *discriminator_block(32, 128),                              # [B, 128, H/8, W/8]
            *discriminator_block(128, 128),                             # [B, 128, H/16, W/16]
        )
        # Each Conv2d(kernel=4, stride=2, padding=1) maps a side of length n to n // 2.
        feat_h = height // 2 ** 4
        feat_w = width // 2 ** 4
        self.final = nn.Sequential(
            nn.Linear(128 * feat_h * feat_w, 1),
        )

    def forward(self, img):
        """Return a 1-D tensor with one logit per sample of ``img``."""
        conv = self.model(img)
        conv = conv.view(conv.shape[0], -1)
        return self.final(conv).view(-1)
综上,与论文框架描述一致,没有弯弯绕绕


被折叠的 条评论
为什么被折叠?



