from __future__ import print_function
from PIL import Image #从文件加载图像(python Image Library)
import os #文件操作
import sys #文件操作
import numpy as np #与torch混合使用搭建数据传输
import argparse #处理命令行参数的库
import torch.utils.data as data ##创建数据集
# Fruit dataset: a thin torch Dataset wrapper around pre-exported NumPy arrays.
class Fruit(data.Dataset):
    """Map-style dataset backed by ``.npy`` image and label arrays.

    Depending on ``train``, either ``train_data.npy`` / ``train_labels.npy``
    or ``validation_data.npy`` / ``validation_labels.npy`` is read from
    ``root_dir``. Both arrays are loaded eagerly into memory by ``np.load``.

    Args:
        root_dir: Folder holding the ``.npy`` files; stored as an absolute path.
        train: Truthy selects the training files, falsy the validation files.
        transform: Optional callable applied to each image sample (never to
            the label) in ``__getitem__``.

    Raises:
        FileNotFoundError: if the selected ``.npy`` files do not exist.
        ValueError: if the image array is not 4-D (the axis permutation in
            ``__init__`` names exactly four axes).
    """

    def __init__(self, root_dir, train=True, transform=None):
        self.root_dir = os.path.abspath(root_dir)
        self.transform = transform
        self.train = train

        # Pick the (images, labels) file pair for the requested split.
        split_files = (
            ("train_data.npy", "train_labels.npy")
            if self.train
            else ("validation_data.npy", "validation_labels.npy")
        )
        images_path, labels_path = (
            os.path.join(self.root_dir, file_name) for file_name in split_files
        )

        self.data = np.load(images_path)
        self.labels = np.load(labels_path)
        # Move axis 1 to the end: (N, a, b, c) -> (N, b, c, a). Presumably the
        # arrays are stored channels-first and this yields channels-last
        # images -- TODO confirm against whatever exported the .npy files.
        self.data = self.data.transpose((0, 2, 3, 1))

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair stored at ``index``.

        The image is the raw NumPy slice (no PIL conversion is done here);
        when ``transform`` is set it is applied to the image before returning.
        """
        sample, label = self.data[index], self.labels[index]
        if self.transform is None:
            return sample, label
        return self.transform(sample), label

    def __len__(self):
        """Number of samples, i.e. the length of the image array's first axis."""
        return len(self.data)
## 引入函数库
import argparse
import os
import sys
import numpy as np
import cv2
import glob
print("INFO: all the modules are imported.")

# Command-line interface: the only option is the mandatory dataset location.
# argparse exits with a usage error if --dataset is not supplied.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--dataset",
    required=True,
    type=str,
    help='Path to the dataset folder',
)
args = parser.parse_args()
## Load 64 of the 94 fruit kinds from fruit-360
fruit_names = [
'AppleBraeburn',
'AppleGolden1',
'AppleGolden2',
'AppleGolden3',
'AppleGrannySmith',
'AppleRed1',
'AppleRed2',
'AppleRed3',
'AppleRedDelicious',
'AppleRedYellow1',
'AppleRedYellow2',
'Apricot',
'Avocado',
'Avocadoripe',
'Banana',
'BananaLadyFinger',
'BananaRed',
'Cactusfruit',
'Cantaloupe1',
'Cantaloupe2',
'Carambula',
'Cherry1',
'Cherry2',
'CherryRainier',
'CherryWaxBlack',
'CherryWaxRed',
'