学习目标:
学习用pytorch读取显示数据集中的图片
学习内容:
参考视频教程:https://www.bilibili.com/video/BV1xA411e7oX
视频以chest数据集为例,讲解了如何读取数据集中的图片,然后将其进行处理,最终显示出来
# 1 加载库
import torch
import torch.nn as nn
from torchvision import datasets, transforms
import os
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
import numpy as np
import matplotlib.pyplot as plt
# 2 定义一个方法:显示图片(该方法是官方定义好的方法,用于显示图片,所以直接用就行了)
def image_show(inp, title=None):
plt.figure(figsize=(10, 3)) # 设置画布大小
inp = inp.numpy().transpose((1, 2, 0)) # 将torch.FloatTensor 转换为numpy ;transpose转置
mean = np.array([0.485, 0.456, 0.486]) # 均值
std = np.array([0.229, 0.224, 0.225]) # 方差
inp = std * inp + mean
inp = np.clip(inp, 0, 1)
plt.imshow(in