关于 s = torch.Tensor(2,3):fill(1)(Lua Torch 写法)在 PyTorch 中用 fill 填充指定数据的有关问题

博客内容展示了将代码改为使用torch.Tensor创建一个2行3列的张量,并将其元素全部填充为1的形式,即s = torch.Tensor(2,3).fill_(1),涉及到深度学习框架中张量的操作。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

在 PyTorch 中需要改为如下的形式(方法调用用“.”而不是“:”,并且使用带下划线的原地操作 fill_):

 

s = torch.Tensor(2,3).fill_(1)

 

# 这是源代码
# (Poster's note: "this is the source code" -- the main application script.)
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Dict, Optional
import time
from collections import deque
import tkinter as tk
from tkinter import ttk
import os
from lut3d import LUT3D


class ImageEnhancer:
    """Real-time webcam enhancer with a Tk control panel and a 3D LUT stage.

    Frames are read from camera 0, converted to a (1, 3, H, W) float tensor,
    passed through brightness / contrast / saturation / sharpening / LUT
    stages and shown next to the original frame.
    """

    def __init__(self):
        # Use CUDA when it is available, otherwise fall back to the CPU.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Enhancement parameters; 1.0 leaves the corresponding stage neutral.
        self.brightness = 1.0
        self.contrast = 1.0
        self.sharpness = 1.0
        self.saturation = 1.0

        # Camera capture.
        self.cap = cv2.VideoCapture(0)

        # 3D LUT state.
        self.lut = LUT3D(size=33, device=self.device)
        self.current_lut_name = 'identity'
        self.lut_strength = 1.0
        self.available_luts = self._load_preset_luts()

        # Scene recognition.
        self.auto_mode = False
        self.scene_presets = self._init_scene_presets()

        # The sharpening kernel never changes, so build it once instead of
        # allocating and uploading it for every frame.
        self._laplacian = torch.tensor(
            [[0, -1, 0], [-1, 4, -1], [0, -1, 0]],
            dtype=torch.float32,
            device=self.device,
        ).view(1, 1, 3, 3).repeat(3, 1, 1, 1)

        # GUI state. ``_syncing_gui`` is True while the code (not the user)
        # is writing slider values, so the write traces can be ignored.
        self.gui = None
        self.parameter_widgets = {}
        self._syncing_gui = False

        # Parameter history for undo / redo; the initial state is entry 0.
        self.history = deque(maxlen=20)
        self.history_position = -1
        self._save_history()

    def preprocess_frame(self, frame: np.ndarray) -> torch.Tensor:
        """Convert an OpenCV frame to a PyTorch tensor.

        Args:
            frame: BGR uint8 image as a numpy array of shape (H, W, 3).

        Returns:
            Image tensor of shape (1, 3, H, W) on ``self.device`` with values
            in [0, 1].
        """
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        # Move the uint8 data to the device first and normalise there: the
        # transfer is a quarter of the size of a float32 upload and the
        # resulting values are the same.
        frame_tensor = torch.from_numpy(frame_rgb).to(self.device)
        frame_tensor = frame_tensor.permute(2, 0, 1).unsqueeze(0)
        return frame_tensor.float().div_(255.0)

    def postprocess_frame(self, frame_tensor: torch.Tensor) -> np.ndarray:
        """Convert an image tensor back to an OpenCV frame.

        Args:
            frame_tensor: Image tensor of shape (1, 3, H, W), values in [0, 1].

        Returns:
            BGR uint8 image as a numpy array of shape (H, W, 3).
        """
        # Quantise on the device so only uint8 data crosses back to the host.
        frame_u8 = (
            frame_tensor.squeeze(0)
            .permute(1, 2, 0)
            .mul(255)
            .clamp(0, 255)
            .to(torch.uint8)
            .contiguous()
        )
        frame_np = frame_u8.cpu().numpy()
        return cv2.cvtColor(frame_np, cv2.COLOR_RGB2BGR)

    def _init_scene_presets(self) -> Dict[str, Dict[str, float]]:
        """Return the parameter presets used for each recognised scene type."""
        return {
            'portrait': {
                'brightness': 1.1,
                'contrast': 1.2,
                'saturation': 1.1,
                'sharpness': 1.3,
                'lut_strength': 0.8
            },
            'landscape': {
                'brightness': 1.0,
                'contrast': 1.3,
                'saturation': 1.4,
                'sharpness': 1.5,
                'lut_strength': 0.9
            },
            'night': {
                'brightness': 1.4,
                'contrast': 1.1,
                'saturation': 0.9,
                'sharpness': 0.8,
                'lut_strength': 0.7
            }
        }

    def _load_preset_luts(self) -> dict:
        """Build the preset LUT grids without changing the active LUT.

        The ``create_*_lut`` helpers modify ``self.lut.grid`` in place, so
        each preset is generated from a fresh identity grid and captured
        afterwards; the grid that was active on entry is restored at the end.

        Returns:
            Mapping of preset name ('identity', 'warm', 'cool') to LUT grid.
        """
        active_grid = self.lut.grid
        identity = self.lut.get_identity_lut()
        presets = {'identity': identity}
        builders = (
            ('warm', self.lut.create_warm_lut),
            ('cool', self.lut.create_cool_lut),
        )
        try:
            for name, build in builders:
                self.lut.grid = identity.clone()
                build(0.5)
                presets[name] = self.lut.grid
        finally:
            self.lut.grid = active_grid
        return presets

    def analyze_scene(self, frame_tensor: torch.Tensor) -> str:
        """Classify the scene with simple global statistics.

        The mean of the tensor is used as the brightness measure and its
        standard deviation as a rough stand-in for saturation.

        Returns:
            One of 'night', 'landscape' or 'portrait'.
        """
        brightness = torch.mean(frame_tensor).item()
        saturation = torch.std(frame_tensor).item()

        if brightness < 0.3:
            return 'night'
        elif saturation > 0.5:
            return 'landscape'
        else:
            return 'portrait'

    def get_current_params(self) -> Dict[str, float]:
        """Return a snapshot of the current enhancement parameters."""
        return {
            'brightness': self.brightness,
            'contrast': self.contrast,
            'saturation': self.saturation,
            'sharpness': self.sharpness,
            'lut_strength': self.lut_strength
        }

    def apply_params(self, params: Dict[str, float]) -> None:
        """Set all enhancement parameters from ``params`` and refresh the GUI.

        Raises:
            KeyError: If ``params`` lacks one of the five parameter keys.
        """
        self.brightness = params['brightness']
        self.contrast = params['contrast']
        self.saturation = params['saturation']
        self.sharpness = params['sharpness']
        self.lut_strength = params['lut_strength']

        if self.gui is not None:
            self._update_gui_values()

    def set_lut_strength(self, strength: float) -> None:
        """Set the LUT blend strength, clamped to [0, 1], and record it.

        The range matches the 'lut_strength' slider of the control panel.
        """
        self.lut_strength = min(max(float(strength), 0.0), 1.0)
        self._update_gui_values()
        self._save_history()

    def _save_history(self) -> None:
        """Append the current parameters to the undo history.

        If the history cursor is not at the newest entry, the entries after
        it are discarded first (a new edit invalidates the redo chain).
        """
        current_params = self.get_current_params()

        while len(self.history) > self.history_position + 1:
            self.history.pop()

        self.history.append(current_params)
        self.history_position = len(self.history) - 1

    def undo(self) -> None:
        """Step back one entry in the parameter history, if possible."""
        if self.history_position > 0:
            self.history_position -= 1
            self.apply_params(self.history[self.history_position])

    def redo(self) -> None:
        """Step forward one entry in the parameter history, if possible."""
        if self.history_position < len(self.history) - 1:
            self.history_position += 1
            self.apply_params(self.history[self.history_position])

    def init_gui(self) -> None:
        """Create the Tk control panel and sync it with the current values."""
        self.gui = tk.Tk()
        self.gui.title("图像增强控制面板")

        self._create_parameter_controls()
        self._create_control_buttons()
        self._update_gui_values()

    def _create_parameter_controls(self) -> None:
        """Create one labelled slider per enhancement parameter."""
        params_frame = ttk.LabelFrame(self.gui, text="参调整")
        params_frame.pack(padx=5, pady=5, fill="x")

        # name -> (label, slider minimum, slider maximum)
        parameters = {
            'brightness': ('亮度', 0.0, 2.0),
            'contrast': ('对比度', 0.0, 2.0),
            'saturation': ('饱和度', 0.0, 2.0),
            'sharpness': ('锐化', 0.0, 2.0),
            'lut_strength': ('LUT强度', 0.0, 1.0)
        }

        for param, (label, min_val, max_val) in parameters.items():
            frame = ttk.Frame(params_frame)
            frame.pack(fill="x", padx=5, pady=2)

            ttk.Label(frame, text=label).pack(side="left")

            var = tk.DoubleVar(value=getattr(self, param))
            slider = ttk.Scale(frame, from_=min_val, to=max_val,
                               variable=var, orient="horizontal")
            slider.pack(side="right", fill="x", expand=True)

            self.parameter_widgets[param] = (var, slider)
            # ``p=param`` binds the current name; a plain closure would see
            # only the last loop value.
            var.trace_add("write", lambda *args, p=param: self._on_parameter_change(p))

    def _create_control_buttons(self) -> None:
        """Create the undo / redo buttons and the auto-mode checkbox."""
        button_frame = ttk.Frame(self.gui)
        button_frame.pack(padx=5, pady=5, fill="x")

        ttk.Button(button_frame, text="撤销",
                   command=self.undo).pack(side="left", padx=2)
        ttk.Button(button_frame, text="重做",
                   command=self.redo).pack(side="left", padx=2)

        auto_var = tk.BooleanVar(value=self.auto_mode)
        ttk.Checkbutton(button_frame, text="自动模式",
                        variable=auto_var,
                        command=lambda: setattr(self, 'auto_mode', auto_var.get())
                        ).pack(side="right", padx=2)

    def _update_gui_values(self) -> None:
        """Write the current parameter values into the sliders.

        Setting a Tk variable fires its write trace; ``_syncing_gui`` marks
        these programmatic writes so ``_on_parameter_change`` does not treat
        them as user edits (which would corrupt the undo history).
        """
        if self.gui is None:
            return

        self._syncing_gui = True
        try:
            for param, (var, _) in self.parameter_widgets.items():
                var.set(getattr(self, param))
        finally:
            self._syncing_gui = False

    def _on_parameter_change(self, parameter: str) -> None:
        """Handle a slider edit: store the value and add a history entry.

        Ignored in auto mode and while the GUI is being synced by the code.
        """
        if self._syncing_gui or self.auto_mode:
            return
        value = self.parameter_widgets[parameter][0].get()
        setattr(self, parameter, value)
        self._save_history()

    def enhance_frame(self, frame_tensor: torch.Tensor) -> torch.Tensor:
        """Enhance one frame, using a scene preset when auto mode is on.

        In auto mode the preset is passed straight to the processing stage;
        the manual parameters and the sliders are left untouched, which also
        avoids rewriting every Tk variable twice per frame.
        """
        if self.auto_mode:
            scene_type = self.analyze_scene(frame_tensor)
            return self._enhance_frame_internal(
                frame_tensor, self.scene_presets[scene_type])
        return self._enhance_frame_internal(frame_tensor)

    def _enhance_frame_internal(
        self,
        frame_tensor: torch.Tensor,
        params: Optional[Dict[str, float]] = None,
    ) -> torch.Tensor:
        """Run the enhancement pipeline on a frame.

        Args:
            frame_tensor: Input image tensor (B, C, H, W), values in [0, 1].
            params: Parameter set to use; defaults to the current manual
                parameters (``get_current_params()``).

        Returns:
            Processed image tensor (B, C, H, W), values in [0, 1].
        """
        if params is None:
            params = self.get_current_params()
        brightness = params['brightness']
        contrast = params['contrast']
        saturation = params['saturation']
        sharpness = params['sharpness']
        lut_strength = params['lut_strength']

        # Make sure the input is inside [0, 1].
        frame_tensor = torch.clamp(frame_tensor, 0, 1)

        # Brightness: plain gain.
        enhanced = frame_tensor * brightness

        # Contrast: scale the distance from the global mean.
        mean = torch.mean(enhanced)
        enhanced = (enhanced - mean) * contrast + mean

        # Saturation: blend between the per-pixel channel mean and the image.
        gray = torch.mean(enhanced, dim=1, keepdim=True)
        enhanced = torch.lerp(gray, enhanced, saturation)

        # Sharpening: add a scaled Laplacian response, only above neutral.
        if sharpness > 1.0:
            sharp = F.conv2d(enhanced, self._laplacian, padding=1, groups=3)
            enhanced = enhanced + (sharpness - 1) * sharp

        # 3D LUT, blended with the unmodified image by ``lut_strength``.
        # Best effort: a LUT failure is reported and the stage is skipped.
        if lut_strength > 0:
            try:
                lut_enhanced = self.lut.apply(enhanced)
                enhanced = torch.lerp(enhanced, lut_enhanced, lut_strength)
            except (RuntimeError, ValueError) as e:
                print(f"应用LUT时出错: {e}")

        # Make sure the output is inside [0, 1].
        return torch.clamp(enhanced, 0, 1)

    def run(self):
        """Run the capture / enhance / display loop until the user quits.

        Keys: ``q`` quit, ``[`` / ``]`` lower / raise the LUT strength,
        ``z`` undo, ``y`` redo. Closing the control panel also ends the loop.
        The camera, the OpenCV windows and the Tk window are always released.
        """
        try:
            self.init_gui()

            while True:
                ret, frame = self.cap.read()
                if not ret:
                    break

                start_time = time.perf_counter()

                frame_tensor = self.preprocess_frame(frame)
                enhanced_tensor = self.enhance_frame(frame_tensor)
                enhanced_frame = self.postprocess_frame(enhanced_tensor)

                # Guard against a zero interval on coarse timers.
                elapsed = time.perf_counter() - start_time
                fps = 1.0 / elapsed if elapsed > 0 else float('inf')

                # On-screen information.
                cv2.putText(enhanced_frame,
                            f'LUT: {self.current_lut_name} ({self.lut_strength:.2f})',
                            (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
                cv2.putText(enhanced_frame,
                            f'Mode: {"Auto" if self.auto_mode else "Manual"}',
                            (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
                cv2.putText(enhanced_frame,
                            f'FPS: {fps:.1f}',
                            (10, 90), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

                cv2.imshow('Enhanced', enhanced_frame)
                cv2.imshow('Original', frame)

                # Keyboard input.
                key = cv2.waitKey(1) & 0xFF
                if key == ord('q'):
                    break
                elif key == ord('['):
                    self.set_lut_strength(self.lut_strength - 0.1)
                elif key == ord(']'):
                    self.set_lut_strength(self.lut_strength + 0.1)
                elif key == ord('z'):  # undo
                    self.undo()
                elif key == ord('y'):  # redo
                    self.redo()

                # Pump the Tk event loop; TclError means the panel was closed.
                try:
                    self.gui.update()
                except tk.TclError:
                    break

        finally:
            self.cap.release()
            cv2.destroyAllWindows()
            if self.gui is not None:
                try:
                    self.gui.destroy()
                except tk.TclError:
                    # The window was already destroyed by the user.
                    pass
                self.gui = None


if __name__ == '__main__':
    enhancer = ImageEnhancer()
    enhancer.run()


# 下面这是lut3d.py代码
# (Poster's note: "below is the code of lut3d.py" -- save this part as lut3d.py.)
import numpy as np
import torch
import torch.nn.functional as F
from typing import Optional, Tuple, Union
import os


class LUT3D:
    """3D colour look-up table applied with trilinear interpolation.

    ``self.grid`` has shape (1, S, S, S, 3); ``grid[0, r, g, b]`` is the RGB
    output for the input colour at grid indices (r, g, b).
    """

    def __init__(self, size: int = 33, device: Optional[torch.device] = None):
        """Create an identity LUT.

        Args:
            size: Number of grid points per axis (size x size x size).
            device: Device to keep the LUT on; defaults to CUDA when
                available, otherwise the CPU.
        """
        self.size = size
        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Start from the identity mapping.
        self.grid = self._create_identity_lut()

    def _create_identity_lut(self) -> torch.Tensor:
        """Return a new identity grid of shape (1, S, S, S, 3).

        Built by broadcasting one axis three ways, which gives the
        matrix-style (r, g, b) index order.
        """
        n = self.size
        axis = torch.linspace(0, 1, n, device=self.device)
        r = axis.view(n, 1, 1).expand(n, n, n)
        g = axis.view(1, n, 1).expand(n, n, n)
        b = axis.view(1, 1, n).expand(n, n, n)
        return torch.stack([r, g, b], dim=-1).unsqueeze(0)

    def get_identity_lut(self) -> torch.Tensor:
        """Return a fresh identity LUT grid.

        The result is an identity mapping even after ``self.grid`` has been
        modified or loaded from a file.
        """
        return self._create_identity_lut()

    def apply(self, image: torch.Tensor) -> torch.Tensor:
        """Apply the LUT to an image with trilinear interpolation.

        The whole lookup is a single ``F.grid_sample`` call on the LUT
        volume instead of eight gathers plus manual blending.

        Args:
            image: Image tensor (B, 3, H, W); values are clamped to [0, 1]
                before the lookup.

        Returns:
            Processed image tensor (B, 3, H, W).

        Raises:
            ValueError: If ``image`` is not a 4-D tensor with 3 channels.
        """
        if image.dim() != 4 or image.shape[1] != 3:
            raise ValueError(
                f"expected an image of shape (B, 3, H, W), got {tuple(image.shape)}")

        batch = image.shape[0]

        # LUT as a 3-channel volume: (1, 3, D=r, H=g, W=b), shared by the batch.
        volume = self.grid.to(device=image.device, dtype=image.dtype)
        volume = volume.permute(0, 4, 1, 2, 3).expand(batch, -1, -1, -1, -1)

        # grid_sample wants (x, y, z) = (W, H, D) = (b, g, r) coordinates in
        # [-1, 1]; with align_corners=True, -1 / +1 hit the first / last
        # grid point exactly, matching ``value * (size - 1)`` indexing.
        coords = image.clamp(0, 1).permute(0, 2, 3, 1).flip(-1) * 2 - 1
        coords = coords.unsqueeze(1)  # (B, 1, H, W, 3)

        # For 5-D input, mode='bilinear' performs trilinear interpolation.
        output = F.grid_sample(volume, coords, mode='bilinear',
                               padding_mode='border', align_corners=True)
        return output.squeeze(2)

    def load_from_file(self, file_path: str) -> None:
        """Load LUT data from a file.

        Only ``.npy`` files are supported. The array must have shape
        (S, S, S, 3) or (1, S, S, S, 3); ``self.size`` is updated to S.

        Raises:
            FileNotFoundError: If ``file_path`` does not exist.
            ValueError: If the file format or the array shape is unsupported.
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"LUT文件不存在: {file_path}")

        if not file_path.endswith('.npy'):
            raise ValueError("不支持的文件格式")

        lut_data = np.load(file_path)
        if lut_data.ndim == 0 or lut_data.shape[-1] != 3:
            raise ValueError("LUT数据格式错误:最后一维必须是3(RGB)")

        if lut_data.ndim == 4:
            lut_data = lut_data[np.newaxis]
        is_cubic = (
            lut_data.ndim == 5
            and lut_data.shape[0] == 1
            and lut_data.shape[1] == lut_data.shape[2] == lut_data.shape[3]
        )
        if not is_cubic:
            raise ValueError("LUT数据格式错误:形状必须是(S, S, S, 3)或(1, S, S, S, 3)")

        self.grid = torch.from_numpy(lut_data).float().to(self.device)
        self.size = int(lut_data.shape[1])

    def save_to_file(self, file_path: str) -> None:
        """Save the LUT grid to ``file_path`` with ``np.save``."""
        lut_data = self.grid.cpu().numpy()
        np.save(file_path, lut_data)

    def modify_lut(self, transform_func) -> None:
        """Replace the grid with ``transform_func(self.grid)``.

        Args:
            transform_func: Callable taking and returning a ``torch.Tensor``.
        """
        self.grid = transform_func(self.grid)

    def create_warm_lut(self, strength: float = 0.5) -> None:
        """Shift the current LUT towards warm tones (more red, less blue)."""
        def warm_transform(x):
            x = x.clone()
            # ``...`` selects the colour channel on the last axis.
            x[..., 0] = torch.clamp(x[..., 0] + strength * 0.2, 0, 1)  # red
            x[..., 2] = torch.clamp(x[..., 2] - strength * 0.1, 0, 1)  # blue
            return x

        self.modify_lut(warm_transform)

    def create_cool_lut(self, strength: float = 0.5) -> None:
        """Shift the current LUT towards cool tones (more blue, less red)."""
        def cool_transform(x):
            x = x.clone()
            # ``...`` selects the colour channel on the last axis.
            x[..., 2] = torch.clamp(x[..., 2] + strength * 0.2, 0, 1)  # blue
            x[..., 0] = torch.clamp(x[..., 0] - strength * 0.1, 0, 1)  # red
            return x

        self.modify_lut(cool_transform)


# 我要在不影响画质的情况下提高帧率使得实验能在使用中流畅运行
# (Poster's request: raise the frame rate without degrading image quality so
# that the experiment runs smoothly in use.)
最新发布
05-31
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值