# 用VGG16对102种花卉进行分类 (Classifying the 102 flower categories with VGG16)
# Dataset:
import os
import torch
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
class Datasets(Dataset):
def __init__(self, path) -> None:
    """Load the list of samples for the dataset rooted at *path*.

    Reads ``<path>/label.txt`` and stores its raw lines, newlines included,
    in ``self.dataset``. Each line is parsed later, in ``__getitem__``, by
    splitting on whitespace: the first field is an image path relative to
    *path*, and the second is an integer class label.

    Args:
        path: Root directory that contains ``label.txt`` and the images it
            lists.
    """
    self.path = path
    self.dataset = []
    # A context manager closes the label file right away, instead of
    # leaking the handle until the garbage collector reclaims it.
    with open(os.path.join(path, "label.txt")) as label_file:
        self.dataset.extend(label_file.readlines())
def __getitem__(self, index):
strs=self.dataset[index].strip().split( )
# print(strs)
image_path=os.path.join(self.path,strs[0])
label=torch.Tensor([int(strs[1])])
# offset=torch.Tensor([float(strs[2]),float(strs[3]),float(strs[4]),float(strs[5])])
image_data=Image.open(image_path)
image_data = image_data.convert('RGB')
#把图片制作成正方形,否则采样的时候会报错
w, h = image_data.size
background = Image.new('RGB', size=(max(w, h), max(w, h)), color=(127, 127, 127)) # 创建背景图,颜色值为127
length = int(abs(w - h) // 2) # 一侧需要填充的长度
box = (length, 0) if w < h else (0, length) # 粘贴的位置
background.paste(image_data, box)
image_data=background.resize((224,224))
image_data=torch.Tensor(np.array(image_data)