基于暗通道先验的图像去雾算法实现与分析

  1. 引言
    雾霾天气会严重影响图像质量,导致对比度降低、颜色失真等问题。从而导致下游任务不好开展,如目标检测、图像分类、语义分割等计算机视觉任务,构造深度学习去雾数据集耗时耗力,想快速的验证自己的idea的话传统的方法是一个不错的选择,本文介绍一种经典的单幅图像去雾算法——暗通道先验去雾法,并基于Python实现完整的处理流程。
    暗通道先验理论基础
    暗通道先验是何恺明博士提出的一种图像去雾理论,其核心假设是:
    暗通道原理:对于大多数无雾户外图像,在局部区域至少有一个颜色通道值非常小
    大气散射模型:雾霾图像形成遵循物理公式 I(x)=J(x)t(x)+A(1−t(x))I(x) = J(x)t(x) + A(1-t(x))I(x)=J(x)t(x)+A(1t(x))
    其中:
    I(x) 是观测到的有雾图像
    J(x) 是待恢复的无雾图像
    A 是全局大气光成分
    t(x) 是透射率图
  2. 核心算法实现
    暗通道计算 (DarkChannel 函数)
def DarkChannel(im, sz):
    b, g, r = cv2.split(im)
    dc = cv2.min(cv2.min(r, g), b)
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (sz, sz))
    dark = cv2.erode(dc, kernel)
    return dark

该函数首先分离RGB三个通道,然后取每个像素点三个通道中的最小值,最后通过形态学腐蚀操作获得暗通道图像。
大气光估计 (AtmLight 函数)

def AtmLight(im, dark):
    [h, w] = im.shape[:2]
    imsz = h * w
    numpx = int(max(math.floor(imsz / 1000), 1))
    darkvec = dark.reshape(imsz)
    imvec = im.reshape(imsz, 3)

    indices = darkvec.argsort()
    indices = indices[imsz - numpx::]

    atmsum = np.zeros([1, 3])
    for ind in range(1, numpx):
        atmsum = atmsum + imvec[indices[ind]]

    A = atmsum / numpx
    return A

根据暗通道图找出最亮的0.1%像素点,以其在原图中的像素值作为大气光 A 的估计值。
透射率估算 (TransmissionEstimate 函数)

def TransmissionEstimate(im, A, sz):
    omega = 0.90
    im3 = np.empty(im.shape, im.dtype)

    for ind in range(0, 3):
        im3[:, :, ind] = im[:, :, ind] / A[0, ind]

    transmission = 1 - omega * DarkChannel(im3, sz)
    return transmission

利用大气散射模型反推粗略的透射率图,其中 omega 为去雾程度参数。
透射率优化 (TransmissionRefine 函数)

def TransmissionRefine(im, et):
    gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
    gray = np.float64(gray) / 255
    r = 60
    eps = 0.0001
    t = Guidedfilter(gray, et, r, eps)
    return t

最终根据大气散射模型恢复清晰图像,并限制像素值范围在[0,1]之间
完整处理流程

def process_image(input_path, output_path):
    src = cv2.imread(input_path)
    I = src.astype('float32') / 255
    
    dark = DarkChannel(I, 15)
    A = AtmLight(I, dark)
    te = TransmissionEstimate(I, A, 15)
    t = TransmissionRefine(src, te)
    J = Recover(I, t, A, 0.1)
    
    result = (J * 255).astype(np.uint8)
    cv2.imwrite(output_path, result)

批量处理功能

def batch_process(input_folder, output_folder):
    os.makedirs(output_folder, exist_ok=True)
    extensions = ['*.jpg', '*.jpeg', '*.png', '*.bmp', '*.tif']
    
    image_files = []
    for ext in extensions:
        image_files.extend(glob.glob(os.path.join(input_folder, ext)))
        image_files.extend(glob.glob(os.path.join(input_folder, ext.upper())))
    
    for img_path in image_files:
        filename = os.path.basename(img_path)
        output_path = os.path.join(output_folder, filename)
        process_image(img_path, output_path)

总结
本文实现了完整的暗通道先验去雾算法,具有以下特点:
能有效去除雾霾影响,提升图像对比度和清晰度
支持多种图像格式批量处理
自动优化输出文件大小,避免过度膨胀
可调节参数控制去雾强度和效果
该算法适用于监控视频增强、遥感图像处理、自动驾驶视觉等多个领域,是一种实用性强的经典图像处理方法。
完整代码

import cv2
import math
import numpy as np
import os
import glob


def DarkChannel(im, sz):
    b, g, r = cv2.split(im)
    dc = cv2.min(cv2.min(r, g), b)
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (sz, sz))
    dark = cv2.erode(dc, kernel)
    return dark


def AtmLight(im, dark):
    [h, w] = im.shape[:2]
    imsz = h * w
    numpx = int(max(math.floor(imsz / 1000), 1))
    darkvec = dark.reshape(imsz)
    imvec = im.reshape(imsz, 3)

    indices = darkvec.argsort()
    indices = indices[imsz - numpx::]

    atmsum = np.zeros([1, 3])
    for ind in range(1, numpx):
        atmsum = atmsum + imvec[indices[ind]]

    A = atmsum / numpx
    return A


def TransmissionEstimate(im, A, sz):
    omega = 0.90
    im3 = np.empty(im.shape, im.dtype)

    for ind in range(0, 3):
        im3[:, :, ind] = im[:, :, ind] / A[0, ind]

    transmission = 1 - omega * DarkChannel(im3, sz)
    return transmission


def Guidedfilter(im, p, r, eps):
    mean_I = cv2.boxFilter(im, cv2.CV_64F, (r, r))
    mean_p = cv2.boxFilter(p, cv2.CV_64F, (r, r))
    mean_Ip = cv2.boxFilter(im * p, cv2.CV_64F, (r, r))
    cov_Ip = mean_Ip - mean_I * mean_p

    mean_II = cv2.boxFilter(im * im, cv2.CV_64F, (r, r))
    var_I = mean_II - mean_I * mean_I

    a = cov_Ip / (var_I + eps)
    b = mean_p - a * mean_I

    mean_a = cv2.boxFilter(a, cv2.CV_64F, (r, r))
    mean_b = cv2.boxFilter(b, cv2.CV_64F, (r, r))

    q = mean_a * im + mean_b
    return q


def TransmissionRefine(im, et):
    gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
    gray = np.float64(gray) / 255
    r = 60
    eps = 0.0001
    t = Guidedfilter(gray, et, r, eps)

    return t


def Recover(im, t, A, tx=0.1):
    res = np.empty(im.shape, im.dtype)
    t = cv2.max(t, tx)

    for ind in range(0, 3):
        res[:, :, ind] = (im[:, :, ind] - A[0, ind]) / t + A[0, ind]

    res = np.clip(res, 0, 1)
    return res


def process_image(input_path, output_path):
    """Process a single image"""
    src = cv2.imread(input_path)

    if src is None:
        print(f"Could not read image: {input_path}")
        return

    # 获取原始图像尺寸和文件大小
    original_size = os.path.getsize(input_path)

    # Ensure input is 3-channel
    if len(src.shape) == 2:  # Grayscale image
        src = cv2.cvtColor(src, cv2.COLOR_GRAY2BGR)

    I = src.astype('float32') / 255

    dark = DarkChannel(I, 15)
    A = AtmLight(I, dark)
    te = TransmissionEstimate(I, A, 15)
    t = TransmissionRefine(src, te)
    J = Recover(I, t, A, 0.1)

    # Ensure output is 3-channel before saving
    result = (J * 255).astype(np.uint8)
    if len(result.shape) == 2:  # Prevent accidentally becoming grayscale
        result = cv2.cvtColor(result, cv2.COLOR_GRAY2BGR)

    # Save with default quality first
    cv2.imwrite(output_path, result)

    # Check if output file is larger than input file
    output_size = os.path.getsize(output_path)

    if output_size > original_size:
        # Handle TIFF files with compression
        if output_path.lower().endswith(('.tif', '.tiff')):
            # Remove the original uncompressed TIFF
            os.remove(output_path)
            # Try saving with LZW compression
            compressed_tiff_path = output_path
            cv2.imwrite(compressed_tiff_path, result, [int(cv2.IMWRITE_TIFF_COMPRESSION), 5])

            # # If still too large, convert to JPEG
            # if os.path.getsize(compressed_tiff_path) > original_size:
            #     base_name = os.path.splitext(output_path)[0]
            #     jpg_output_path = base_name + '.jpg'
            #     cv2.imwrite(jpg_output_path, result, [int(cv2.IMWRITE_JPEG_QUALITY), 90])
            #     print(f"Converted to JPEG due to size: {input_path} -> {jpg_output_path}")
            # else:
            #     print(f"Compressed TIFF: {input_path} -> {compressed_tiff_path}")

        # Handle other formats by reducing quality
        else:
            base_name = os.path.splitext(output_path)[0]
            jpg_output_path = base_name + '.jpg'

            # Try saving as JPEG with quality 90
            cv2.imwrite(jpg_output_path, result, [int(cv2.IMWRITE_JPEG_QUALITY), 90])

            # If still too large, reduce quality further
            if os.path.getsize(jpg_output_path) > original_size:
                cv2.imwrite(jpg_output_path, result, [int(cv2.IMWRITE_JPEG_QUALITY), 80])

            # Remove original output file and keep JPEG version
            if os.path.exists(output_path):
                os.remove(output_path)

            print(f"Processed: {input_path} -> {jpg_output_path}")
    else:
        print(f"Processed: {input_path} -> {output_path}")


def batch_process(input_folder, output_folder):
    """Process all images in a folder"""
    # Create output folder if it doesn't exist
    os.makedirs(output_folder, exist_ok=True)

    # Supported image extensions
    extensions = ['*.jpg', '*.jpeg', '*.png', '*.bmp', '*.tif']

    # Get all image files
    image_files = []
    for ext in extensions:
        image_files.extend(glob.glob(os.path.join(input_folder, ext)))
        image_files.extend(glob.glob(os.path.join(input_folder, ext.upper())))

    # print(f"Found {len(image_files)} images to process")

    # Process each image
    for img_path in image_files:
        filename = os.path.basename(img_path)
        output_path = os.path.join(output_folder, filename)
        process_image(img_path, output_path)


if __name__ == '__main__':
    import sys

    if len(sys.argv) >= 2:
        input_folder = sys.argv[1]
        output_folder = sys.argv[2] if len(sys.argv) > 2 else "./output"
        batch_process(input_folder, output_folder)
    else:
        # Default folders if no arguments provided
        input_folder = r"E:\Desktop\ZunYi_Building\SINGLECLASSDataset\images\val"
        output_folder = r"E:\Desktop\ZunYi_Building\SINGLECLASSDataset\images\val_dehaze"
        batch_process(input_folder, output_folder)

带GUI界面去雾脚本
为了直观的观察去雾前后的图像,在上述脚本的基础上对其进行修改。

import cv2
import math
import numpy as np
import tkinter as tk
from tkinter import filedialog, messagebox
from PIL import Image, ImageTk
import os


# 原有的去雾函数保持不变
def DarkChannel(im, sz):
    b, g, r = cv2.split(im)
    dc = cv2.min(cv2.min(r, g), b)
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (sz, sz))
    dark = cv2.erode(dc, kernel)
    return dark


def AtmLight(im, dark):
    [h, w] = im.shape[:2]
    imsz = h * w
    numpx = int(max(math.floor(imsz / 1000), 1))
    darkvec = dark.reshape(imsz)
    imvec = im.reshape(imsz, 3)

    indices = darkvec.argsort()
    indices = indices[imsz - numpx:]

    atmsum = np.zeros([1, 3])
    for ind in range(1, numpx):
        atmsum = atmsum + imvec[indices[ind]]

    A = atmsum / numpx
    return A


def TransmissionEstimate(im, A, sz):
    omega = 0.95
    im3 = np.empty(im.shape, im.dtype)

    for ind in range(0, 3):
        im3[:, :, ind] = im[:, :, ind] / A[0, ind]

    transmission = 1 - omega * DarkChannel(im3, sz)
    return transmission


def Guidedfilter(im, p, r, eps):
    mean_I = cv2.boxFilter(im, cv2.CV_64F, (r, r))
    mean_p = cv2.boxFilter(p, cv2.CV_64F, (r, r))
    mean_Ip = cv2.boxFilter(im * p, cv2.CV_64F, (r, r))
    cov_Ip = mean_Ip - mean_I * mean_p

    mean_II = cv2.boxFilter(im * im, cv2.CV_64F, (r, r))
    var_I = mean_II - mean_I * mean_I

    a = cov_Ip / (var_I + eps)
    b = mean_p - a * mean_I

    mean_a = cv2.boxFilter(a, cv2.CV_64F, (r, r))
    mean_b = cv2.boxFilter(b, cv2.CV_64F, (r, r))

    q = mean_a * im + mean_b
    return q


def TransmissionRefine(im, et):
    gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
    gray = np.float64(gray) / 255
    r = 60
    eps = 0.0001
    t = Guidedfilter(gray, et, r, eps)
    #t=cv2.GaussianBlur(t, (1,1), 0)
    return t


def Recover(im, t, A, tx=0.1):
    res = np.empty(im.shape, im.dtype)
    t = cv2.max(t, tx)

    for ind in range(0, 3):
        res[:, :, ind] = (im[:, :, ind] - A[0, ind]) / t + A[0, ind]

        res = np.clip(res, 0, 1)

    return res


class DehazeApp:
    def __init__(self, root):
        self.root = root
        self.root.title("图像去雾工具")
        self.root.geometry("1200x700")

        # 初始化变量
        self.folder_path = ""
        self.image_files = []
        self.current_index = 0

        # 创建界面
        self.create_widgets()

    def create_widgets(self):
        # 文件夹选择区域和控制按钮都放在顶部
        top_frame = tk.Frame(self.root)
        top_frame.pack(pady=10)

        # 文件夹选择区域
        folder_frame = tk.Frame(top_frame)
        folder_frame.pack()

        tk.Button(folder_frame, text="选择文件夹", command=self.select_folder).pack(side=tk.LEFT, padx=5)
        self.folder_label = tk.Label(folder_frame, text="未选择文件夹")
        self.folder_label.pack(side=tk.LEFT, padx=5)

        # 控制按钮区域
        control_frame = tk.Frame(top_frame)
        control_frame.pack(pady=5)

        tk.Button(control_frame, text="上一张", command=self.prev_image).pack(side=tk.LEFT, padx=5)
        tk.Button(control_frame, text="下一张", command=self.next_image).pack(side=tk.LEFT, padx=5)
        tk.Button(control_frame, text="执行去雾", command=self.process_image).pack(side=tk.LEFT, padx=5)
        tk.Button(control_frame, text="保存结果", command=self.save_result).pack(side=tk.LEFT, padx=5)

        # 图像显示区域
        image_frame = tk.Frame(self.root)
        image_frame.pack(expand=True, fill=tk.BOTH, padx=20, pady=10)

        # 原图显示
        original_frame = tk.LabelFrame(image_frame, text="原图")
        original_frame.pack(side=tk.LEFT, expand=True, fill=tk.BOTH, padx=10)

        self.original_canvas = tk.Canvas(original_frame, bg="lightgray")
        self.original_canvas.pack(expand=True, fill=tk.BOTH)

        # 结果图显示
        result_frame = tk.LabelFrame(image_frame, text="去雾结果")
        result_frame.pack(side=tk.RIGHT, expand=True, fill=tk.BOTH, padx=10)

        self.result_canvas = tk.Canvas(result_frame, bg="lightgray")
        self.result_canvas.pack(expand=True, fill=tk.BOTH)

        # 状态栏
        self.status_bar = tk.Label(self.root, text="就绪", bd=1, relief=tk.SUNKEN, anchor=tk.W)
        self.status_bar.pack(side=tk.BOTTOM, fill=tk.X)

    def select_folder(self):
        folder_selected = filedialog.askdirectory()
        if folder_selected:
            self.folder_path = folder_selected
            self.folder_label.config(text=folder_selected)

            # 获取文件夹中的图像文件
            self.image_files = [
                f for f in os.listdir(self.folder_path)
                if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff'))
            ]

            if not self.image_files:
                messagebox.showwarning("警告", "所选文件夹中没有图像文件")
                return

            self.current_index = 0
            self.show_image()

    def show_image(self):
        if not self.image_files:
            return

        # 显示当前图像名状态
        self.status_bar.config(
            text=f"图像 {self.current_index + 1}/{len(self.image_files)}: {self.image_files[self.current_index]}"
        )

        # 加载并显示原图
        img_path = os.path.join(self.folder_path, self.image_files[self.current_index])
        original_img = Image.open(img_path)

        # 调整图像大小以适应显示区域
        max_width, max_height = 1024, 1024
        original_img.thumbnail((max_width, max_height), Image.LANCZOS)

        self.original_photo = ImageTk.PhotoImage(original_img)
        self.original_canvas.delete("all")
        self.original_canvas.create_image(
            max_width // 2, max_height // 2,
            image=self.original_photo,
            anchor=tk.CENTER
        )

    def process_image(self):
        if not self.image_files:
            messagebox.showwarning("警告", "请先选择包含图像的文件夹")
            return

        # 加载当前图像
        img_path = os.path.join(self.folder_path, self.image_files[self.current_index])
        src = cv2.imread(img_path)

        if src is None:
            messagebox.showerror("错误", "无法加载图像文件")
            return

        # 执行去雾算法
        I = src.astype('float64') / 255
        dark = DarkChannel(I, 15)
        A = AtmLight(I, dark)
        te = TransmissionEstimate(I, A, 15)
        t = TransmissionRefine(src, te)
        J = Recover(I, t, A, 0.1)

        # 将结果转换为可显示格式
        result_img = (J * 255).astype(np.uint8)
        result_img = cv2.cvtColor(result_img, cv2.COLOR_BGR2RGB)
        result_pil = Image.fromarray(result_img)

        # 调整图像大小以适应显示区域
        max_width, max_height = 1024, 1024
        result_pil.thumbnail((max_width, max_height), Image.LANCZOS)

        self.result_photo = ImageTk.PhotoImage(result_pil)
        self.result_canvas.delete("all")
        self.result_canvas.create_image(
            max_width // 2, max_height // 2,
            image=self.result_photo,
            anchor=tk.CENTER
        )

    def prev_image(self):
        if not self.image_files:
            return

        self.current_index = (self.current_index - 1) % len(self.image_files)
        self.show_image()
        # 清空结果图
        self.result_canvas.delete("all")

    def next_image(self):
        if not self.image_files:
            return

        self.current_index = (self.current_index + 1) % len(self.image_files)
        self.show_image()
        # 清空结果图
        self.result_canvas.delete("all")
        
    def save_result(self):
        if not hasattr(self, 'result_photo') or not self.result_photo:
            messagebox.showwarning("警告", "请先执行去雾操作")
            return
        
        if not self.image_files:
            messagebox.showwarning("警告", "请先选择包含图像的文件夹")
            return
        
        # 构造保存路径
        current_image_name = self.image_files[self.current_index]
        name_without_ext = os.path.splitext(current_image_name)[0]
        save_filename = f"{name_without_ext}_dehazed.png"
        save_path = os.path.join(self.folder_path, save_filename)
        
        # 保存图像
        try:
            # 重新处理获取原始分辨率的结果
            img_path = os.path.join(self.folder_path, self.image_files[self.current_index])
            src = cv2.imread(img_path)
            
            if src is None:
                messagebox.showerror("错误", "无法加载图像文件")
                return
                
            # 执行去雾算法
            I = src.astype('float64') / 255
            dark = DarkChannel(I, 15)
            A = AtmLight(I, dark)
            te = TransmissionEstimate(I, A, 15)
            t = TransmissionRefine(src, te)
            J = Recover(I, t, A, 0.1)
            
            # 转换为可保存格式
            result_img = (J * 255).astype(np.uint8)
            result_img = cv2.cvtColor(result_img, cv2.COLOR_BGR2RGB)
            
            # 保存图像
            success = cv2.imwrite(save_path, cv2.cvtColor(result_img, cv2.COLOR_RGB2BGR))
            
            if success:
                messagebox.showinfo("成功", f"去雾结果已保存至:\n{save_path}")
            else:
                messagebox.showerror("错误", "保存失败")
                
        except Exception as e:
            messagebox.showerror("错误", f"保存过程中发生错误:\n{str(e)}")


if __name__ == "__main__":
    root = tk.Tk()
    app = DehazeApp(root)
    root.mainloop()
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

VisionX Lab

你的鼓励将是我更新的动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值