一.PyTorch环境配置及安装
1.1 工具安装
1.1.1 Anaconda下载
清华大学镜像站下载,版本为Anaconda3-5.2.0-Windows-x86_64(对应python3.6.5)
Index of /anaconda/archive/ | 清华大学开源软件镜像站 | Tsinghua Open Source Mirror
1.1.2 Pytorch安装
进入官网。选择合适版本下载Start Locally | PyTorch(不推荐),因为我下载了好几次失败了
推荐使用下面的方法:
-
torch下载:
pip install torch -i https://pypi.tuna.tsinghua.edu.cn/simple
-
opencv下载:
pip install opencv-python -i https://pypi.tuna.tsinghua.edu.cn/simple
-
torchvision
pip install torchvision -i https://pypi.tuna.tsinghua.edu.cn/simple
二、DataSet
2.1 DataSet的作用
①获取每个数据及其label
②获取数据总数
2.2 认识DataSet
由上图可知dataset是一个抽象类,可以用来创造数据集,而抽象类不能实例化,需要构造抽象类的子类来创建数据集,所有的datasets继承这个类,并重写两个方法:(1)get_item 方法获取数据和label(2)len:获取数据总数
2.3 重写dataset
PIL中的Image
img=Image.open(image_path) 读取图像路径作为一个变量
img.show() 打开图片
os:
os.path.join(dir1,dir2):将两个路径合并在一起
os.listdir(dir):将目标路径dir中的所有文件路径生成一个列表
from torch.utils.data import Dataset, ConcatDataset
import os
from PIL import Image
class MyDataset(Dataset):
def __init__(self, root_dir, label_dir):
self.root_dir = root_dir
self.label_dir = label_dir
self.path = os.path.join(self.root_dir, self.label_dir)
self.img_path = os.listdir(self.path)
def __len__(self):
return len(self.img_path)
def __getitem__(self, idx):
&nbs