pytorch可用于图像识别,但我们现在绝大部分用的是MINIST和cifar10图片,想要用自己的训练和测试图像路径,需要制作读取训练集和测试集的代码。本文讲述pytorch实现读取训练集和测试集通用代码。
首先讲一下读取图片路径的框架:
torch.utils.data.Dataset是一个pytorch用来表示数据集的抽象类,我们用这个类来处理自己的数据集时必须继承Dataset,然后重写下面的函数:
len:使得len(dataset)返回数据集的大小
getitem:使得支持dataset[i]能够返回第i个数据样本的下标操作。
接下来给出具体的代码片段,需要调整的是img_id = int(img_path[-12:-9]),这段是读文件名,需要根据自己的文件名来设定img_path里的数值。该文件命名为Path.py
import glob
import torch
from torch.utils.data import Dataset
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
class path(Dataset):
def __init__(self, root_path):
self.mDataX = []
self.mDataY = []
for img_path in glob.glob(root_path + r'\*'):
img = Image.open(img_path)
img = img.convert('L') # 转换成灰度图
# new_size = np.array(img.size) / 4
# new_size = new_size.astype(int)
# img = img.resize(new_size, Image.BILINEAR) # 从(宽,高)(640, 480)缩小为(160, 120)
img_data = np.array(img, dtype=float)
img_data = img_data.reshape(-1)
self.mDataX.append(img_data)
img_id = int(img_path[-12:-9]) ##文件名的后多少位,以后用这里需要根据图像名称更改
self.mDataY.append(img_id)
self.mDataX = torch.tensor(self.mDataX)
self.mDataY = torch.tensor(self.mDataY)
def __getitem__(self, data_index):
input_tensor = torch.tensor(self.mDataX[data_index])
output_tensor = torch.tensor(self.mDataY[data_index])
return input_tensor, output_tensor
def __len__(self):
return len(self.mDataX)
接下来是调用上面写好的.py文件
import torch
from Path import *
train_set = Path(r'C:\datasets\人脸图像\训练样例')
train_loader = torch.utils.data.DataLoader(train_set, batch_size=1, shuffle=False) # , num_workers=8
print('OK! ', len(train_set), len(train_loader))
test_set = Path(r'C:\datasets\人脸图像\测试样例')
test_loader = torch.utils.data.DataLoader(test_set, batch_size=1, shuffle=False) # , num_workers=8
print('OK! ', len(test_set), len(test_loader))
接下来可以使用你自己的训练集和测试集图像了!!!