YOLOv3-model-pruning
Hand detection with a YOLOv3 model on the open-source oxford hand dataset, followed by model pruning. On this dataset, channel pruning of YOLOv3 cuts the parameter count and model size by 80% and FLOPs by 70%, doubles forward-inference speed, and keeps mAP essentially unchanged (these figures are specific to this dataset and may not carry over to others).
Environment
Python 3.6, PyTorch 1.0 or later
The YOLOv3 implementation follows eriklindernoren's PyTorch-YOLOv3, so that repo's dependency setup applies here as well.
Dataset preparation
- Download the WIDER FACE dataset and obtain the archive (extraction code: ymx2)
- Extract the archive into Dataset
- Run widerface_label.py to generate the images and labels folders and the train.txt and valid.txt files
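The training commands below also reference config/widerface.data. This repo inherits PyTorch-YOLOv3's darknet-style .data convention, so a hypothetical widerface.data could look like the following (the paths are assumptions; adjust them to where widerface_label.py wrote its output, and note the trailing-newline requirement mentioned later):

```
classes=1
train=data/wider/train.txt
valid=data/wider/valid.txt
names=data/wider/widerfaces.names
```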
Pruning algorithm
The channel pruning implemented here adapts the method of Learning Efficient Convolutional Networks Through Network Slimming (ICCV 2017); a similar implementation is yolov3-network-slimming. The algorithm in the original paper targets classification models and prunes channels based on the gamma (scale) coefficients of BN layers.
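To make the BN-gamma criterion concrete, here is a minimal sketch of global channel selection (the function name and the prune_ratio parameter are illustrative assumptions, not this repo's test_prune.py): collect every |gamma|, take a global threshold at the desired prune ratio, and keep only channels above it.

```python
import torch
import torch.nn as nn

def bn_prune_masks(model: nn.Module, prune_ratio: float = 0.85):
    """Return {layer_name: boolean mask of channels to keep}, ranked by |gamma|."""
    # gather the absolute gamma (scale) values of every BN layer
    gammas = torch.cat([m.weight.data.abs().flatten()
                        for m in model.modules()
                        if isinstance(m, nn.BatchNorm2d)])
    # global threshold below which prune_ratio of all channels fall
    threshold = torch.sort(gammas)[0][int(gammas.numel() * prune_ratio)]
    return {name: m.weight.data.abs() > threshold
            for name, m in model.named_modules()
            if isinstance(m, nn.BatchNorm2d)}
```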
Pruning steps on the oxford hand dataset
# 1. Normal training (baseline)
python3 train.py --model_def config/yolov3-hand.cfg
# 2. The steps below are only an outline of the pruning procedure; in practice
#    you may need to experiment with the s parameter or prune iteratively.
# 2.1 Sparsity training (a sketch of what -sr does follows this block)
python3 train.py --model_def config/yolov3-hand.cfg -sr --s 0.01
# 2.2 Prune with test_prune.py to obtain the pruned model
python3 test_prune.py
# 2.3 Fine-tune the pruned model
python3 train.py --model_def config/prune_yolov3-hand.cfg -pre checkpoints/prune_yolov3_ckpt.pth
# 3. Test (the commented-out line is an alternative run against prune_yolov3-hand.cfg)
# python3 test.py --model_def config/prune_yolov3-hand.cfg --weights_path weights/prune_yolov3_ckpt.pth --data_config config/oxfordhand.data --class_path data/oxfordhand.names --conf_thres 0.01
python3 test.py --model_def config/prune_0.85_yolov3-hand.cfg --weights_path checkpoints/yolov3_ckpt_99_08211153.pth --data_config config/oxfordhand.data --class_path data/oxfordhand.names --conf_thres 0.01
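The -sr flag enables sparsity training and --s sets the penalty strength. Following Network Slimming, this adds an L1 penalty s * sum(|gamma|) over all BN scale factors, which during backprop amounts to adding s * sign(gamma) to each BN weight gradient. Below is a minimal sketch of that update; the function name and training-loop shape are assumptions, not necessarily what train.py does verbatim.

```python
import torch
import torch.nn as nn

def update_bn_grads(model: nn.Module, s: float = 0.01):
    """Add the L1 subgradient s * sign(gamma) to every BN scale gradient."""
    for module in model.modules():
        if isinstance(module, nn.BatchNorm2d) and module.weight.grad is not None:
            # d/d(gamma) of s * |gamma| is s * sign(gamma)
            module.weight.grad.data.add_(s * torch.sign(module.weight.data))

# Typical use inside a training loop (sketch):
#   loss.backward()
#   if opt.sr:                        # sparsity training enabled via -sr
#       update_bn_grads(model, s=opt.s)
#   optimizer.step()
```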
Training and pruning YOLOv3 on the WIDER FACE dataset
1. Run widerface_label.py to generate the images and labels folders and the train.txt and valid.txt files
**Note**
When training on your own dataset, widerface.data and widerfaces.names must end with a trailing empty line (newline), while the last line of train.txt and valid.txt must be non-empty; otherwise loading fails with IndexError: list index out of range.
yolov3-face.cfg can be generated with creat_custom_model.sh
2. Normal training (baseline)
python3 train.py --model_def config/yolov3-face.cfg -lr 0.004 --data_config config/widerface.data
3. Sparsity training
python3 train.py --model_def config/yolov3-face.cfg -sr --s 0.01 --data_config config/widerface.data
mAP 0.4869 at step 45, 0.4954 at step 95
Test:
python3 test.py --model_def config/yolov3-face.cfg --weights_path checkpoints/yolov3_ckpt_45_08241046.pth --data_config config/widerface.data --class_path data/wider/widerfaces.names --conf_thres 0.01
4. Prune with test_prune.py to obtain the pruned model
python3 test_prune.py
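Conceptually, test_prune.py rebuilds a compact network from the channel masks: each convolution's weight tensor is index-selected along its output channels (this layer's mask) and its input channels (the preceding layer's mask). The sketch below shows just that copy step under assumed names; the real script must additionally handle YOLOv3's shortcut and route layers, which is where most of the complexity lives.

```python
import torch
import torch.nn as nn

def copy_pruned_conv(old_conv: nn.Conv2d, out_mask: torch.Tensor,
                     in_mask: torch.Tensor) -> nn.Conv2d:
    """Build a smaller Conv2d keeping only the masked input/output channels."""
    out_idx = torch.nonzero(out_mask).flatten()
    in_idx = torch.nonzero(in_mask).flatten()
    new_conv = nn.Conv2d(len(in_idx), len(out_idx),
                         kernel_size=old_conv.kernel_size,
                         stride=old_conv.stride,
                         padding=old_conv.padding,
                         bias=old_conv.bias is not None)
    # select surviving output channels, then surviving input channels
    new_conv.weight.data.copy_(old_conv.weight.data[out_idx][:, in_idx])
    if old_conv.bias is not None:
        new_conv.bias.data.copy_(old_conv.bias.data[out_idx])
    return new_conv
```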
5. Fine-tune the pruned model
python3 train.py --model_def config/prune_0.85_yolov3-face.cfg --data_config config/widerface.data -pre checkpoints/prune_0.85_yolov3_ckpt_95_08241046.pth
mAP 0.5417 at step 35, 0.5660 at step 80
6. Test
python3 test.py --model_def config/prune_0.85_yolov3-face.cfg --weights_path checkpoints/yolov3_ckpt_80_08261039.pth --data_config config/widerface.data --class_path data/wider/widerfaces.names --conf_thres 0.01
7. Run detection
python3 detect.py --image_folder data/samples/ --weights_path checkpoints/yolov3_ckpt_80_08261039.pth --model_def config/prune_0.85_yolov3-face.cfg --class_path data/wider/widerfaces.names --conf_thres 0.04 --nms_thres 0.4
python3 detect.py --image_folder data/samples/test --weights_path checkpoints/yolov3_ckpt_80_08261039.pth --model_def config/prune_0.85_yolov3-face.cfg --class_path data/wider/widerfaces.names --conf_thres 0.6 --nms_thres 0.2
wider_annotation.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Copyright (c) 2018. All rights reserved.
Created by C. L. Wang on 2018/6/14
Output format, one image per line:
eg: wider/WIDER_val/images/0--Parade/0_Parade_marchingband_1_353.jpg 263,381,376,550,0 635,271,769,440,0
Lines in this format can be fed to kmeans.py to cluster anchors.
"""
import os
val_bbx_file = '/media/gavin/home/gavin/DataSet/wider/wider_face_split/wider_face_val_bbx_gt.txt'
train_bbx_file = '/media/gavin/home/gavin/DataSet/wider/wider_face_split/wider_face_train_bbx_gt.txt'
val_data_folder = '/media/gavin/home/gavin/DataSet/wider/WIDER_val'
train_data_folder = '/media/gavin/home/gavin/DataSet/wider/WIDER_train'
out_file = 'data/wider/WIDER_train.txt'
def generate_train_file(bbx_file, data_folder, out_file):
paths_list, names_list = traverse_dir_files(data_folder)
name_dict = dict()
for path, name in zip(paths_list, names_list):
name_dict[name] = path
data_lines = read_file(bbx_file)
sub_count = 0
item_count = 0
out_list = []
for data_line in data_lines:
item_count += 1
if item_count % 1000 == 0:
print('item_count: ' + str(item_count))
data_line = data_line.strip()
l_names = data_line.split('/')
if len(l_names) == 2:
if out_list:
out_line = ' '.join(out_list)
write_line(out_file, out_line)
out_list = []
name = l_names[-1]
img_path = name_dict[name]
sub_count = 1
out_list.append(img_path)
continue
if sub_count == 1:
sub_count += 1
continue
if sub_count >= 2:
n_list = data_line.split(' ')
x_min = n_list[0]
y_min = n_list[1]
x_max = str(int(n_list[0]) + int(n_list[2]))
y_max = str(int(n_list[1]) + int(n_list[3]))
            p_list = ','.join([x_min, y_min, x_max, y_max, '0'])  # every label is 0 (face)
out_list.append(p_list)
            continue
    # flush the last image's line; the loop above only writes an image's
    # entry when the next filename line is encountered
    if out_list:
        write_line(out_file, ' '.join(out_list))
def traverse_dir_files(root_dir, ext=None):
"""
    List files under a folder via a depth-first walk
    :param root_dir: root directory
    :param ext: file extension filter
    :return: [list of file paths, list of file names]
"""
names_list = []
paths_list = []
for parent, _, fileNames in os.walk(root_dir):
for name in fileNames:
            if name.startswith('.'):  # skip hidden files
                continue
            if ext:  # keep only files with a matching extension
if name.endswith(tuple(ext)):
names_list.append(name)
paths_list.append(os.path.join(parent, name))
else:
names_list.append(name)
paths_list.append(os.path.join(parent, name))
paths_list, names_list = sort_two_list(paths_list, names_list)
return paths_list, names_list
def sort_two_list(list1, list2):
"""
    Sort two lists in parallel (ordered by the first list)
    :param list1: list 1
    :param list2: list 2
    :return: the two sorted lists
"""
list1, list2 = (list(t) for t in zip(*sorted(zip(list1, list2))))
return list1, list2
def read_file(data_file, mode='more'):
"""
    Read a file
    :return: the whole file as one string, or a list of lines
"""
try:
with open(data_file, 'r') as f:
if mode == 'one':
output = f.read()
return output
elif mode == 'more':
output = f.readlines()
# return map(str.strip, output)
return output
else:
return list()
except IOError:
return list()
def write_line(file_name, line):
"""
    Append one line to a file
    :param file_name: file name
    :param line: line data
:return: None
"""
if file_name == "":
return
with open(file_name, "a+") as fs:
        if isinstance(line, (tuple, list)):
fs.write("%s\n" % ", ".join(line))
else:
fs.write("%s\n" % line)
if __name__ == '__main__':
generate_train_file(val_bbx_file, val_data_folder, out_file) # 46000+
generate_train_file(train_bbx_file, train_data_folder, out_file) # 185000+
widerface_label.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Copyright (c) 2019. All rights reserved.
Created by Gavin on 2019/8/22
ps: Each row in the annotation file should define one bounding box, using the syntax:
label_idx x_center y_center width height
The coordinates should be scaled [0, 1],
and the label_idx should be zero-indexed and correspond to the row number of the class name in data/custom/classes.names.
eg: 0 0.593042 0.674682 0.067564 0.043648
Here we prepare the face detection dataset for YOLOv3 pruning.
We need data/custom/classes.names, e.g. data/wider/widerfaces.names,
an image folder data/custom/images/ and an annotation folder data/custom/labels/.
The dataloader expects the annotation file for the image data/custom/images/train.jpg at the path
data/custom/labels/train.txt
"""
from PIL import Image
import os
import datetime
import shutil
created_images_dir = '/home/gavin/Dataset/wider_yolo3/images'
created_labels_dir = '/home/gavin/Dataset/wider_yolo3/labels'
val_bbx_file = '/media/gavin/home/gavin/DataSet/wider/wider_face_split/wider_face_val_bbx_gt.txt'
train_bbx_file = '/media/gavin/home/gavin/DataSet/wider/wider_face_split/wider_face_train_bbx_gt.txt'
val_data_folder = '/media/gavin/home/gavin/DataSet/wider/WIDER_val'
train_data_folder = '/media/gavin/home/gavin/DataSet/wider/WIDER_train'
test_data_folder = '/media/gavin/home/gavin/DataSet/wider/WIDER_test'
out_file_train = 'data/wider/train.txt'
out_file_valid = 'data/wider/valid.txt'
# skip faces whose width or height is at or below this size (in pixels)
minsize2select = 10
def hms_string(sec_elapsed):  # format elapsed seconds as h:mm:ss.ss
h = int(sec_elapsed / (60 * 60))
m = int((sec_elapsed % (60 * 60)) / 60)
s = sec_elapsed % 60.
return "{}:{:>02}:{:>05.2f}".format(h, m, s)
def generate_train_file(set_name, bbx_file, data_folder, out_file):
    # prepare the new dataset folders
    new_images_dir = os.path.join(created_images_dir, set_name)  # images are copied here from the original folder
    new_annotation_dir = os.path.join(created_labels_dir, set_name)
if not os.path.exists(new_images_dir):
os.makedirs(new_images_dir)
if not os.path.exists(new_annotation_dir):
os.makedirs(new_annotation_dir)
paths_list, names_list = traverse_dir_files(data_folder)
name_dict = dict()
for path, name in zip(paths_list, names_list):
        name_dict[name] = path  # map each file name to its original path
data_lines = read_file(bbx_file)
sub_count = 0
item_count = 0
out_list = []
    # per-image state, filled in while scanning the annotation file
width = 0
height = 0
filename = ''
bboxes = []
numbbox = 0
img_path = ''
img_path_new = ''
    for data_idx, data_line in enumerate(data_lines):
item_count += 1
if item_count % 1000 == 0:
print('item_count: ' + str(item_count))
data_line = data_line.strip()
l_names = data_line.split('/')
if len(l_names) == 2:
name = l_names[-1]
            filename = name.split(".")[0]  # image name without extension
img_path = name_dict[name]
            pil_image = Image.open(img_path)  # open the image to read its size
            width, height = pil_image.size
sub_count = 1
img_path_new = os.path.join(new_images_dir, name)
            name_dict[name] = img_path_new  # point the dict at the image's new location
bboxes = []
continue
if sub_count == 1:
numbbox = int(data_line.split(' ')[0])
sub_count += 1
continue
if sub_count >= 2:
sub_count += 1
n_list = data_line.split(' ')
x_min = int(n_list[0])
y_min = int(n_list[1])
x_max = int(n_list[0]) + int(n_list[2])
y_max = int(n_list[1]) + int(n_list[3])
w = int(n_list[2])
h = int(n_list[3])
bbox = (x_min, y_min, w, h)
if int(x_max) - int(x_min) == 0 or int(y_max) - int(y_min) == 0:
continue
if (h <= minsize2select or w <= minsize2select):
continue
bboxes.append(bbox)
            # clip the box to the image bounds
maxX = min(x_max, width - 1)
minX = max(x_min, 0)
maxY = min(y_max, height - 1)
minY = max(y_min, 0)
# (<absolute_x> / <image_width>)
norm_width = (maxX - minX) / width
# (<absolute_y> / <image_height>)
norm_height = (maxY - minY) / height
center_x, center_y = (maxX + minX) / 2, (maxY + minY) / 2
norm_center_x = center_x / width
norm_center_y = center_y / height
            with open(os.path.join(new_annotation_dir, filename + ".txt"), "a+") as hs:
                hs.write("0 %f %f %f %f\n" % (norm_center_x, norm_center_y, norm_width, norm_height))  # class index 0 (face)
            if sub_count == 2 + numbbox:  # reached the last annotation line of this image
if len(bboxes) == 0:
print("warrning: no face")
continue
shutil.copy(img_path, new_images_dir)
write_line(out_file, img_path_new)
continue
def traverse_dir_files(root_dir, ext=None):
"""
    List files under a folder via a depth-first walk
    :param root_dir: root directory
    :param ext: file extension filter
    :return: [list of file paths, list of file names]
"""
names_list = []
paths_list = []
for parent, _, fileNames in os.walk(root_dir):
for name in fileNames:
            if name.startswith('.'):  # skip hidden files
                continue
            if ext:  # keep only files with a matching extension
if name.endswith(tuple(ext)):
names_list.append(name)
paths_list.append(os.path.join(parent, name))
else:
names_list.append(name)
paths_list.append(os.path.join(parent, name))
paths_list, names_list = sort_two_list(paths_list, names_list)
return paths_list, names_list
def sort_two_list(list1, list2):
"""
    Sort two lists in parallel (ordered by the first list)
    :param list1: list 1
    :param list2: list 2
    :return: the two sorted lists
"""
list1, list2 = (list(t) for t in zip(*sorted(zip(list1, list2))))
return list1, list2
def read_file(data_file, mode='more'):
"""
    Read a file
    :return: the whole file as one string, or a list of lines
"""
try:
with open(data_file, 'r') as f:
if mode == 'one':
output = f.read()
return output
elif mode == 'more':
output = f.readlines()
# return map(str.strip, output)
return output
else:
return list()
except IOError:
return list()
def write_line(file_name, line):
"""
    Append one line to a file
    :param file_name: file name
    :param line: line data
:return: None
"""
if file_name == "":
return
with open(file_name, "a+") as fs:
        if isinstance(line, (tuple, list)):
fs.write("%s\n" % ", ".join(line))
else:
fs.write("%s\n" % line)
if __name__ == '__main__':
start_time = datetime.datetime.now()
generate_train_file("validation",val_bbx_file, val_data_folder, out_file_valid) # 46000+ 第一个参数表示生成的文件夹的名称
generate_train_file("train",train_bbx_file, train_data_folder, out_file_train) # 185000+
end_time = datetime.datetime.now()
seconds_elapsed = (end_time - start_time).total_seconds()
print("It took {} to execute this".format(hms_string(seconds_elapsed)))