Keras深度学习框架实战(2):估计模型训练所需的样本量

1、模型训练样本量评估概述

1.1 样本量评估的意义

预估模型需要的样本量对于机器学习项目的成功至关重要,以下是几个主要原因:

  1. 防止过拟合与欠拟合

    • 过拟合:当模型在训练数据上表现极好,但在未见过的测试数据上表现糟糕时,就发生了过拟合。这通常是因为模型过于复杂,而训练数据不足以支持其学习数据的真实模式。通过预估足够的样本量,我们可以减少过拟合的风险。
    • 欠拟合:与过拟合相反,欠拟合是模型未能捕捉到数据中的关键模式。这可能是因为模型过于简单或训练数据不足。预估样本量有助于确保模型有足够的数据来学习数据的复杂模式。
  2. 资源分配

    • 预估样本量有助于项目团队合理分配资源。如果预计需要大量数据,团队可以提前开始数据收集工作,或考虑使用更高效的数据收集方法。此外,了解所需样本量还可以帮助团队估算项目的时间和成本。
  3. 实验设计

    • 在设计实验或研究时,预估样本量有助于确定实验的规模。这有助于确保实验具有足够的统计功效,以检测感兴趣的效应或差异。
  4. 模型性能评估

    • 有了足够的样本量,我们可以更准确地评估模型的性能。通过将模型应用于独立的测试集,我们可以评估模型在未见过的数据上的表现,并据此调整模型参数或结构。
  5. 可解释性与泛化能力

    • 充足的样本量有助于模型学习数据的普遍规律,而不仅仅是训练数据的特定模式。这使得模型更有可能在类似但不同的数据集上表现良好,即具有更强的泛化能力。此外,充足的样本量还可以提高模型的可解释性,使结果更易于理解和解释给非技术利益相关者。
  6. 合规性与伦理

    • 在某些领域,如医疗、金融和法律等,数据收集和使用受到严格的法规和伦理准则的约束。预估样本量有助于确保项目符合这些要求,避免潜在的合规性问题和伦理争议。
  7. 提高项目成功率

    • 通过预估模型需要的样本量,项目团队可以更好地规划和管理项目资源。这有助于提高项目的成功率和效率,减少因资源不足或分配不当而导致的延误和失败。

预估模型需要的样本量是机器学习项目成功的关键一步。通过仔细考虑和计算所需的样本量,我们可以确保模型具有足够的数据来学习数据的真实模式,并减少过拟合和欠拟合的风险。同时,这还有助于项目团队更好地规划和管理资源,提高项目的成功率和效率。

1.2 样本量评估的一般方法

在许多现实世界的场景中,用于训练深度学习模型的图像数据量是有限的。特别是在医疗成像领域,数据集的创建成本高昂。当面临一个新的问题时,通常首先出现的问题是:“我们需要多少张图像来训练一个足够好的机器学习模型?”

在大多数情况下,只有一小部分样本可用,我们可以利用这些样本来模拟训练数据大小与模型性能之间的关系。这样的模型可以用于估计达到所需模型性能所需的最优图像数量。

样本量确定方法

  1. 平衡子采样方案

    • 在这个例子中,使用平衡子采样方案来确定模型的最佳样本量。该方案通过选择由Y个图像组成的随机子样本,并使用该子样本训练模型来完成。
    • 随后,在一个独立的测试集上对模型进行评估。
    • 该过程对每个子样本重复N次,并进行替换,以构建观测性能的平均值和置信区间。
  2. 样本量与模型性能的关系建模

    • 利用现有的一小部分样本,我们可以构建一个模型来模拟训练数据大小与模型性能之间的关系。
    • 这个模型可以帮助我们预测,随着训练数据量的增加,模型性能将如何变化。
  3. 最优样本量的估计

    • 通过分析模型性能与训练数据大小之间的关系,我们可以估计出达到特定性能水平所需的最优样本量。
    • 这有助于我们确定在资源限制下,应收集多少图像来训练模型。
  4. 重复实验与统计评估

    • 为了获得更准确的估计,我们重复上述过程多次,并计算观测性能的平均值和置信区间。
    • 这有助于我们评估估计的可靠性,并确定所需的样本量是否足够稳健。

通过采用平衡子采样方案和构建模型性能与训练数据大小之间的关系模型,我们可以系统地估计出达到所需模型性能所需的最优图像数量。这种方法不仅可以帮助我们在有限的资源下做出明智的决策,还可以提高机器学习模型在实际应用中的性能和可靠性。在医疗成像等数据稀缺的领域,这种方法尤为重要。

2、设置

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import keras
from keras import layers
import tensorflow_datasets as tfds

# Define seed and fixed variables
seed = 42
keras.utils.set_random_seed(seed)
AUTO = tf.data.AUTOTUNE

3、数据集加载

我们将使用 TF Flowers 数据集,加载它并将其转换为 NumPy 数组。
数据下载地址如下:
https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
下面是一个示例代码,展示如何使用 TensorFlow 的 tf.keras.preprocessing.image_dataset_from_directory 函数加载数据集,并将其转换为 NumPy 数组:

# Specify dataset parameters
dataset_name = "tf_flowers"
batch_size = 64
image_size = (224, 224)

# Load data from tfds and split 10% off for a test set
(train_data, test_data), ds_info = tfds.load(
    dataset_name,
    split=["train[:90%]", "train[90%:]"],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

# Extract number of classes and list of class names
num_classes = ds_info.features["label"].num_classes
class_names = ds_info.features["label"].names

print(f"Number of classes: {
     num_classes}")
print(f"Class names: {
     class_names}")


# Convert datasets to NumPy arrays
def dataset_to_array(dataset, image_size, num_classes):
    images, labels = [], []
    for img, lab in dataset.as_numpy_iterator():
        images.append(tf.image.resize(img, image_size).numpy())
        labels.append(tf.one_hot(lab, num_classes))
    return np.array(images), np.array(labels)


img_train, label_train = dataset_to_array(train_data, image_size, num_classes)
img_test, label_test = dataset_to_array(test_data, image_size, num_classes)

num_train_samples = len(img_train)
print(f"Number of training samples: {
     num_train_samples}")
Number of classes: 5
Class names: ['dandelion', 'daisy', 'tulips', 'sunflowers', 'roses']
Number of training samples: 3303

从测试集中绘制几个示例的图表

plt.figure(figsize=(16, 12))
for n in range(30):
    ax = plt.subplot(5, 6, n + 1)
    plt.imshow(img_test[n].astype("uint8"))
    plt.title(np.array(class_names)[label_test[n] == True][0])
    plt.axis("off")

<

评论 11
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

MUKAMO

你的鼓励是我们创作最大的动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值