模型加速--IO加速,tfRecord和keras Sequence

本文介绍如何使用Keras的Sequence和TensorFlow的TFRecord优化大规模数据集的模型训练过程,减少数据读取时间,提高GPU利用率。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

每次在训练模型时,尤其是训练数据较大时,都会大部分时间都会花在数据IO读写上,而不是真正的GPU计算,这也就意味着,GPU实际上很多时候是空闲等待状态!
在keras中可以通过sequence实现,在tensorflow中可以通过tfRecord实现。或者将图片以.npy的格式保存在本地,在训练的时候读取也会快很多。

如果将大规模数据一次性读进内存会很耗内存,可以使用tensorflow的queue和keras的sequence来存储数据。

1、keras的sequence
2、tensorflow的tfRecord

1、keras sequence

keras官方给出了参考实例,这里给了个常用的

import ast
import os
import numpy as np
import random
import math
from tensorflow.python.keras.preprocessing.image import img_to_array as img_to_array
from tensorflow.python.keras.preprocessing.image import load_img as load_img
def load_image(image_path, size):
    return img_to_array(load_img(image_path, target_size=(size, size))) / 255.

# shuffle好像有点问题,如有问题可参考https://www.kaggle.com/wrosinski/pretrained-cnn-albumentations
class KagglePlanetSequence(tf.keras.utils.Sequence):
    """
    在不把数据一次性读进内存的情况下,我们使用Sequence完成数据相对高效的IO
    """
    
    def __init__(self, df, data_path, im_size, batch_size, mode='train'):
        """
        df: pandas dataframe that contains columns with image names and labels
        data_path: path that contains the training images
        im_size: image size
        mode: when in training mode, data will be shuffled between epochs
        """
        self.df = df
        self.batch_size = batch_size
        self.im_size = im_size
        self.mode = mode
        
        # Take labels and a list of image locations in memory
        # ast.literal_eval类似eval,将字符转化为原有形式,更安全
        self.wlabels = self.df['weather_labels'].apply(lambda x: ast.literal_eval(x)).tolist()
        self.glabels = self.df['ground_labels'].apply(lambda x: ast.literal_eval(x)).tolist()
        self.image_list = self.df['image_name'].apply(lambda x: os.path.join(data_path, x + '.jpg')).tolist()

    def __len__(self):
        return int(math.ceil(len(self.df) / float(self.batch_size)))

    def on_epoch_end(self):
        # 每一轮之后对数据乱序
        self.indexes = range(len(self.image_list))
        if self.mode == 'train':
            self.indexes = random.sample(self.indexes, k=len(self.indexes))

    def get_batch_labels(self, idx): 
        # 拿到一个batch的标签
        return [self.wlabels[idx * self.batch_size: (idx + 1) * self.batch_size],
                self.glabels[idx * self.batch_size: (idx + 1) * self.batch_size]]

    def get_batch_features(self, idx):
        # 拿到一个batch的图像
        batch_images = self.image_list[idx * self.batch_size: (1 + idx) * self.batch_size]
        return np.array([load_image(im, self.im_size) for im in batch_images])

    def __getitem__(self, idx):
        batch_x = self.get_batch_features(idx)
        batch_y = self.get_batch_labels(idx)
        return batch_x, batch_y
    

可以使用此Sequence对象代替自定义生成器fit_generator(来训练模型。大家注意一下,不需要提供每个epoch的步数(steps),__len__方法已经为生成器内部实现了这个逻辑。此外,tf.keras提供对可用于增强训练循环的所有可用Keras回调函数。可以作为辅助功能加入,比如可以提供early stopping,学习速率调度,为TensorBoard可视化写日志等等…比如使用ModelCheckPoint回调在每个时期之后保存模型,以便我们可以随时从预训练模型开始训练。

# 使用
seq = KagglePlanetSequence(df_train,
                       './train-jpg/',
                       im_size=IM_SIZE,
                       batch_size=32)
another_model = tf.keras.models.load_model('./model.h5')
another_model.fit_generator(generator=seq, verbose=1, epochs=1)

# 测试
test_seq = KagglePlanetSequence(df_train,
                       './train-jpg/',
                       im_size=IM_SIZE,
                       batch_size=32,
                       mode='test') # test mode disables shuffling

predictions = model.predict_generator(generator=test_seq, verbose=1)

2、tensorflow tfrecord

tensorflow中提供的tf.data是一个非常强大的数据源接口,它可以接受很不同形态的数据输入到模型中进行学习训练。
TFRecords 其实是一种二进制文件,虽然它不如其他格式好理解,但是它能更好的利用内存,更方便赋值和移动,并且不需要单独的标签文件,理论上,它能保存所有的信息。
在这里插入图片描述
tfrecord原理讲解可参考https://www.jianshu.com/p/b251e85ac582
tfrecord代码讲解可参考https://www.cnblogs.com/wj-1314/p/11211333.html

2.1 tf.Example

TFRecord 的核心内容在于内部有一系列的Example,Example 是protocolbuf 协议(protocolbuf 是通用的协议格式,对主流的编程语言都适用。所以这些 List对应到Python语言当中是列表。而对于Java 或者 C/C++来说他们就是数组)下的消息体。

一个Example消息体包含了一系列的feature属性。每一个feature是一个map,也就是 key-value 的键值对。key 取值是String类型。而value是Feature类型的消息体。将数据表示为{‘string’: value}形式的 message类型,TensorFlow经常使用 tf.Example 来写入,读取 TFRecord数据。
  
通常情况下,tf.Example中可以使用以下几种格式:

tf.train.BytesList: 可以使用的类型包括 string和byte
tf.train.FloatList: 可以使用的类型包括 float和double
tf.train.Int64List: 可以使用的类型包括 enum,bool, int32, uint32, int64
  TFRecord支持写入三种格式的数据:string,int64,float32,以列表的形式分别通过tf.train.BytesList,tf.train.Int64List,tf.train.FloatList 写入 tf.train.Feature

# 建立tfEample
# 以dict的形式把要写入的数据汇总,并构建 tf.train.Features,然后构建 tf.train.Example
def get_tfrecords_example(feature, label):
    tfrecords_features = {}
    feat_shape = feature.shape
    tfrecords_features['feature'] = tf.train.Feature(bytes_list=
                                              tf.train.BytesList(value=[feature.tostring()]))
    tfrecords_features['shape'] = tf.train.Feature(int64_list=
                                              tf.train.Int64List(value=list(feat_shape)))
    tfrecords_features['label'] = tf.train.Feature(float_list=
                                              tf.train.FloatList(value=label))
 
    return tf.train.Example(features=tf.train.Features(feature=tfrecords_features))

# 涉及的具体函数如下
def _bytes_feature(value):
    """Returns a bytes_list from a string/byte."""
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
 
def _float_feature(value):
    """Return a float_list form a float/double."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
 
def _int64_feature(value):
    """Return a int64_list from a bool/enum/int/uint."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

把创建的tf.train.Example序列化下,便可以通过 tf.python_io.TFRecordWriter 写入 tfrecord文件中,如下:

#创建tfrecord的writer,文件名为xxx
tfrecord_wrt = tf.python_io.TFRecordWriter('xxx.tfrecord') 
#把数据写入Example
exmp = get_tfrecords_example(feats[inx], labels[inx]) 
#Example序列化
exmp_serial = exmp.SerializeToString()   
#写入tfrecord文件 
tfrecord_wrt.write(exmp_serial)   
#写完后关闭tfrecord的writer
tfrecord_wrt.close()  

2.2 如何将一张图片转换为tfRecord格式

针对上面写入tfRecord,多了如何将image转化为feature

# 处理image
# 读取图片并进行解码
image = tf.read_file(input)
image_data = tf.image.decode_jpeg(image_data)
# 将图片转换成string
image_data = image_data.tostring()

# 或者keras处理方式(将其resize)
im = np.array(img_to_array(load_img(im_list[i], target_size=(IM_SIZE, IM_SIZE))) / 255.).tostring()

# 处理label名字
name = bytes('cat', encoding='utf-8')

总代码如下:

# _*_coding:utf-8_*_
import tensorflow as tf
 
def write_test(input, output):
    # 借助于TFRecordWriter 才能将信息写入TFRecord 文件
    writer = tf.python_io.TFRecordWriter(output)
 
    # 读取图片并进行解码
    image = tf.read_file(input)
    image = tf.image.decode_jpeg(image)
 
    with tf.Session() as sess:
        image = sess.run(image)
        shape = image.shape
        # 将图片转换成string
        image_data = image.tostring()
        print(type(image))
        print(len(image_data))
        name = bytes('cat', encoding='utf-8')
        print(type(name))
        # 创建Example对象,并将Feature一一对应填充进去
        example = tf.train.Example(features=tf.train.Features(feature={
             'name': tf.train.Feature(bytes_list=tf.train.BytesList(value=[name])),
             # 如果图片大小固定,可以不实用shape这一栏
             'shape': tf.train.Feature(int64_list=tf.train.Int64List(value=[shape[0], shape[1], shape[2]])),
             'data': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_data]))
        }
        ))
        # 将example序列化成string 类型,然后写入。
        writer.write(example.SerializeToString())
    writer.close()
 
 
if __name__ == '__main__':
    input_photo = 'cat.jpg'
    output_file = 'cat.tfrecord'
    write_test(input_photo, output_file)

2.3 处理tfRecord

读取之前存取的tfRecord并形成batch进行训练

# 用dataset读取TFRecords文件
dataset = tf.contrib.data.TFRecordDataset(input_file)

def _parse_record(example_photo):
    features = {
        'name': tf.FixedLenFeature((), tf.string),
        'shape': tf.FixedLenFeature([3], tf.int64),
        'data': tf.FixedLenFeature((), tf.string)
    }
    # 解析tfrecord 文件的每条记录,即序列化后的 tf.train.Example;使用 tf.parse_single_example 来解析:
    parsed_features = tf.parse_single_example(example_photo,features=features)
    return parsed_features
    
# 解析tfrecord文件中的所有记录,我们需要使用dataset的map方法
dataset = dataset.map(_parse_record)


# map方法可以接受任意函数对dataset中的数据进行处理;另外可以使用repeat,shuffle,batch方法对dataset进行重复,混洗,分批;用repeat赋值dataset以进行多个epoch;如下:
ds_train = dataset.repeat(epochs).shuffle(buffer_size).batch(batch_size)

对ds_train进行迭代

history = model.fit(ds_train, 
                    steps_per_epoch=100, # let's just take some steps
                    epochs=1)

# 或者通过创建iterator来进行
iterator = ds_train.make_one_shot_iterator()
features = sess.run(iterator.get_next())
<think>好的,我现在要解决用户提出的关于TensorFlow中的UnknownError问题,特别是错误信息中的“KeyError: '[0] not in index'”。首先,我需要理解这个错误的上下文可能的原因。用户之前已经询问过关于`tf.config.run_functions_eagerly(True)`的问题,说明他们在使用TensorFlow时遇到了调试相关的问题,现在可能是在调试过程中遇到了这个新的错误。 首先,我需要分析错误信息本身。错误类型是UnknownError,但内部提到了KeyError,指出'[0] not in index'。这通常与数据结构的索引有关,可能在处理数据时试图访问不存在的索引。结合TensorFlow的上下文,用户可能在使用数据集(Dataset API)或数据管道时出现了问题,特别是在使用迭代器(如IteratorGetNext)获取数据时。 接下来,我需要考虑可能的场景。用户可能在加载数据时,例如从TFRecord文件或其他格式的数据集中读取数据,但数据的结构与代码中的预期不匹配。例如,特征字典中的键可能缺失或名称不匹配,或者在解析数据时指定的键不存在。另一个可能是数据的形状或类型不符合预期,导致在访问特定索引时出错。 另外,考虑到用户之前的问题涉及Eager Execution模式,可能他们在尝试调试时启用了`run_functions_eagerly(True)`,但错误依然存在。这说明问题可能出现在数据预处理阶段,而不是计算图转换的问题。需要检查数据管道的每个步骤,确保数据解析转换的正确性。 可能的解决方案包括:检查数据文件是否正确生成,是否有损坏;验证数据解析逻辑中的键名是否正确;确保数据集中的每个元素都包含预期的特征;使用TensorFlow的调试工具,如tf.data.experimental.enable_debug_mode()来获取更详细的错误信息。 同时,需要提醒用户注意数据预处理读取步骤的一致性,特别是在使用TFRecord时,写入读取的feature_description必须完全匹配。此外,如果使用Keras的输入层,检查输入层配置是否与数据形状相符。 总结来说,这个错误通常源于数据加载或解析时的键不匹配或索引错误,需要逐步检查数据管道的每个环节,确保数据结构预期一致,并使用调试工具来定位具体出错的位置。</think>遇到 `UnknownError` 并伴随 `KeyError: '[0] not in index'` 通常与 **TensorFlow 数据管道中的数据格式不匹配**有关,尤其是在使用 `tf.data.Dataset` 或加载结构化数据(如 TFRecord、CSV、Pandas DataFrame)时。以下是逐步排查解决思路: --- ### **错误原因分析** 1. **数据键名不匹配** - 当代码尝试通过键(如 `features[0]` 或 `features['key_name']`)访问数据时,**实际数据中不存在对应的键或索引**。 - 常见场景: - 解析 TFRecord 时,`feature_description` 中的键名与写入时的键名不一致。 - 使用 `tf.data.Dataset.from_generator()` 生成数据时,返回的数据结构与模型输入层不兼容。 2. **数据形状不匹配** - 模型输入层期望的维度与实际数据维度冲突(例如模型需要 `(batch, 256, 256, 3)`,但数据实际形状为 `(256, 256)`)。 3. **索引越界** - 在使用 `tf.data.Dataset` 或 `tf.keras.utils.Sequence` 时,尝试访问超出数据长度的索引(如数据集共有 100 条数据,但试图取第 101 条)。 --- ### **解决步骤** #### **1. 检查数据加载代码** - **若使用 TFRecord**: 确保 `feature_description` 中的键名与写入数据时使用的键名 **完全一致**: ```python # 写入时的键名(例如:'image' 'label') def _bytes_feature(value): return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) example = tf.train.Example(features=tf.train.Features(feature={ 'image': _bytes_feature(image_data), 'label': _bytes_feature(label_data) })) # 读取时的解析函数必须匹配键名 def parse_tfrecord(example_proto): feature_description = { 'image': tf.io.FixedLenFeature([], tf.string), 'label': tf.io.FixedLenFeature([], tf.string) # 键名必须一致! } return tf.io.parse_single_example(example_proto, feature_description) ``` - **若使用 Pandas DataFrame**: 检查列名是否被意外修改,或尝试通过整数索引访问列(如 `df[0]` 可能失败,而应用 `df.iloc[:, 0]`)。 #### **2. 验证数据形状** - **打印数据样本**: 在数据管道中添加 `dataset.take(1).as_numpy_iterator()` 查看实际数据形状: ```python dataset = ... # 你的数据集 sample = next(iter(dataset.take(1))) print("Sample shape:", sample[0].shape) # 检查输入特征形状 print("Sample keys:", sample[0].keys()) # 若为字典,检查键名 ``` - **匹配模型输入层**: 确保模型的第一层(`Input` 层)的 `shape` 参数与数据形状兼容: ```python # 假设数据形状为 (256, 256, 3) model = tf.keras.Sequential([ tf.keras.layers.Input(shape=(256, 256, 3)), # 必须与数据形状一致 # 后续层... ]) ``` #### **3. 启用调试模式** 使用 TensorFlow 的调试工具定位具体出错位置: ```python # 启用数据集调试模式(TF 2.3+) tf.data.experimental.enable_debug_mode() # 或手动捕获错误 try: for batch in dataset: # 处理数据 except tf.errors.UnknownError as e: print("Error details:", e.message) # 查看具体错误信息 ``` #### **4. 检查索引越界** 若使用自定义生成器或 `tf.keras.utils.Sequence`,验证 `__getitem__` 方法是否在有效范围内: ```python class CustomGenerator(tf.keras.utils.Sequence): def __getitem__(self, index): if index >= len(self): # 避免越界 raise IndexError # 返回数据... ``` --- ### **代码示例修正** 假设错误源于 TFRecord 键名不匹配: ```python # 错误代码:解析时使用了错误的键名 'img' def parse_tfrecord(example_proto): feature_description = {'img': tf.io.FixedLenFeature([], tf.string)} # 实际键名为 'image' return tf.io.parse_single_example(example_proto, feature_description) # KeyError! # 修正后:键名改为 'image' feature_description = {'image': tf.io.FixedLenFeature([], tf.string)} ``` --- ### **总结** 此错误的核心是 **数据访问的键名或索引与实际数据不匹配**。解决方法包括: 1. 检查数据写入/读取的键名一致性。 2. 验证数据形状与模型输入兼容性。 3. 使用调试工具定位错误源头。 4. 确保索引操作不越界。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值