安装环境部分:
1、Anaconda的安装(自行参考其他博客)
2、使用pip安装pytorch
- 打开anaconda prompt创建一个pytorch的环境
- conda create -n pytorch_gpu python=3.7 pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
- 安装COUD11(可以查看自己显卡的种类去pytorch官网上指定下载)
- 测试 import
- torch.cuda.is_available()
- end
代码部分:
import os
from PIL import Image
from torch.utils.data import Dataset
# dataset有两个作用:1、加载每一个数据,并获取其label;2、用len()查看数据集的长度
class MyData(Dataset):
def __init__(self, root_dir, label_dir): # 初始化,为这个函数用来设置在类中的全局变量
self.root_dir = root_dir
self.label_dir = label_dir
self.path = os.path.join(self.root_dir,self.label_dir) # 单纯的连接起来而已,背下来怎么用就好了,因为在win下和linux下的斜线方向不一样,所以用这个函数来连接路径
self.img_path = os.listdir(self.path) # img_path 的返回值,就已经是一个列表了
def __getitem__(self, idx): # 获取数据对应的 label
img_name = self.img_path[idx] # img_name 在上一个函数的最后,返回就是一个列表了
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name) # 这行的返回,是一个图片的路径,加上图片的名称了,能够直接定位到某一张图片了
img = Image.open(img_item_path) # 这个步骤看来是不可缺少的,要想 show 或者 操作图片之前,必须要把图片打开(读取),也就是 Image.open()一下,这可能是 PIL 这个类型图片的特有操作
label = self.label_dir # 这个例子中,比较特殊,因为图片的 label 值,就是图片所在上一级的目录
return img, label # img 是每一张图片的名称,根据这个名称,就可以使用查看(直接img)、print、size等功能
# label 是这个图片的标签,在当前这个类中,标签,就是只文件夹名称,因为我们就是这样定义的
def __len__(self):
return len(self.img_path) # img_path,已经是一个列表了,len()就是在对这个列表进行一些操作
if __name__ == '__main__':
root_dir = "dataset/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = MyData(root_dir, ants_label_dir)
bees_dataset = MyData(root_dir, bees_label_dir)
train_dataset = ants_dataset + bees_dataset
# runfile('E:/pythonProject/learn_pytorch/read_data.py')