深度学习之用CelebA_Spoof数据集搭建一个活体检测-模型搭建和训练

上一篇:深度学习之用CelebA_Spoof数据集搭建一个活体检测-数据处理,我们使用了CelebA_Spoof数据集进行了处理,目的是为了一个2D的活体检测。本文将详细介绍如何使用CelebA_Spoof数据集训练一个2D活体检测模型,包括模型架构设计、分布式训练实现和关键代码解析。

1. 项目概述

对于活体检测,个人认为一直是个伪命题,在人脸识别中,其实真正意义是不管这个人脸是真还是假,随着人脸识别技术的应用越来越广泛,人脸识别技术已经不再是一种技术,而是技术意外的一个比较重要的工具了,那么安全性就必须得到提高,活体检测自热而然就有了它的用武之地(本人是很不情愿做这个活体检测的)。如何设计一个活体检测的模型呢,当然现在这个领域已经有了很多的研究和很好的方法,在现有的实际应用中已经完全可以安全使用。但是作为入门者,还是得从简单的步骤做起来。
在这里,我暂且把2D的活体检测当作一个二分类的问题去解决(实际中,活体检测很复杂),考虑到要移植到嵌入式设备上,所以用一个小模型去作为二分类的主体,当然有人说为了模型的小,随便搭几个卷积层就行了,但是,凭什么你自己搭建的能有别人团队开源的好呢?所以拿来主义打败了自己。本项目采用知识蒸馏技术,使用ResNet18作为教师模型,SqueezeNet1.0作为学生模型,在CelebA_Spoof数据集上进行训练。

2. 环境配置

2.1 硬件要求

  • 多GPU服务器(建议2-8块GPU),当然一块也行,毕竟二分类,而且数据量不是很大
  • CUDA 11.0及以上

2.2 软件依赖

pip install torch torchvision#这是主要的工具

2.3 数据预处理

可以参照之前的文章,深度学习之用CelebA_Spoof数据集搭建一个活体检测-数据处理,可以用数据集里面的json文件,也可以自己重新构建数据集。在这里我是重新构建了:

import os
from PIL import Image
from pathlib import Path
from PIL import ImageFile
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading

ImageFile.LOAD_TRUNCATED_IMAGES = True
print_lock = threading.Lock()

def load_bbox(bbox_file_path):
    """加载边界框文件"""
    if not os.path.exists(bbox_file_path):
        raise FileNotFoundError(f"Bounding box file not found: {
     
     bbox_file_path}")
    with open(bbox_file_path, 'r') as file:
        return list(map(float, file.read().strip().split()))

def process_image(args):
    """处理单张图片(线程安全)"""
    img_path, output_path, thread_id, remaining = args
    try:
        with print_lock:
            print(f"[Thread-{
     
     thread_id}] Processing {
     
     Path(img_path).name} ({
     
     remaining} remaining)")
            
        bbox_file = os.path.join(os.path.dirname(img_path), f"{
     
     Path(img_path).stem}_BB.txt")
        bbox = load_bbox(bbox_file)
        img = Image.open(img_path)
        real_w, real_h = img.size
        
        # 计算实际边界框
        x1 = int(bbox[0] * (real_w / 224))
        y1 = int(bbox[1] * (real_h / 224))
        w1 = int(bbox[2] * (real_w / 224))
        h1 = int(bbox[3] * (real_h / 224))
        
        # 裁剪并保存
        cropped_img = img.crop((x1, y1, x1 + w1, y1 + h1))
        cropped_img.save(output_path)
        
        #with print_lock:
        #    print(f"[Thread-{thread_id}] Finished {Path(img_path).name}")
        return True
    except Exception as e:
        with print_lock:
            print(f"[Thread-{
     
     thread_id}] Error processing {
     
     Path(img_path).name}: {
     
     str(e)}")
        return False

def process_directory(input_dir, output_dir, max_workers=8):
    """使用线程池处理目录中的所有图片"""
    os.makedirs(output_dir, exist_ok=True)
    tasks = []
    
    # 获取所有图片文件
    image_files = [f for f in os.listdir(input_dir) if f.endswith((".png", ".jpg", ".jpeg"))]
    total = len(image_files)
    
    for idx, img_file in enumerate(image_files):
        img_path = os.path.join(input_dir, img_file)
        output_file = Path(img_file).with_suffix(".png")
        output_path = os.path.join(output_dir, output_file)
        remaining = total - idx - 1
        tasks.append((img_path, output_path, idx % max_workers, remaining))
    
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(process_image, task) for task in tasks]
        for future in as_completed(futures):
            future.result()

def process_data(root_dir, output_base_dir):
    """处理整个数据集目录"""
    for split in ["test", "train"]:
        split_dir = os.path.join(root_dir, split)
        if not os.path.exists(split_dir):
            print(f"Warning: {
     
     split_dir} does not exist")
            continue
            
        # 创建对应的输出目录
        output_split_dir = os.path.join(output_base_dir, split)
        os.makedirs(os.path.join(output_split_dir, "live"), exist_ok=True)
        os.makedirs(os.path.join(output_split_dir, "spoof"), exist_ok=True)
            
        for sub_dir in os.listdir(split_dir):
            sub_path = os.path.join(split_dir, sub_dir)
            if not os.path.isdir(sub_path):
                print(f"Skipping non-directory: {
     
     sub_path}")
                continue
                
            for label in ["live", "spoof"]:
                label_path = os.path.join(sub_path, label)
                if not os.path.exists(label_path)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值