1. 读取数据集
import pickle
# Load one CIFAR batch file and print the shape of its image matrix.
# NOTE(security): pickle can execute arbitrary code while loading; only
# unpickle dataset files obtained from a trusted source.
with open('data_batch_2', 'rb') as f:
    # encoding='latin1' yields a dict with ordinary str keys such as 'data'.
    # encoding='bytes' also loads the file, but the keys then come back as
    # bytes (b'data'), so the lookup below would have to change accordingly.
    x = pickle.load(f, encoding='latin1')
print(x['data'].shape)
# Expected output: (10000, 3072)
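- CIFAR 数据集以 pickle 序列化格式保存,Python 2 与 Python 3 的读取方式不同,此处采用 Python 3 的写法。这些文件是在 Python 2 下序列化的,其中的字符串是 8 位字节串;而 Python 3 的 `pickle.load` 默认按 ASCII 解码,遇到非 ASCII 字节就会报错,因此必须显式指定 encoding。encoding 可以是 bytes,也可以是 latin1:`encoding='bytes'` 会把这些字节串原样读成 bytes 对象,此时字典的键是 `b'data'` 而不是 `'data'`;`encoding='latin1'` 则把 0–255 的每个字节一一映射为字符,解码不会失败,键是普通的 str,并且 Python 官方文档说明读取 Python 2 序列化的 NumPy 数组时需要使用 latin1。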
def cifarLoad(file_prefix='data_batch_', train_per_batch=9000):
    """Load batch files 1-5 and split each into training and validation parts.

    For every file ``file_prefix + '1'`` .. ``file_prefix + '5'`` the first
    ``train_per_batch`` samples go to the training set and the remaining
    samples go to the validation set.

    Relies on ``unpickle`` and ``np`` (numpy) being available at module
    level; neither is defined inside this function.

    Args:
        file_prefix: Path prefix of the batch files; the batch number 1-5 is
            appended to it. Defaults to ``'data_batch_'``.
        train_per_batch: Number of leading samples of each batch used for
            training. Defaults to 9000.

    Returns:
        A tuple ``(train_data, train_label, val_data, val_label)`` of numpy
        arrays. Assuming each batch holds 10000 rows of 3072 values (the
        shape noted for ``data_batch_2``), the defaults give 45000 training
        and 5000 validation samples -- TODO confirm for the other batches.
    """
    train_data = []
    train_label = []
    val_data = []
    val_label = []
    for i in range(1, 6):
        data_batch = unpickle(file_prefix + str(i))
        samples = data_batch['data']
        labels = data_batch['labels']
        # Slice along the first (sample) axis only, the same way for both
        # splits; extend() then appends one row per sample.
        train_data.extend(samples[:train_per_batch])
        train_label.extend(labels[:train_per_batch])
        val_data.extend(samples[train_per_batch:])
        val_label.extend(labels[train_per_batch:])
    return (
        np.array(train_data),
        np.array(train_label),
        np.array(val_data),
        np.array(val_label),
    )