pytorch实现Unet模型

自定义数据集

# -*- I Love Python!!! And You? -*-
# @Time    : 2022/3/27 12:25
# @Author  : sunao
# @Email   : 939419697@qq.com
# @File    : img_segData.py
# @Software: PyCharm
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms
import torch.nn.functional as F
import os

class img_segData(Dataset):
	def __init__(self,img_h=256,img_w=256,path="./data/img_seg",data_file="images",label_file="profiles",
	             preprocess=True):
		'''
		数据集初始化
		:param img_h: resize图像高度
		:param img_w: resize图像宽度
		:param path: 数据集路径
		:param data_file: 数据特征值文件夹名称
		:param label_file: 数据标签文件夹名称
		:param preprocess: 是否进行数据预处理
		'''
		super(img_segData, self).__init__()
		self.file_list = os.listdir(path+"/"+data_file)
		self.data_file = data_file
		self.label_files = label_file
		self.path = path
		self.img_h = img_h
		self.img_w = img_w
		self.preprocess = preprocess
		pass
		
		
	def __len__(self):
		# 返回数据集大小
		return len(self.file_list)
		
	
	def __getitem__(self, item):
		# 返回指定索引的数据集
		img_name = self.file_list[item]
		label_name = img_name.split(".")[0]+"-profile.jpg"
		label_path = self.path+"/"+self.label_files+"/"+label_name
		img_path = self.path+"/"+self.data_file+"/"+img_name
		
		# 读取数据
		img = Image.open(img_path)
		label = Image.open(label_path)
		
		# 数据预处理
		if self.preprocess:
			trans_img = transforms.Compose([
				transforms.Resize(size=(self.img_w,self.img_h)),
				transforms.ToTensor(),
				transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))
			])
			img = trans_img(img)
			trans_label = transforms.Compose([
				transforms.Resize(size=(self.img_w,self.img_h)),
				transforms.ToTensor(),
			])
			label = trans_label(label)
		return img,label



if __name__ == '__main__':
	trans_data = img_segData()
	img,label = trans_data.__getitem__(5)
	print(img.size(),label.size())
	
	# plt.imshow(img.data.numpy().transpose([1,2,0]))
	# plt.show()
	# plt.imshow(label.data.numpy().reshape(256,256))
	# plt.show()
	label = torch.where(label==
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

语音不识别

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值