# 先放例程：dataloader (example code first: the dataloader)
import torch
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import os
import numpy as np
import os.path as osp
import json
from tqdm import tqdm
from PIL import Image
from utils import Config
class polyvore_dataset:
    """Index the Polyvore images and build the category-classification split.

    Paths come from the project ``Config``: images are read from
    ``<Config['root_path']>/images`` and item metadata from
    ``<Config['root_path']>/<Config['meta_file']>``.
    """

    def __init__(self):
        self.root_dir = Config['root_path']
        self.image_dir = osp.join(self.root_dir, 'images')
        # Preprocessing pipelines keyed by split name ('train' / 'test').
        self.transforms = self.get_data_transforms()

    def get_data_transforms(self):
        """Return ``{'train': Compose, 'test': Compose}`` preprocessing pipelines.

        Steps, applied in order by ``Compose``:

        - ``Resize(256)``: scales the *shorter* side to 256 px and keeps the
          aspect ratio (bilinear interpolation by default). It does not crop
          to a square.
        - ``CenterCrop(224)``: cuts a 224x224 patch from the image center.
        - ``ToTensor``: PIL image with values 0-255 -> float C x H x W tensor
          in [0, 1].
        - ``Normalize(mean, std)``: computes ``(x - mean) / std`` per channel.
          0.5 / 0.5 are fixed constants, not statistics measured on this
          dataset; they map [0, 1] to [-1, 1].
        """
        # Train and test share the same geometry. Center-cropping the raw
        # image without resizing first (as test does) would show the model
        # objects at a different scale in training than in evaluation.
        data_transforms = {
            'train': transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]),
            'test': transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]),
        }
        return data_transforms

    def create_dataset(self, test_size=0.2, random_state=None):
        """Pair each image file with its category label and split it.

        Args:
            test_size: fraction of samples put in the test split; forwarded
                to ``train_test_split`` (default 0.2).
            random_state: forwarded to ``train_test_split``. Pass an int for
                a reproducible split. ``None`` (default) gives a different
                split on every call.

        Returns:
            ``(X_train, X_test, y_train, y_test, n_classes)``.

            - ``X_*`` are image file names relative to ``self.image_dir``.
            - ``y_*`` are label-encoded category ids in ``[0, n_classes)``.
            - ``n_classes`` is the number of distinct categories found.

        Raises:
            ValueError: if no image in ``self.image_dir`` has a metadata entry.
        """
        # Map item id -> raw category id; ``with`` guarantees the file is closed.
        meta_path = osp.join(self.root_dir, Config['meta_file'])
        with open(meta_path, 'r', encoding='utf-8') as meta_file:
            meta_json = json.load(meta_file)
        id_to_category = {k: v['category_id'] for k, v in tqdm(meta_json.items())}

        # Keep every image whose file stem (name without extension) has a
        # metadata entry. Sorting makes the sample order, and hence the split
        # when random_state is given, independent of os.listdir ordering.
        X, y = [], []
        for file_name in sorted(os.listdir(self.image_dir)):
            item_id = osp.splitext(file_name)[0]
            if item_id in id_to_category:
                X.append(file_name)
                y.append(int(id_to_category[item_id]))
        if not X:
            raise ValueError(
                'no image in {} has an entry in {}'.format(self.image_dir, meta_path))

        # Re-map raw (possibly sparse) category ids to contiguous labels.
        encoder = LabelEncoder()
        y = encoder.fit_transform(y)
        n_classes = len(encoder.classes_)
        print('len of X: {}, # of categories: {}'.format(len(X), n_classes))

        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=test_size, random_state=random_state)
        return X_train, X_test, y_train, y_test, n_classes
# For category classification
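# The map-style PyTorch Datasets below each define three methods:
# __init__, __len__ and __getitem__.
# The DataLoader's sampler calls __getitem__(index) with integer indices to fetch
# individual samples (it is an indexing method, not a generator), and __len__
# tells the sampler how many indices exist. The attribute name self.transform is
# only a convention, not a reserved name: DataLoader never calls it; __getitem__
# applies it to each sample itself.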
class polyvore_train(Dataset):
    """Map-style dataset yielding ``(image_tensor, label)`` training pairs.

    Args:
        X_train: image file names relative to ``<Config['root_path']>/images``.
        y_train: labels aligned index-by-index with ``X_train``.
        transform: callable applied to each loaded RGB PIL image.
    """

    def __init__(self, X_train, y_train, transform):
        self.X_train = X_train
        self.y_train = y_train
        self.transform = transform
        self.image_dir = osp.join(Config['root_path'], 'images')

    def __len__(self):
        return len(self.X_train)

    def __getitem__(self, item):
        file_path = osp.join(self.image_dir, self.X_train[item])
        # Force 3-channel RGB. Grayscale ('L'), palette ('P') or RGBA files
        # would otherwise become 1- or 4-channel tensors and fail a 3-value
        # Normalize such as the one built by polyvore_dataset.get_data_transforms.
        # The ``with`` block closes the file handle deterministically.
        with Image.open(file_path) as img:
            image = img.convert('RGB')
        return self.transform(image), self.y_train[item]
class polyvore_test(Dataset):
    """Map-style dataset yielding ``(image_tensor, label)`` evaluation pairs.

    Args:
        X_test: image file names relative to ``<Config['root_path']>/images``.
        y_test: labels aligned index-by-index with ``X_test``.
        transform: callable applied to each loaded RGB PIL image.
    """

    def __init__(self, X_test, y_test, transform):
        self.X_test = X_test
        self.y_test = y_test
        self.transform = transform
        self.image_dir = osp.join(Config['root_path'], 'images')

    def __len__(self):
        return len(self.X_test)

    def __getitem__(self, item):
        file_path = osp.join(self.image_dir, self.X_test[item])
        # Force 3-channel RGB. Grayscale ('L'), palette ('P') or RGBA files
        # would otherwise become 1- or 4-channel tensors and fail a 3-value
        # Normalize such as the one built by polyvore_dataset.get_data_transforms.
        # The ``with`` block closes the file handle deterministically.
        with Image.open(file_path) as img:
            image = img.convert('RGB')
        return self.transform(image), self.y_test[item]
def get_dataloader(debug, batch_size, num_workers):
    """Build train/test DataLoaders for Polyvore category classification.

    Args:
        debug: if truthy, keep only the first 100 samples of each split so a
            smoke run finishes quickly.
        batch_size: batch size used by both loaders.
        num_workers: number of loader worker processes (per loader).

    Returns:
        ``(dataloaders, classes, dataset_size)``.

        - ``dataloaders`` maps ``'train'`` / ``'test'`` to a ``DataLoader``;
          only the train loader shuffles.
        - ``classes`` is the category count returned by
          ``polyvore_dataset.create_dataset``.
        - ``dataset_size`` maps each split to the number of samples its
          dataset actually holds.
    """
    dataset = polyvore_dataset()
    # Named data_transforms so the torchvision ``transforms`` module is not shadowed.
    data_transforms = dataset.get_data_transforms()
    X_train, X_test, y_train, y_test, classes = dataset.create_dataset()
    if debug:
        X_train, y_train = X_train[:100], y_train[:100]
        X_test, y_test = X_test[:100], y_test[:100]

    datasets = {
        'train': polyvore_train(X_train, y_train, data_transforms['train']),
        'test': polyvore_test(X_test, y_test, data_transforms['test']),
    }
    # Sizes are taken from the datasets actually built, so in debug mode they
    # reflect the truncated splits rather than the full ones.
    dataset_size = {split: len(ds) for split, ds in datasets.items()}
    dataloaders = {
        split: DataLoader(
            datasets[split],
            shuffle=(split == 'train'),
            batch_size=batch_size,
            num_workers=num_workers,
        )
        for split in ('train', 'test')
    }
    return dataloaders, classes, dataset_size
########################################################################
# For Pairwise Compatibility Classification
# TODO: 后续待更新 (to be continued)