# Main program source code.
import os
import threading
import time
import tkinter as tk
from collections import deque
from tkinter import ttk
from typing import Dict, Optional, Tuple

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from lut3d import LUT3D
class ImageEnhancer:
    """Live camera image enhancer with a Tkinter control panel.

    Each frame goes: BGR uint8 camera frame -> (1, 3, H, W) float32 RGB tensor
    in [0, 1] on ``self.device`` -> brightness, contrast, saturation,
    sharpening and an optional 3D LUT -> BGR uint8 frame for display.

    Frame-rate measures (none of them lowers image quality):
    * the camera is read on a background thread, so waiting for the next
      frame overlaps with processing the current one;
    * frames cross the host/device boundary as uint8 instead of float32, and
      the channel swap and scaling run on ``self.device``;
    * adjustments whose factor is exactly neutral (1.0) and the identity LUT
      are skipped, and the sharpening kernel is built once and cached;
    * auto mode feeds scene presets straight into the pipeline instead of
      rewriting the instance parameters (and the GUI) twice per frame;
    * LUT errors are reported once instead of printed on every frame.
    """

    def __init__(self):
        # Run on the GPU when one is available.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Enhancement parameters; 1.0 means "leave unchanged".
        self.brightness = 1.0
        self.contrast = 1.0
        self.sharpness = 1.0
        self.saturation = 1.0
        # Camera 0.
        self.cap = cv2.VideoCapture(0)
        # 3D LUT state. The active table stays the identity until set_lut().
        self.lut = LUT3D(size=33, device=self.device)
        self.current_lut_name = 'identity'
        self.lut_strength = 1.0
        self.available_luts = self._load_preset_luts()
        # Scene recognition.
        self.auto_mode = False
        self.scene_presets = self._init_scene_presets()
        # Parameter history for undo/redo (at most 20 states).
        self.history = deque(maxlen=20)
        self.history_position = -1
        self._save_history()  # record the initial state
        # GUI.
        self.gui = None
        self.parameter_widgets = {}
        # True while the program itself writes the Tk variables, so the trace
        # callbacks do not mistake that for a user edit (which would record
        # history and destroy the redo stack during undo/redo).
        self._updating_gui = False
        self._stop_requested = False
        # Sharpening kernel, built lazily on first use and then reused.
        self._laplacian_kernel = None
        self._lut_error_reported = False
        # Background capture state; _frame_cond guards the fields below.
        self._frame_cond = threading.Condition()
        self._latest_frame = None
        self._frame_id = 0
        self._capture_running = False
        self._capture_thread = None

    # ------------------------------------------------------------------
    # Frame conversion
    # ------------------------------------------------------------------
    def preprocess_frame(self, frame: np.ndarray) -> torch.Tensor:
        """Convert an OpenCV frame into a PyTorch tensor.

        Args:
            frame: BGR image as an (H, W, 3) uint8 numpy array.

        Returns:
            RGB tensor of shape (1, 3, H, W) with values in [0, 1] on
            ``self.device``.
        """
        # Upload the raw uint8 pixels (a quarter of the bytes of float32) and
        # do the float conversion on the target device.
        pixels = torch.from_numpy(np.ascontiguousarray(frame)).to(self.device)
        # flip(-1) turns BGR into RGB; permute gives CHW; unsqueeze adds batch.
        tensor = pixels.flip(-1).permute(2, 0, 1).unsqueeze(0)
        return tensor.float().div_(255.0)

    def postprocess_frame(self, frame_tensor: torch.Tensor) -> np.ndarray:
        """Convert a (1, 3, H, W) RGB tensor in [0, 1] back to an OpenCV image.

        Returns:
            Contiguous, writable (H, W, 3) uint8 BGR numpy array.
        """
        # Quantize on the device and download uint8 rather than float32.
        # Rounding (not truncation) makes uint8 -> float -> uint8 lossless;
        # truncation turned values such as 254.99998 into 254.
        pixels = frame_tensor.squeeze(0).mul(255.0).round_().clamp_(0, 255).to(torch.uint8)
        # flip(0) turns RGB into BGR; permute back to HWC for OpenCV.
        pixels = pixels.flip(0).permute(1, 2, 0).contiguous()
        return pixels.cpu().numpy()

    # ------------------------------------------------------------------
    # Presets
    # ------------------------------------------------------------------
    def _init_scene_presets(self) -> Dict[str, Dict[str, float]]:
        """Return the enhancement parameters used for each scene type."""
        return {
            'portrait': {
                'brightness': 1.1,
                'contrast': 1.2,
                'saturation': 1.1,
                'sharpness': 1.3,
                'lut_strength': 0.8,
            },
            'landscape': {
                'brightness': 1.0,
                'contrast': 1.3,
                'saturation': 1.4,
                'sharpness': 1.5,
                'lut_strength': 0.9,
            },
            'night': {
                'brightness': 1.4,
                'contrast': 1.1,
                'saturation': 0.9,
                'sharpness': 0.8,
                'lut_strength': 0.7,
            },
        }

    def _load_preset_luts(self) -> dict:
        """Build the preset LUT tables without changing the active LUT.

        Returns:
            Mapping of preset name ('identity', 'warm', 'cool') to LUT grid.
        """
        identity = self.lut.get_identity_lut()
        presets = {'identity': identity}
        active_grid = self.lut.grid
        try:
            builders = (('warm', self.lut.create_warm_lut),
                        ('cool', self.lut.create_cool_lut))
            for name, build in builders:
                # Each preset starts from a fresh identity table; the builders
                # replace self.lut.grid with the transformed table.
                self.lut.grid = identity.clone()
                build(0.5)
                presets[name] = self.lut.grid
        finally:
            # Restore the active table so current_lut_name stays truthful.
            self.lut.grid = active_grid
        return presets

    def set_lut(self, name: str) -> None:
        """Activate one of ``self.available_luts`` by name.

        Raises:
            KeyError: if ``name`` is not a known preset.
        """
        if name not in self.available_luts:
            raise KeyError(f"未知的LUT: {name}")
        self.lut.grid = self.available_luts[name].clone()
        self.current_lut_name = name

    def set_lut_strength(self, value: float) -> None:
        """Set the LUT blend strength, clamped to [0, 1], and record history."""
        # Rounding stops repeated +/-0.1 steps from drifting to 1e-16 and
        # keeping the LUT pass alive at an invisible strength.
        value = round(min(max(float(value), 0.0), 1.0), 3)
        if value == self.lut_strength:
            return
        self.lut_strength = value
        self._update_gui_values()
        self._save_history()

    # ------------------------------------------------------------------
    # Scene analysis and parameter state
    # ------------------------------------------------------------------
    def analyze_scene(self, frame_tensor: torch.Tensor) -> str:
        """Classify the scene with simple global-statistics rules.

        Returns:
            'night' when the mean value is below 0.3; otherwise 'landscape'
            when the global standard deviation exceeds 0.5; else 'portrait'.
        """
        # One reduction and one host transfer instead of two .item() syncs.
        # Note: the "saturation" cue is really the std over all channels.
        std, mean = torch.std_mean(frame_tensor)
        brightness, spread = torch.stack((mean, std)).tolist()
        if brightness < 0.3:
            return 'night'
        if spread > 0.5:
            return 'landscape'
        return 'portrait'

    def get_current_params(self) -> Dict[str, float]:
        """Return a snapshot of the current enhancement parameters."""
        return {
            'brightness': self.brightness,
            'contrast': self.contrast,
            'saturation': self.saturation,
            'sharpness': self.sharpness,
            'lut_strength': self.lut_strength,
        }

    def apply_params(self, params: Dict[str, float]) -> None:
        """Set all enhancement parameters from ``params`` and refresh the GUI."""
        self.brightness = params['brightness']
        self.contrast = params['contrast']
        self.saturation = params['saturation']
        self.sharpness = params['sharpness']
        self.lut_strength = params['lut_strength']
        if self.gui is not None:
            self._update_gui_values()

    def _save_history(self) -> None:
        """Record the current parameters as a new undo step.

        Recording after an undo discards the redo entries beyond the current
        position. A state equal to the current entry is not recorded again.
        """
        current_params = self.get_current_params()
        if (0 <= self.history_position < len(self.history)
                and self.history[self.history_position] == current_params):
            return
        while len(self.history) > self.history_position + 1:
            self.history.pop()
        self.history.append(current_params)
        self.history_position = len(self.history) - 1

    def undo(self) -> None:
        """Step back one history entry, if there is one."""
        if self.history_position > 0:
            self.history_position -= 1
            self.apply_params(self.history[self.history_position])

    def redo(self) -> None:
        """Step forward one history entry, if there is one."""
        if self.history_position < len(self.history) - 1:
            self.history_position += 1
            self.apply_params(self.history[self.history_position])

    # ------------------------------------------------------------------
    # GUI
    # ------------------------------------------------------------------
    def init_gui(self) -> None:
        """Create the Tk control window with sliders and buttons."""
        self.gui = tk.Tk()
        self.gui.title("图像增强控制面板")
        # Closing the window ends the main loop cleanly instead of making the
        # next gui.update() raise on a destroyed window.
        self.gui.protocol("WM_DELETE_WINDOW", self._request_stop)
        self._create_parameter_controls()
        self._create_control_buttons()
        self._update_gui_values()

    def _request_stop(self) -> None:
        """Ask the main loop to exit after the current iteration."""
        self._stop_requested = True

    def _create_parameter_controls(self) -> None:
        """Create one labelled slider per enhancement parameter."""
        params_frame = ttk.LabelFrame(self.gui, text="参数调整")
        params_frame.pack(padx=5, pady=5, fill="x")
        parameters = {
            'brightness': ('亮度', 0.0, 2.0),
            'contrast': ('对比度', 0.0, 2.0),
            'saturation': ('饱和度', 0.0, 2.0),
            'sharpness': ('锐化', 0.0, 2.0),
            'lut_strength': ('LUT强度', 0.0, 1.0),
        }
        for param, (label, min_val, max_val) in parameters.items():
            frame = ttk.Frame(params_frame)
            frame.pack(fill="x", padx=5, pady=2)
            ttk.Label(frame, text=label).pack(side="left")
            var = tk.DoubleVar(value=getattr(self, param))
            slider = ttk.Scale(frame, from_=min_val, to=max_val,
                               variable=var, orient="horizontal")
            slider.pack(side="right", fill="x", expand=True)
            self.parameter_widgets[param] = (var, slider)
            # Bind param as a default so each callback keeps its own name.
            var.trace_add("write", lambda *args, p=param: self._on_parameter_change(p))

    def _create_control_buttons(self) -> None:
        """Create the undo/redo buttons and the auto-mode checkbox."""
        button_frame = ttk.Frame(self.gui)
        button_frame.pack(padx=5, pady=5, fill="x")
        ttk.Button(button_frame, text="撤销", command=self.undo).pack(side="left", padx=2)
        ttk.Button(button_frame, text="重做", command=self.redo).pack(side="left", padx=2)
        auto_var = tk.BooleanVar(value=self.auto_mode)
        ttk.Checkbutton(button_frame, text="自动模式",
                        variable=auto_var,
                        command=lambda: setattr(self, 'auto_mode', auto_var.get())
                        ).pack(side="right", padx=2)

    def _update_gui_values(self) -> None:
        """Push the current parameter values into the GUI sliders."""
        if self.gui is None:
            return
        # Tk variable traces fire synchronously inside var.set(); the flag
        # makes _on_parameter_change ignore these programmatic writes.
        self._updating_gui = True
        try:
            for param, (var, _) in self.parameter_widgets.items():
                var.set(getattr(self, param))
        finally:
            self._updating_gui = False

    def _on_parameter_change(self, parameter: str) -> None:
        """Slider callback: copy the user's value in and record history."""
        if self._updating_gui or self.auto_mode:
            return
        value = self.parameter_widgets[parameter][0].get()
        if value == getattr(self, parameter):
            return
        setattr(self, parameter, value)
        self._save_history()

    # ------------------------------------------------------------------
    # Enhancement pipeline
    # ------------------------------------------------------------------
    def enhance_frame(self, frame_tensor: torch.Tensor) -> torch.Tensor:
        """Enhance one frame using the manual parameters or, in auto mode,
        the preset of the detected scene (instance parameters are unchanged).
        """
        params = None
        if self.auto_mode:
            params = self.scene_presets[self.analyze_scene(frame_tensor)]
        return self._enhance_frame_internal(frame_tensor, params)

    def _get_laplacian_kernel(self, channels: int, like: torch.Tensor) -> torch.Tensor:
        """Return a cached depthwise 3x3 Laplacian matching ``like``'s device/dtype."""
        kernel = getattr(self, '_laplacian_kernel', None)
        if (kernel is None or kernel.shape[0] != channels
                or kernel.device != like.device or kernel.dtype != like.dtype):
            base = torch.tensor([[0, -1, 0],
                                 [-1, 4, -1],
                                 [0, -1, 0]], dtype=like.dtype, device=like.device)
            kernel = base.view(1, 1, 3, 3).repeat(channels, 1, 1, 1)
            self._laplacian_kernel = kernel
        return kernel

    def _enhance_frame_internal(self, frame_tensor: torch.Tensor,
                                params: Optional[Dict[str, float]] = None) -> torch.Tensor:
        """Apply the enhancement pipeline.

        Args:
            frame_tensor: image tensor (B, C, H, W) with values in [0, 1].
            params: parameter dict (same keys as get_current_params); the
                instance's current parameters are used when None.

        Returns:
            Enhanced tensor (B, C, H, W) clamped to [0, 1].
        """
        p = self.get_current_params() if params is None else params
        brightness = p['brightness']
        contrast = p['contrast']
        saturation = p['saturation']
        sharpness = p['sharpness']
        lut_strength = p['lut_strength']

        # clamp returns a new tensor, so the in-place ops below never modify
        # the caller's tensor. Steps whose factor is exactly 1.0 are no-ops
        # and are skipped.
        enhanced = torch.clamp(frame_tensor, 0, 1)
        if brightness != 1.0:
            enhanced.mul_(brightness)
        if contrast != 1.0:
            mean = enhanced.mean()
            enhanced.sub_(mean).mul_(contrast).add_(mean)
        if saturation != 1.0:
            gray = enhanced.mean(dim=1, keepdim=True)
            enhanced = torch.lerp(gray, enhanced, saturation)
        if sharpness > 1.0:
            channels = enhanced.shape[1]
            kernel = self._get_laplacian_kernel(channels, enhanced)
            # Replicate padding: zero padding made the Laplacian large along
            # the image border and produced a bright halo there.
            padded = F.pad(enhanced, (1, 1, 1, 1), mode='replicate')
            edges = F.conv2d(padded, kernel, groups=channels)
            enhanced.add_(edges, alpha=sharpness - 1.0)
        # The identity LUT leaves colours unchanged, so it is skipped entirely.
        # Code that loads a custom table into self.lut must also update
        # current_lut_name for it to take effect.
        if lut_strength > 0 and self.current_lut_name != 'identity':
            try:
                lut_enhanced = self.lut.apply(enhanced)
            except Exception as e:  # keep the live preview running
                if not getattr(self, '_lut_error_reported', False):
                    print(f"应用LUT时出错: {e}")
                    self._lut_error_reported = True
            else:
                if lut_strength >= 1.0:
                    enhanced = lut_enhanced
                else:
                    enhanced = torch.lerp(enhanced, lut_enhanced, lut_strength)
        return enhanced.clamp_(0, 1)

    # ------------------------------------------------------------------
    # Background capture
    # ------------------------------------------------------------------
    def _start_capture(self) -> None:
        """Start the thread that keeps reading frames from ``self.cap``."""
        self._capture_running = True
        self._capture_thread = threading.Thread(
            target=self._capture_loop, name="camera-capture", daemon=True)
        self._capture_thread.start()

    def _capture_loop(self) -> None:
        """Thread body: publish each new frame and stop when a read fails."""
        while self._capture_running:
            ret, frame = self.cap.read()
            with self._frame_cond:
                if ret:
                    self._latest_frame = frame
                    self._frame_id += 1
                else:
                    self._capture_running = False
                self._frame_cond.notify_all()

    def _wait_for_frame(self, last_id: int, timeout: float = 0.1):
        """Wait for a frame newer than ``last_id``.

        Returns:
            ``(frame, frame_id)`` for a new frame, or ``(None, last_id)`` on
            timeout or when capture has stopped.
        """
        with self._frame_cond:
            self._frame_cond.wait_for(
                lambda: self._frame_id != last_id or not self._capture_running,
                timeout=timeout)
            if self._frame_id == last_id:
                return None, last_id
            return self._latest_frame, self._frame_id

    def _stop_capture(self) -> None:
        """Stop the capture thread and wait briefly for it to exit."""
        with self._frame_cond:
            self._capture_running = False
            self._frame_cond.notify_all()
        if self._capture_thread is not None:
            # cap.read() may block for up to one frame period.
            self._capture_thread.join(timeout=2.0)
            self._capture_thread = None

    # ------------------------------------------------------------------
    # Main loop
    # ------------------------------------------------------------------
    def _draw_overlay(self, image: np.ndarray, fps: float) -> None:
        """Draw LUT, mode and FPS information onto ``image`` in place."""
        lines = (
            f'LUT: {self.current_lut_name} ({self.lut_strength:.2f})',
            f'Mode: {"Auto" if self.auto_mode else "Manual"}',
            f'FPS: {fps:.1f}',
        )
        for i, text in enumerate(lines):
            cv2.putText(image, text, (10, 30 + 30 * i),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

    def _handle_key(self, key: int) -> bool:
        """Handle one key code from cv2.waitKey; return False to quit."""
        if key == ord('q'):
            return False
        if key == ord('['):
            self.set_lut_strength(self.lut_strength - 0.1)
        elif key == ord(']'):
            self.set_lut_strength(self.lut_strength + 0.1)
        elif key == ord('z'):  # undo
            self.undo()
        elif key == ord('y'):  # redo
            self.redo()
        elif key == ord('l'):  # cycle LUT presets
            names = list(self.available_luts)
            if names:
                index = names.index(self.current_lut_name) if self.current_lut_name in names else -1
                self.set_lut(names[(index + 1) % len(names)])
        return True

    def run(self) -> None:
        """Capture, enhance and display frames until quit.

        Keys (OpenCV windows): q quit, [ / ] LUT strength -/+ 0.1, z undo,
        y redo, l cycle LUT presets. Closing the control window also quits.
        """
        self.init_gui()
        self._start_capture()
        fps = 0.0
        last_tick = time.perf_counter()
        frame_id = 0
        try:
            with torch.no_grad():
                while not self._stop_requested:
                    frame, frame_id = self._wait_for_frame(frame_id)
                    if frame is not None:
                        frame_tensor = self.preprocess_frame(frame)
                        enhanced_tensor = self.enhance_frame(frame_tensor)
                        enhanced_frame = self.postprocess_frame(enhanced_tensor)
                        now = time.perf_counter()
                        elapsed = now - last_tick
                        last_tick = now
                        if elapsed > 0:
                            # Moving average of the displayed frame rate.
                            instant = 1.0 / elapsed
                            fps = instant if fps == 0.0 else 0.9 * fps + 0.1 * instant
                        self._draw_overlay(enhanced_frame, fps)
                        cv2.imshow('Enhanced', enhanced_frame)
                        cv2.imshow('Original', frame)
                    elif not self._capture_running:
                        break  # the camera stopped delivering frames
                    key = cv2.waitKey(1) & 0xFF
                    if not self._handle_key(key):
                        break
                    self.gui.update()
        finally:
            # Stop the reader before releasing the device it reads from.
            self._stop_capture()
            self.cap.release()
            cv2.destroyAllWindows()
            if self.gui is not None:
                try:
                    self.gui.destroy()
                except tk.TclError:
                    pass  # window already destroyed by the user
                self.gui = None
# Script entry point: only start the camera loop when run directly.
if __name__ == '__main__':
    enhancer = ImageEnhancer()
    enhancer.run()
# ---------------------------------------------------------------------------
# lut3d.py -- the LUT3D module imported by the main program above.
# ---------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn.functional as F
from typing import Optional, Tuple, Union
import os


class LUT3D:
    """3D colour look-up table applied with trilinear interpolation.

    ``self.grid`` has shape (1, S, S, S, 3): ``grid[0, r, g, b]`` is the
    output RGB colour for the input colour ``(r, g, b) / (S - 1)``.
    """

    def __init__(self, size: int = 33, device: Optional[torch.device] = None):
        """Create an identity LUT.

        Args:
            size: number of nodes per colour axis (table is size^3 entries).
            device: device holding the table; GPU when available by default.
        """
        self.size = size
        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.grid = self._create_identity_lut()

    def _create_identity_lut(self) -> torch.Tensor:
        """Build the identity table, shape (1, S, S, S, 3), axes ordered (r, g, b)."""
        n = self.size
        axis = torch.linspace(0, 1, n, device=self.device)
        # Broadcasting gives the same result as meshgrid(indexing='ij')
        # without the version-dependent `indexing` argument.
        r = axis.view(n, 1, 1).expand(n, n, n)
        g = axis.view(1, n, 1).expand(n, n, n)
        b = axis.view(1, 1, n).expand(n, n, n)
        return torch.stack([r, g, b], dim=-1).unsqueeze(0)

    def get_identity_lut(self) -> torch.Tensor:
        """Return a fresh identity table (independent of the current grid)."""
        return self._create_identity_lut()

    def _volume(self, like: torch.Tensor) -> torch.Tensor:
        """Return the table as a (1, 3, S, S, S) volume with axes (r, g, b),
        on the same device and dtype as ``like``."""
        n = self.size
        vol = self.grid.reshape(n, n, n, 3).permute(3, 0, 1, 2).unsqueeze(0)
        return vol.to(device=like.device, dtype=like.dtype).contiguous()

    def apply(self, image: torch.Tensor) -> torch.Tensor:
        """Map every pixel of ``image`` through the LUT.

        Args:
            image: tensor (B, 3, H, W); values are clamped to [0, 1].

        Returns:
            Tensor (B, 3, H, W) with the trilinearly interpolated colours.

        Raises:
            ValueError: if ``image`` is not (B, 3, H, W).
        """
        if image.dim() != 4 or image.shape[1] != 3:
            raise ValueError(f"输入图像形状必须为(B, 3, H, W), 实际为{tuple(image.shape)}")
        B, _, H, W = image.shape
        # grid_sample reads grid entries as (x, y, z) = (W, H, D) of the
        # volume, i.e. (b, g, r) here, so the RGB channel order is reversed.
        coords = image.clamp(0, 1).permute(0, 2, 3, 1).flip(-1)
        # [0, 1] -> [-1, 1]; with align_corners=True, -1 and 1 hit the first
        # and last nodes exactly, matching x * (S - 1) indexing.
        coords = coords.mul(2).sub(1)
        # Fold the batch into the depth axis of one sampling grid.
        coords = coords.reshape(1, B, H, W, 3)
        # For 5-D inputs, mode='bilinear' performs trilinear interpolation,
        # the same as the manual 8-neighbour blend but as one fused op.
        out = F.grid_sample(self._volume(image), coords, mode='bilinear',
                            padding_mode='border', align_corners=True)
        # (1, 3, B, H, W) -> (B, 3, H, W)
        return out[0].permute(1, 0, 2, 3)

    def load_from_file(self, file_path: str) -> None:
        """Load a LUT from a .npy file of shape (S, S, S, 3) or (1, S, S, S, 3).

        Raises:
            FileNotFoundError: if the file does not exist.
            ValueError: for a non-.npy file or a wrongly shaped array.
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"LUT文件不存在: {file_path}")
        if not file_path.endswith('.npy'):
            raise ValueError("不支持的文件格式")
        lut_data = np.load(file_path)
        if lut_data.ndim == 0 or lut_data.shape[-1] != 3:
            raise ValueError("LUT数据格式错误:最后一维必须是3(RGB)")
        cube = tuple(lut_data.shape[:-1])
        if len(cube) == 4 and cube[0] == 1:
            cube = cube[1:]
        if len(cube) != 3 or len(set(cube)) != 1:
            raise ValueError(
                f"LUT数据格式错误:形状必须为(S, S, S, 3)或(1, S, S, S, 3), 实际为{tuple(lut_data.shape)}")
        # Keep size in sync so apply() reshapes the table correctly.
        self.size = cube[0]
        self.grid = torch.from_numpy(lut_data).float().reshape(1, *cube, 3).to(self.device)

    def save_to_file(self, file_path: str) -> None:
        """Save the current table with numpy.save."""
        lut_data = self.grid.cpu().numpy()
        np.save(file_path, lut_data)

    def modify_lut(self, transform_func) -> None:
        """Replace the table with ``transform_func(self.grid)``.

        Args:
            transform_func: function taking and returning a torch.Tensor.
        """
        self.grid = transform_func(self.grid)

    def create_warm_lut(self, strength: float = 0.5) -> None:
        """Warm the current table: red += 0.2*strength, blue -= 0.1*strength."""
        def warm_transform(x):
            x = x.clone()
            # Index the RGB channel (last axis), not the red axis of the cube.
            x[..., 0] = torch.clamp(x[..., 0] + strength * 0.2, 0, 1)  # red
            x[..., 2] = torch.clamp(x[..., 2] - strength * 0.1, 0, 1)  # blue
            return x
        self.modify_lut(warm_transform)

    def create_cool_lut(self, strength: float = 0.5) -> None:
        """Cool the current table: blue += 0.2*strength, red -= 0.1*strength."""
        def cool_transform(x):
            x = x.clone()
            # Index the RGB channel (last axis), not the red axis of the cube.
            x[..., 2] = torch.clamp(x[..., 2] + strength * 0.2, 0, 1)  # blue
            x[..., 0] = torch.clamp(x[..., 0] - strength * 0.1, 0, 1)  # red
            return x
        self.modify_lut(cool_transform)

# Requirement stated with this paste: raise the frame rate without degrading
# image quality, so the experiment runs smoothly in practice.
# "最新发布" ("latest release"): web-page label captured with the paste; not part of the code.