MXNet SSD Object Detection (Part 2)


This post does not aim to reproduce the results of the SSD paper; it focuses on the details of the algorithm.

1. Data preparation

Download the merged VOC2007 and VOC2012 dataset (a VPN may be required).
Dataset link
VGG16 pretrained model link
After downloading, the files are organized as follows:
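Annotations: the merged XML files, 21,503 in total.
JPEGImages: the merged image files, 21,503 in total, in one-to-one correspondence with the XML files.
ImageSets: the MXNet .lst and .rec files are created under Main/ in this folder.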
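Because every image needs a matching annotation, it can be worth confirming the one-to-one correspondence before generating the list files. The following is a minimal sketch, assuming the folder layout described above (Annotations/*.xml and JPEGImages/*.jpg under one dataset root); the function name find_unpaired is hypothetical.

```python
import os

def find_unpaired(dataset_path):
    """Return (xml_without_jpg, jpg_without_xml) as sorted lists of file stems."""
    def stems(subdir, ext):
        # Collect file names without extension for one sub-folder.
        folder = os.path.join(dataset_path, subdir)
        return {os.path.splitext(name)[0]
                for name in os.listdir(folder)
                if name.lower().endswith(ext)}

    xml_stems = stems("Annotations", ".xml")
    jpg_stems = stems("JPEGImages", ".jpg")
    return sorted(xml_stems - jpg_stems), sorted(jpg_stems - xml_stems)

# Usage (the path is the one used by the commands in this post):
#   missing_jpg, missing_xml = find_unpaired("../data/VOCdevkit/VOC")
#   print(len(missing_jpg), len(missing_xml))
```

Both returned lists should be empty when the two folders really are in one-to-one correspondence.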

(base) yuyang@oceanshadow:~/下载/MXNet-Deep-Learning-in-Action-master/demo9/9.2-objectDetection/tools$ python create_list.py --set test --save-path ../data/VOCdevkit/VOC/ImageSets/Main --dataset-path ../data/VOCdevkit/VOC
(base) yuyang@oceanshadow:~/下载/MXNet-Deep-Learning-in-Action-master/demo9/9.2-objectDetection/tools$ python create_list.py --set trainval --save-path ../data/VOCdevkit/VOC/ImageSets/Main --dataset-path ../data/VOCdevkit/VOC --shuffle True
(base) yuyang@oceanshadow:~/下载/MXNet-Deep-Learning-in-Action-master/demo9/9.2-objectDetection/tools$ conda activate mxnet
(mxnet) yuyang@oceanshadow:~/下载/MXNet-Deep-Learning-in-Action-master/demo9/9.2-objectDetection/tools$ python im2rec.py ../data/VOCdevkit/VOC/ImageSets/Main/test.lst ../data/VOCdevkit --no-shuffle --pack-label
Creating .rec file from /home/yuyang/下载/MXNet-Deep-Learning-in-Action-master/demo9/9.2-objectDetection/data/VOCdevkit/VOC/ImageSets/Main/test.lst in /home/yuyang/下载/MXNet-Deep-Learning-in-Action-master/demo9/9.2-objectDetection/data/VOCdevkit/VOC/ImageSets/Main
multiprocessing not available, fall back to single threaded encoding
time: 0.02251458168029785  count: 0
time: 3.2134971618652344  count: 1000
time: 3.0257110595703125  count: 2000
time: 3.4036951065063477  count: 3000
time: 3.0267889499664307  count: 4000
(mxnet) yuyang@oceanshadow:~/下载/MXNet-Deep-Learning-in-Action-master/demo9/9.2-objectDetection/tools$ python im2rec.py ../data/VOCdevkit/VOC/ImageSets/Main/trainval.lst ../data/VOCdevkit --pack-label
Creating .rec file from /home/yuyang/下载/MXNet-Deep-Learning-in-Action-master/demo9/9.2-objectDetection/data/VOCdevkit/VOC/ImageSets/Main/trainval.lst in /home/yuyang/下载/MXNet-Deep-Learning-in-Action-master/demo9/9.2-objectDetection/data/VOCdevkit/VOC/ImageSets/Main
multiprocessing not available, fall back to single threaded encoding
time: 0.0036492347717285156  count: 0
time: 3.3942601680755615  count: 1000
time: 3.395172119140625  count: 2000
time: 3.3868696689605713  count: 3000
time: 3.387535810470581  count: 4000
time: 3.3465471267700195  count: 5000
time: 3.422719955444336  count: 6000
time: 3.7703046798706055  count: 7000
time: 4.644562482833862  count: 8000
time: 3.329374313354492  count: 9000
time: 3.3966259956359863  count: 10000
time: 3.372823476791382  count: 11000
time: 3.3226966857910156  count: 12000
time: 3.4331839084625244  count: 13000
time: 3.3643810749053955  count: 14000
time: 3.428262233734131  count: 15000
time: 3.4237232208251953  count: 16000
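
The "--set" argument selects the split: create_list.py reads <set>.txt and writes <set>.lst, here test and trainval.
"--no-shuffle" means the data is not randomly shuffled.
"--pack-label" means the label information is packed into the RecordIO file.
The generated files are written to Main/: one .lst file per split from create_list.py, plus a .rec and an .idx file per split from im2rec.py.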
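The two im2rec.py flags are plain argparse switches. The sketch below copies only those two definitions from the im2rec.py listing in this post to show what each command line resolves to; the helper name parse_flags is made up for illustration.

```python
import argparse

def parse_flags(argv):
    """Parse only the two switches discussed above, defined as in im2rec.py."""
    parser = argparse.ArgumentParser()
    # store_false: 'shuffle' defaults to True and becomes False when the flag is given.
    parser.add_argument('--no-shuffle', dest='shuffle', action='store_false')
    # store_true: 'pack_label' defaults to False and becomes True when the flag is given.
    parser.add_argument('--pack-label', action='store_true')
    return parser.parse_args(argv)

test_flags = parse_flags(['--no-shuffle', '--pack-label'])   # flags used for test.lst
train_flags = parse_flags(['--pack-label'])                  # flags used for trainval.lst
print(test_flags.shuffle, test_flags.pack_label)
print(train_flags.shuffle, train_flags.pack_label)
```

Note that in the im2rec.py listing, args.shuffle is read only inside make_list, which runs when --list is passed; as far as that listing shows, the flag does not change the order of records when a .rec file is built from an existing .lst file.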

  • create_list.py
import os
import argparse
from PIL import Image
import xml.etree.ElementTree as ET
import random

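def str2bool(value):
    # argparse's type=bool treats any non-empty string (even "False") as True,
    # so the text is parsed explicitly instead.
    return str(value).lower() in ('true', '1', 'yes')

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--set', type=str, default='train')
    parser.add_argument('--save-path', type=str, default='')
    parser.add_argument('--dataset-path', type=str, default='')
    parser.add_argument('--shuffle', type=str2bool, default=False)
    args = parser.parse_args()
    return args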

def main():
    # Map the 20 PASCAL VOC class names to integer label indices.
    label_dic = {"aeroplane": 0, "bicycle": 1, "bird": 2, "boat": 3, "bottle": 4, "bus": 5,
                 "car": 6, "cat": 7, "chair": 8, "cow": 9, "diningtable": 10, "dog": 11,
                 "horse": 12, "motorbike": 13, "person": 14, "pottedplant": 15, "sheep": 16,
                 "sofa": 17, "train": 18, "tvmonitor": 19}
    args = parse_args()
    # Last component of the dataset path (e.g. "VOC"); normpath tolerates a trailing slash.
    dataset_name = os.path.basename(os.path.normpath(args.dataset_path))
    with open(os.path.join(args.save_path, "{}.txt".format(args.set)), "r") as input_file:
        lines = input_file.readlines()
    if args.shuffle:
        random.shuffle(lines)
    # Mode "w" creates the .lst file if needed, so no separate os.mknod call is required.
    with open(os.path.join(args.save_path, "{}.lst".format(args.set)), "w") as output_file:
        index = 0
        for line in lines:
            line = line.strip()
            # Header fields: image index, header width (2), number of values per object (6).
            out_str = "\t".join([str(index), "2", "6"])
            with Image.open(os.path.join(args.dataset_path, "JPEGImages", line+".jpg")) as img:
                width, height = img.size
            xml_path = os.path.join(args.dataset_path, "Annotations", line+".xml")
            root = ET.parse(xml_path).getroot()
            # "obj" avoids shadowing the built-in name "object".
            for obj in root.findall('object'):
                name = obj.find('name').text
                difficult = ("%.4f" % int(obj.find('difficult').text))
                label_idx = ("%.4f" % label_dic[name])
                bndbox = obj.find('bndbox')
                # Normalize the box corners to [0, 1] by the image size.
                xmin = ("%.4f" % (int(bndbox.find('xmin').text)/width))
                ymin = ("%.4f" % (int(bndbox.find('ymin').text)/height))
                xmax = ("%.4f" % (int(bndbox.find('xmax').text)/width))
                ymax = ("%.4f" % (int(bndbox.find('ymax').text)/height))
                object_str = "\t".join([label_idx, xmin, ymin, xmax, ymax, difficult])
                out_str = "\t".join([out_str, object_str])
            # The image path is written relative to the parent of the dataset folder.
            out_str = "\t".join([out_str, "{}/JPEGImages/".format(dataset_name)+line+".jpg"+"\n"])
            output_file.write(out_str)
            index += 1

if __name__ == '__main__':
   main()
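To make the output format concrete: each line written by create_list.py holds the image index, the two header values 2 and 6, then six values per object (label, xmin, ymin, xmax, ymax, difficult, with coordinates normalized by the image size), and finally the relative image path. The sketch below parses one such line; parse_lst_line and the sample line are made up for illustration, and the reading of the first header value as a width that counts both header fields is my understanding of the MXNet detection label layout rather than something stated in this post.

```python
def parse_lst_line(line):
    """Split one tab-separated detection .lst line into index, objects and path."""
    fields = line.rstrip("\n").split("\t")
    index = int(fields[0])
    header_width = int(fields[1])   # 2 here: assumed to count the two header fields
    label_width = int(fields[2])    # 6 here: values stored per object
    path = fields[-1]
    # Object values start right after the header and end before the path.
    values = [float(v) for v in fields[1 + header_width:-1]]
    if len(values) % label_width != 0:
        raise ValueError("incomplete object record in line %d" % index)
    objects = [values[i:i + label_width] for i in range(0, len(values), label_width)]
    return index, objects, path

# Hypothetical sample line in the format produced above (one "person" box).
sample = "\t".join(["0", "2", "6",
                    "14.0000", "0.1000", "0.2000", "0.5000", "0.9000", "0.0000",
                    "VOC/JPEGImages/000001.jpg"]) + "\n"
index, objects, path = parse_lst_line(sample)
print(index, path)
for label, xmin, ymin, xmax, ymax, difficult in objects:
    print(int(label), xmin, ymin, xmax, ymax, int(difficult))
```

Multiplying xmin and xmax by the image width, and ymin and ymax by the height, recovers the pixel coordinates from the XML annotation up to the four-decimal rounding.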
  • im2rec.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import print_function
import os
import sys

curr_path = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(curr_path, "../python"))
import mxnet as mx
import random
import argparse
import cv2
import time
import traceback

try:
   import multiprocessing
except ImportError:
   multiprocessing = None

def list_image(root, recursive, exts):
   """Traverses the root of directory that contains images and
   generates image list iterator.
   Parameters
   ----------
   root: string
   recursive: bool
   exts: string
   Returns
   -------
   image iterator that contains all the image under the specified path
   """

   i = 0
   if recursive:
       cat = {}
       for path, dirs, files in os.walk(root, followlinks=True):
           dirs.sort()
           files.sort()
           for fname in files:
               fpath = os.path.join(path, fname)
               suffix = os.path.splitext(fname)[1].lower()
               if os.path.isfile(fpath) and (suffix in exts):
                   if path not in cat:
                       cat[path] = len(cat)
                   yield (i, os.path.relpath(fpath, root), cat[path])
                   i += 1
       for k, v in sorted(cat.items(), key=lambda x: x[1]):
           print(os.path.relpath(k, root), v)
   else:
       for fname in sorted(os.listdir(root)):
           fpath = os.path.join(root, fname)
           suffix = os.path.splitext(fname)[1].lower()
           if os.path.isfile(fpath) and (suffix in exts):
               yield (i, os.path.relpath(fpath, root), 0)
               i += 1

def write_list(path_out, image_list):
   """Helper function to write image list into the file.
   The format is as below,
   integer_image_index \t float_label_index \t path_to_image
   Note that the blank between number and tab is only used for readability.
   Parameters
   ----------
   path_out: string
   image_list: list
   """
   with open(path_out, 'w') as fout:
       for i, item in enumerate(image_list):
           line = '%d\t' % item[0]
           for j in item[2:]:
               line += '%f\t' % j
           line += '%s\n' % item[1]
           fout.write(line)

def make_list(args):
   """Generates .lst file.
   Parameters
   ----------
   args: object that contains all the arguments
   """
   image_list = list_image(args.root, args.recursive, args.exts)
   image_list = list(image_list)
   if args.shuffle is True:
       random.seed(100)
       random.shuffle(image_list)
   N = len(image_list)
   chunk_size = (N + args.chunks - 1) // args.chunks
   for i in range(args.chunks):
       chunk = image_list[i * chunk_size:(i + 1) * chunk_size]
       if args.chunks > 1:
           str_chunk = '_%d' % i
       else:
           str_chunk = ''
       sep = int(chunk_size * args.train_ratio)
       sep_test = int(chunk_size * args.test_ratio)
       if args.train_ratio == 1.0:
           write_list(args.prefix + str_chunk + '.lst', chunk)
       else:
           if args.test_ratio:
               write_list(args.prefix + str_chunk + '_test.lst', chunk[:sep_test])
           if args.train_ratio + args.test_ratio < 1.0:
               write_list(args.prefix + str_chunk + '_val.lst', chunk[sep_test + sep:])
           write_list(args.prefix + str_chunk + '_train.lst', chunk[sep_test:sep_test + sep])

def read_list(path_in):
   """Reads the .lst file and generates corresponding iterator.
   Parameters
   ----------
   path_in: string
   Returns
   -------
   item iterator that contains information in .lst file
   """
   with open(path_in) as fin:
       while True:
           line = fin.readline()
           if not line:
               break
           line = [i.strip() for i in line.strip().split('\t')]
           line_len = len(line)
           # check the data format of .lst file
           if line_len < 3:
               print('lst should have at least three parts, but only has %s parts for %s' % (line_len, line))
               continue
           try:
               item = [int(line[0])] + [line[-1]] + [float(i) for i in line[1:-1]]
           except Exception as e:
               print('Parsing lst met error for %s, detail: %s' % (line, e))
               continue
           yield item

def image_encode(args, i, item, q_out):
   """Reads, preprocesses, packs the image and put it back in output queue.
   Parameters
   ----------
   args: object
   i: int
   item: list
   q_out: queue
   """
   fullpath = os.path.join(args.root, item[1])

   if len(item) > 3 and args.pack_label:
       header = mx.recordio.IRHeader(0, item[2:], item[0], 0)
   else:
       header = mx.recordio.IRHeader(0, item[2], item[0], 0)

   if args.pass_through:
       try:
           with open(fullpath, 'rb') as fin:
               img = fin.read()
           s = mx.recordio.pack(header, img)
           q_out.put((i, s, item))
       except Exception as e:
           traceback.print_exc()
           print('pack_img error:', item[1], e)
           q_out.put((i, None, item))
       return

   try:
       img = cv2.imread(fullpath, args.color)
   except:
       traceback.print_exc()
       print('imread error trying to load file: %s ' % fullpath)
       q_out.put((i, None, item))
       return
   if img is None:
       print('imread read blank (None) image for file: %s' % fullpath)
       q_out.put((i, None, item))
       return
   if args.center_crop:
       if img.shape[0] > img.shape[1]:
           margin = (img.shape[0] - img.shape[1]) // 2
           img = img[margin:margin + img.shape[1], :]
       else:
           margin = (img.shape[1] - img.shape[0]) // 2
           img = img[:, margin:margin + img.shape[0]]
   if args.resize:
       if img.shape[0] > img.shape[1]:
           newsize = (args.resize, img.shape[0] * args.resize // img.shape[1])
       else:
           newsize = (img.shape[1] * args.resize // img.shape[0], args.resize)
       img = cv2.resize(img, newsize)

   try:
       s = mx.recordio.pack_img(header, img, quality=args.quality, img_fmt=args.encoding)
       q_out.put((i, s, item))
   except Exception as e:
       traceback.print_exc()
       print('pack_img error on file: %s' % fullpath, e)
       q_out.put((i, None, item))
       return

def read_worker(args, q_in, q_out):
   """Function that will be spawned to fetch the image
   from the input queue and put it back to output queue.
   Parameters
   ----------
   args: object
   q_in: queue
   q_out: queue
   """
   while True:
       deq = q_in.get()
       if deq is None:
           break
       i, item = deq
       image_encode(args, i, item, q_out)

def write_worker(q_out, fname, working_dir):
   """Function that will be spawned to fetch processed image
   from the output queue and write to the .rec file.
   Parameters
   ----------
   q_out: queue
   fname: string
   working_dir: string
   """
   pre_time = time.time()
   count = 0
   fname = os.path.basename(fname)
   fname_rec = os.path.splitext(fname)[0] + '.rec'
   fname_idx = os.path.splitext(fname)[0] + '.idx'
   record = mx.recordio.MXIndexedRecordIO(os.path.join(working_dir, fname_idx),
                                          os.path.join(working_dir, fname_rec), 'w')
   buf = {}
   more = True
   while more:
       deq = q_out.get()
       if deq is not None:
           i, s, item = deq
           buf[i] = (s, item)
       else:
           more = False
       while count in buf:
           s, item = buf[count]
           del buf[count]
           if s is not None:
               record.write_idx(item[0], s)

           if count % 1000 == 0:
               cur_time = time.time()
               print('time:', cur_time - pre_time, ' count:', count)
               pre_time = cur_time
           count += 1

def parse_args():
   """Defines all arguments.
   Returns
   -------
   args object that contains all the params
   """
   parser = argparse.ArgumentParser(
       formatter_class=argparse.ArgumentDefaultsHelpFormatter,
       description='Create an image list or \
       make a record database by reading from an image list')
   parser.add_argument('prefix', help='prefix of input/output lst and rec files.')
   parser.add_argument('root', help='path to folder containing images.')

   cgroup = parser.add_argument_group('Options for creating image lists')
   cgroup.add_argument('--list', action='store_true',
                       help='If this is set im2rec will create image list(s) by traversing root folder\
       and output to <prefix>.lst.\
       Otherwise im2rec will read <prefix>.lst and create a database at <prefix>.rec')
   cgroup.add_argument('--exts', nargs='+', default=['.jpeg', '.jpg', '.png'],
                       help='list of acceptable image extensions.')
   cgroup.add_argument('--chunks', type=int, default=1, help='number of chunks.')
   cgroup.add_argument('--train-ratio', type=float, default=1.0,
                       help='Ratio of images to use for training.')
   cgroup.add_argument('--test-ratio', type=float, default=0,
                       help='Ratio of images to use for testing.')
   cgroup.add_argument('--recursive', action='store_true',
                       help='If true recursively walk through subdirs and assign an unique label\
       to images in each folder. Otherwise only include images in the root folder\
       and give them label 0.')
   cgroup.add_argument('--no-shuffle', dest='shuffle', action='store_false',
                       help='If this is passed, \
       im2rec will not randomize the image order in <prefix>.lst')
   rgroup = parser.add_argument_group('Options for creating database')
   rgroup.add_argument('--pass-through', action='store_true',
                       help='whether to skip transformation and save image as is')
   rgroup.add_argument('--resize', type=int, default=0,
                       help='resize the shorter edge of image to the newsize, original images will\
       be packed by default.')
   rgroup.add_argument('--center-crop', action='store_true',
                       help='specify whether to crop the center image to make it rectangular.')
   rgroup.add_argument('--quality', type=int, default=95,
                       help='JPEG quality for encoding, 1-100; or PNG compression for encoding, 1-9')
   rgroup.add_argument('--num-thread', type=int, default=1,
                       help='number of thread to use for encoding. order of images will be different\
       from the input list if >1. the input list will be modified to match the\
       resulting order.')
   rgroup.add_argument('--color', type=int, default=1, choices=[-1, 0, 1],
                       help='specify the color mode of the loaded image.\
       1: Loads a color image. Any transparency of image will be neglected. It is the default flag.\
       0: Loads image in grayscale mode.\
       -1:Loads image as such including alpha channel.')
   rgroup.add_argument('--encoding', type=str, default='.jpg', choices=['.jpg', '.png'],
                       help='specify the encoding of the images.')
   rgroup.add_argument('--pack-label', action='store_true',
       help='Whether to also pack multi dimensional label in the record file')
   args = parser.parse_args()
   args.prefix = os.path.abspath(args.prefix)
   args.root = os.path.abspath(args.root)
   return args

if __name__ == '__main__':
   args = parse_args()
   # if the '--list' is used, it generates .lst file
   if args.list:
       make_list(args)
   # otherwise read .lst file to generates .rec file
   else:
       if os.path.isdir(args.prefix):
           working_dir = args.prefix
       else:
           working_dir = os.path.dirname(args.prefix)
       files = [os.path.join(working_dir, fname) for fname in os.listdir(working_dir)
                   if os.path.isfile(os.path.join(working_dir, fname))]
       count = 0
       for fname in files:
           if fname.startswith(args.prefix) and fname.endswith('.lst'):
               print('Creating .rec file from', fname, 'in', working_dir)
               count += 1
               image_list = read_list(fname)
               # -- write_record -- #
               if args.num_thread > 1 and multiprocessing is not None:
                   q_in = [multiprocessing.Queue(1024) for i in range(args.num_thread)]
                   q_out = multiprocessing.Queue(1024)
                   # define the process
                   read_process = [multiprocessing.Process(target=read_worker, args=(args, q_in[i], q_out)) \
                                   for i in range(args.num_thread)]
                   # process images with num_thread process
                   for p in read_process:
                       p.start()
                   # only use one process to write .rec to avoid a race condition
                   write_process = multiprocessing.Process(target=write_worker, args=(q_out, fname, working_dir))
                   write_process.start()
                   # put the image list into input queue
                   for i, item in enumerate(image_list):
                       q_in[i % len(q_in)].put((i, item))
                   for q in q_in:
                       q.put(None)
                   for p in read_process:
                       p.join()

                   q_out.put(None)
                   write_process.join()
               else:
                   print('multiprocessing not available, fall back to single threaded encoding')
                   try:
                       import Queue as queue
                   except ImportError:
                       import queue
                   q_out = queue.Queue()
                   fname = os.path.basename(fname)
                   fname_rec = os.path.splitext(fname)[0] + '.rec'
                   fname_idx = os.path.splitext(fname)[0] + '.idx'
                   record = mx.recordio.MXIndexedRecordIO(os.path.join(working_dir, fname_idx),
                                                          os.path.join(working_dir, fname_rec), 'w')
                   cnt = 0
                   pre_time = time.time()
                   for i, item in enumerate(image_list):
                       image_encode(args, i, item, q_out)
                       if q_out.empty():
                           continue
                       _, s, _ = q_out.get()
                       if s is None:
                           # image_encode reports unreadable images as None; skip them
                           # instead of passing None to write_idx
                           continue
                       record.write_idx(item[0], s)
                       if cnt % 1000 == 0:
                           cur_time = time.time()
                           print('time:', cur_time - pre_time, ' count:', cnt)
                           pre_time = cur_time
                       cnt += 1
       if not count:
           print('Did not find any list file with prefix %s'%args.prefix)
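One detail of write_worker worth isolating is how it restores the input order when several reader processes finish out of order: each result is parked in a dict keyed by its position, and records are flushed only while the next expected position is present. The sketch below reproduces that idea without MXNet or multiprocessing; reorder is a hypothetical helper, not part of im2rec.py.

```python
def reorder(results):
    """Yield payloads in index order from (index, payload) pairs arriving in any order."""
    buf = {}
    expected = 0
    for index, payload in results:
        buf[index] = payload
        # Flush every consecutive item that is now available.
        while expected in buf:
            yield buf.pop(expected)
            expected += 1

# Simulated out-of-order arrivals from several workers.
arrivals = [(2, "c"), (0, "a"), (1, "b"), (4, "e"), (3, "d")]
print(list(reorder(arrivals)))  # → ['a', 'b', 'c', 'd', 'e']
```

The buffer only grows while an earlier item is still outstanding, which keeps memory use small as long as the workers progress at similar speeds.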

2. Training parameters and configuration

See the SSD repository on my GitHub; the code runs as-is after downloading.
