TFRecords 拆分

背景:由于之前已经生成了训练数据的tfrecords,需要对已经生成的tfrecords根据类别去按比例拆分。或者将一个大的tfrecors拆分成几个小的tfrecords。

参考:https://stackoverflow.com/questions/54519309/split-tfrecords-file-into-many-tfrecords-files?rq=1

1. 将大的tfrecords拆分成小的tfrecords可以参考上述链接

2. 根据类别去按比例拆分

由于已经生成了tfrecords,目前考虑的两种解决办法。要么直接在生成的时候去按比例重新生成tfrecords数据。要么去读取tfrecortds去解析并判断类别按比例重新写。此处也只是为了练习tfrecords相关的方法,该方法也不是最高效的处理手段。只是用作记录。

def split_tfrecord_by_class(tfrecord_path):
    with tf.Graph().as_default(), tf.Session() as sess:
        ds = tf.data.TFRecordDataset(tfrecord_path)
        ds_iterator = ds.make_one_shot_iterator()
        next_element = ds_iterator.get_next()
        total_cnt =0
        train_cnt_A = 0
        train_cnt_B = 0
        val_cnt_A = 0
        val_cnt_B = 0
        train_path = '/train.tfrecords'
        val_path = '/val.tfrecords'
        train_writer = tf.python_io.TFRecordWriter(train_path)
        val_writer = tf.python_io.TFRecordWriter(val_path)
        while True:
            try:
                parsed_example =tf.parse_single_example(next_element,
                                        features={
                                               'label': tf.FixedLenFeature([], tf.int64),
                                               'img_raw': tf.FixedLenFeature([], tf.string),
                                           })
                
                label = tf.cast(parsed_example['label'], tf.int32)
                label =sess.run([label])
                record = sess.run(next_element)
                if len(label)==1:
                    if label[0] == 1 and train_cnt_white < 6:  
                        
                        train_writer.write(record)
                        train_cnt_A += 1
                        total_cnt+=1
                    elif label[0] == 1 and train_cnt_white >=6:
                        
                        val_writer.write(record)
                        val_cnt_A += 1
                        total_cnt += 1
                        if val_cnt_A == 2:
                            train_cnt_A = 0
                            val_cnt_A = 0
                    if label[0] == 2 and train_cnt_yellow < 6:  
                       
                        train_writer.write(record)
                        train_cnt_B += 1
                        total_cnt += 1
                    elif label[0] == 2 and train_cnt_yellow >=6:
                        
                        val_writer.write(record)
                        val_cnt_yellow += 1
                        total_cnt += 1
                        if val_cnt_B == 2:
                            train_cnt_B = 0
                            val_cnt_B = 0
                print("already preprocessing total {} lanes.".format(total_cnt))
            except tf.errors.OutOfRangeError:
                train_writer.close()
                val_writer.close()
                break

上述代码:

第一步:ds =tf.data.TFRecordDataset(path)方法可以用来读取已经生成好的tfrecords文件。官方链接:https://www.tensorflow.org/tutorials/load_data/tf_records

第二步:ds_iterator =ds.make_one_shot_iterator()构建数据集的一个迭代器,每次可以获取数据集中的一个数据。

第三步:element = ds_iterator.get_next()获取下一个数据集中的数据

第四步:生成writer = tf.python_io.TFRecordWriter(path)实例

第五步:parse获取的example实例通过tf.parse_single_example(example,features)方法

第六步:从解析的example实例中获取相应的值,这时候取出的值是tensor。如果想获取tensor的value 需要在Session中run一下。

 

 

 

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值