import os

import numpy as np
import torch as t
from matplotlib import pyplot as plt
from PIL import Image
from torch.utils import data
from torch.utils.data import DataLoader
from torchvision import transforms as T
from torchvision.datasets import ImageFolder

############## Data processing ##############

# First, transform-free version of the dataset (kept for reference):
#
# class DogCat(data.Dataset):
#     def __init__(self, root):
#         imgs = os.listdir(root)
#         self.imgs = [os.path.join(root, img) for img in imgs]
#
#     def __getitem__(self, index):
#         img_path = self.imgs[index]
#         # dog -> 1, cat -> 0
#         label = 1 if 'dog' in img_path.split('/')[-1] else 0
#         pil_img = Image.open(img_path)
#         array = np.asarray(pil_img)
#         data = t.from_numpy(array)
#         return data, label
#
#     def __len__(self):
#         return len(self.imgs)
#
# dataset = DogCat(r'D:/Python相关/PyTorch入门与实践/pytorch-book-master/chapter5-常用工具/data/dogcat/')
# img, label = dataset[0]  # equivalent to dataset.__getitem__(0)
# for img, label in dataset:
#     print(img.size(), img.float().mean(), label)

# Chain several image operations together.
transform = T.Compose([
    T.Resize(224),
    T.CenterCrop(224),
    T.ToTensor(),  # PIL Image -> Tensor, scaled to [0, 1]
    T.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]),  # then rescaled to [-1, 1]
])


# Custom dataset
class DogCat(data.Dataset):
    """Dogs-vs-cats dataset over a flat directory of image files.

    The label comes from the file name: 1 if it contains ``'dog'``,
    otherwise 0 (cat).

    Args:
        root: directory whose entries are the image files.
        transforms: optional callable applied to the loaded PIL image,
            e.g. a ``torchvision.transforms.Compose`` pipeline.
    """

    def __init__(self, root, transforms=None):
        # os.listdir returns entries in arbitrary order; sort them so that
        # index -> file is reproducible across runs and machines.
        imgs = sorted(os.listdir(root))
        self.imgs = [os.path.join(root, img) for img in imgs]
        self.transforms = transforms

    def __getitem__(self, index):
        """Return ``(image, label)`` for the file at ``index``.

        Any error raised by PIL or by the transform for an unreadable file
        propagates to the caller.
        """
        img_path = self.imgs[index]
        # dog -> 1, cat -> 0. Look at the file name only. The directory name
        # itself (e.g. ".../dogcat") may contain "dog". Also, on Windows
        # os.path.join inserts "\\" when root has no trailing slash, so
        # splitting on "/" would keep the directory in the tested string.
        label = 1 if 'dog' in os.path.basename(img_path) else 0
        # Decode inside ``with`` so the file handle is released right away.
        # convert('RGB') returns an independent image object and gives every
        # sample the 3 channels the Normalize pipelines in this file expect.
        # This matches torchvision's ImageFolder default loader.
        with Image.open(img_path) as img:
            sample = img.convert('RGB')
        if self.transforms is not None:
            sample = self.transforms(sample)
        return sample, label

    def __len__(self):
        return len(self.imgs)


# dataset = DogCat(r'D:/Python相关/PyTorch入门与实践/pytorch-book-master/chapter5-常用工具/data/dogcat/', transforms=transform)
#
# img, label = dataset[0]  # equivalent to dataset.__getitem__(0)
# for img, label in dataset:
#     print((img.size(), img.float().mean(), label))

# ImageFolder
# dataset = ImageFolder(r'D:/Python相关/PyTorch入门与实践/pytorch-book-master/chapter5-常用工具/data/dogcat_2/')
# print(dataset.class_to_idx)
# Inspect how labels correspond to file / folder names:
# print(list(dataset.imgs))
# print(dataset[0][1], dataset[0][0])  # first index picks the sample; then [1] is the label, [0] the image
# plt.imshow(dataset[0][0])
# plt.show()

# Root of the book's chapter-5 sample data; the dataset paths below are built from it.
BOOK_DATA_DIR = r'D:/Python相关/PyTorch入门与实践/pytorch-book-master/chapter5-常用工具/data/'

# ImageFolder with an augmentation pipeline
normalize = T.Normalize(mean=[.4, .4, .4], std=[.2, .2, .2])
transform = T.Compose([
    T.RandomResizedCrop(224),
    T.RandomHorizontalFlip(),
    T.ToTensor(),  # PIL Image -> Tensor, scaled to [0, 1]
    normalize,
])
dataset = ImageFolder(BOOK_DATA_DIR + 'dogcat_2/', transform=transform)
# print(dataset[0][0].shape)
# to_img = T.ToPILImage()
# plt.imshow(to_img(dataset[0][0]))
# plt.pause(0.5)
# plt.imshow(to_img(dataset[0][0] * 0.2 + 0.4))  # undo normalize: x * std + mean
# plt.show()

# DataLoader
dataloader = DataLoader(dataset, batch_size=3, shuffle=True, num_workers=0, drop_last=False)
dataiter = iter(dataloader)
imgs, labels = next(dataiter)
print(imgs.size())


# Handling samples that cannot be read
class NewDogCat(DogCat):
    """DogCat variant that yields ``(None, None)`` for unreadable samples.

    Use it with ``my_collate_fn`` so the placeholders are filtered out
    before batching.
    """

    def __getitem__(self, index):
        try:
            return super().__getitem__(index)  # defer to DogCat's loader
        except Exception:
            # Drop the broken sample. Catch ``Exception`` rather than using a
            # bare ``except`` so KeyboardInterrupt / SystemExit still propagate.
            return None, None
            # Alternative: substitute a random other sample (needs ``import random``):
            # new_index = random.randint(0, len(self) - 1)
            # return self[new_index]


from torch.utils.data.dataloader import default_collate  # PyTorch's default batching function


def my_collate_fn(batch):
    """Collate a batch after discarding samples whose data is ``None``.

    Args:
        batch: list of ``(data, label)`` pairs produced by the dataset.

    Returns:
        ``default_collate`` applied to the remaining samples. If every sample
        in the batch was dropped, returns an empty tensor and an empty int64
        label tensor instead. ``default_collate`` indexes ``batch[0]`` and
        would raise IndexError on an empty list. The empty pair still lets
        ``for datas, labels in dataloader`` unpack.
    """
    batch = [sample for sample in batch if sample[0] is not None]
    if not batch:
        return t.empty(0), t.empty(0, dtype=t.long)
    return default_collate(batch)


dataset = NewDogCat(BOOK_DATA_DIR + 'dogcat_wrong/', transforms=transform)
# print(dataset[8])  # (None, None)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=my_collate_fn, num_workers=0, drop_last=False)
# for batch_datas, batch_labels in dataloader:
#     print(batch_datas.size(), batch_labels.size())

# Sampler module
dataset = DogCat(BOOK_DATA_DIR + 'dogcat/', transforms=transform)
# Dogs (label 1) get twice the sampling weight of cats.
weights = [2 if label == 1 else 1 for _, label in dataset]
# print(weights)
from torch.utils.data.sampler import WeightedRandomSampler

# Draw 6 distinct indices (replacement=False), favouring larger weights.
sampler = WeightedRandomSampler(weights, num_samples=6, replacement=False)
dataloader = DataLoader(dataset, batch_size=2, sampler=sampler)
# for datas, labels in dataloader:
#     print(labels.tolist())

################## Computer-vision toolkit: torchvision ##################
from torchvision import models
from torch import nn

# resnet34 = models.resnet34(pretrained=True, num_classes=1000)
# # Downloading: "https://download.pytorch.org/models/resnet34-333f7ec4.pth" to C:\Users\45840/.torch\models\resnet34-333f7ec4.pth
# resnet34.fc = nn.Linear(512, 10)

from torchvision import datasets

transform = T.Compose([
    T.ToTensor(),
    T.Lambda(lambda x: x.repeat(3, 1, 1)),  # convert the single-channel image to three channels (the modified step)
    T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])
# NOTE(review): r'\D\...' is a root-relative path on the *current* drive,
# unlike the 'D:/...' paths used elsewhere in this file; r'D:\...' may have
# been intended. Left unchanged - confirm.
dataset = datasets.MNIST(r'\D\Python相关\数据集\cy\data', download=True, train=False, transform=transform)
# Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to \D\Python相关\数据集\cy\data\MNIST\raw\train-images-idx3-ubyte.gz
to_pil = T.ToPILImage()
# plt.imshow(to_pil(t.randn(3, 64, 64)))
# plt.show()
print(len(dataset))
dataloader = DataLoader(dataset, shuffle=True, batch_size=16)

from torchvision.utils import make_grid, save_image

dataiter = iter(dataloader)
# Ran into an error here; the fix was converting the single-channel images
# to three channels in the transform above.
imgs, labels = next(dataiter)
# print(imgs.size())
# img = make_grid(next(dataiter)[0], 4)  # 4x4 grid; images must already be three-channel
# plt.imshow(to_pil(img))
# plt.show()

# Read back the saved image
# save_image(img, 'a.png')
# a = Image.open('a.png')
# plt.imshow(a)
# plt.show()

############################## visdom ####################
import visdom

# Start the visdom server first:
#   nohup python -m visdom.server
# Create a client connection: env=u'test1', default port 8097, host 'localhost'.
# vis = visdom.Visdom(env=u'test1')
#
# x = t.arange(1, 30, 0.01)
# y = t.sin(x)
# vis.line(X=x, Y=y, win='sinx', opts={'title': 'y=sin(x)'})
#
# # append: add points to an existing line
# for ii in range(0, 10):
#     x = t.Tensor([ii])
#     y = x
#     vis.line(X=x, Y=y, win='polynomial', update='append' if ii > 0 else None)
#
# # updateTrace: add a new line to the window
# # NOTE(review): updateTrace may not exist in newer visdom releases; there,
# # vis.line(..., name=..., update='append') is the usual replacement - confirm.
# x = t.arange(0, 9, 0.1)
# y = (x ** 2) / 9
# vis.updateTrace(X=x, Y=y, win='polynomial', name='this is a Trace')
#
# # Then open http://localhost:8097 in a browser.
#
# # Visualize a random grayscale image (vis.image accepts 2-D or 3-D arrays)
# vis.image(t.randn(64, 64).numpy())
#
# # Visualize a random color image
# vis.image(t.randn(3, 64, 64).numpy(), win='random2')
#
# # Visualize 36 random color images, 6 per row
# vis.images(t.randn(36, 3, 64, 64).numpy(), nrow=6, win='random3', opts={'title': 'random_imgs'})
#
# # Text output
# vis.text(u'''<h1>Hello visdom</h1><br>visdom是Facebook专门为<b>PyTorch</b>开发的一个可视化工具''',
#          win='visdom', opts={'title': u'visdom简介'})
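# PyTorch入门与实践——常用的工具 (PyTorch: Introduction and Practice - Common Tools)
# 最新推荐文章于 2025-05-24 11:46:04 发布 (web-page residue: "latest recommended article published 2025-05-24 11:46:04")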