用 Python 实现识别猫的神经网络
本项目使用 PyCharm 编写,具体细节可参考:
https://blog.youkuaiyun.com/qq_34290470/article/details/99849514
百度云pycharm项目源码:https://pan.baidu.com/s/12q_Er1vJpeo-O8h_KQYgCQ
完整python代码:
import numpy as np
import matplotlib.pyplot as plt
import h5py
def load_dataset():
    """Load the cat / non-cat training and test sets from HDF5 files.

    Reads ``train_catvnoncat.h5`` and ``test_catvnoncat.h5`` relative to the
    current working directory. Both files are closed before returning; the
    returned arrays are in-memory copies and remain valid afterwards.

    Returns:
        A 5-tuple ``(train_set_x_orig, train_set_y_orig, test_set_x_orig,
        test_set_y_orig, classes)`` of NumPy arrays. The two label arrays
        are reshaped to row vectors of shape ``(1, m)``, where ``m`` is the
        length of the stored label dataset. The feature arrays and
        ``classes`` are returned exactly as stored in the files.
    """
    # `dataset[:]` materialises the data in memory, so it is safe to let the
    # context manager close each file as soon as the reads are done.
    with h5py.File('train_catvnoncat.h5', "r") as train_dataset:
        train_set_x_orig = np.array(train_dataset["train_set_x"][:])  # train set features
        train_set_y_orig = np.array(train_dataset["train_set_y"][:])  # train set labels

    with h5py.File('test_catvnoncat.h5', "r") as test_dataset:
        test_set_x_orig = np.array(test_dataset["test_set_x"][:])  # test set features
        test_set_y_orig = np.array(test_dataset["test_set_y"][:])  # test set labels
        classes = np.array(test_dataset["list_classes"][:])  # the list of classes

    # Turn the 1-D label vectors into (1, m) row vectors.
    train_set_y_orig = train_set_y_orig.reshape((1, train_set_y_orig.shape[0]))
    test_set_y_orig = test_set_y_orig.reshape((1, test_set_y_orig.shape[0]))

    return train_set_x_orig, train_set_y_orig, test_set_x_orig, test_set_y_orig, classes
# Load the dataset (features, labels and class names for both splits).
(
    train_set_x_orig,
    train_set_y,
    test_set_x_orig,
    test_set_y,
    classes,
) = load_dataset()

# Display one picture from the training set.
index = 25  # position of the training example to show
plt.imshow(train_set_x_orig[index])
plt.show()
# 打印出当前的训练标签值
# train_set_y是二维数组,使用np.squeeze的目的是压缩维度,即去掉shape中的1
# classe[0]='non-cat',clas