TensorFlow Dataset API 实战:从内存到磁盘数据加载

TensorFlow Dataset API 实战:从内存到磁盘数据加载

【免费下载链接】asl-ml-immersion This repos contains notebooks for the Advanced Solutions Lab: ML Immersion 【免费下载链接】asl-ml-immersion 项目地址: https://gitcode.com/gh_mirrors/as/asl-ml-immersion

概述

本文基于Google Cloud Platform机器学习实践项目中的TensorFlow Dataset API教程,深入讲解如何使用tf.data API高效地加载和处理数据。我们将从内存数据开始,逐步过渡到磁盘数据加载,并探讨生产级数据管道的构建方法。

从内存加载数据

创建内存数据集

首先我们创建一个简单的合成数据集,模拟线性关系 y = 2x + 10:

N_POINTS = 10
X = tf.constant(range(N_POINTS), dtype=tf.float32)
Y = 2 * X + 10

构建Dataset对象

我们定义一个函数将数据转换为tf.data.Dataset对象:

def create_dataset(X, Y, epochs, batch_size):
    dataset = tf.data.Dataset.from_tensor_slices((X, Y))
    dataset = dataset.repeat(epochs).batch(batch_size, drop_remainder=True)
    return dataset

关键点说明:

  • from_tensor_slices:从内存数据创建Dataset
  • repeat:指定数据集重复次数
  • batch:将数据划分为批次,drop_remainder确保每批大小一致

训练循环实现

使用Dataset API的训练循环更加简洁:

EPOCHS = 250
BATCH_SIZE = 2
LEARNING_RATE = 0.02

w0 = tf.Variable(0.0)
w1 = tf.Variable(0.0)

dataset = create_dataset(X, Y, epochs=EPOCHS, batch_size=BATCH_SIZE)

for step, (X_batch, Y_batch) in enumerate(dataset):
    dw0, dw1 = compute_gradients(X_batch, Y_batch, w0, w1)
    w0.assign_sub(dw0 * LEARNING_RATE)
    w1.assign_sub(dw1 * LEARNING_RATE)

这种方法实现了真正的随机梯度下降,每次迭代处理一个小批次数据。

从磁盘加载数据

CSV文件处理

对于存储在磁盘上的数据(如CSV文件),tf.data提供了更高级的功能:

CSV_COLUMNS = [
    "fare_amount", "pickup_datetime", 
    "pickup_longitude", "pickup_latitude",
    "dropoff_longitude", "dropoff_latitude",
    "passenger_count", "key"
]
LABEL_COLUMN = "fare_amount"
DEFAULTS = [[0.0], ["na"], [0.0], [0.0], [0.0], [0.0], [0.0], ["na"]]

创建数据集函数

def create_dataset(pattern, batch_size=1, mode="eval"):
    dataset = tf.data.experimental.make_csv_dataset(
        pattern, batch_size, CSV_COLUMNS, DEFAULTS, shuffle=False
    )
    dataset = dataset.map(features_and_labels).cache()
    
    if mode == "train":
        dataset = dataset.shuffle(1000).repeat()
    
    dataset = dataset.prefetch(1)
    return dataset

关键特性:

  • make_csv_dataset:自动解析CSV文件
  • shuffle:训练时打乱数据顺序
  • cache:缓存数据提高性能
  • prefetch:预取数据减少等待时间

数据转换

我们需要从原始数据中提取特征和标签:

UNWANTED_COLS = ["pickup_datetime", "key"]

def features_and_labels(row_data):
    label = row_data.pop(LABEL_COLUMN)
    features = row_data
    
    for unwanted_col in UNWANTED_COLS:
        features.pop(unwanted_col)
        
    return features, label

低阶API使用

对于更精细的控制,可以使用底层TextLineDataset:

def parse_csv(row):
    ds = tf.strings.split(row, ",")
    label = tf.strings.to_number(ds[0])
    features = tf.strings.to_number(ds[2:6])  # 仅使用部分特征
    return features, label

def create_dataset(pattern, batch_size):
    ds = tf.data.TextLineDataset(pattern)
    ds = ds.map(parse_csv).batch(batch_size)
    return ds

这种方法提供了最大的灵活性,但需要手动处理更多细节。

最佳实践总结

  1. 内存数据:使用from_tensor_slices快速创建Dataset
  2. CSV文件:优先使用make_csv_dataset简化处理
  3. 性能优化
    • 使用cache()缓存预处理结果
    • 使用prefetch()重叠数据预处理和模型执行
    • 训练时添加shuffle()确保数据随机性
  4. 灵活控制:需要特殊处理时使用TextLineDataset等底层API

通过合理使用tf.data API,可以构建高效的数据管道,充分发挥TensorFlow模型的性能潜力。

【免费下载链接】asl-ml-immersion This repos contains notebooks for the Advanced Solutions Lab: ML Immersion 【免费下载链接】asl-ml-immersion 项目地址: https://gitcode.com/gh_mirrors/as/asl-ml-immersion

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值