import numpy as np
import os
import sys
import keras.backend as K
from six.moves import cPickle
import cv2


def load_batch(fpath, label_key='labels'):
    """Load one pickled CIFAR-10 batch file.

    Args:
        fpath: Path to a batch file (e.g. ``data_batch_1`` or ``test_batch``).
        label_key: Key under which the labels are stored in the unpickled dict.

    Returns:
        A ``(data, labels)`` tuple. ``data`` is ``d['data']`` reshaped to
        ``(num_images, 3, 32, 32)``. ``labels`` is ``d[label_key]`` exactly as
        stored in the file (no conversion is applied).
    """
    # SECURITY: unpickling can execute arbitrary code; only load batch files
    # obtained from a trusted source.
    with open(fpath, 'rb') as f:  # `with` closes the file even if load fails
        if sys.version_info < (3,):
            d = cPickle.load(f)
        else:
            # With encoding='bytes', Python-2 str objects in the pickle are
            # returned as bytes, so decode the keys to allow str lookups below.
            d = cPickle.load(f, encoding='bytes')
            d = {k.decode('utf8'): v for k, v in d.items()}

    data = d['data']
    labels = d[label_key]
    data = data.reshape(data.shape[0], 3, 32, 32)
    return data, labels


def load_data(dirname='./cifar-10-batches-py/', data_format=None):
    """Load the CIFAR-10 train and test sets from an extracted batch directory.

    Args:
        dirname: Directory containing ``data_batch_1`` .. ``data_batch_5`` and
            ``test_batch``. Defaults to the previously hard-coded path.
        data_format: ``'channels_last'`` or ``'channels_first'``. When None,
            ``K.image_data_format()`` is used, as before.

    Returns:
        ``(x_train, y_train), (x_test, y_test)``. The image arrays are
        ``(N, 3, 32, 32)``, or ``(N, 32, 32, 3)`` for ``'channels_last'``.
        The label arrays are reshaped to ``(N, 1)``.
    """
    if data_format is None:
        data_format = K.image_data_format()

    num_train_samples = 50000
    samples_per_batch = 10000  # five training batches of equal size

    x_train = np.empty((num_train_samples, 3, 32, 32), dtype='uint8')
    y_train = np.empty((num_train_samples,), dtype='uint8')

    for i in range(1, 6):
        fpath = os.path.join(dirname, 'data_batch_' + str(i))
        start = (i - 1) * samples_per_batch
        stop = i * samples_per_batch
        x_train[start:stop, :, :, :], y_train[start:stop] = load_batch(fpath)

    fpath = os.path.join(dirname, 'test_batch')
    x_test, y_test = load_batch(fpath)

    y_train = np.reshape(y_train, (len(y_train), 1))
    y_test = np.reshape(y_test, (len(y_test), 1))

    if data_format == 'channels_last':
        # Move the channel axis from position 1 to the end.
        x_train = x_train.transpose(0, 2, 3, 1)
        x_test = x_test.transpose(0, 2, 3, 1)

    return (x_train, y_train), (x_test, y_test)


# CIFAR-10 data download (Baidu Pan): https://pan.baidu.com/s/1feueGEEjJ1sYZCC08Q3HFA
# Extraction password: ausw
# After extracting, point load_data's dirname at the extracted folder to read
# it (e.g. load_data(dirname=...)). The archive also contains the MNIST data.
CIFAR-10 数据集的获取与加载
最新推荐文章于 2024-11-12 15:07:01 发布