import cv2
import h5py
import numpy as np
from scipy.misc import imsave
from skimage import transform
def load_dataset(train_path='train_happy.h5', test_path='test_happy.h5'):
    """Load the "happy" train/test split from two HDF5 files.

    Args:
        train_path: HDF5 file holding the ``train_set_x`` / ``train_set_y``
            datasets.
        test_path: HDF5 file holding the ``test_set_x`` / ``test_set_y`` and
            ``list_classes`` datasets.

    Returns:
        Tuple ``(train_x, train_y, test_x, test_y, classes)`` of in-memory
        NumPy arrays. Both label arrays are reshaped into ``(1, m)`` row
        vectors.
    """
    # Context managers close the HDF5 handles as soon as the data has been
    # copied into memory (previously both files were left open).
    with h5py.File(train_path, "r") as train_dataset:
        train_set_x_orig = np.array(train_dataset["train_set_x"][:])  # train set features
        train_set_y_orig = np.array(train_dataset["train_set_y"][:])  # train set labels
    with h5py.File(test_path, "r") as test_dataset:
        test_set_x_orig = np.array(test_dataset["test_set_x"][:])  # test set features
        test_set_y_orig = np.array(test_dataset["test_set_y"][:])  # test set labels
        classes = np.array(test_dataset["list_classes"][:])  # the list of classes

    # Expose labels as (1, m) row vectors; assumes the stored label arrays
    # are 1-D -- TODO confirm against the .h5 files.
    train_set_y_orig = train_set_y_orig.reshape((1, train_set_y_orig.shape[0]))
    test_set_y_orig = test_set_y_orig.reshape((1, test_set_y_orig.shape[0]))

    return train_set_x_orig, train_set_y_orig, test_set_x_orig, test_set_y_orig, classes
def processing(out_dir='images/train', scale=10):
    """Save every training image as an upscaled JPEG file.

    Each image from ``load_dataset()``'s training set is enlarged by
    ``scale`` on both spatial axes (64x64 -> 640x640 with the default) and
    written to ``<out_dir>/<index>-[<label>].jpg``.

    Args:
        out_dir: Directory that receives the JPEG files; created if missing.
        scale: Integer upscaling factor for the height and width.

    Raises:
        OSError: If OpenCV fails to write one of the images.
    """
    import os  # local import keeps this block self-contained

    X_train_orig, Y_train_orig = load_dataset()[:2]
    os.makedirs(out_dir, exist_ok=True)

    # Y_train_orig is a (1, m) row vector; flatten it so each label pairs
    # with its image.
    labels = Y_train_orig.reshape(-1)

    for i, (image, label) in enumerate(zip(X_train_orig, labels)):
        height, width = image.shape[:2]
        # resize() with a 2-D output shape keeps the channel axis on every
        # scikit-image release. rescale() with a scalar factor also scales
        # the channel axis once multichannel/channel_axis defaults to off,
        # giving (640, 640, 30) instead of (640, 640, 3).
        enlarged = transform.resize(
            image,
            (height * scale, width * scale),
            mode='constant',
            preserve_range=True,
        )
        # Convert exactly back to 8-bit (no per-image contrast stretching).
        # Assumes the dataset stores 0-255 pixel values -- TODO confirm.
        pixels = np.clip(np.rint(enlarged), 0, 255).astype(np.uint8)
        if pixels.ndim == 3 and pixels.shape[2] == 3:
            # The dataset is RGB, but cv2.imwrite expects BGR channel order.
            pixels = cv2.cvtColor(pixels, cv2.COLOR_RGB2BGR)

        name = os.path.join(out_dir, str(i) + '-[' + str(np.squeeze(label)) + '].jpg')
        # cv2.imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(name, pixels):
            raise OSError('failed to write image: ' + name)
def reading(path='images/train/16-[1].jpg', scale=0.5):
    """Show one saved image in an OpenCV window, resized by ``scale``.

    Prints the image's array shape, then blocks until a key is pressed.

    Args:
        path: Image file to display; the default is one of the files that
            processing() writes.
        scale: Resize factor applied to the height and width.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``path``.
    """
    image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    # cv2.imread signals failure by returning None rather than raising.
    if image is None:
        raise FileNotFoundError('could not read image: ' + path)
    print(image.shape)

    height, width = image.shape[:2]
    # Resize only the spatial axes. rescale() with a scalar factor would also
    # shrink the colour-channel axis on recent scikit-image releases.
    target = (
        max(1, int(round(height * scale))),
        max(1, int(round(width * scale))),
    )
    shown = transform.resize(image, target, mode='constant')

    cv2.namedWindow('input_image', cv2.WINDOW_AUTOSIZE)
    cv2.imshow('input_image', shown)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
if __name__ == '__main__':
    # Script entry point: export the training set as JPEG files.
    # Call reading() here instead to preview one of the exported images.
    processing()
## 数据集说明 ##
本数据集包含训练集数据X_train_orig、训练集标签Y_train_orig、测试集数据X_test_orig、测试集标签Y_test_orig以及类别列表classes。其中共有600张训练图片和150张测试图片,每张图片都以64×64的RGB彩色图像存储。
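load_dataset()函数负责读取h5格式的数据集,以NumPy数组的形式返回训练/测试数据、标签和类别,并将标签整理为(1, 样本数)的行向量;processing()函数负责把训练集中的四维图像矩阵逐张放大10倍(640×640)后保存为图片,文件名形如images/train/<序号>-[<标签>].jpg;reading()函数负责读取已保存的示例图片images/train/16-[1].jpg,将其缩小为一半尺寸后借助OpenCV窗口显示出来。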