caffe2学习笔记二:利用numpy数组格式图像数据集生成lmdb格式图像数据集

本文介绍如何使用Python脚本将CIFAR-10数据集转换为Caffe2所需的LMDB格式,并通过示例代码展示了创建过程中的注意事项。

想用caffe2训练网络,首先要做的就是生成caffe/caffe2使用的数据集格式,常用的是lmdb格式。lmdb 是Lightning Memory-Mapped Database的缩写。 从名字可以看出来这种格式比较轻量化,采用的是一种key-value对的存储方式,LMDB示例文件包含一个数据文件data.mdb和一个锁文件lock.mdb,关于它的详细介绍可以搜wiki,用python和caffe对lmdb文件进行读写可以参考文件格式之lmdb

之前本来想着用caffe里面生成cifar10数据集的脚本create_cifar10.sh来生成可供caffe2训练的lmdb格式文件,不过生成的lmdb文件caffe2读取时会报错,网上有人说caffe用的lmdb文件和caffe2用的并不完全兼容,只得另寻它法。后来发现caffe2的github上有人已经提供了用于生成lmdb格式文件的python代码lmdb_create_example.py,用它作参考改一改就能用了。
下面贴出的是我参考上面链接里的代码修改的用cifar10-python文件生成lmdb格式文件的代码。

## @package lmdb_create_example
# Module caffe2.python.examples.lmdb_create_example
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import argparse
import numpy as np

import lmdb
from caffe2.proto import caffe2_pb2
from caffe2.python import workspace, model_helper

'''
Simple example to create an lmdb database of random image data and labels.
This can be used a skeleton to write your own data import.

It also runs a dummy-model with Caffe2 that reads the data and
validates the checksum is same.
'''
def unpickle_cifar10():
    data_path = '/media/ygj/00030DB30006F338/ygj/dataset/cifar-10/cifar-10-python/cifar-10-batches-py/'
    import cPickle
#     # unpickle cifar10 training dataset
#     for i in range(1,6):
#         file = data_path+'data_batch_'+str(i)
#         with open(file, 'rb') as fo:
#             dict = cPickle.load(fo)
#         if i == 1:
#             img_data = dict['data']
#             labels = dict['labels']
#         else:
#             img_data = np.concatenate((img_data,dict['data']),axis=0)
#             labels = labels+dict['labels']
#         print(img_data.shape,len(labels))
    # unpickle cifar10 test dataset
    file = data_path+'test_batch'
    with open(file,'rb') as fo:
        dict = cPickle.load(fo)
        img_data = dict['data']
        labels = dict['labels']
    return img_data, labels

def create_db(output_file):
    print(">>> Write database...")
    LMDB_MAP_SIZE = 1 << 40   # MODIFY
    env = lmdb.open(output_file, map_size=LMDB_MAP_SIZE)

    checksum = 0
    with env.begin(write=True) as txn:
        img_num = 10000
        imgs_data,labels = unpickle_cifar10()
        for j in range(0, img_num): 
#             # MODIFY: add your own data reader / creator
#             label = j % 10
#             width = 64
#             height = 32
# 
#             img_data = np.random.rand(3, width, height)
#             # ...

            img_data = imgs_data[j].reshape((3,32,32))
#             if j==1: # visualize to test
#                 import matplotlib.pyplot as plt
#                 show_data=img_data.swapaxes(0,1).swapaxes(1,2)
#                 plt.imshow(show_data)
#                 plt.show()
#                 pass
            label = labels[j]
            # Create TensorProtos
            tensor_protos = caffe2_pb2.TensorProtos()
            img_tensor = tensor_protos.protos.add()
            img_tensor.dims.extend(img_data.shape)
            img_tensor.data_type = 2 #1:float, 2:int32 

            flatten_img = img_data.reshape(np.prod(img_data.shape))
            img_tensor.int32_data.extend(flatten_img) # need to be correspond with data_type

            label_tensor = tensor_protos.protos.add()
            label_tensor.data_type = 2
            label_tensor.int32_data.append(label)
            txn.put(
                '{}'.format(j).encode('ascii'),
                tensor_protos.SerializeToString()
            )

            checksum += np.sum(img_data) * label
            if (j % 16 == 0):
                print("Inserted {} rows".format(j))

    print("Checksum/write: {}".format(int(checksum)))
    return checksum


def read_db_with_caffe2(db_file, expected_checksum):
    print(">>> Read database...")
    model = model_helper.ModelHelper(name="lmdbtest")
    batch_size = 125
    data, label = model.TensorProtosDBInput(
        [], ["data", "label"], batch_size=batch_size,
        db=db_file, db_type="lmdb")

    checksum = 0

    workspace.RunNetOnce(model.param_init_net)
    workspace.CreateNet(model.net)

    for _ in range(0, 80):
        workspace.RunNet(model.net.Proto().name)

        img_datas = workspace.FetchBlob("data")
        labels = workspace.FetchBlob("label")
        for j in range(batch_size):
            checksum += np.sum(img_datas[j, :]) * labels[j]

    print("Checksum/read: {}".format(int(checksum)))
    assert np.abs(expected_checksum - checksum < 0.1), \
        "Read/write checksums dont match"


def main():
    parser = argparse.ArgumentParser(
        description="Example LMDB creation"
    )
    parser.add_argument("--output_file", type=str, 
                        default=None,
                        help="Path to write the database to",
                        required=True)

    args = parser.parse_args()
    checksum = create_db(args.output_file)

    # For testing reading:
    read_db_with_caffe2(args.output_file, checksum)

if __name__ == '__main__':
    main()

遇到的问题:
Prefetching error [enforce fail at blob_serialization.h:490] chunkSize == proto.byte_data().size(). Incorrect proto field size
一开始也不知道问题出在哪里,不过隐约感觉是生成的lmdb文件有问题,折腾了半天才发现是在做lmdb格式数据时img_tensor.data_type和后面的img_tensor.xxx_data格式没有对应上,希望大家在自己写的时候也注意不要弄错了。

# Create TensorProtos
            tensor_protos = caffe2_pb2.TensorProtos()
            img_tensor = tensor_protos.protos.add()
            img_tensor.dims.extend(img_data.shape)
            img_tensor.data_type = 1

            flatten_img = img_data.reshape(np.prod(img_data.shape))
            img_tensor.float_data.extend(flatten_img) # float_data要与上面的img_tensor.data_type的数字对应
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值