使用caffe训练一个多标签分类/回归模型

前言

这篇博客和上一篇性质差不多,都是旨在说明使用caffe训练图像分类模型的大概流程。不同的是,上篇博客讲的单标签图像分类问题,顾名思义,其输入和输出都是单标签或者可以说是单类别的,而此篇则把重点放在如何处理多标签分类/回归问题的输入和输出上。多标签分类/回归问题和单标签的工作流程比较类似,大致分为以下几个步骤,然后本博客再对各个环节做进一步解释。

  1. 准备数据(图像整理好放到合适的文件夹中,对应的ground-truth整理到一个txt中)。
  2. 编写用于模型输入的代码(使用python data layer来读取图片和label)。
  3. 网络模型的定义/编辑(从头编写或修改prototxt文件)。
  4. Solver的定义/编辑(solver是caffe用来训练网络的一个类似配置文件的东西)。
  5. 进行网络模型的训练与测试。

可以看出,此博客和上一篇在流程上的区别主要集中在第二步。博客写作时,caffe还不支持将含有多标签ground-truth的txt文件直接转换成LMDB文件。虽然仍然可以使用LMDB的格式,但是需要分别将图片和label放到两个LMDB文件中,操作略复杂,而caffe官网上比较推荐的一种方式是使用python data layer来直接读取图片和label,故本文也采用了这种形式。

准备数据

这一步先按一定的比例、一定的划分策略将数据集划分为训练集、验证集(根据具体情况可有可无)和测试集,如果要在使用caffe前手动地进行数据预处理或数据增强,尽量在数据集划分之前完成。数据划分好之后,对应的还要有标有ground-truth的txt文件,大致如下图所示:

这里写图片描述
文件的每一行代表一个样本,一行中第一项是图片的名称(名称前还可以有路径),后面是类别标签,具体的任务包含几类图像就对应着有多少个标签,每个标签都是非0即1的,0代表当前图片不属于这个标签,1则相反。图片名称以及各个标签之间以空格间隔。

编写用于模型输入的代码

这一步是这篇博客的重点。caffe自带的例子中也包含了使用python data layer来处理多标签输入的情况。我们更改一下例子给出的源文件大致即可实现自己的需求。首先,找到”caffe_root”/examples/pycaffe/layers/pascal_multilabel_datalayers.py,这个文件就是用于处理多标签输入的代码。此例子是关于PASCAL VOC 2012数据集上的多标签分类任务的,样本情况大致如下所示:


这里写图片描述

对应的ground-truth里除了有物体类别标签,还有各物体在图片中对应的bounding box,如果读者的任务和该例子相似,可以仔细研究下这份代码。本人的工作虽然也是处理多标签问题,但是是研究天气识别的,和这个例子略有不同,ground-truth中只有类别标签而已,所以对代码做了一定的修改,其实是变简单了。下面来看一下代码,先贴出全部代码,然后再主要讲一些需要修改的部分。

import scipy.misc
import caffe

import numpy as np
import os.path as osp

from random import shuffle
from PIL import Image

from tools import SimpleTransformer


class WeatherMultilabelDataLayerSync(caffe.Layer):

    """
    This is a simple syncronous datalayer for training a multilabel model on
    weather dataset.
    """

    def setup(self, bottom, top):

        self.top_names = ['data', 'label']

        # === Read input parameters ===

        # params is a python dictionary with layer parameters.
        params = eval(self.param_str)

        # Check the paramameters for validity.
        check_params(params)

        # store input as class variables
        self.batch_size = params['batch_size']

        # Create a batch loader to load the images.
        self.batch_loader = BatchLoader(params, None)

        # === reshape tops ===
        # since we use a fixed input image size, we can shape the data layer
        # once. Else, we'd have to do it in the reshape call.
        top[0].reshape(
            self.batch_size, 3, params['im_shape'][0], params['im_shape'][1])
        top[1].reshape(self.batch_size, 7)

        print_info("WeatherMultilabelDataLayerSync", params)

    def forward(self, bottom, top):
        """
        Load data.
        """
        for itt in range(self.batch_size):
            # Use the batch loader to load the next image.
            im, multilabel = self.batch_loader.load_next_image()
            # Add directly to the caffe data layer
            top[0].data[itt, ...] = im
            top[1].data[itt, ...] = multilabel

    def reshape(self, bottom, top):
        """
        There is no need to reshape the data, since the input is of fixed size
        (rows and columns)
        """
        pass

    def backward(self, top, propagate_down, bottom):
        """
        These layers does not back propagate
        """
        pass


class BatchLoader(object):

    """
    This class abstracts away the loading of images.
    Images can either be loaded singly, or in a batch. The latter is used for
    the asyncronous data layer to preload batches while other processing is
    performed.
    """

    def __init__(self, params, result):
        self.result = result
        self.batch_size = params['batch_size']
        self.weather_root = params['weather_root']
        self.im_shape = params['im_shape']
        if params['split'] == 'train':
            self.isshuffle = True
            self.iscenter = False
            self.isflip = True
        else:
            self.isshuffle = False
            self.iscenter = True
            self.isflip = False
        # get list of image indexes.
        self.flode
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值