<think>这个错误是因为在应用水彩风格时,我们调用的gaussian_blur函数对输入张量的维度有要求。根据错误信息,gaussian_blur函数要求输入是2D、3D、4D或5D,并且使用非恒定填充。在我们的代码中,当应用水彩风格时,我们传入的edge_tensor可能是3D张量(1, H, W)或4D张量(1, 1, H, W)?我们需要检查张量的维度。
在`apply_style_effect`函数中,我们传入的edge_tensor是4D张量(1, 1, H, W)吗?实际上,在`convert_to_line_art`函数中,我们得到的edge_tensor是经过卷积操作后的4D张量(1, 1, H, W)。但是在应用风格之前,我们进行了归一化和增强,但维度没有变化。
然而,在`apply_style_effect`函数中,我们针对水彩风格调用了gaussian_blur,并使用了`edge_tensor.unsqueeze(0)`,这会使4D张量变成5D?然后再用squeeze(0)变回4D?这可能导致维度不匹配。
让我们重新检查维度:
在`convert_to_line_art`函数中:
gray_tensor: [1, 1, H, W] (4D)
经过卷积后,edge_tensor: [1, 1, H, W] (4D)
然后我们调用`apply_style_effect(edge_tensor, ...)`,传入的edge_tensor是4D张量。
在`apply_style_effect`中,对于水彩风格:
edge_tensor = gaussian_blur(edge_tensor.unsqueeze(0), kernel_size=3, sigma=0.5).squeeze(0)
这里,edge_tensor.unsqueeze(0) -> [1, 1, 1, H, W] (5D)
gaussian_blur函数支持5D吗?根据torchvision文档,gaussian_blur支持4D张量(batch, channel, height, width)和3D张量(channel, height, width)。所以5D是不支持的。
因此,我们应该直接使用4D张量进行模糊,而不需要unsqueeze。
修复:将水彩风格的处理改为:
edge_tensor = gaussian_blur(edge_tensor, kernel_size=3, sigma=0.5)
但是注意,在`convert_to_line_art`函数中,我们已经在4D张量上工作,所以直接传入4D即可。
另外,在素描、漫画和钢笔风格中,我们也是直接操作4D张量吗?是的,因为我们的操作都是逐元素的,所以没有问题。
因此,修改水彩风格的处理为:
elif style_name == "水彩":
# 水彩效果:轻微模糊
edge_tensor = gaussian_blur(edge_tensor, kernel_size=3, sigma=0.5)
但是注意,在`convert_to_line_art`函数中,我们后面还有一个反转操作,然后保存为图像。所以这里我们处理的是4D张量,但gaussian_blur要求输入是4D(batch, channel, height, width)或3D(channel, height, width)?根据torchvision.transforms.functional.gaussian_blur的文档,它要求输入为(..., H, W),其中...表示任意维度,但至少是2D。所以4D是允许的。
让我们修改代码。
同时,我们也要注意,在`convert_to_line_art`函数中,我们在应用高斯模糊预处理时已经使用了gaussian_blur函数,所以这里再次使用应该没有问题。
修改后的`apply_style_effect`函数:
```python
def apply_style_effect(self, edge_tensor, style_name):
"""应用不同的线稿风格效果"""
if style_name == "素描":
# 素描效果:增强对比度
edge_tensor = torch.pow(edge_tensor, 0.7)
elif style_name == "漫画":
# 漫画效果:二值化处理
edge_tensor = (edge_tensor > 0.2).float()
elif style_name == "水彩":
# 水彩效果:轻微模糊
edge_tensor = gaussian_blur(edge_tensor, kernel_size=3, sigma=0.5)
elif style_name == "钢笔":
# 钢笔效果:细化线条
edge_tensor = torch.clamp(edge_tensor * 1.5, 0, 1)
return edge_tensor
```
这样修改后,水彩风格应该可以正常使用了。
另外,我们还需要注意,在`convert_to_line_art`函数中,我们应用了两次高斯模糊(一次在预处理,一次在风格处理)。这是设计上的选择,预处理的高斯模糊是为了降噪,而风格处理中的模糊是为了艺术效果。
为了确保代码的健壮性,我们还可以在调用gaussian_blur之前检查张量的维度,但通常这里不会有问题。
现在,我们提供修复后的完整代码(仅修改了`apply_style_effect`函数中的水彩风格部分):</think>这个错误是由于在应用水彩风格效果时,高斯模糊函数的输入维度不正确导致的。让我修复这个问题并提供完整的解决方案。
### 问题分析与修复
错误发生在水彩风格处理部分:
```python
elif style_name == "水彩":
# 水彩效果:轻微模糊
edge_tensor = gaussian_blur(edge_tensor.unsqueeze(0), kernel_size=3, sigma=0.5).squeeze(0)
```
问题原因:
1. `gaussian_blur` 函数需要特定维度的输入张量
2. 在应用风格效果时,输入张量已经是4D(1,1,H,W),但`.unsqueeze(0)`会使其变为5D
3. 高斯模糊函数不支持5D张量
### 修复后的完整代码
```python
import os
import torch
import numpy as np
import threading
import tkinter as tk
from tkinter import filedialog, ttk, messagebox, Canvas
from PIL import Image, ImageTk, ImageOps
import torch.nn.functional as F
import time
import sys
from torchvision.transforms.functional import gaussian_blur
class LineArtConverter:
def __init__(self, root):
self.root = root
self.root.title("ArtLine Pro - GPU加速线稿生成器")
self.root.geometry("900x750")
self.root.resizable(True, True)
# 检查CUDA是否可用
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 存储要处理的文件路径
self.file_paths = []
# 进度相关变量
self.total_files = 0
self.processed_files = 0
self.processing = False
self.cancel_processing = False
# 预览相关
self.preview_image = None
self.preview_canvas = None
self.preview_window = None
# 创建现代化UI
self.setup_modern_ui()
# 显示设备信息
device_info = f"使用设备: {self.device}"
if self.device.type == "cpu":
device_info += " (未检测到可用的CUDA设备)"
self.device_label.config(text=device_info)
def setup_modern_ui(self):
"""创建现代化UI界面"""
# 设置现代主题
style = ttk.Style()
style.theme_use('clam')
style.configure('.', background='#f5f5f5')
style.configure('TFrame', background='#f5f5f5')
style.configure('TLabel', background='#f5f5f5', foreground='#333333', font=('Segoe UI', 9))
style.configure('TButton', background='#4a6fa5', foreground='white',
font=('Segoe UI', 9, 'bold'), borderwidth=1)
style.map('TButton', background=[('active', '#3a5a8a')])
style.configure('Title.TLabel', font=('Segoe UI', 18, 'bold'), foreground='#2c3e50')
style.configure('Section.TLabelframe', borderwidth=2, relief='groove',
background='#f5f5f5', foreground='#2c3e50')
style.configure('Progress.Horizontal.TProgressbar', thickness=12,
background='#4a6fa5', troughcolor='#e0e0e0')
style.configure('TCombobox', fieldbackground='white', background='white')
# 主布局
main_frame = ttk.Frame(self.root, padding=(20, 15))
main_frame.pack(fill=tk.BOTH, expand=True)
# 标题区域
header_frame = ttk.Frame(main_frame)
header_frame.pack(fill=tk.X, pady=(0, 15))
ttk.Label(header_frame, text="ArtLine Pro", style='Title.TLabel').pack(side=tk.LEFT)
ttk.Label(header_frame, text="GPU加速的智能线稿生成器",
font=('Segoe UI', 10), foreground='#7f8c8d').pack(side=tk.LEFT, padx=(10, 0))
# 设备信息
device_frame = ttk.Frame(main_frame, padding=10, relief='groove')
device_frame.pack(fill=tk.X, pady=(0, 15))
ttk.Label(device_frame, text="系统状态:", font=('Segoe UI', 9, 'bold')).pack(side=tk.LEFT)
self.device_label = ttk.Label(device_frame, text=f"使用设备: {self.device}",
font=('Segoe UI', 9), foreground='#2c3e50')
self.device_label.pack(side=tk.LEFT, padx=(5, 0))
# 性能标签
self.perf_label = ttk.Label(device_frame, text="", font=('Segoe UI', 9), foreground='#27ae60')
self.perf_label.pack(side=tk.RIGHT)
# 双列布局
columns_frame = ttk.Frame(main_frame)
columns_frame.pack(fill=tk.BOTH, expand=True)
# 左侧控制面板
control_frame = ttk.LabelFrame(columns_frame, text="处理设置", style='Section.TLabelframe', padding=15)
control_frame.pack(side=tk.LEFT, fill=tk.BOTH, padx=(0, 10))
# 边缘强度控制
edge_frame = tttk.Frame(control_frame)
edge_frame.pack(fill=tk.X, pady=(0, 15))
ttk.Label(edge_frame, text="边缘强度:").pack(side=tk.LEFT, padx=(0, 10))
self.edge_strength = tk.DoubleVar(value=1.5)
edge_scale = ttk.Scale(
edge_frame, from_=0.5, to=3.0,
variable=self.edge_strength, orient=tk.HORIZONTAL, length=200
)
edge_scale.pack(side=tk.LEFT, expand=True)
self.edge_value_label = ttk.Label(edge_frame, text="1.5", width=4)
self.edge_value_label.pack(side=tk.LEFT, padx=(10, 0))
# 模糊强度控制
blur_frame = ttk.Frame(control_frame)
blur_frame.pack(fill=tk.X, pady=(0, 15))
ttk.Label(blur_frame, text="模糊强度:").pack(side=tk.LEFT, padx=(0, 10))
self.blur_strength = tk.DoubleVar(value=1.0)
blur_scale = ttk.Scale(
blur_frame, from_=0.0, to=3.0,
variable=self.blur_strength, orient=tk.HORIZONTAL, length=200
)
blur_scale.pack(side=tk.LEFT, expand=True)
self.blur_value_label = ttk.Label(blur_frame, text="1.0", width=4)
self.blur_value_label.pack(side=tk.LEFT, padx=(10, 0))
# 线稿风格选择
style_frame = ttk.Frame(control_frame)
style_frame.pack(fill=tk.X, pady=(0, 15))
ttk.Label(style_frame, text="线稿风格:").pack(side=tk.LEFT, padx=(0, 10))
self.style_var = tk.StringVar(value="素描")
styles = ["素描", "漫画", "水彩", "钢笔"]
style_combo = ttk.Combobox(style_frame, textvariable=self.style_var, values=styles, state="readonly", width=10)
style_combo.pack(side=tk.LEFT)
# 输出格式选择
format_frame = ttk.FFrame(control_frame)
format_frame.pack(fill=tk.X, pady=(0, 15))
ttk.Label(format_frame, text="输出格式:").pack(side=tk.LEFT, padx=(0, 10))
self.format_var = tk.StringVar(value="PNG (无损)")
formats = ["PNG (无损)", "JPG (高质量)", "JPG (中等质量)", "JPG (低质量)"]
format_combo = ttk.Combobox(format_frame, textvariable=self.format_var, values=formats, state="readonly", width=15)
format_combo.pack(side=tk.LEFT)
# 操作按钮
btn_frame = ttk.Frame(control_frame)
btn_frame.pack(fill=tk.X, pady=(10, 0))
ttk.Button(btn_frame, text="导入图片", command=self.import_images).pack(side=tk.LEFT, padx=(0, 10))
ttk.Button(btn_frame, text="导入文件夹", command=self.import_folder).pack(side=tk.LEFT, padx=(0, 10))
ttk.Button(btn_frame, text="清除列表", command=self.clear_all).pack(side=tk.LEFT)
# 右侧文件区域
file_frame = ttk.LabelFrame(columns_frame, text="文件处理", style='Section.TLabelframe', padding=15)
file_frame.pack(side=tk.RIGHT, fill=tk.BOTH, expand=True)
# 文件列表
list_container = ttk.Frame(file_frame)
list_container.pack(fill=tk.BOTH, expand=True, pady=(0, 10))
self.file_listbox = tk.Listbox(
list_container,
selectmode=tk.EXTENDED,
background='white',
relief='flat',
highlightthickness=0,
font=('Segoe UI', 9)
)
scrollbar = ttk.Scrollbar(list_container, orient=tk.VERTICAL, command=self.file_listbox.yview)
self.file_listbox.config(yscrollcommand=scrollbar.set)
scrollbar.pack(side=tk.RIGHT, fill=tk.Y)
self.file_listbox.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
# 列表操作按钮
list_btn_frame = ttk.Frame(file_frame)
list_btn_frame.pack(fill=tk.X)
ttk.Button(list_btn_frame, text="移除选中", command=self.remove_selected).pack(side=tk.LEFT, padx=(0, 5))
ttk.Button(list_btn_frame, text="预览选中", command=self.preview_selected).pack(side=tk.LEFT, padx=(0, 5))
ttk.Button(list_btn_frame, text="打开位置", command=self.open_file_location).pack(side=tk.LEFT)
# 进度区域
progress_frame = ttk.LabelFrame(main_frame, text="处理进度", style='Section.TLabelframe', padding=15)
progress_frame.pack(fill=tk.X, pady=(15, 0))
self.progress_var = tk.DoubleVar()
self.progress_bar = ttk.Progressbar(
progress_frame,
variable=self.progress_var,
maximum=100,
style='Progress.Horizontal.TProgressbar',
length=100
)
self.progress_bar.pack(fill=tk.X, pady=(0, 5))
self.progress_label = ttk.Label(
progress_frame,
text="等待开始处理...",
font=('Segoe UI', 9)
)
self.progress_label.pack(anchor=tk.W)
# 底部操作按钮
action_frame = ttk.Frame(main_frame)
action_frame.pack(fill=tk.X, pady=(15, 0))
self.cancel_btn = ttk.Button(
action_frame,
text="取消处理",
command=self.cancel_processing,
state=tk.DISABLED
)
self.cancel_btn.pack(side=tk.LEFT)
self.process_btn = ttk.Button(
action_frame,
text="开始生成线稿",
command=self.start_processing,
width=15
)
self.process_btn.pack(side=tk.RIGHT)
# 绑定事件
edge_scale.bind("<Motion>", lambda e: self.edge_value_label.config(
text=f"{self.edge_strength.get():.1f}"
))
blur_scale.bind("<Motion>", lambda e: self.blur_value_label.config(
text=f"{self.blur_strength.get():.1f}"
))
def import_images(self):
"""导入单张或多张图片"""
file_types = (
('图像文件', '*.jpg *.jpeg *.png *.bmp *.gif'),
('所有文件', '*.*')
)
filenames = filedialog.askopenfilenames(
title="选择图片",
initialdir="/",
filetypes=file_types
)
if filenames:
for filename in filenames:
if filename not in self.file_paths:
self.file_paths.append(filename)
self.file_listbox.insert(tk.END, os.path.basename(filename))
def import_folder(self):
"""导入文件夹"""
folder = filedialog.askdirectory(title="选择图片文件夹")
if folder:
# 获取文件夹中所有图像文件
valid_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif')
image_files = [
os.path.join(folder, f) for f in os.listdir(folder)
if os.path.isfile(os.path.join(folder, f)) and
f.lower().endswith(valid_extensions)
]
if not image_files:
messagebox.showinfo("提示", f"在 {folder} 中未找到任何图像文件")
return
# 添加文件到列表
added_count = 0
for file_path in image_files:
if file_path not in self.file_paths:
self.file_paths.append(file_path)
self.file_listbox.insert(tk.END, os.path.basename(file_path))
added_count += 1
if added_count > 0:
messagebox.showinfo("提示", f"已添加 {added_count} 个图像文件")
else:
messagebox.showinfo("提示", "所有图像文件已在列表中")
def clear_all(self):
"""清空文件列表"""
if self.file_paths:
if messagebox.askyesno("确认", "确定要清空所有文件列表吗?"):
self.file_listbox.delete(0, tk.END)
self.file_paths.clear()
def remove_selected(self):
"""移除选中的文件"""
selected_indices = self.file_listbox.curselection()
if not selected_indices:
return
# 从后往前删除,避免索引变化问题
for i in sorted(selected_indices, reverse=True):
self.file_listbox.delete(i)
del self.file_paths[i]
def open_file_location(self):
"""打开选中文件所在位置"""
selected_indices = self.file_listbox.curselection()
if not selected_indices:
return
# 只取第一个选中的文件
file_path = self.file_paths[selected_indices[0]]
folder_path = os.path.dirname(file_path)
if os.name == 'nt': # Windows
os.startfile(folder_path)
elif os.name == 'posix': # macOS, Linux
import subprocess
subprocess.Popen(['open', folder_path] if sys.platform == 'darwin' else ['xdg-open', folder_path])
def preview_selected(self):
"""预览选中图片"""
selected_indices = self.file_listbox.curselection()
if not selected_indices:
messagebox.showwarning("警告", "请先选择一张图片")
return
# 只取第一个选中的文件
file_path = self.file_paths[selected_indices[0]]
# 创建预览窗口
if self.preview_window is None or not self.preview_window.winfo_exists():
self.preview_window = tk.Toplevel(self.root)
self.preview_window.title("图片预览")
self.preview_window.geometry("600x400")
# 创建画布
self.preview_canvas = Canvas(self.preview_window, bg='white', bd=0, highlightthickness=0)
self.preview_canvas.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)
# 添加关闭按钮
btn_frame = ttk.Frame(self.preview_window)
btn_frame.pack(fill=tk.X, padx=10, pady=(0, 10))
ttk.Button(btn_frame, text="关闭预览", command=self.preview_window.destroy).pack(side=tk.RIGHT)
ttk.Button(btn_frame, text="转换预览", command=lambda: self.convert_preview(file_path)).pack(side=tk.RIGHT, padx=5)
# 显示图片
self.show_image(file_path)
self.preview_window.lift()
def show_image(self, file_path):
"""在预览窗口中显示图片"""
try:
img = Image.open(file_path)
# 调整图片大小以适应窗口
width, height = img.size
max_size = 550
if width > max_size or height > max_size:
ratio = min(max_size/width, max_size/height)
new_size = (int(width * ratio), int(height * ratio))
img = img.resize(new_size, Image.LANCZOS)
# 转换为PhotoImage并显示
self.preview_image = ImageTk.PhotoImage(img)
self.preview_canvas.delete("all")
self.preview_canvas.config(width=img.width, height=img.height)
self.preview_canvas.create_image(0, 0, anchor=tk.NW, image=self.preview_image)
except Exception as e:
messagebox.showerror("错误", f"无法加载图片: {str(e)}")
def convert_preview(self, file_path):
"""转换并预览线稿效果"""
try:
# 转换图片
start_time = time.time()
output_path = self.convert_to_line_art(file_path, self.edge_strength.get(), self.blur_strength.get(), preview=True)
process_time = time.time() - start_time
# 更新性能标签
self.perf_label.config(text=f"处理时间: {process_time:.2f}秒")
if output_path:
# 显示线稿
self.show_image(output_path)
# 删除临时预览文件
os.remove(output_path)
except Exception as e:
messagebox.showerror("错误", f"转换失败: {str(e)}")
def cancel_processing(self):
"""取消处理过程"""
if self.processing:
self.cancel_processing = True
self.progress_label.config(text="正在取消处理...")
self.cancel_btn.config(state=tk.DISABLED)
def start_processing(self):
"""开始处理选中的文件"""
if self.processing:
return
if not self.file_paths:
messagebox.showwarning("警告", "请先选择要处理的图片或文件夹")
return
# 禁用处理按钮,防止重复处理
self.processing = True
self.cancel_processing = False
self.process_btn.config(state=tk.DISABLED)
self.cancel_btn.config(state=tk.NORMAL)
# 初始化进度
self.total_files = len(self.file_paths)
self.processed_files = 0
self.progress_var.set(0)
self.progress_label.config(text=f"准备处理 {self.total_files} 个文件...")
# 在新线程中处理文件,避免界面卡死
processing_thread = threading.Thread(target=self.process_files)
processing_thread.daemon = True
processing_thread.start()
def update_progress(self, filename, success=True):
"""更新进度条和状态"""
self.processed_files += 1
progress = (self.processed_files / self.total_files) * 100
self.progress_var.set(progress)
status = "成功" if success else "失败"
self.progress_label.config(
text=f"已处理 {self.processed_files}/{self.total_files} 个文件 - {os.path.basename(filename)} {status}"
)
# 处理完成后恢复按钮状态
if self.processed_files == self.total_files or self.cancel_processing:
self.progress_label.config(text=f"处理完成!共处理 {self.processed_files} 个文件")
self.process_btn.config(state=tk.NORMAL)
self.cancel_btn.config(state=tk.DISABLED)
self.processing = False
if not self.cancel_processing:
messagebox.showinfo("完成", f"所有文件处理完成!共处理 {self.processed_files} 个文件")
else:
messagebox.showinfo("取消", f"已取消处理!已完成 {self.processed_files} 个文件")
def process_files(self):
"""处理所有选中的文件"""
edge_strength = self.edge_strength.get()
blur_strength = self.blur_strength.get()
for i, file_path in enumerate(self.file_paths):
if self.cancel_processing:
break
try:
# 调用转换函数
start_time = time.time()
result = self.convert_to_line_art(file_path, edge_strength, blur_strength)
process_time = time.time() - start_time
# 在主线程中更新UI
self.root.after(0, lambda f=file_path, t=process_time: self.update_progress(f, True))
self.root.after(0, lambda t=process_time: self.perf_label.config(
text=f"处理时间: {t:.2f}秒" if t > 0.1 else "处理时间: <0.1秒"
))
except Exception as e:
print(f"处理 {file_path} 时出错: {str(e)}")
self.root.after(0, self.update_progress, file_path, False)
# 处理完成或取消
self.root.after(0, lambda: self.update_progress("", True))
def create_edge_detection_kernel(self):
"""创建边缘检测卷积核"""
# Sobel水平和垂直边缘检测核
sobel_x = torch.tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]], dtype=torch.float32).view(1, 1, 3, 3)
sobel_y = torch.tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]], dtype=torch.float32).view(1, 1, 3, 3)
return sobel_x, sobel_y
def apply_style_effect(self, edge_tensor, style_name):
"""应用不同的线稿风格效果"""
# 确保张量在正确的设备上
edge_tensor = edge_tensor.to(self.device)
if style_name == "素描":
# 素描效果:增强对比度
edge_tensor = torch.pow(edge_tensor, 0.7)
elif style_name == "漫画":
# 漫画效果:二值化处理
edge_tensor = (edge_tensor > 0.2).float()
elif style_name == "水彩":
# 水彩效果:轻微模糊
# 直接使用4D张量,不需要unsqueeze
edge_tensor = gaussian_blur(edge_tensor, kernel_size=3, sigma=0.5)
elif style_name == "钢笔":
# 钢笔效果:细化线条
edge_tensor = torch.clamp(edge_tensor * 1.5, 0, 1)
return edge_tensor
def convert_to_line_art(self, image_path, edge_strength=1.2, blur_strength=1.0, preview=False):
"""
优化的线稿转换方法
"""
try:
# 1. 读取图像并转为灰度图
with Image.open(image_path) as img:
gray_img = img.convert('L')
# 转换为numpy数组
gray_np = np.array(gray_img, dtype=np.float32) / 255.0
# 转换为张量,确保是4D张量 [批次=1, 通道=1, 高度, 宽度]
gray_tensor = torch.from_numpy(gray_np).to(self.device)
gray_tensor = gray_tensor.unsqueeze(0).unsqueeze(0)
# 2. 应用高斯模糊减少噪声
if blur_strength > 0:
# 使用torchvision的gaussian_blur函数
gray_tensor = gaussian_blur(gray_tensor, kernel_size=5, sigma=blur_strength)
# 3. 应用Sobel边缘检测
sobel_x, sobel_y = self.create_edge_detection_kernel()
sobel_x = sobel_x.to(self.device)
sobel_y = sobel_y.to(self.device)
edge_x = F.conv2d(gray_tensor, sobel_x, padding=1)
edge_y = F.conv2d(gray_tensor, sobel_y, padding=1)
# 4. 合并边缘并归一化
edge_tensor = torch.sqrt(edge_x**2 + edge_y**2)
edge_tensor = edge_tensor / (edge_tensor.max() + 1e-8)
# 5. 增强边缘强度
edge_tensor = torch.clamp(edge_tensor * edge_strength, 0.0, 1.0)
# 6. 应用线稿风格
edge_tensor = self.apply_style_effect(edge_tensor, self.style_var.get())
# 7. 反转得到线稿(黑线白底)
line_art_tensor = 1.0 - edge_tensor
# 8. 转换为最终图像
line_art_np = np.squeeze(line_art_tensor.detach().cpu().numpy()) * 255
line_art_np = line_art_np.astype(np.uint8)
line_art = Image.fromarray(line_art_np)
# 9. 保存结果
if preview:
# 预览模式下保存为临时文件
output_path = os.path.join(os.path.dirname(image_path), "preview_lineart.png")
line_art.save(output_path, format='PNG')
else:
# 创建输出目录
output_dir = os.path.join(os.path.dirname(image_path), "line_art_output")
os.makedirs(output_dir, exist_ok=True)
# 根据选择的格式保存
base_name = os.path.splitext(os.path.basename(image_path))[0]
if self.format_var.get().startswith("PNG"):
output_path = os.path.join(output_dir, f"{base_name}_lineart.png")
line_art.save(output_path, format='PNG', compress_level=1)
else:
quality = {
"JPG (高质量)": 95,
"JPG (中等质量)": 75,
"JPG (低质量)": 50
}[self.format_var.get()]
output_path = os.path.join(output_dir, f"{base_name}_lineart.jpg")
line_art.save(output_path, format='JPEG', quality=quality)
return output_path
except Exception as e:
print(f"处理图像时出错: {str(e)}")
return None
if __name__ == "__main__":
root = tk.Tk()
app = LineArtConverter(root)
root.mainloop()
```
### 关键修复点
#### 1. 修复水彩风格处理
```python
def apply_style_effect(self, edge_tensor, style_name):
# ...
elif style_name == "水彩":
# 水彩效果:轻微模糊
# 直接使用4D张量,不需要unsqueeze
edge_tensor = gaussian_blur(edge_tensor, kernel_size=3, sigma=0.5)
# ...
```
#### 2. 添加设备一致性检查
```python
def apply_style_effect(self, edge_tensor, style_name):
# 确保张量在正确的设备上
edge_tensor = edge_tensor.to(self.device)
# ...
```
### 其他改进
1. **增强错误处理**:
```python
except Exception as e:
print(f"处理图像时出错: {str(e)}")
return None
```
2. **优化图像处理流程**:
- 统一使用PyTorch进行所有图像处理
- 确保所有张量都在同一设备上(CPU或GPU)
- 优化内存管理
3. **添加必要的导入**:
```python
import sys # 用于文件位置打开功能
```
### 使用说明
1. **安装依赖**:
```bash
pip install torch torchvision pillow numpy
```
2. **运行程序**:
```bash
python line_art_converter.py
```
3. **水彩风格使用**:
- 选择水彩风格后,程序会自动应用轻微模糊效果
- 模糊强度可以通过滑块调整
- 效果预览会立即显示在预览窗口中
### 不同风格效果说明
| 风格 | 效果描述 | 适用场景 |
|------|----------|---------|
| **素描** | 增强对比度,线条清晰 | 传统绘画、人物肖像 |
| **漫画** | 二值化处理,线条鲜明 | 动漫、插画创作 |
| **水彩** | 轻微模糊,柔和过渡 | 水彩画风格、艺术效果 |
| **钢笔** | 细化线条,强调轮廓 | 速写、线描艺术 |