1.网络结构
结构可以分为两部分 左边部分是编码结构,进行特征提取 右边是解码结果,进行特征还原
2.数据集准备
import os.path
from torchvision import transforms
from torch.utils.data import Dataset
from utils import *
#数据归一化
transform = transforms.Compose([
transforms.ToTensor()
])
class MyDataset(Dataset):
def __init__(self,path):
self.path = path
#获取索引的名字 E:\Pcproject\LB-UNet-main\isic2018\train
self.name = os.listdir(os.path.join(path,'masks'))
def __len__(self):
return len(self.name)
def __getitem__(self,index):
segment_name = self.name[index]
segment_path = os.path.join(self.path,'masks',segment_name)
#原图地址
image_path = os.path.join(self.path,'images',segment_name)
#规范图片的大小尺寸
segment_image = keep_image_size_open(segment_path)
image = keep_image_size_open(image_path)
return transform(image),transform(segment_image)
if __name__=='__main__':
data = MyDataset('E:/Pcproject/pythonProjectlw/UNet')
print(data[0][0].shape)
print(data[0][0].shape)
数据图片规范函数:
from PIL import Image
def keep_image_size_open(path,size = (256,256)):
#打开图像文件
img = Image.open(path)
#取最长边 获取图像尺寸 最长边
temp = max(img.size)
#创建空白图像
mask = Image.new('RGB',(temp,temp),(0,0,0))
#粘贴原始图像
mask.paste(img,(0,0))
#调整图像大小
mask = mask.resize(size)
#返回调整后的图像
return mask
下面

最低0.47元/天 解锁文章
3632

被折叠的 条评论
为什么被折叠?



