python工具——图像数据清洗&图像质量检查

功能介绍

随着目前采集的数据集中的图像越来越多,出现了数据格式十分杂乱、质量不统一、部分图像存在损坏等各种问题。本程序提供图像数据清洗&质量检查等功能,防止后续处理在加载数据时出现的各种异常。

  1. 使用PIL、opencv等方式读取图像,检查图像是否损坏,支持多种图像格式,包括(‘.png’, ‘.jpg’, ‘.jpeg’, ‘.bmp’, ‘.gif’, ‘.tiff’)不区分大小写
  2. 读取图像exif旋转信息,防止后续标注异常,TODO(可选)先将源文件移出去,然后根据exit的值调整图像重新保存到原路径。
  3. 记录图像属性信息:
    ① 图像元信息:编码格式、分辨率、通道数、文件大小等,便于判断图像其他属性
    ② MD5,phash16等值,后续可以作为判断图像是否存在重复的依据
  4. 判断图像模糊质量的常用方法:
    ① Laplacian 变换(边缘检测):一种常见的边缘检测算法,通过计算图像中像素值的变化率来评估图像的清晰度。
    ② 傅里叶变换(频域分析):将图像从空间域转换到频率域,高频部分代表图像中的细节和边缘,低频部分代表平滑区域。(推荐)
    ③ Tenengrad 方法:基于图像梯度的平方和来衡量图像的清晰度,图像越清晰,梯度变化越大。(推荐)
    ④ 熵(Entropy):可以用来衡量图像的信息量,高熵图像通常包含更多的细节和信息,但不一定直接反映图像的清晰度。
    ⑤ 结构相似性(SSIM):主要用于比较两幅图像之间的相似性,但也可以用于评估单张图像的清晰度,特别是与参考图像进行比较时。
    ⑥ 峰值信噪比(PSNR):基于图像的峰值信噪比来判断图像的质量。SSIM 和 PSNR 更适合于比较两幅图像之间的差异,特别是当有参考图像时。
  5. 记录检查报告并输出。
  6. 使用 concurrent.futures.ProcessPoolExecutor 和 ThreadPoolExecutor 多线程和多进程加速处理。

可执行程序

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# 功能:图像数据清洗&图像质量检查
# 作者:AYangSN
# 时间:2025-03-12 19:23:22
# 版本:v1.0
# here is important imports
import time
import csv
import os
import shutil
import cv2
import hashlib
import imagehash
import numpy as np
from tqdm import tqdm
from PIL import Image, ImageOps, ExifTags
from skimage.metrics import structural_similarity as ssim
from scipy.stats import entropy
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed


def check_image_with_pil(filepath):
    '''使用PIL检查图像是否损坏'''
    try:
        img = Image.open(filepath)
        img.verify()  # 验证图像完整性
        img = Image.open(filepath)  # 再次打开以确保图像可以正常加载
        return True, img
    except Exception as e:
        return False, str(e)

def check_image_with_opencv(filepath):
    '''使用OpenCV检查图像是否损坏'''
    try:
        img = cv2.imread(filepath)
        if img is None or img.size == 0:
            return False, 'OpenCV无法加载图像'
        return True, img
    except Exception as e:
        return False, str(e)

def check_file_header(filepath):
    '''通过读取文件头信息检查图像格式是否正确'''
    valid_headers = {
        'JPEG': b'\xff\xd8\xff',
        'PNG': b'\x89\x50\x4e\x47\x0d\x0a\x1a\x0a',
        'GIF87a': b'GIF87a',
        'GIF89a': b'GIF89a',
        'BMP': b'BM'
    }
    with open(filepath, 'rb') as f:
        header = f.read(8)  # 读取前8个字节以覆盖所有格式
        for format, magic in valid_headers.items():
            if header.startswith(magic):
                return True, None
    return False, '未知的文件头'


def get_exif_orientation(img):
    '''获取图像方向信息'''
    try:
        exif = img._getexif()
    except AttributeError:
        exif = None
    if exif is None:
        return None
    exif = {
        ExifTags.TAGS[k]: v
        for k, v in exif.items()
        if k in ExifTags.TAGS
    }
    
    orientation = exif.get('Orientation', None)
    return orientation


def exif_update_image_files(img, orientation, output_dir):
    '''根据参数旋转图片'''
    if orientation == 2:
        # left-to-right mirror
        img = ImageOps.mirror(img)
    elif orientation == 3:
        # rotate 180
        img = img.transpose(Image.ROTATE_180)
    elif orientation == 4:
        # top-to-bottom mirror
        img = ImageOps.flip(img)
    elif orientation == 5:
        # top-to-left mirror
        img = ImageOps.mirror(img.transpose(Image.ROTATE_270))
    elif orientation == 6:
        # rotate 270
        img = img.transpose(Image.ROTATE_270)
    elif orientation == 7:
        # top-to-right mirror
        img =  ImageOps.mirror(img.transpose(Image.ROTATE_90))
    elif orientation == 8:
        # rotate 90
        img = img.transpose(Image.ROTATE_90)
    else:
        pass

    # 使用opencv读取,去除exif信息
    img = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
   
    # 重新保存图片
    cv2.imwrite(output_dir, img)


def compute_md5(filepath):
    '''计算文件的MD5值'''
    hash_md5 = hashlib.md5()
    with open(filepath, 'rb') as f:
        for chunk in iter(lambda: f.read(4096), b''):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()


def compute_phash(img, hash_size=16):
    '''计算图像的phash值'''
    phash = imagehash.phash(img, hash_size=hash_size, highfreq_factor=4)
    hex_string = str(phash)
    return hex_string


def diff_phash(p1, p2, hash_size = 8):
    '''计算两个phash值之间的相似度差异'''
    return (p1 - p2) / hash_size ** 2


def check_blur(img, ref_image=None):
    '''综合评估图像的模糊质量'''
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    # Laplacian 方差
    laplacian_var = cv2.Laplacian(gray, cv2.CV_64F).var()

    # 傅里叶变换
    f = np.fft.fft2(gray)
    fshift = np.fft.fftshift(f)
    magnitude_spectrum = 20 * np.log(np.abs(fshift) + 1e-10)  # 加上一个小值以避免log(0)
    # 计算总能量
    total_energy = np.sum(magnitude_spectrum)
    # 计算高频部分的能量占比
    high_freq_energy = np.sum(magnitude_spectrum[magnitude_spectrum > np.percentile(magnitude_spectrum, 90)])
    high_freq_ratio = high_freq_energy / total_energy

    # Tenengrad 方法
    gradient_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
    gradient_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
    gradient_magnitude = np.sqrt(gradient_x**2 + gradient_y**2)
    tenengrad_value = np.mean(gradient_magnitude)

    # 熵
    hist = cv2.calcHist([gray], [0], None, [256], [0, 256])
    hist_norm = hist.ravel() / hist.max()
    entropy_value = entropy(hist_norm, base=2)

    # 计算SSIM & PSNR(如果有参考图像)
    psnr = None
    ssim_score = None
    if ref_image is not None:
        gray_ref = cv2.cvtColor(ref_image, cv2.COLOR_BGR2GRAY)
        ssim_score, _ = ssim(gray, gray_ref, full=True)
        psnr = cv2.PSNR(gray, gray_ref)

    return laplacian_var, high_freq_ratio, tenengrad_value, entropy_value, ssim_score, psnr


def process_images(filepath, output_dir, save_file=True):
    '''处理单张图像'''
    log_entry = {}
    broken, exif, blur = None, None, None
    try:
        filename = os.path.basename(filepath)
        # file_extension = os.path.splitext(filepath)[1].lower()

        # 检查图像是否损坏
        pil_result, img_pil = check_image_with_pil(filepath)
        opencv_result, img_opencv = check_image_with_opencv(filepath)
        header_result, header_error = check_file_header(filepath)

        # 如果图像没有损坏,则继续处理
        if pil_result and opencv_result and header_result:
            
            # 获取文件大小  字节(bytes)
            file_size = os.path.getsize(filepath)

            # 计算MD5校验码
            md5_checksum = compute_md5(filepath)

            # 获取分辨率
            width, height = img_pil.size
            
            # 获取颜色模式
            color_mode = img_pil.mode

            # 获取通道数
            channels = len(color_mode) if isinstance(color_mode, str) else None

            # 获取位深度
            bit_depth = img_pil.bits if hasattr(img_pil, 'bits') else None
            
            # 获取压缩类型
            compression = img_pil.info.get('compression', 'Unknown')
            
            # 获取EXIF数据
            orientation = get_exif_orientation(img_pil)
            
            # 根据旋转信息更新图像
            if not (orientation is None or orientation==1):
                exif = filepath
                if save_file:
                    os.makedirs(os.path.join(output_dir,'exif_ori', str(orientation)), exist_ok=True)             # 创建新路径
                    shutil.copyfile(filepath, os.path.join(output_dir,'exif_ori', str(orientation), filename))    # 复制一份原始图像到新路径
                    exif_update_image_files(img_pil, orientation, filepath)                     # 更新原始图像

            # 计算phash16校验码
            hex_string = compute_phash(img_pil, hash_size=16)

            # # 获取直方图
            # hist = img_pil.histogram()

            # 计算模糊度
            laplacian_var, high_freq_ratio, tenengrad_value, entropy_value, ssim_score, psnr = check_blur(img_opencv)
            ### if laplacian_var < 100 or high_freq_ratio < 0.1 or tenengrad_value < 1000 or entropy_value < 5 or ssim_score < 0.9 or psnr < 30:  # 模糊度阈值
            ### if laplacian_var > 300 or high_freq_ratio > 0.5 or tenengrad_value > 5000 or entropy_value > 8 or ssim_score > 1.0 or psnr < 50: # 清晰度阈值
            if high_freq_ratio < 0.1:
                blur = filepath
                if save_file:
                    os.makedirs(os.path.join(output_dir,'blur'), exist_ok=True)
                    shutil.copyfile(filepath, os.path.join(output_dir,'blur',filename))

            log_entry = {
                'filename': filepath,
                'pil_check': pil_result,
                'opencv_check': opencv_result,
                'header_check': header_result,
                'header_error': header_error,
                'file_size': file_size,
                'resolution': (width, height),
                'color_mode': color_mode,
                'bit_depth': bit_depth,
                'channels': channels,
                'compression': compression,
                'exif_data': orientation,
                'md5_checksum': md5_checksum,
                'phash16_checksum': hex_string,
                'laplacian_var': laplacian_var,
                'high_freq_ratio': high_freq_ratio,
                'tenengrad_value': tenengrad_value,
                'entropy_value': entropy_value,
                'ssim_score': ssim_score,
                'psnr': psnr,
            }
            
        else:
            log_entry = {
                'filename': filepath,
                'pil_check': pil_result,
                'opencv_check': opencv_result,
                'header_check': header_result,
                'header_error': header_error,
            }
            broken = filepath
            if save_file:
                os.makedirs(os.path.join(output_dir,'broken'), exist_ok=True)
                shutil.copyfile(filepath, os.path.join(output_dir,'broken',filename))
        
    except Exception as e:
        print("Error processing image:{} {}".format(filepath,e))

    return log_entry, broken, exif, blur

def main_one(input_dir, output_dir):
    """单进程版本"""
    os.makedirs(output_dir, exist_ok=True)

    output_broken_txt_path = os.path.join(output_dir, 'image_broken_lists_{}.txt'.format(time.strftime("%Y%m%d%H%M%S%H%M%S", time.localtime())))
    output_exif_txt_path = os.path.join(output_dir, 'image_exif_lists_{}.txt'.format(time.strftime("%Y%m%d%H%M%S", time.localtime())))
    output_blur_txt_path = os.path.join(output_dir, 'image_blur_lists_{}.txt'.format(time.strftime("%Y%m%d%H%M%S", time.localtime())))
    output_csv_path = os.path.join(output_dir, 'image_integrity_report_{}.csv'.format(time.strftime("%Y%m%d%H%M%S", time.localtime())))

    filepaths = []
    # 遍历输入目录下的所有文件,包括子目录
    for root, dir, fs in os.walk(input_dir):
        filepaths.extend([os.path.join(root, f) for f in fs if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tiff'))])
    print(f'Found {len(filepaths)} images to process.')

    fieldnames = [
        'filename', 'pil_check', 'opencv_check', 'header_check', 'header_error', \
        'file_size', 'resolution', 'color_mode', 'bit_depth','channels', 'compression', 'exif_data', 'md5_checksum', 'phash16_checksum', \
        'laplacian_var', 'high_freq_ratio', 'tenengrad_value', 'entropy_value', 'ssim_score', 'psnr'
    ]

    for fp in tqdm(filepaths, desc="Processing images"):

        log_entry, broken, exif, blur = process_images(fp, output_dir)

        mode = 'a' if os.path.exists(output_csv_path) else 'w'
        with open(output_csv_path, mode, newline='', encoding='utf-8-sig') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            if mode == 'w':
                writer.writeheader()
            writer.writerow(log_entry)

        if broken is not None:
            mode = 'a' if os.path.exists(output_broken_txt_path) else 'w'
            with open(output_broken_txt_path, mode, newline='', encoding='utf-8-sig') as f_broken:
                f_broken.writelines(broken + '\n')

        if exif is not None:
            mode = 'a' if os.path.exists(output_exif_txt_path) else 'w'
            with open(output_exif_txt_path, mode, newline='', encoding='utf-8-sig') as f_exif:
                f_exif.writelines(exif + '\n')

        if blur is not None:
            mode = 'a' if os.path.exists(output_blur_txt_path) else 'w'
            with open(output_blur_txt_path, mode, newline='', encoding='utf-8-sig') as f_blur:
                f_blur.writelines(blur + '\n')

    print('Done.')


def write_to_file(log_entry, broken, exif, blur, output_dir, current_time):
    """
    将处理结果写入相应的文件。
    
    :param log_entry: 日志条目
    :param broken: 损坏的文件列表
    :param exif: 含有EXIF数据的文件列表
    :param blur: 模糊的文件列表
    :param output_dir: 输出目录
    """
    output_broken_txt_path = os.path.join(output_dir, 'image_broken_lists_{}.txt'.format(current_time))
    output_exif_txt_path = os.path.join(output_dir, 'image_exif_lists_{}.txt'.format(current_time))
    output_blur_txt_path = os.path.join(output_dir, 'image_blur_lists_{}.txt'.format(current_time))
    output_csv_path = os.path.join(output_dir, 'image_integrity_report_{}.csv'.format(current_time))

    fieldnames = [
        'filename', 'pil_check', 'opencv_check', 'header_check', 'header_error',
        'file_size', 'resolution', 'color_mode', 'bit_depth', 'channels', 'compression', 'exif_data',
        'md5_checksum', 'phash16_checksum', 'laplacian_var', 'high_freq_ratio', 'tenengrad_value',
        'entropy_value', 'ssim_score', 'psnr'
    ]

    def write_csv():
        mode = 'a' if os.path.exists(output_csv_path) else 'w'
        with open(output_csv_path, mode, newline='', encoding='utf-8-sig') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            if mode == 'w':
                writer.writeheader()
            writer.writerow(log_entry)

    def write_txt(file_path, content):
        mode = 'a' if os.path.exists(file_path) else 'w'
        with open(file_path, mode, newline='', encoding='utf-8-sig') as f:
            f.writelines(content + '\n')

    # 使用线程池处理文件写入
    with ThreadPoolExecutor(max_workers=4) as executor:
        futures = []
        futures.append(executor.submit(write_csv))
        if broken is not None:
            futures.append(executor.submit(write_txt, output_broken_txt_path, broken))
        if exif is not None:
            futures.append(executor.submit(write_txt, output_exif_txt_path, exif))
        if blur is not None:
            futures.append(executor.submit(write_txt, output_blur_txt_path, blur))

        for future in futures:
            future.result()


def main(input_dir, output_dir):
    """
    多进程+多线程主函数,用于处理指定目录下所有的图像文件。
    
    Args:
        input_dir: 输入目录路径,包含要处理的图像文件。
        output_dir: 输出目录路径,用于保存处理后的结果。
    """
    current_time = time.strftime("%Y%m%d%H%M%S", time.localtime())

    os.makedirs(output_dir, exist_ok=True)
    
    filepaths = []
    # 遍历输入目录下的所有文件,包括子目录
    for root, dir, fs in os.walk(input_dir):
        filepaths.extend([os.path.join(root, f) for f in fs if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tiff'))])
    print(f'Found {len(filepaths)} images to process.')

    # 使用进程池处理图片处理
    with ProcessPoolExecutor() as executor:
        futures = [executor.submit(process_images, fp, output_dir) for fp in filepaths]

        for future in tqdm(as_completed(futures), total=len(filepaths), desc="Processing images"):
            log_entry, broken, exif, blur = future.result()
            if log_entry:
                write_to_file(log_entry, broken, exif, blur, output_dir, current_time)

    print('Done.')


if __name__ == '__main__':
    # 示例用法
    input_directory = 'your_inputpath'
    output_directory = 'your_outputpath'
    os.makedirs(output_directory, exist_ok=True)
    main(input_directory, output_directory)    
    # main_one(input_directory, output_directory)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值