论文地址
开源代码
https://github.com/yelusaleng/RRU-Net
下面是复现的代码，
使用的数据集是 NIST。
项目结构
详细代码
./dataload/NIST_data.py
import torch
import torchvision.transforms as tfs
import os
from PIL import Image
import time
# Platform switch for the dataset root: LINUX == 0 selects the Windows-style
# drive path, any other value selects the Linux home-directory path.
LINUX = 0
prefix = "G:/data/" if LINUX == 0 else "/home/yyl/data/"

# Image size constants. Not referenced in the visible code; presumably the
# resize/crop target for the network input -- confirm against the callers.
HEIGHT = 256
WIDTH = 256
class NIST_DATA(object):
def __init__(self, mode="train"):
    """Collect the NIST file lists, build the transforms and keep one split.

    Args:
        mode: ``"train"`` keeps the first 80% of the collected image/mask
            path lists; any other value keeps the remaining 20%.
    """
    super(NIST_DATA, self).__init__()
    self.nist_data_root = prefix + "NIST/"
    self.image_name = []
    self.gt_name = []
    self.readImage()

    # Per-channel image normalization statistics.
    self.mean = [0.40789654, 0.44719302, 0.47026115]
    self.std = [0.28863828, 0.27408164, 0.27809835]
    self.im_tfs = tfs.Compose([tfs.ToTensor(), tfs.Normalize(self.mean, self.std)])
    self.im_tfs_gt = tfs.Compose([tfs.ToTensor()])
    # self.filter() is currently disabled.

    def _keep_split(paths):
        # Position-based split: head 80% for "train", tail 20% otherwise.
        cut = round(len(paths) * 0.8)
        return paths[:cut] if mode == "train" else paths[cut:]

    self.image_name = _keep_split(self.image_name)
    self.gt_name = _keep_split(self.gt_name)
    print(f"{mode}->一共加载了{len(self.image_name)}张图片")
def readImage(self, subsets=("1", "2")):
    """Append image paths and their matching mask paths for each NIST subset.

    For every subset ``s`` the images are listed from
    ``<nist_data_root>/s/`` and the mask of image ``name.ext`` is taken to be
    ``<nist_data_root>/s_mask/name.png``. Paths are appended to
    ``self.image_name`` and ``self.gt_name`` in lockstep. Mask existence is
    not checked here.

    Args:
        subsets: Sub-directory names to read, in order. Defaults to the two
            subsets ``"1"`` and ``"2"``.
    """
    for subset in subsets:
        image_root = self.nist_data_root + subset + "/"
        gt_root = self.nist_data_root + subset + "_mask/"
        # os.listdir order is arbitrary and filesystem-dependent; sort it so
        # the positional 80/20 train/test split in __init__ is reproducible
        # across machines and does not leak test images into training.
        for fname in sorted(os.listdir(image_root)):
            # splitext keeps the full stem for names containing extra dots
            # (e.g. "a.b.jpg" -> "a.b"), unlike split(".")[0].
            stem, _ext = os.path.splitext(fname)
            self.image_name.append(image_root + fname)
            self.gt_name.append(gt_root + stem + ".png")
def filter(self, thresold=10):
tmp_image_name = self.image_name.copy()
tmp_gt_name = self.gt_name.copy()
for idx in range(len(self.image_name)):
img = Image.open(self.image_name[idx])
img_gt = Image.open(self.gt_name[idx])
img, img_gt = self.image_transform(img, img_gt)
if(img_gt.sum()<thresold):
#print("删除掉%d"%idx)
tmp_image_name.remove(self.image_name[idx])
tmp_gt_name.r