Training a Face Detection Model with YOLOv3

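This article uses YOLOv3 for hand detection on the Oxford Hand dataset and for face detection on the WIDER FACE dataset, then applies channel pruning so that the parameter count and FLOPs drop substantially while mAP stays roughly stable. It covers environment setup, dataset preparation, and the channel pruning algorithm based on Network Slimming.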

YOLOv3-model-pruning

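This project runs hand detection with a YOLOv3 model on oxford hand, an open-source hand detection dataset, and then prunes the model. On this dataset, channel pruning of YOLOv3 reduces the parameter count and model size by 80% and FLOPs by 70%, and forward inference runs at about 200% of the original speed, while mAP stays essentially unchanged. (These results are specific to this dataset; the same effect is not guaranteed on other datasets.)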
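
To give a feel for why removing channels cuts parameters and FLOPs so sharply, here is a rough sketch that counts the weights and multiply-accumulates of a single convolution. The layer sizes and the 50% keep ratio are made up for illustration and are not taken from the experiments above; `conv_cost` is a hypothetical helper.

```python
def conv_cost(c_in, c_out, k, h_out, w_out):
    """Return (weights, multiply-accumulates) of a bias-free k x k convolution."""
    params = c_in * c_out * k * k
    macs = params * h_out * w_out
    return params, macs


# Hypothetical layer: 256 -> 512 channels, 3x3 kernel, 26x26 output map.
full_params, full_macs = conv_cost(256, 512, 3, 26, 26)

# Keep half of the channels on both the input and the output side
# (illustrative ratio only).
pruned_params, pruned_macs = conv_cost(128, 256, 3, 26, 26)

param_ratio = pruned_params / full_params
mac_ratio = pruned_macs / full_macs
print(param_ratio, mac_ratio)
```

Because the cost scales with the product of input and output channels, halving both leaves a quarter of the weights in this toy layer.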

Environment

Python 3.6, PyTorch 1.0 or later

The YOLOv3 implementation follows eriklindernoren's PyTorch-YOLOv3, so the dependencies can also be taken from that repo.

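Dataset preparation

  1. Download the WIDER FACE dataset to obtain the archive (extraction code: ymx2)
  2. Extract the archive into Dataset
  3. Run widerface_label.py to generate the images and labels folders and the train.txt and valid.txt files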
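
Step 3 turns each WIDER FACE box (left, top, width, height in pixels) into a normalized YOLO label line. A minimal sketch of that conversion, assuming the same clipping to the image bounds that widerface_label.py applies, could look like this; `wider_to_yolo` is a hypothetical helper name, and the box and image size are invented.

```python
def wider_to_yolo(x, y, w, h, img_w, img_h):
    """Convert a pixel box (left, top, width, height) to YOLO format:
    (x_center, y_center, width, height), each scaled to [0, 1]."""
    # Clip the box to the image bounds
    x_min, y_min = max(x, 0), max(y, 0)
    x_max, y_max = min(x + w, img_w - 1), min(y + h, img_h - 1)
    return ((x_min + x_max) / 2 / img_w,
            (y_min + y_max) / 2 / img_h,
            (x_max - x_min) / img_w,
            (y_max - y_min) / img_h)


# Invented example: a 200x100 box at (100, 50) in a 1000x500 image.
box = wider_to_yolo(100, 50, 200, 100, 1000, 500)
print("0 %f %f %f %f" % box)  # the leading 0 is the class index (face)
```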

 

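Pruning algorithm

The channel pruning algorithm in this code is an adapted implementation of the paper Learning Efficient Convolutional Networks Through Network Slimming (ICCV 2017); a similar implementation is yolov3-network-slimming. The algorithm in the original paper targets classification models and prunes channels based on the gamma coefficients of the BN layers.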
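
The channel-selection idea can be sketched in plain Python: pool the absolute gamma values of all BN layers, pick a global threshold at the desired prune ratio, and keep only the channels above it. This is an illustration under assumptions, not the code in test_prune.py: `select_channels`, the gamma values, and the safeguard that keeps at least one channel per layer are all hypothetical.

```python
def select_channels(bn_gammas, prune_ratio):
    """bn_gammas maps a layer name to its list of BN gamma values.
    Returns (threshold, masks); masks[layer][i] is True if channel i is kept."""
    # Pool |gamma| over all layers and take a global percentile as threshold
    all_abs = sorted(abs(g) for gammas in bn_gammas.values() for g in gammas)
    cut = int(len(all_abs) * prune_ratio)
    threshold = all_abs[min(cut, len(all_abs) - 1)]

    masks = {}
    for name, gammas in bn_gammas.items():
        mask = [abs(g) >= threshold for g in gammas]
        if not any(mask):
            # Assumed safeguard: keep the strongest channel so that
            # no layer is removed entirely.
            best = max(range(len(gammas)), key=lambda i: abs(gammas[i]))
            mask[best] = True
        masks[name] = mask
    return threshold, masks


# Invented gamma values for two layers
gammas = {
    "conv1": [0.90, 0.02, 0.40, 0.01],
    "conv2": [0.03, 0.70, 0.05, 0.60],
}
threshold, masks = select_channels(gammas, prune_ratio=0.5)
print(threshold, masks)
```

In a real model the masks would then be used to slice the convolution weights and write a smaller cfg file, which is the part this sketch leaves out.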


# 1. Normal training (baseline)
python3 train.py --model_def config/yolov3-hand.cfg
# 2. The following is only an outline of the pruning procedure; in practice you still need to experiment with the s parameter or prune iteratively.

# 2.1 Sparsity training

python3 train.py --model_def config/yolov3-hand.cfg -sr --s 0.01

# 2.2 Prune with test_prune.py to obtain the pruned model
python3 test_prune.py

# 2.3 Fine-tune the pruned model

python3 train.py --model_def config/prune_yolov3-hand.cfg -pre checkpoints/prune_yolov3_ckpt.pth


# 3. Test
#python3 test.py --model_def config/prune_yolov3-hand.cfg --weights_path weights/prune_yolov3_ckpt.pth --data_config config/oxfordhand.data --class_path data/oxfordhand.names --conf_thres 0.01

python3 test.py --model_def config/prune_0.85_yolov3-hand.cfg --weights_path checkpoints/yolov3_ckpt_99_08211153.pth --data_config config/oxfordhand.data --class_path data/oxfordhand.names --conf_thres 0.01
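
The sparsity training in step 2.1 follows the Network Slimming idea of adding an L1 penalty on the BN gamma coefficients, which pushes unimportant channels toward zero. A minimal sketch of one such update, assuming plain SGD and assuming the `--s` value plays the role of the penalty weight, is shown below; `sparse_sgd_step` is a hypothetical helper and not the code in train.py.

```python
def sign(x):
    """Return -1, 0 or 1 depending on the sign of x."""
    return (x > 0) - (x < 0)


def sparse_sgd_step(gammas, grads, lr, s):
    """One SGD step on BN scale factors with an L1 penalty s * sum(|gamma|).
    The penalty contributes the subgradient s * sign(gamma) to each gradient."""
    return [g - lr * (grad + s * sign(g)) for g, grad in zip(gammas, grads)]


gammas = [0.5, -0.2, 0.0]
grads = [0.0, 0.0, 0.0]  # zero task gradient, to isolate the effect of the penalty
updated = sparse_sgd_step(gammas, grads, lr=0.1, s=0.01)
print(updated)  # every non-zero gamma moves slightly toward zero
```

A larger s drives more gammas toward zero but can hurt accuracy, which is why the outline above suggests experimenting with the s parameter.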

#==================**************************================================
#==================**************************================================
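# Steps for pruning and training YOLOv3 on the WIDER FACE dataset
1. Run widerface_label.py to generate the images and labels folders and the train.txt and valid.txt files
**Note**
When training on your own dataset, widerface.data and widerfaces.names must end with an empty line (a trailing newline),
whereas the last line of train.txt and valid.txt must be a non-empty line; otherwise training fails with IndexError: list index out of range.
yolov3-face.cfg can be generated with create_custom_model.sh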
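
The line-ending rules in the note are easy to get wrong, so a small check before training can save a failed run. The sketch below encodes one reading of the note (a trailing newline is required for the .data and .names files, and no file may end in a blank line); `check_text` is a hypothetical helper that works on file content passed in as a string.

```python
def check_text(text, need_trailing_newline):
    """Return a list of problems found in the content of one config or list file."""
    problems = []
    if need_trailing_newline and not text.endswith("\n"):
        problems.append("missing trailing newline")
    lines = text.split("\n")
    # A trailing newline makes split() yield one final empty element; ignore it.
    if lines and lines[-1] == "":
        lines = lines[:-1]
    if not lines or lines[-1].strip() == "":
        problems.append("last line is blank")
    return problems


print(check_text("face\n", need_trailing_newline=True))             # names-style content
print(check_text("a.jpg\nb.jpg\n\n", need_trailing_newline=False))  # list with a blank last line
```

To check real files, read each one with `open(path).read()` and pass the result to `check_text`.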

2. Normal training (baseline)
python3 train.py --model_def config/yolov3-face.cfg  -lr 0.004 --data_config config/widerface.data
3. Sparsity training
python3 train.py --model_def config/yolov3-face.cfg -sr --s 0.01 --data_config config/widerface.data

mAP: 0.4869 at step 45, 0.4954 at step 95

Test:
python3 test.py --model_def config/yolov3-face.cfg --weights_path checkpoints/yolov3_ckpt_45_08241046.pth --data_config config/widerface.data --class_path data/wider/widerfaces.names --conf_thres 0.01
4. Prune with test_prune.py to obtain the pruned model
python3 test_prune.py

5. Fine-tune the pruned model

python3 train.py --model_def config/prune_0.85_yolov3-face.cfg --data_config config/widerface.data -pre checkpoints/prune_0.85_yolov3_ckpt_95_08241046.pth

mAP: 0.5417 at step 35, 0.5660 at step 80
6. Test
python3 test.py --model_def config/prune_0.85_yolov3-face.cfg --weights_path checkpoints/yolov3_ckpt_80_08261039.pth --data_config config/widerface.data --class_path data/wider/widerfaces.names --conf_thres 0.01

7. Run detection
python3 detect.py --image_folder data/samples/ --weights_path checkpoints/yolov3_ckpt_80_08261039.pth --model_def config/prune_0.85_yolov3-face.cfg --class_path data/wider/widerfaces.names --conf_thres 0.04 --nms_thres 0.4

python3 detect.py --image_folder data/samples/test --weights_path checkpoints/yolov3_ckpt_80_08261039.pth --model_def config/prune_0.85_yolov3-face.cfg --class_path data/wider/widerfaces.names --conf_thres 0.6 --nms_thres 0.2
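
The two detect.py commands above differ only in --conf_thres and --nms_thres. As a rough illustration of how such thresholds interact, here is a generic greedy non-maximum suppression in plain Python. It is a simplified sketch, not the repo's implementation; `iou`, `filter_detections`, and the sample boxes are hypothetical.

```python
def iou(a, b):
    """Intersection over union of two boxes given as (x1, y1, x2, y2)."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def filter_detections(dets, conf_thres, nms_thres):
    """dets is a list of (x1, y1, x2, y2, score).
    Drop low scores, then greedily keep boxes that overlap no kept box too much."""
    dets = sorted((d for d in dets if d[4] >= conf_thres),
                  key=lambda d: d[4], reverse=True)
    kept = []
    for d in dets:
        if all(iou(d[:4], k[:4]) <= nms_thres for k in kept):
            kept.append(d)
    return kept


dets = [
    (10, 10, 110, 110, 0.90),
    (20, 20, 120, 120, 0.70),    # overlaps the first box heavily
    (200, 200, 260, 260, 0.05),  # weak, isolated detection
]
loose = filter_detections(dets, conf_thres=0.04, nms_thres=0.4)
strict = filter_detections(dets, conf_thres=0.6, nms_thres=0.2)
print(len(loose), len(strict))
```

With the loose setting the weak isolated box survives; with the strict setting only the strongest box remains. A low conf_thres favors recall, a high one favors precision.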

wider_annotation.py

#!/usr/bin/env python
# -- coding: utf-8 --
"""
Copyright (c) 2018. All rights reserved.
Created by C. L. Wang on 2018/6/14
Each output line looks like this:
e.g. wider/WIDER_val/images/0--Parade/0_Parade_marchingband_1_353.jpg 263,381,376,550,0 635,271,769,440,0
This format can be fed to kmeans.py to cluster anchors.

"""
import os

val_bbx_file = '/media/gavin/home/gavin/DataSet/wider/wider_face_split/wider_face_val_bbx_gt.txt'
train_bbx_file = '/media/gavin/home/gavin/DataSet/wider/wider_face_split/wider_face_train_bbx_gt.txt'

val_data_folder = '/media/gavin/home/gavin/DataSet/wider/WIDER_val'
train_data_folder = '/media/gavin/home/gavin/DataSet/wider/WIDER_train'

out_file = 'data/wider/WIDER_train.txt'


def generate_train_file(bbx_file, data_folder, out_file):
    paths_list, names_list = traverse_dir_files(data_folder)
    name_dict = dict()
    for path, name in zip(paths_list, names_list):
        name_dict[name] = path

    data_lines = read_file(bbx_file)

    sub_count = 0
    item_count = 0
    out_list = []

    for data_line in data_lines:
        item_count += 1
        if item_count % 1000 == 0:
            print('item_count: ' + str(item_count))

        data_line = data_line.strip()
        l_names = data_line.split('/')
        if len(l_names) == 2:
            # A new image entry starts: flush the previous image's line first
            if out_list:
                out_line = ' '.join(out_list)
                write_line(out_file, out_line)
                out_list = []

            name = l_names[-1]
            img_path = name_dict[name]
            sub_count = 1
            out_list.append(img_path)
            continue

        if sub_count == 1:
            # Skip the line that holds the number of boxes
            sub_count += 1
            continue

        if sub_count >= 2:
            n_list = data_line.split(' ')
            x_min = n_list[0]
            y_min = n_list[1]
            x_max = str(int(n_list[0]) + int(n_list[2]))
            y_max = str(int(n_list[1]) + int(n_list[3]))
            p_list = ','.join([x_min, y_min, x_max, y_max, '0'])  # every label is 0 (face)
            out_list.append(p_list)
            continue

    # Flush the last image; no later image entry would trigger the write above
    if out_list:
        write_line(out_file, ' '.join(out_list))


def traverse_dir_files(root_dir, ext=None):
    """
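    List the files in a folder, walking the directory tree recursively
    :param root_dir: root directory
    :param ext: iterable of file extensions to keep, e.g. ['.jpg']
    :return: [list of file paths, list of file names]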
    """
    names_list = []
    paths_list = []
    for parent, _, fileNames in os.walk(root_dir):
        for name in fileNames:
            if name.startswith('.'):  # skip hidden files
                continue
            if ext:  # filter by extension
                if name.endswith(tuple(ext)):
                    names_list.append(name)
                    paths_list.append(os.path.join(parent, name))
            else:
                names_list.append(name)
                paths_list.append(os.path.join(parent, name))
    paths_list, names_list = sort_two_list(paths_list, names_list)
    return paths_list, names_list


def sort_two_list(list1, list2):
    """
    Sort two lists together, ordered by the first list
    :param list1: first list
    :param list2: second list
    :return: the two sorted lists
    """
    list1, list2 = (list(t) for t in zip(*sorted(zip(list1, list2))))
    return list1, list2


def read_file(data_file, mode='more'):
    """
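    Read a file (a source file or a data file)
    :return: the whole content as one string (mode='one') or a list of lines (mode='more')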
    """
    try:
        with open(data_file, 'r') as f:
            if mode == 'one':
                output = f.read()
                return output
            elif mode == 'more':
                output = f.readlines()
                # return map(str.strip, output)
                return output
            else:
                return list()
    except IOError:
        return list()


def write_line(file_name, line):
    """
    Append one line of data to a file
    :param file_name: file name
    :param line: line data
    :return: None
    """
    if file_name == "":
        return
    with open(file_name, "a+") as fs:
        if isinstance(line, (tuple, list)):
            fs.write("%s\n" % ", ".join(line))
        else:
            fs.write("%s\n" % line)


if __name__ == '__main__':
    generate_train_file(val_bbx_file, val_data_folder, out_file) # 46000+
    generate_train_file(train_bbx_file, train_data_folder, out_file) #  185000+
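
The docstring of wider_annotation.py notes that its output format can be used for anchor clustering. As a sketch of the first step of such clustering, the snippet below reads one output line back into box widths and heights. `box_sizes` is a hypothetical helper, and it assumes the image path contains no spaces.

```python
def box_sizes(line):
    """Parse 'path x1,y1,x2,y2,cls ...' into the path and a list of (width, height)."""
    parts = line.split()
    sizes = []
    for item in parts[1:]:
        x1, y1, x2, y2, _cls = (int(v) for v in item.split(','))
        sizes.append((x2 - x1, y2 - y1))
    return parts[0], sizes


# Sample line taken from the docstring of wider_annotation.py
sample = ("wider/WIDER_val/images/0--Parade/0_Parade_marchingband_1_353.jpg "
          "263,381,376,550,0 635,271,769,440,0")
path, sizes = box_sizes(sample)
print(path, sizes)
```

The resulting (width, height) pairs are the kind of input a k-means anchor clustering would work on.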

 

widerface_label.py

#!/usr/bin/env python
# -- coding: utf-8 --
"""
Copyright (c) 2019. All rights reserved.
Created by Gavin on 2019/8/22
ps: Each row in the annotation file should define one bounding box, using the syntax:
 label_idx x_center y_center width height

The coordinates should be scaled [0, 1],
and the label_idx should be zero-indexed and correspond to the row number of the class name in data/custom/classes.names.

eg: 0 0.593042 0.674682 0.067564 0.043648

Now we prepare face detect datasets for yolov3-pruning
we need data/custom/classes.names,eg:data/wider/widerfaces.names
image folder : data/custom/images/ and Annotation Folder :data/custom/labels/
The dataloader expects that the annotation file corresponding to the image data/custom/images/train.jpg has the path
 data/custom/labels/train.txt
"""

from PIL import Image
import os
import datetime
import shutil


created_images_dir = '/home/gavin/Dataset/wider_yolo3/images'
created_labels_dir = '/home/gavin/Dataset/wider_yolo3/labels'


val_bbx_file = '/media/gavin/home/gavin/DataSet/wider/wider_face_split/wider_face_val_bbx_gt.txt'
train_bbx_file = '/media/gavin/home/gavin/DataSet/wider/wider_face_split/wider_face_train_bbx_gt.txt'

val_data_folder = '/media/gavin/home/gavin/DataSet/wider/WIDER_val'
train_data_folder = '/media/gavin/home/gavin/DataSet/wider/WIDER_train'
test_data_folder = '/media/gavin/home/gavin/DataSet/wider/WIDER_test'

out_file_train = 'data/wider/train.txt'
out_file_valid = 'data/wider/valid.txt'

# Faces whose width or height is minsize2select pixels or less are skipped
minsize2select = 10

def hms_string(sec_elapsed):    # format the elapsed time for display
    h = int(sec_elapsed / (60 * 60))
    m = int((sec_elapsed % (60 * 60)) / 60)
    s = sec_elapsed % 60.
    return "{}:{:>02}:{:>05.2f}".format(h, m, s)

def generate_train_file(set_name,bbx_file, data_folder, out_file):
    # prepare new folder dataset
    new_images_dir = os.path.join(created_images_dir, set_name)  # images are copied from their original folder into this one
    new_annotation_dir = os.path.join(created_labels_dir, set_name)

    if not os.path.exists(new_images_dir):
        os.makedirs(new_images_dir)

    if not os.path.exists(new_annotation_dir):
        os.makedirs(new_annotation_dir)

    paths_list, names_list = traverse_dir_files(data_folder)
    name_dict = dict()
    for path, name in zip(paths_list, names_list):
        name_dict[name] = path  # original path; switched to the new path once the image is processed

    data_lines = read_file(bbx_file)

    sub_count = 0
    item_count = 0
    out_list = []

    # add
    width = 0
    height = 0
    filename = ''
    bboxes = []
    numbbox = 0
    img_path = ''
    img_path_new = ''



    for data_idx,data_line in enumerate(data_lines):
        item_count += 1
        if item_count % 1000 == 0:
            print('item_count: ' + str(item_count))

        data_line = data_line.strip()
        l_names = data_line.split('/')
        if len(l_names) == 2:
            name = l_names[-1]
            filename = name.split(".")[0] # add

            img_path = name_dict[name]
            pil_image = Image.open(img_path)  # add
            width, height = pil_image.size #add
            sub_count = 1

            img_path_new = os.path.join(new_images_dir, name)
            name_dict[name] = img_path_new  # switch the path to the new location

            bboxes = []

            continue

        if sub_count == 1:
            numbbox = int(data_line.split(' ')[0])
            sub_count += 1
            continue

        if sub_count >= 2:
            sub_count += 1
            n_list = data_line.split(' ')
            x_min = int(n_list[0])
            y_min = int(n_list[1])
            w = int(n_list[2])
            h = int(n_list[3])
            x_max = x_min + w
            y_max = y_min + h

            # Keep only non-empty boxes whose width and height both exceed minsize2select
            if w > 0 and h > 0 and w > minsize2select and h > minsize2select:
                bboxes.append((x_min, y_min, w, h))

                # Clip to the image bounds
                maxX = min(x_max, width - 1)
                minX = max(x_min, 0)
                maxY = min(y_max, height - 1)
                minY = max(y_min, 0)

                # Box size relative to the image size
                norm_width = (maxX - minX) / width
                norm_height = (maxY - minY) / height

                # Box center relative to the image size
                center_x, center_y = (maxX + minX) / 2, (maxY + minY) / 2
                norm_center_x = center_x / width
                norm_center_y = center_y / height

                with open(os.path.join(new_annotation_dir, filename + ".txt"), "a+") as hs:
                    hs.write("0 %f %f %f %f\n" % (norm_center_x, norm_center_y, norm_width, norm_height))  # 0 is the class index

            # After the last box of an image, decide whether to keep the image.
            # This check runs even when the last box itself was skipped above.
            if sub_count == 2 + numbbox:
                if len(bboxes) == 0:
                    print("warning: no face")
                    continue
                shutil.copy(img_path, new_images_dir)
                write_line(out_file, img_path_new)
                continue



def traverse_dir_files(root_dir, ext=None):
    """
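    List the files in a folder, walking the directory tree recursively
    :param root_dir: root directory
    :param ext: iterable of file extensions to keep, e.g. ['.jpg']
    :return: [list of file paths, list of file names]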
    """
    names_list = []
    paths_list = []
    for parent, _, fileNames in os.walk(root_dir):
        for name in fileNames:
            if name.startswith('.'):  # skip hidden files
                continue
            if ext:  # filter by extension
                if name.endswith(tuple(ext)):
                    names_list.append(name)
                    paths_list.append(os.path.join(parent, name))
            else:
                names_list.append(name)
                paths_list.append(os.path.join(parent, name))
    paths_list, names_list = sort_two_list(paths_list, names_list)
    return paths_list, names_list


def sort_two_list(list1, list2):
    """
    Sort two lists together, ordered by the first list
    :param list1: first list
    :param list2: second list
    :return: the two sorted lists
    """
    list1, list2 = (list(t) for t in zip(*sorted(zip(list1, list2))))
    return list1, list2


def read_file(data_file, mode='more'):
    """
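    Read a file (a source file or a data file)
    :return: the whole content as one string (mode='one') or a list of lines (mode='more')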
    """
    try:
        with open(data_file, 'r') as f:
            if mode == 'one':
                output = f.read()
                return output
            elif mode == 'more':
                output = f.readlines()
                # return map(str.strip, output)
                return output
            else:
                return list()
    except IOError:
        return list()


def write_line(file_name, line):
    """
    Append one line of data to a file
    :param file_name: file name
    :param line: line data
    :return: None
    """
    if file_name == "":
        return
    with open(file_name, "a+") as fs:
        if isinstance(line, (tuple, list)):
            fs.write("%s\n" % ", ".join(line))
        else:
            fs.write("%s\n" % line)


if __name__ == '__main__':
    start_time = datetime.datetime.now()
    generate_train_file("validation",val_bbx_file, val_data_folder, out_file_valid) # 46000+; the first argument is the name of the generated subfolder
    generate_train_file("train",train_bbx_file, train_data_folder, out_file_train) #  185000+

    end_time = datetime.datetime.now()
    seconds_elapsed = (end_time - start_time).total_seconds()
    print("It took {} to execute this".format(hms_string(seconds_elapsed)))
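
After running widerface_label.py it can be worth spot-checking a few generated label lines against the format described in its docstring (five fields, coordinates scaled to [0, 1]). The following sketch shows one way to do that; `parse_label_line` is a hypothetical helper and not part of the script.

```python
def parse_label_line(line):
    """Parse 'label_idx x_center y_center width height' and validate the ranges."""
    fields = line.split()
    if len(fields) != 5:
        raise ValueError("expected 5 fields, got %d" % len(fields))
    label_idx = int(fields[0])
    x_c, y_c, w, h = (float(v) for v in fields[1:])
    for value in (x_c, y_c, w, h):
        if not 0.0 <= value <= 1.0:
            raise ValueError("coordinate out of [0, 1]: %r" % value)
    return label_idx, x_c, y_c, w, h


# Example line taken from the docstring of widerface_label.py
parsed = parse_label_line("0 0.593042 0.674682 0.067564 0.043648")
print(parsed)
```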




 

 
