问题描述
使用 Keras 预处理层,对图片进行例如亮度,对比度,灰度,水平旋转,垂直旋转等的操作。降低神经网络模型对数据的过拟合。
数据集
采用tensorflow自带数据集beans
Beans 是使用智能手机相机在田间拍摄的豆类图像数据集。它由3个类别组成:2个疾病类别和健康类别。描述的疾病包括角叶斑病和豆锈病。数据由乌干达国家作物资源研究所 (NaCRRI) 的专家进行注释,并由 Makerere AI 研究实验室收集。
代码分析
1.在使用tfds.load数据集会出现无法获取google的token信息,从而无法下载数据集的情况
# 添加代码,取消认证
# 获得 Google 身份
tfds.core.utils.gcs_utils._is_gcs_disabled = True
2.下载数据集并划分训练,验证,测试集
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras import layers
# # 查看数据集
# tfds.list_builders()
# 下载数据集合
(train_ds,val_ds,test_ds),metadata = tfds.load(
'beans',
split=['train[:80%]','train[80%:90%]','train[90%:]'],
with_info=True,
as_supervised=True
)
# 数据集合的类别
num_classes = metadata.features['label'].num_classes
print(nu