具体怎么学习pytorch,看b站刘二大人的视频。
完整代码:
import numpy as np
import os
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
# 设置随机种子
torch.manual_seed(42)
np.random.seed(42)
'''https://zhuanlan.zhihu.com/p/156926543'''
# 定义图片目录
image_dir = 'images'
# 初始化图片路径列表
img_list = []
# 遍历指定目录及其子目录中的所有文件
for parent, _, filenames in os.walk(image_dir):
for filename in filenames:
# 拼接文件的完整路径
filename_path = os.path.join(parent, filename)
img_list.append(filename_path)
# 初始化图像张量列表和标签列表
image_tensors = []
y_list = []
for image_path in img_list:
# 提取标签 (假设标签是文件名的第一个字符)
label = int(os.path.basename(image_path)[0])
y_list.append(label)
# 打开图像
img = Image.open(image_path)
# 获取图像尺寸
width, height = img.siz