Tensorflow模型训练过程需要使用训练集和测试集,这里涉及比较细节的方式,其中ndarry类型和Dataset类型是最常用的两种类型,很多开发者没有完成理解清楚这两种在引用和使用过程需要注意的细节,造成运行过程出现错误,本文对这一问题通过实例进行解释和说明。
一、tensorflow模型训练可以使用的数据类型
tensorflow在训练模型的过程中,如果需要训练的数据量不是很大,例如不到1G,那么可以直接全部读入内存中进行训练,这样一般效率最高。但如果需要训练的数据很大,例如超过10G,无法一次载入内存,那么通常需要在训练的过程中分批逐渐读入。使用tf.data API可以构建数据输入管道,处理大量的数据,不同的数据格式以及不同的数据转换。
tensorflow可以直接载入的数据类型主要包括以下几种:
- 从Numpy array构建数据管道
- 从Pandas DataFrame构建数据管道
- 从Python generator构建数据管道
- 从csv文件构建数据管道
- 从文本文件构建数据管道
- 从文件路径构建数据管道
- 从tfrecords文件构建数据管道
在这里介绍训练集采用ndarray和tensorflow中dataset格式作为训练集数据涉及的问题
二、tensorflow模型训练中使用ndarry类型和Dataset类型的说明
ndarry类型和Dataset类型是tensorflow中最常用两种类型,它们两个各自有各自的优势,总结起来,ndarry类型数据直接看见,开发者可以清楚的知道任意一个位置的数据,可以与其它数学算法的接口良好匹配,但在tensorflow中使用的时候需要进行一些手动操作,这对初级者来说容易出现错误;Dataset类型是则进行了比较好的封装,是tensorflow自带的数据类型,可以在tensorflow中形成非常好的兼容关系,自带了多种方法,使用起来比较简单,但数据对用户来说不是直接可见,很多算法功能包的接口不支持Dataset类型,需要开发者自己进行转换。
2.1 使用ndarry类型的实例说明:
代码:
import numpy as np
import tensorflow as tf
from tensorflow.keras.utils import plot_model
import random
mnist = tf.keras.datasets.mnist
data0 = tf.keras.datasets.mnist.load_data()
print('data0 type:',type(data0))
print('data0 len:',len(data0))
data0_0 = tf.keras.datasets.mnist.load_data()[0]
print('data0_0 type:',type(data0_0))
print('data0_0 len:',len(data0_0))
data0_0_0 = tf.keras.datasets.mnist.load_data()[0][0]
data0_0_1 = tf.keras.datasets.mnist.load_data()[0][1]
print('data0_0_0 type:',type(data0_0_0))
print('data0_0_0 shape:',np.shape(tf.keras.datasets.mnist.load_data()[0][0]))
print('data0_0_1 shape:',np.shape(tf.keras.datasets.mnist.load_data()[0][1]))
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
print("训练集样本及标签", train_images.shape, train_labels.shape)
print("测试集样本及标签", test_images