Understanding the Code Used in Model Testing
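Once a model has been trained, testing it is the most important next step. So far I have come across two ways of doing it:
- A: save the model output directly with torchvision.utils.save_image
- B: convert the output to a numpy array first, then extract and assemble the channels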
Testing with torchvision.utils.save_image
For the source code, see the official documentation.
- If the model output is a three-channel RGB image, save it directly:
torchvision.utils.save_image(model_out, './out.png', padding=0)
- If the model outputs a single channel, the channels have to be assembled into an RGB image before saving:
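import torch
import torchvision
# model_r / model_g: trained single-channel models; r / g: their inputs on the GPU
HR_2_r = model_r(r)  # [1, 1, 1024, 1024]
HR_2_g = model_g(g)  # [1, 1, 1024, 1024]
# All-zero (black) blue channel with the same shape as the model outputs
black_2 = torch.zeros(1, 1024, 1024).unsqueeze(0).cuda()
# Drop the batch axis, concatenate along dim 0, then restore the batch axis
HR_2 = torch.cat((HR_2_r.squeeze(0), HR_2_g.squeeze(0), black_2.squeeze(0))).unsqueeze(0)
torchvision.utils.save_image(HR_2, './out.png', padding=0)
print("saved !")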
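- Assume HR_2_r.shape after the model is [1, 1, 1024, 1024].
- Assume HR_2_g.shape after the model is [1, 1, 1024, 1024].
- black_2 is created to get an all-black B channel that carries no information. It is combined with the R and G channels produced by the models to form an RGB image. unsqueeze(0) turns the shape of torch.zeros(1, 1024, 1024) into [1, 1, 1024, 1024], matching the other two channels so that they can be merged.
- HR_2_r.squeeze(0) changes its shape from [1, 1, 1024, 1024] to [1, 1024, 1024], and likewise for the other two. torch.cat then joins the three along dimension 0 (its default) into [3, 1024, 1024], the final unsqueeze(0) restores the batch axis to give [1, 3, 1024, 1024], and torchvision.utils.save_image saves the result.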
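The squeeze / cat / unsqueeze round trip described above can be shortened by concatenating directly along the channel axis. The following is only a sketch under stated assumptions: merge_to_rgb is a hypothetical helper name, the inputs are random stand-ins for the model outputs, and the tensors are small and kept on the CPU so that the example runs anywhere.

```python
import torch


def merge_to_rgb(red, green):
    """Stack two single-channel batches [N, 1, H, W] into one RGB batch [N, 3, H, W]."""
    if red.shape != green.shape or red.dim() != 4 or red.size(1) != 1:
        raise ValueError("expected two tensors of identical shape [N, 1, H, W]")
    # zeros_like copies shape, dtype and device, so no hard-coded size or .cuda() call
    blue = torch.zeros_like(red)
    # dim=1 is the channel axis, so the batch axis never has to be removed
    return torch.cat((red, green, blue), dim=1)


# Random stand-ins for the model outputs HR_2_r and HR_2_g
HR_2_r = torch.rand(1, 1, 64, 64)
HR_2_g = torch.rand(1, 1, 64, 64)
HR_2 = merge_to_rgb(HR_2_r, HR_2_g)
print(HR_2.shape)
```

The result has the same layout as HR_2 in the original snippet, so it can be passed to torchvision.utils.save_image in the same way. Using torch.zeros_like also avoids a device mismatch if the outputs are not on the GPU.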
Saving the image via a numpy array
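import numpy as np
import torch
from PIL import Image
from torchvision.transforms import Compose, Lambda, ToTensor
# opt holds the parsed command-line arguments (model, input, outputHR2, ...)
model = torch.load(opt.model).cuda()
# ToTensor scales pixel values to [0, 1]; repeat copies the single channel three times
transform = Compose([ToTensor(), Lambda(lambda x: x.repeat(3,1,1)),])
img = Image.open(opt.input).convert('RGB')
LR_r_spilt, _, _ = img.split()  # keep only the R channel
LR_r = transform(LR_r_spilt)
LR_r = LR_r.unsqueeze(0)  # add the batch axis
LR_r = LR_r.cuda()
HR_2_r = model(LR_r)
HR_2_r = HR_2_r.cpu()  # move to host memory before converting to numpy
out_HR_2_r_detach = HR_2_r.data[0].numpy()
out_HR_2_r = out_HR_2_r_detach * 255.0
out_HR_2_r = out_HR_2_r.clip(0, 255)
out_HR_2_r_result = Image.fromarray(np.uint8(out_HR_2_r[0]), mode='L')
out_HR_2_r_result.save(opt.outputHR2)
print('output image saved to ', opt.outputHR2)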
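- Assume the input image is 512 x 512 with three channels (RGB). After split, LR_r_spilt.size is (512, 512). After transform (ToTensor followed by repeat(3,1,1)), the model input has LR_r.shape torch.Size([3, 512, 512]). The network, however, expects a 4-D tensor of the form [batchSize, C, H, W], so unsqueeze(0) is applied to turn LR_r into torch.Size([1, 3, 512, 512]).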
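- After the model, HR_2_r.shape is torch.Size([1, 3, 1024, 1024]). Because the model runs on the GPU, the result is a CUDA tensor and has to be moved to host memory with .cpu() before it can be converted to numpy.
- HR_2_r.data[0] separates the data from the autograd graph (.data) and selects the first item of the batch ([0]), so the 4-D tensor becomes the 3-D torch.Size([3, 1024, 1024]). detach() can be used instead of .data for the same purpose.
- The network output lies roughly in [0, 1], while the later np.uint8() conversion needs values in [0, 255], so the result is multiplied by 255.0. If the network overshoots slightly, the scaled values can fall outside that range, so clip(0, 255) is applied to keep them valid.
- out_HR_2_r[0] extracts the first channel, i.e. the R channel. Finally, Image.fromarray converts the array into an Image; out_HR_2_r_result is then a single-channel grayscale image.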
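The post-processing steps listed above can be collected into one small function. This is a minimal sketch, not the author's code: tensor_to_gray_image is a hypothetical helper name, it uses detach() in place of .data, and the input is a synthetic tensor that deliberately runs slightly outside [0, 1] instead of a real model output.

```python
import numpy as np
import torch
from PIL import Image


def tensor_to_gray_image(batch, channel=0):
    """Turn one channel of the first item in a [N, C, H, W] batch into an 8-bit image."""
    # detach() drops autograd tracking, cpu() moves the data to host memory,
    # [0] selects the first batch item, leaving a (C, H, W) float array
    array = batch.detach().cpu()[0].numpy()
    # Scale from [0, 1] to [0, 255] and cut off anything that overshoots
    array = (array * 255.0).clip(0, 255)
    # A 2-D uint8 array is interpreted by Pillow as a grayscale ('L') image
    return Image.fromarray(array[channel].astype(np.uint8))


# Synthetic stand-in for a model output, ranging from -0.1 to 1.1
fake_output = torch.linspace(-0.1, 1.1, 3 * 8 * 8).reshape(1, 3, 8, 8)
image = tensor_to_gray_image(fake_output)
print(image.mode, image.size)
```

The returned object can be saved with its save method, as in the original snippet. Taking the channel index as a parameter makes it easy to inspect the other channels as well.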
The following is the test output and can be skipped:
Namespace(cuda=True, input='/content/drive/My Drive/TestImg/val/95-512pix-speed7-ave1.tif', model='/content/drive/My Drive/TestImg/H_Adagrad_epoch_r_400.pth', outputHR2='/content/drive/My Drive/TestImg/val2/95_X2.tif', outputHR4='/content/drive/My Drive/TestImg/val4/95_X4.tif')
After Split:(512, 512)
After transform:torch.Size([3, 512, 512])
After unsqueeze:torch.Size([1, 3, 512, 512])
After Model:torch.Size([1, 3, 1024, 1024])
After data[0]:torch.Size([3, 1024, 1024])
After data[0].numpy():(3, 1024, 1024)
# Network output
[[[0.0726895 0.05225104 0.05731031 ... 0.08038348 0.09176005 0.11681005]
[0.05010414 0.03608295 0.05038059 ... 0.07422554 0.08273476 0.09623137]
[0.05970517 0.04536656 0.04714845 ... 0.07344079 0.07673076 0.08257505]
...
[0.10552698 0.10863981 0.1160759 ... 0.07840011 0.05360675 0.05854338]
[0.12117186 0.11651129 0.12565961 ... 0.06953818 0.03986105 0.05407226]
[0.13890806 0.13018039 0.12679848 ... 0.07166685 0.05829817 0.08199665]]
[[0.06548429 0.04780073 0.05077077 ... 0.07777905 0.08438891 0.08978723]
[0.04724503 0.03454134 0.04663415 ... 0.07442069 0.07981256 0.08830814]
[0.05518261 0.04601964 0.04341775 ... 0.07343534 0.07746741 0.07461968]
...
[0.10186043 0.10778397 0.11563006 ... 0.07755548 0.05533218 0.04998586]
[0.11432448 0.11585081 0.12556481 ... 0.06927705 0.04027918 0.04754832]
[0.12643194 0.12204635 0.12268594 ... 0.06363812 0.05656445 0.07072814]]
[[0.06303221 0.04812427 0.05175997 ... 0.0773088 0.08727033 0.09965152]
[0.04497686 0.03826225 0.04548033 ... 0.07697698 0.08157301 0.09077145]
[0.05506983 0.04517868 0.04388288 ... 0.07452801 0.07780927 0.07592566]
...
[0.10016826 0.10798413 0.11497539 ... 0.078601 0.05558252 0.05430539]
[0.11343104 0.11644471 0.12352681 ... 0.06749362 0.03726557 0.0477092 ]
[0.125687 0.12345901 0.12515822 ... 0.06406513 0.05620003 0.0672472 ]]]
# Multiplied by 255
[[[18.535824 13.324016 14.61413 ... 20.497787 23.398813 29.786564 ]
[12.776556 9.201153 12.84705 ... 18.927513 21.097364 24.539 ]
[15.224818 11.568472 12.022855 ... 18.727402 19.566343 21.056639 ]
...
[26.90938 27.70315 29.599356 ... 19.992027 13.669721 14.928563 ]
[30.898825 29.710379 32.0432 ... 17.732235 10.164569 13.788426 ]
[35.421555 33.196 32.333614 ... 18.275047 14.866034 20.909145 ]]
[[16.698492 12.189187 12.9465475 ... 19.833658 21.519173 22.895744 ]
[12.047482 8.808042 11.891709 ... 18.977276 20.352201 22.518576 ]
[14.071565 11.735009 11.071527 ... 18.726011 19.75419 19.028019 ]
...
[25.974411 27.484913 29.485666 ... 19.776648 14.109707 12.746393 ]
[29.152742 29.541956 32.019028 ... 17.665648 10.271191 12.124823 ]
[32.240147 31.12182 31.284914 ... 16.22772 14.423935 18.035675 ]]
[[16.073214 12.271688 13.198793 ... 19.713745 22.253935 25.411137 ]
[11.469099 9.756873 11.597483 ... 19.629131 20.801117 23.14672 ]
[14.042808 11.520564 11.190133 ... 19.004642 19.841366 19.361044 ]
...
[25.542906 27.535952 29.318726 ... 20.043255 14.173544 13.847875 ]
[28.924913 29.6934 31.499336 ... 17.210873 9.50272 12.165845 ]
[32.050186 31.482048 31.915346 ... 16.336607 14.331007 17.148035 ]]]
# After clip
[[[18.535824 13.324016 14.61413 ... 20.497787 23.398813 29.786564 ]
[12.776556 9.201153 12.84705 ... 18.927513 21.097364 24.539 ]
[15.224818 11.568472 12.022855 ... 18.727402 19.566343 21.056639 ]
...
[26.90938 27.70315 29.599356 ... 19.992027 13.669721 14.928563 ]
[30.898825 29.710379 32.0432 ... 17.732235 10.164569 13.788426 ]
[35.421555 33.196 32.333614 ... 18.275047 14.866034 20.909145 ]]
[[16.698492 12.189187 12.9465475 ... 19.833658 21.519173 22.895744 ]
[12.047482 8.808042 11.891709 ... 18.977276 20.352201 22.518576 ]
[14.071565 11.735009 11.071527 ... 18.726011 19.75419 19.028019 ]
...
[25.974411 27.484913 29.485666 ... 19.776648 14.109707 12.746393 ]
[29.152742 29.541956 32.019028 ... 17.665648 10.271191 12.124823 ]
[32.240147 31.12182 31.284914 ... 16.22772 14.423935 18.035675 ]]
[[16.073214 12.271688 13.198793 ... 19.713745 22.253935 25.411137 ]
[11.469099 9.756873 11.597483 ... 19.629131 20.801117 23.14672 ]
[14.042808 11.520564 11.190133 ... 19.004642 19.841366 19.361044 ]
...
[25.542906 27.535952 29.318726 ... 20.043255 14.173544 13.847875 ]
[28.924913 29.6934 31.499336 ... 17.210873 9.50272 12.165845 ]
[32.050186 31.482048 31.915346 ... 16.336607 14.331007 17.148035 ]]]
# First channel extracted
[[18.535824 13.324016 14.61413 ... 20.497787 23.398813 29.786564]
[12.776556 9.201153 12.84705 ... 18.927513 21.097364 24.539 ]
[15.224818 11.568472 12.022855 ... 18.727402 19.566343 21.056639]
...
[26.90938 27.70315 29.599356 ... 19.992027 13.669721 14.928563]
[30.898825 29.710379 32.0432 ... 17.732235 10.164569 13.788426]
[35.421555 33.196 32.333614 ... 18.275047 14.866034 20.909145]]
After PostDeal:(1024, 1024)
output image saved to /content/drive/My Drive/TestImg/val2/95_X2.tif
output image saved to /content/drive/My Drive/TestImg/val4/95_X4.tif
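In the log above, the values before and after clip are identical because every value was already inside [0, 255]. The small sketch below, which uses made-up numbers rather than real model output, shows what clip does when that is not the case and what the later uint8 conversion does to the fractional part.

```python
import numpy as np

# Hypothetical scaled outputs: the first and last values fall outside [0, 255]
scaled = np.array([-3.0, 18.5, 254.9, 300.0], dtype=np.float32)

# clip pulls out-of-range values back to the nearest bound
clipped = scaled.clip(0, 255)

# The cast to uint8 drops the fractional part (truncation, not rounding)
as_uint8 = clipped.astype(np.uint8)
print(clipped)
print(as_uint8)
```

Without the clip step the cast would be applied to values that uint8 cannot represent, and the result of such a cast is not something to rely on. Because the cast truncates, a value such as 18.535824 from the log is stored as 18. If rounding to the nearest integer is preferred, np.rint can be applied before the cast.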