上一篇:深度学习之用CelebA_Spoof数据集搭建一个活体检测-数据处理,我们使用了CelebA_Spoof数据集进行了处理,目的是为了一个2D的活体检测。本文将详细介绍如何使用CelebA_Spoof数据集训练一个2D活体检测模型,包括模型架构设计、分布式训练实现和关键代码解析。
1. 项目概述
对于活体检测,个人认为一直是个伪命题,在人脸识别中,其实真正意义是不管这个人脸是真还是假,随着人脸识别技术的应用越来越广泛,人脸识别技术已经不再是一种技术,而是技术意外的一个比较重要的工具了,那么安全性就必须得到提高,活体检测自热而然就有了它的用武之地(本人是很不情愿做这个活体检测的)。如何设计一个活体检测的模型呢,当然现在这个领域已经有了很多的研究和很好的方法,在现有的实际应用中已经完全可以安全使用。但是作为入门者,还是得从简单的步骤做起来。
在这里,我暂且把2D的活体检测当作一个二分类的问题去解决(实际中,活体检测很复杂),考虑到要移植到嵌入式设备上,所以用一个小模型去作为二分类的主体,当然有人说为了模型的小,随便搭几个卷积层就行了,但是,凭什么你自己搭建的能有别人团队开源的好呢?所以拿来主义打败了自己。本项目采用知识蒸馏技术,使用ResNet18作为教师模型,SqueezeNet1.0作为学生模型,在CelebA_Spoof数据集上进行训练。
2. 环境配置
2.1 硬件要求
- 多GPU服务器(建议2-8块GPU),当然一块也行,毕竟二分类,而且数据量不是很大
- CUDA 11.0及以上
2.2 软件依赖
pip install torch torchvision#这是主要的工具
2.3 数据预处理
可以参照之前的文章,深度学习之用CelebA_Spoof数据集搭建一个活体检测-数据处理,可以用数据集里面的json文件,也可以自己重新构建数据集。在这里我是重新构建了:
import os
from PIL import Image
from pathlib import Path
from PIL import ImageFile
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading
ImageFile.LOAD_TRUNCATED_IMAGES = True
print_lock = threading.Lock()
def load_bbox(bbox_file_path):
"""加载边界框文件"""
if not os.path.exists(bbox_file_path):
raise FileNotFoundError(f"Bounding box file not found: {
bbox_file_path}")
with open(bbox_file_path, 'r') as file:
return list(map(float, file.read().strip().split()))
def process_image(args):
"""处理单张图片(线程安全)"""
img_path, output_path, thread_id, remaining = args
try:
with print_lock:
print(f"[Thread-{
thread_id}] Processing {
Path(img_path).name} ({
remaining} remaining)")
bbox_file = os.path.join(os.path.dirname(img_path), f"{
Path(img_path).stem}_BB.txt")
bbox = load_bbox(bbox_file)
img = Image.open(img_path)
real_w, real_h = img.size
# 计算实际边界框
x1 = int(bbox[0] * (real_w / 224))
y1 = int(bbox[1] * (real_h / 224))
w1 = int(bbox[2] * (real_w / 224))
h1 = int(bbox[3] * (real_h / 224))
# 裁剪并保存
cropped_img = img.crop((x1, y1, x1 + w1, y1 + h1))
cropped_img.save(output_path)
#with print_lock:
# print(f"[Thread-{thread_id}] Finished {Path(img_path).name}")
return True
except Exception as e:
with print_lock:
print(f"[Thread-{
thread_id}] Error processing {
Path(img_path).name}: {
str(e)}")
return False
def process_directory(input_dir, output_dir, max_workers=8):
"""使用线程池处理目录中的所有图片"""
os.makedirs(output_dir, exist_ok=True)
tasks = []
# 获取所有图片文件
image_files = [f for f in os.listdir(input_dir) if f.endswith((".png", ".jpg", ".jpeg"))]
total = len(image_files)
for idx, img_file in enumerate(image_files):
img_path = os.path.join(input_dir, img_file)
output_file = Path(img_file).with_suffix(".png")
output_path = os.path.join(output_dir, output_file)
remaining = total - idx - 1
tasks.append((img_path, output_path, idx % max_workers, remaining))
with ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = [executor.submit(process_image, task) for task in tasks]
for future in as_completed(futures):
future.result()
def process_data(root_dir, output_base_dir):
"""处理整个数据集目录"""
for split in ["test", "train"]:
split_dir = os.path.join(root_dir, split)
if not os.path.exists(split_dir):
print(f"Warning: {
split_dir} does not exist")
continue
# 创建对应的输出目录
output_split_dir = os.path.join(output_base_dir, split)
os.makedirs(os.path.join(output_split_dir, "live"), exist_ok=True)
os.makedirs(os.path.join(output_split_dir, "spoof"), exist_ok=True)
for sub_dir in os.listdir(split_dir):
sub_path = os.path.join(split_dir, sub_dir)
if not os.path.isdir(sub_path):
print(f"Skipping non-directory: {
sub_path}")
continue
for label in ["live", "spoof"]:
label_path = os.path.join(sub_path, label)
if not os.path.exists(label_path)

最低0.47元/天 解锁文章

被折叠的 条评论
为什么被折叠?



