一、加载数据
import tensorflow as tf #导入tensorflow库
mnist=tf.keras.datasets.mnist
(x_train,y_train),(x_test,y_test)=mnist.load_data() #加载数据
如果无法加载数据的话,如类似出现如下图所示问题,则可以根据下载地址,将mnist.npz文件下载到本地。
然后通过以下代码加载本地数据文件
import tensorflow as tf
import numpy as np
def load_data(path):
with np.load(path) as f:
x_train, y_train = f['x_train'], f['y_train']
x_test, y_test = f['x_test'], f['y_test']
return (x_train, y_train), (x_test, y_test)
(x_train, y_train),(x_test, y_test) = load_data(path="../data/mnist.npz") #mnist的本地路径
x_train表示训练数据集,y_train表示训练数据集对应的结果;x_test表示测试数据集,y_test表示测试集对应的结果。加载完成后,可以通过以下代码来检查一下
x_train.shape