PyTorch入门与实践——常用工具

本文介绍了PyTorch中数据处理的基本步骤,包括使用`Dataset`和`DataLoader`,`transforms`的组合应用,处理异常样本的方法,以及`WeightedRandomSampler`的使用。还涉及到计算机视觉工具包`torchvision`中的模型和数据集预处理,并展示了如何使用`visdom`进行可视化。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

import torch as t
from torch.utils import data

import os
from PIL import Image
import numpy as np
from matplotlib import pyplot as plt

from torchvision import transforms as T
from torchvision.datasets import ImageFolder

from torch.utils.data import DataLoader

##############数据处理###################
# class DogCat(data.Dataset):
#     def __init__(self, root):
#         imgs = os.listdir(root)
#         self.imgs = [os.path.join(root, img) for img in imgs]
#
#     def __getitem__(self, index):
#         img_path = self.imgs[index]
#         #dog->1, cat->0
#         label = 1 if 'dog' in img_path.split('/')[-1] else 0
#         pil_img = Image.open(img_path)
#         array = np.asarray(pil_img)
#         data = t.from_numpy(array)
#         return data, label
#
#     def __len__(self):
#         return len(self.imgs)

# dataset = DogCat(r'D:/Python相关/PyTorch入门与实践/pytorch-book-master/chapter5-常用工具/data/dogcat/')
# img, label = dataset[0]     #相当于调用dataset.__getitem__(0)
# for img, label in dataset:
#     print(img.size(), img.float().mean(), label)


# Preprocessing pipeline for the DogCat example: resize and centre-crop to 224,
# convert the PIL image to a tensor with values in [0, 1], then normalise with
# mean 0.5 / std 0.5 so that the values end up in [-1, 1].
transform = T.Compose(
    [
        T.Resize(224),
        T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]),
    ]
)

#实现自定义数据集
class DogCat(data.Dataset):
    """Dogs-vs-cats dataset backed by a flat directory of image files.

    Every entry of ``root`` is treated as one sample. The label comes from
    the file name: names containing ``'dog'`` get label 1, all others get 0.

    Args:
        root: Directory whose entries are the image files.
        transforms: Optional callable applied to the opened PIL image
            (e.g. a ``T.Compose`` pipeline). ``None`` returns the raw image.
    """

    def __init__(self, root, transforms=None):
        # os.listdir() returns entries in arbitrary, filesystem-dependent
        # order; sort so that the index -> file mapping is reproducible.
        imgs = sorted(os.listdir(root))
        self.imgs = [os.path.join(root, img) for img in imgs]
        self.transforms = transforms

    @staticmethod
    def _label_from_path(img_path):
        """Return 1 (dog) if the file name contains 'dog', otherwise 0 (cat)."""
        # Inspect only the file name. Splitting on '/' is not enough on
        # Windows, where os.path.join inserts '\\': the parent directory
        # name (e.g. '.../dogcat') would then be searched too, and every
        # sample would be labelled as a dog.
        return 1 if 'dog' in os.path.basename(img_path) else 0

    def __getitem__(self, index):
        """Return ``(image, label)`` for sample ``index``.

        ``image`` is the opened PIL image, or whatever ``self.transforms``
        returns for it when a transform was supplied.

        Raises:
            IndexError: if ``index`` is out of range for ``self.imgs``.
        """
        img_path = self.imgs[index]
        label = self._label_from_path(img_path)
        # Named `img` rather than `data` so the module-level
        # `torch.utils.data` import is not shadowed.
        img = Image.open(img_path)
        if self.transforms is not None:
            img = self.transforms(img)
        return img, label

    def __len__(self):
        """Return the number of entries found in ``root``."""
        return len(self.imgs)

# dataset = DogCat(r'D:/Python相关/PyTorch入门与实践/pytorch-book-master/chapter5-常用工具/data/dogcat/', transforms=transform)
# # img, label = dataset[0]     #相当于调用dataset.__getitem__(0)
# for img, label in dataset:
#     print((img.size(), img.float().mean(), label))


#ImageFolder
# dataset = ImageFolder(r'D:/Python相关/PyTorch入门与实践/pytorch-book-master/chapter5-常用工具/data/dogcat_2/')
# print(dataset.class_to_idx)     #了解label和文件名的对应关系
# print(list(dataset.imgs))
# print(dataset[0][0], dataset[0][1])     #第一维是第几张图片,第二维为1返回label,为0返回图片数据
# plt.imshow(dataset[0][0])
# plt.show()

# ImageFolder with a transform attached: random resized crop and horizontal
# flip for augmentation, then conversion to a tensor and normalisation
# (mean 0.4, std 0.2 per channel).
normalize = T.Normalize(mean=[.4, .4, .4], std=[.2, .2, .2])
transform = T.Compose([
    T.RandomResizedCrop(224),
    T.RandomHorizontalFlip(),
    T.ToTensor(),   # convert the PIL Image to a Tensor scaled to [0, 1]
    normalize
])
# NOTE: this `transform` is also the one used by the NewDogCat and sampler
# examples further down; it is only redefined in the torchvision section.
# Labels come from the sub-folder names (see dataset.class_to_idx).
dataset = ImageFolder(r'D:/Python相关/PyTorch入门与实践/pytorch-book-master/chapter5-常用工具/data/dogcat_2/', transform=transform)
# print(dataset[0][0].shape)

# To display a sample, undo the normalisation (x * 0.2 + 0.4, i.e. x * std + mean)
# before converting back to a PIL image.
# to_img = T.ToPILImage()
# plt.imshow(to_img(dataset[0][0]))
# plt.pause(0.5)
# plt.imshow(to_img(dataset[0][0]*0.2+0.4))
# plt.show()

# DataLoader: shuffled batches of 3 samples; the last, smaller batch is kept.
dataloader = DataLoader(dataset, batch_size=3, shuffle=True, num_workers=0, drop_last=False)
dataiter = iter(dataloader)
imgs, labels = next(dataiter)   # fetch a single batch
print(imgs.size())

#处理样本无法读取问题
class NewDogCat(DogCat):    # inherits the DogCat dataset defined above
    """DogCat variant that tolerates unreadable samples.

    A sample whose loading or transform fails is returned as ``(None, None)``
    so that a collate function can drop it from the batch.
    """

    def __getitem__(self, index):
        """Return ``(image, label)``, or ``(None, None)`` for a broken sample.

        Raises:
            IndexError: if ``index`` is out of range. This must propagate:
                ``Dataset`` defines no ``__iter__``, so ``for x in dataset``
                falls back to calling ``__getitem__(0, 1, 2, ...)`` until an
                IndexError is raised. Swallowing it would loop forever.
        """
        try:
            return super().__getitem__(index)    # delegate to DogCat
        except IndexError:
            raise
        except Exception:
            # Best-effort: drop the broken image; the caller's collate
            # function is expected to filter out (None, None) samples.
            # (Catching Exception rather than a bare except lets
            # KeyboardInterrupt / SystemExit through.)
            return None, None
            # Alternative: substitute a random other sample (needs `import random`):
            # new_index = random.randint(0, len(self)-1)
            # return self[new_index]

from torch.utils.data.dataloader import default_collate     #导入默认的拼接方式
def my_collate_fn(batch):
    """Merge a batch after discarding samples whose data is ``None``.

    Each element of ``batch`` is a ``(data, label)`` pair. Pairs whose data is
    ``None`` (as NewDogCat returns for unreadable images) are removed, and the
    survivors are combined with PyTorch's ``default_collate``.

    NOTE(review): if every sample in the batch was dropped, ``default_collate``
    receives an empty list; presumably it raises in that case - confirm.
    """
    kept = [sample for sample in batch if sample[0] is not None]
    return default_collate(kept)

# Dataset that presumably contains unreadable images (e.g. dataset[8] below).
# NewDogCat returns (None, None) for such samples and my_collate_fn drops them,
# so some batches may hold fewer than batch_size=2 samples.
dataset = NewDogCat(r'D:/Python相关/PyTorch入门与实践/pytorch-book-master/chapter5-常用工具/data/dogcat_wrong/', transforms=transform)
# print(dataset[8])   # (None, None)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=my_collate_fn, num_workers=0, drop_last=False)
# for batch_datas, batch_labels in dataloader:
#     print(batch_datas.size(), batch_labels.size())

# Sampler module: WeightedRandomSampler draws 6 samples without replacement,
# with dog images weighted twice as heavily as cat images.
dataset = DogCat(r'D:/Python相关/PyTorch入门与实践/pytorch-book-master/chapter5-常用工具/data/dogcat/', transforms=transform)
# One weight per sample, in dataset order: dog (label 1) -> 2, cat (label 0) -> 1.
# The label is read straight from the file name, using DogCat's rule ('dog' in
# the file name -> 1). Iterating `dataset` instead would open, decode and
# randomly transform every image only to discard the pixels.
weights = [2 if 'dog' in os.path.basename(img_path) else 1 for img_path in dataset.imgs]
# print(weights)

from torch.utils.data.sampler import WeightedRandomSampler
sampler = WeightedRandomSampler(weights, num_samples=6, replacement=False)
dataloader = DataLoader(dataset, batch_size=2, sampler=sampler)
# for datas, labels in dataloader:
#     print(labels.tolist())


##################计算机视觉工具包:torchvision######################
from torchvision import models
from torch import nn

# resnet34 = models.resnet34(pretrained=True, num_classes=1000)
# #Downloading: "https://download.pytorch.org/models/resnet34-333f7ec4.pth" to C:\Users\45840/.torch\models\resnet34-333f7ec4.pth
# resnet34.fc = nn.Linear(512, 10)

from torchvision import datasets

# torchvision demo on MNIST. The images are single-channel, so the channel is
# repeated 3 times after ToTensor to match the 3-channel Normalize below and
# the 3-channel grid display further down.
transform = T.Compose([
     T.ToTensor(),
     T.Lambda(lambda x: x.repeat(3, 1, 1)),     # turn the single-channel image into three channels
     T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
 ])   # this is the part that was changed (channel repeat added)

# NOTE(review): r'\D\...' is not drive D: - on Windows it is a root-relative
# path (a folder named "D" at the root of the current drive), and the download
# log below shows the data really landed there. If drive D: was intended, this
# presumably should be r'D:\Python相关\数据集\cy\data' - confirm before changing,
# since download=True would fetch MNIST again into the new location.
dataset = datasets.MNIST(r'\D\Python相关\数据集\cy\data', download=True, train=False, transform=transform)
#Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to \D\Python相关\数据集\cy\data\MNIST\raw\train-images-idx3-ubyte.gz

to_pil = T.ToPILImage()
# plt.imshow(to_pil(t.randn(3, 64, 64)))
# plt.show()


print(len(dataset))
dataloader = DataLoader(dataset, shuffle=True, batch_size=16)

from torchvision.utils import make_grid, save_image
dataiter = iter(dataloader)
imgs, labels = next(dataiter)   # this originally failed: the transform must turn the single-channel images into three channels
# print(imgs.size())
# img = make_grid(next(dataiter)[0], 4)   # tile the 16 images into a 4x4 grid; they must already be 3-channel
# plt.imshow(to_pil(img))
# plt.show()
# Save the grid, then read the saved image back and show it
# save_image(img, 'a.png')
# a = Image.open('a.png')
# plt.imshow(a)
# plt.show()


##############################   visdom    ####################
import visdom

#启动visdom服务器
#nohup python -m visdom.server

# 新建一个连接客户端
# 指定env = u'test1',默认端口为8097,host是'localhost'
# vis = visdom.Visdom(env=u'test1')
#
# x = t.arange(1, 30, 0.01)
# y = t.sin(x)
# vis.line(X=x, Y=y, win='sinx', opts={'title': 'y=sin(x)'})
#
#
# #append 追加数据
# for ii in range(0, 10):
#     x = t.Tensor([ii])
#     y = x
#     vis.line(X=x, Y=y, win='polynomial', update='append' if ii>0 else None)

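# updateTrace 新增一条线(注意:新版 visdom 已移除 updateTrace,可改用 vis.line(X=x, Y=y, win='polynomial', name='this is a Trace', update='append'))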
# x = t.arange(0, 9, 0.1)
# y = (x**2)/9
# vis.updateTrace(X=x, Y=y, win='polynomial', name='this is a Trace')     #浏览器输入http://localhost:8097


# # 可视化一张随机的黑白图片
# vis.image(t.randn(64, 64).numpy())  #接受的是二维或三维向量
#
# # 可视化一张随机的彩色图片
# vis.image(t.randn(3, 64, 64).numpy(), win='random2')
#
# # 可视化36张随机的彩色图片,每一行6张
# vis.images(t.randn(36, 3, 64, 64).numpy(), nrow=6, win='random3', opts={'title': 'random_imgs'})
#
# # text的可视化输出
# vis.text(u'''<h1>Hello visdom</h1><br>visdom是Facebook专门为<b>PyTorch</b>开发的一个可视化工具''',\
#          win='visdom', opts={'title': u'visdom简介'})

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值