-
加载数据集,导入相应的包
-
import numpy as np import torch from torch import nn import h5py from torch.nn import init def load_data(): train_dataset = h5py.File('E:\文档\【吴恩达课后编程作业】第二周作业 - Logistic回归-识别猫的图片资源/train_catvnoncat.h5', "r") train_set_x_orig = np.array(train_dataset["train_set_x"][:]) # your train set features train_set_y_orig = np.array(train_dataset["train_set_y"][:]) # your train set labels test_dataset = h5py.File('E:\文档\【吴恩达课后编程作业】第二周作业 - Logistic回归-识别猫的图片资源/test_catvnoncat.h5', "r") test_set_x_orig = np.array(test_dataset["test_set_x"][:]) # your test set features test_set_y_orig = np.array(test_dataset["test_set_y"][:]) # your test set labels classes = np.array(test_dataset["list_classes"][:]) # the list of classes train_set_y_orig = train_set_y_orig.reshape((1, train_set_y_orig.shape[0])) test_set_y_orig = test_set_y_orig.reshape((1, test_set_y_orig.shape[0]))