TensorFlow2训练数据集的两种方式

该博客介绍了两种使用TensorFlow对CIFAR100数据集进行预处理的方法,包括数据归一化和one-hot编码,并展示了如何使用自定义的MyNetwork模型进行训练。此外,还提供了数据集加载、数据增强、模型编译、训练与评估的完整流程。代码参考了李沐的《动手深度学习》和龙良曲的《深度学习与TensorFlow入门实战》项目。
部署运行你感兴趣的模型镜像

方式一:

def pre_process(x, y):
    x = 2. * tf.cast(x, dtype=tf.float32) / 255. - 1.
    y = tf.cast(y, dtype=tf.int32)
    return x, y


(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar100.load_data()
x_train, y_train = pre_process(x_train, y_train)
x_test, y_test = pre_process(x_test, y_test)
print(x_train.shape, y_train.shape)

history = net.fit(x_train, y_train,
                  batch_size=512,
                  epochs=100,
                  validation_split=0.2)

test_scores = net.evaluate(x_test, y_test, verbose=2)

训练方式二:

def pre_process(x, y):
    # [0,255] => [-1,1] ,[-1,1]可能是一个最适合神经网络计算的范围
    x = 2. * tf.cast(x, dtype=tf.float32) / 255. - 1
    y = tf.squeeze(y)  # 从张量形状中移除大小为1的维度.
    y = tf.cast(y, dtype=tf.int32)
    y = tf.one_hot(y, depth=10)
    return x, y


batch_size = 128
(x, y), (x_val, y_val) = datasets.cifar10.load_data()
print('datasets:', x.shape, y.shape, x.min(), y.min())

train_db = tf.data.Dataset.from_tensor_slices((x, y))
train_db = train_db.map(pre_process).shuffle(1000).batch(batch_size)

test_db = tf.data.Dataset.from_tensor_slices((x_val, y_val))
test_db = test_db.map(pre_process).shuffle(1000).batch(batch_size)

sample = next(iter(train_db))
print('batch:', sample[0].shape, sample[1].shape)

network = MyNetwork()  # MYNetwork是Keras.Model的一个子类
network.compile(
    optimizer=optimizers.Adam(learning_rate=1e-3),
    loss=tf.losses.CategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)
network.fit(train_db, epochs=50, validation_data=test_db, validation_freq=1)
network.evaluate(test_db)
network.save_weights('./ckpt/cifar10_weights.ckpt') # b将模型保存到磁盘文件

参考链接:

1.李沐大神《动手深度学习》TensorFlow实现,GitHub链接:https://github.com/TrickyGo/Dive-into-DL-TensorFlow2.0,参考了其中的CNN5.9GoogleNet部分代码

2龙良曲.深度学习与TensorFlow入门实战,项目GitHub链接:https://github.com/dragen1860/Deep-Learning-with-TensorFlow-book,摘自其中的Lesson40--CIFAR与VGG实战

您可能感兴趣的与本文相关的镜像

TensorFlow-v2.15

TensorFlow-v2.15

TensorFlow

TensorFlow 是由Google Brain 团队开发的开源机器学习框架,广泛应用于深度学习研究和生产环境。 它提供了一个灵活的平台,用于构建和训练各种机器学习模型

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值