import os
import cv2
import subprocess
import shutil
import tempfile
from PIL import Image
from tqdm import tqdm
import numpy as np
import queue
import threading
import onnxruntime as ort
import platform
import re
import sys
# Quiet ONNX Runtime logging. Severity levels: 0=VERBOSE, 1=INFO, 2=WARNING,
# 3=ERROR, 4=FATAL -- so 3 keeps only errors and fatal messages.
os.environ['ORT_LOG_LEVEL'] = '3'
# NOTE(review): ORT_LOG_LEVEL is not a documented onnxruntime setting, so also
# set the default logger severity through the supported Python API.
ort.set_default_logger_severity(3)
# Expose only the first GPU. This must be set before any CUDA context exists,
# i.e. before the first InferenceSession that uses the CUDA provider.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
def _cuda_version_from_files(root):
    """Read "CUDA X.Y" from version.json (CUDA 11.1+) or version.txt (older) in *root*."""
    json_file = os.path.join(root, "version.json")
    if os.path.isfile(json_file):
        import json
        try:
            with open(json_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            full_version = str(data.get('cuda', {}).get('version', ''))
        except (ValueError, AttributeError):
            full_version = ''  # malformed file: fall through to other sources
        match = re.match(r'(\d+)\.(\d+)', full_version)
        if match:
            return f"CUDA {match.group(1)}.{match.group(2)}"
    txt_file = os.path.join(root, "version.txt")
    if os.path.isfile(txt_file):
        # Typical content: "CUDA Version 10.2.89"
        with open(txt_file, 'r', encoding='utf-8', errors='replace') as f:
            match = re.search(r'(\d+)\.(\d+)', f.read())
        if match:
            return f"CUDA {match.group(1)}.{match.group(2)}"
    return None


def _cuda_version_from_dirname(path):
    """Parse "CUDA X.Y" from a toolkit folder named "vX.Y" (e.g. ...\\CUDA\\v11.8)."""
    match = re.fullmatch(r'v(\d+)\.(\d+)', os.path.basename(os.path.normpath(path)))
    return f"CUDA {match.group(1)}.{match.group(2)}" if match else None


def _cuda_version_from_nvcc(nvcc):
    """Parse "release X.Y" from `<nvcc> --version`; None if nvcc cannot be run."""
    try:
        result = subprocess.run([nvcc, '--version'], capture_output=True, text=True,
                                errors='replace', timeout=15)
    except (OSError, subprocess.SubprocessError):
        return None
    match = re.search(r'release (\d+)\.(\d+)', result.stdout or '')
    return f"CUDA {match.group(1)}.{match.group(2)}" if match else None


def _cuda_version_from_install_roots():
    """Return the newest "vX.Y" toolkit folder under the standard Windows install roots."""
    found = []
    for root in (r'C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA', r'C:\CUDA'):
        if not os.path.isdir(root):
            continue
        for name in os.listdir(root):
            match = re.fullmatch(r'v(\d+)\.(\d+)', name)
            if match:
                found.append((int(match.group(1)), int(match.group(2))))
    if not found:
        return None
    major, minor = max(found)
    return f"CUDA {major}.{minor}"


def _best_cudart_version(file_names):
    """Return the newest version encoded in cudart64_<digits>.dll names, or None.

    Naming scheme: up to CUDA 10.x the digits are major+minor ("65" -> 6.5,
    "101" -> 10.1); every CUDA 11.x release ships cudart64_110.dll; CUDA 12+
    uses only the major ("12"). When the minor is not encoded, "X.x" is returned.
    """
    versions = []
    for name in file_names:
        match = re.fullmatch(r'cudart64_(\d+)\.dll', name.strip(), re.IGNORECASE)
        if not match:
            continue
        digits = match.group(1)
        if digits == '110':
            versions.append((11, None))
        elif len(digits) == 2 and digits.startswith('1'):
            versions.append((int(digits), None))
        elif len(digits) >= 2:
            versions.append((int(digits[:-1]), int(digits[-1])))
    if not versions:
        return None
    major, minor = max(versions, key=lambda v: (v[0], -1 if v[1] is None else v[1]))
    return f"CUDA {major}.{minor}" if minor is not None else f"CUDA {major}.x"


def _cuda_version_from_cudart_dll(cuda_path):
    """Last resort on Windows: infer the version from cudart64_*.dll file names."""
    bin_dir = os.path.join(cuda_path, 'bin') if cuda_path else ''
    if bin_dir and os.path.isdir(bin_dir):
        # Prefer the runtime that belongs to CUDA_PATH; other software often
        # bundles old copies (e.g. cudart64_65.dll) that appear earlier on PATH.
        version = _best_cudart_version(os.listdir(bin_dir))
        if version:
            return version
    try:
        result = subprocess.run(['where', 'cudart64_*.dll'], capture_output=True,
                                text=True, errors='replace')
    except (OSError, subprocess.SubprocessError):
        return None
    if result.returncode != 0:
        return None
    # `where` lists every match on PATH; take the newest, not simply the first.
    return _best_cudart_version(os.path.basename(line.strip())
                                for line in result.stdout.splitlines())


def get_cuda_version():
    """Return a human-readable description of the installed CUDA toolkit version.

    Sources, from most to least authoritative:
      1. %CUDA_PATH%/version.json (CUDA 11.1+) or version.txt (older toolkits);
      2. the "vX.Y" folder name of %CUDA_PATH%;
      3. `nvcc --version` from %CUDA_PATH%/bin, then from PATH;
      4. (Windows) the newest "vX.Y" folder under the standard install roots;
      5. (Windows) cudart64_*.dll names, preferring %CUDA_PATH%/bin.

    The old lookup only knew version.txt (no longer shipped since CUDA 11.1) and
    otherwise trusted the first folder listed or the first cudart DLL on PATH,
    so a CUDA 11.8 machine could report an unrelated 6.5 runtime.

    Returns:
        str: e.g. "CUDA 11.8", or a diagnostic message. Never raises.
    """
    try:
        cuda_path = os.environ.get('CUDA_PATH', '')
        if cuda_path and os.path.isdir(cuda_path):
            nvcc_name = 'nvcc.exe' if platform.system() == "Windows" else 'nvcc'
            version = (_cuda_version_from_files(cuda_path)
                       or _cuda_version_from_dirname(cuda_path)
                       or _cuda_version_from_nvcc(os.path.join(cuda_path, 'bin', nvcc_name)))
            if version:
                return version
        version = _cuda_version_from_nvcc('nvcc')
        if version:
            return version
        if platform.system() == "Windows":
            version = (_cuda_version_from_install_roots()
                       or _cuda_version_from_cudart_dll(cuda_path))
            if version:
                return version
        return "无法确定CUDA版本(请检查CUDA_PATH环境变量)"
    except Exception as e:
        return f"获取CUDA版本时出错: {str(e)}"
def get_cudnn_version():
    """Return a human-readable description of the cuDNN installed under %CUDA_PATH%.

    First reads the CUDNN_MAJOR / CUDNN_MINOR / CUDNN_PATCHLEVEL macros from
    include/cudnn_version.h (cuDNN 8+) or include/cudnn.h (older releases).
    If neither header has them, falls back to the cudnn64_<N>.dll file name in
    bin/ or lib/x64/. That name only encodes the major version, so "cuDNN N.x"
    is returned.

    Returns:
        str: e.g. "cuDNN 8.9.2" or "cuDNN 8.x", or a diagnostic message. Never raises.
    """
    try:
        cuda_path = os.environ.get('CUDA_PATH')
        if not cuda_path:
            return "未找到CUDA_PATH环境变量"
        # Try both headers: in cuDNN 8+ cudnn.h only includes cudnn_version.h.
        for header_name in ("cudnn_version.h", "cudnn.h"):
            header_path = os.path.join(cuda_path, "include", header_name)
            if not os.path.exists(header_path):
                continue
            # Explicit, lenient decoding so a non-UTF-8 default locale
            # (e.g. GBK on Chinese Windows) cannot make the read fail.
            with open(header_path, 'r', encoding='utf-8', errors='ignore') as f:
                content = f.read()
            major = re.search(r"#define CUDNN_MAJOR\s+(\d+)", content)
            minor = re.search(r"#define CUDNN_MINOR\s+(\d+)", content)
            patch = re.search(r"#define CUDNN_PATCHLEVEL\s+(\d+)", content)
            if major and minor and patch:
                return f"cuDNN {major.group(1)}.{minor.group(1)}.{patch.group(1)}"
        # The runtime DLL is normally installed in bin\ (lib\x64 usually holds
        # only the cudnn.lib import library); search both.
        majors = []
        for lib_dir in (os.path.join(cuda_path, "bin"), os.path.join(cuda_path, "lib", "x64")):
            if not os.path.isdir(lib_dir):
                continue
            for file in os.listdir(lib_dir):
                match = re.fullmatch(r"cudnn64_(\d+)\.dll", file, re.IGNORECASE)
                if match:
                    majors.append(int(match.group(1)))
        if majors:
            # e.g. cudnn64_8.dll -> "cuDNN 8.x" (previously rendered as "cuDNN 8.")
            return f"cuDNN {max(majors)}.x"
        return "无法确定cuDNN版本(请检查cudnn64_*.dll文件)"
    except Exception as e:
        return f"获取cuDNN版本时出错: {str(e)}"
def check_onnxruntime_gpu():
    """Report whether ONNX Runtime can actually open a session on the GPU.

    Lists the available execution providers. When CUDAExecutionProvider is among
    them, it builds a tiny Identity model in memory and opens a session restricted
    to that provider, then reports the providers the session ended up using.

    Returns:
        str: a human-readable status line. Never raises.
    """
    try:
        available_providers = ort.get_available_providers()
        gpu_supported = 'CUDAExecutionProvider' in available_providers
        if not gpu_supported:
            return f"支持GPU: {gpu_supported}, 可用提供程序: {available_providers}"
        # The provider is compiled in; check that it really initializes
        # (missing CUDA/cuDNN DLLs only show up at session creation).
        try:
            from onnx import helper, TensorProto
            shape = [1, 3, 224, 224]
            node_def = helper.make_node('Identity', inputs=['input'], outputs=['output'])
            graph_def = helper.make_graph(
                [node_def],
                'test-model',
                [helper.make_tensor_value_info('input', TensorProto.FLOAT, shape)],
                [helper.make_tensor_value_info('output', TensorProto.FLOAT, shape)],
            )
            model_def = helper.make_model(graph_def, producer_name='onnx-example')
            # InferenceSession also accepts a serialized model. That avoids the
            # temporary .onnx file the old code created with delete=False and
            # never removed.
            session = ort.InferenceSession(model_def.SerializeToString(),
                                           providers=['CUDAExecutionProvider'])
            # If CUDA fails to load, ORT may fall back to CPU; get_providers()
            # then makes that visible.
            return f"支持GPU: {gpu_supported}, 当前会话使用: {session.get_providers()}"
        except Exception as e:
            return f"支持GPU: {gpu_supported}, 但初始化失败: {str(e)}"
    except Exception as e:
        return f"检查ONNX Runtime GPU支持时出错: {str(e)}"
def has_ffmpeg():
    """Return True if an ``ffmpeg`` executable can be launched by bare name.

    The rest of this module invokes 'ffmpeg' and 'ffprobe' without a path.
    An install found only in one of the common Windows folders is therefore
    prepended to this process's PATH, so those calls also succeed. Previously
    such an install made this function return True while every later call
    still failed.

    Returns:
        bool: whether ffmpeg is available.
    """
    if shutil.which('ffmpeg'):
        return True
    if platform.system() == "Windows":
        possible_paths = [
            r'C:\Program Files\ffmpeg\bin\ffmpeg.exe',
            r'C:\ffmpeg\bin\ffmpeg.exe',
            r'C:\Program Files (x86)\ffmpeg\bin\ffmpeg.exe'
        ]
        for path in possible_paths:
            if os.path.exists(path):
                bin_dir = os.path.dirname(path)
                os.environ['PATH'] = bin_dir + os.pathsep + os.environ.get('PATH', '')
                return True
    return False
def add_audio_to_video(video_path, original_video_path, output_path, fps):
    """Mux the first audio stream of *original_video_path* into *video_path*.

    The silent video is re-encoded to H.264 / yuv420p and the audio to AAC
    128k. The result is written to a temporary .mp4 and only moved to
    *output_path* when ffmpeg succeeds. In every other case the silent video is
    copied to *output_path* unchanged, so the caller always gets a file:
    ffmpeg missing, no audio track in the original, or a failed mux.

    Args:
        video_path: path of the silent, processed video.
        original_video_path: source video whose audio should be kept.
        output_path: final destination path.
        fps: frame rate passed to ffmpeg as the input rate ('-r') of video_path.

    Returns:
        str: output_path.
    """
    if not has_ffmpeg():
        print("警告: ffmpeg未安装,无法添加音频。请安装ffmpeg以支持音频处理。")
        shutil.copyfile(video_path, output_path)
        return output_path
    # Use a temporary file so a failed ffmpeg run never leaves a truncated output.
    with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as temp_file:
        temp_path = temp_file.name
    try:
        # Check whether the original video has an audio stream.
        check_audio_cmd = [
            'ffprobe',
            '-v', 'error',
            '-select_streams', 'a',
            '-show_entries', 'stream=codec_type',
            '-of', 'default=noprint_wrappers=1:nokey=1',
            original_video_path
        ]
        # ffprobe writes UTF-8. Decode it explicitly so non-ASCII paths in its
        # messages cannot raise UnicodeDecodeError under a GBK/ANSI locale.
        result = subprocess.run(check_audio_cmd, capture_output=True, text=True,
                                encoding='utf-8', errors='replace')
        if 'audio' not in result.stdout:
            print(f"警告: 原始视频 '{os.path.basename(original_video_path)}' 没有音频轨道")
            shutil.copyfile(video_path, output_path)
            return output_path
        cmd = [
            'ffmpeg',
            '-hide_banner', '-loglevel', 'error',  # keep stderr to real errors
            '-y',                         # overwrite the output file
            '-r', str(fps),               # input frame rate of the silent video
            '-i', video_path,             # silent video
            '-i', original_video_path,    # original video (with audio)
            '-c:v', 'libx264',            # H.264 encoding
            '-preset', 'fast',            # encoding speed preset
            '-crf', '23',                 # quality control
            '-pix_fmt', 'yuv420p',        # widely compatible pixel format
            '-c:a', 'aac',                # audio codec
            '-b:a', '128k',               # audio bitrate
            '-map', '0:v:0',              # first video stream of input 0
            '-map', '1:a:0',              # first audio stream of input 1
            '-shortest',                  # stop at the shortest stream
            '-movflags', '+faststart',    # streaming-friendly layout
            temp_path
        ]
        try:
            # Capture stderr (previously DEVNULL) so a failure can be diagnosed.
            subprocess.run(cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE)
        except subprocess.CalledProcessError as e:
            print(f"音频合并失败: {e}")
            if e.stderr:
                print(e.stderr.decode('utf-8', errors='replace').strip())
            print("将使用无声视频")
            shutil.copyfile(video_path, output_path)
            return output_path
        shutil.move(temp_path, output_path)
        return output_path
    except Exception as e:
        print(f"音频合并过程中出错: {str(e)}")
        print("将使用无声视频")
        shutil.copyfile(video_path, output_path)
        return output_path
    finally:
        # Remove the temporary file unless it was moved into place.
        if os.path.exists(temp_path):
            os.remove(temp_path)
class Videocap:
    """Decode a video on a background thread and queue model-ready frames.

    Geometry: the longer edge is scaled down to at most *limit* pixels, then each
    edge is snapped with to_16s. Every queued frame is the output of
    process_frame: RGB float32 in [-1, 1] with shape (1, height, width, 3).

    Attributes:
        total: frame count reported by the container; it may be inaccurate, and
            read() returns None at the real end of the stream.
        fps: frame rate, 25.0 when the container reports none.
        ori_width, ori_height: source resolution.
        width, height: resolution fed to the model.
    """

    def __init__(self, video, model_name, limit=1280):
        self.model_name = model_name
        vid = cv2.VideoCapture(video)
        if not vid.isOpened():
            raise ValueError(f"无法打开视频文件: {video}")
        width = int(vid.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(vid.get(cv2.CAP_PROP_FRAME_HEIGHT))
        self.total = int(vid.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = vid.get(cv2.CAP_PROP_FPS)
        # Some streams report 0 or NaN; ffmpeg '-r' and VideoWriter reject that.
        self.fps = fps if fps > 0 else 25.0
        self.ori_width, self.ori_height = width, height
        max_edge = max(width, height)
        scale_factor = limit / max_edge if max_edge > limit else 1.
        height = int(round(height * scale_factor))
        width = int(round(width * scale_factor))
        self.width, self.height = self.to_16s(width), self.to_16s(height)  # multiples of 16
        self.count = 0
        self.cap = vid
        # Probe the first frame so `ret` reflects whether anything is decodable,
        # then rewind for the reader thread.
        self.ret, _ = self.cap.read()
        self.cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
        self.q = queue.Queue(maxsize=100)
        # Set by read() once the end-of-stream sentinel has been consumed.
        self._finished = False
        t = threading.Thread(target=self._reader)
        t.daemon = True
        t.start()

    def _reader(self):
        """Producer: decode, preprocess and enqueue frames, then enqueue None.

        None is the end-of-stream sentinel. It is queued in `finally` so a
        consumer blocked in read() is always released, even when decoding or
        preprocessing raises. The old empty()/ret polling could block forever
        when the reader finished between the check and q.get().
        """
        try:
            while True:
                self.ret, frame = self.cap.read()
                if not self.ret:
                    break
                frame = np.asarray(self.process_frame(frame, self.width, self.height))
                self.q.put(frame)
                self.count += 1
        finally:
            self.ret = False
            self.cap.release()
            self.q.put(None)

    def read(self):
        """Return the next preprocessed frame, or None once the stream is exhausted."""
        if self._finished:
            return None
        frame = self.q.get()
        self.q.task_done()
        if frame is None:
            self._finished = True
        return frame

    def to_16s(self, x):
        """Snap a dimension down to a multiple of 16 (H.264-friendly), minimum 256."""
        if x < 256:
            return 256
        return x - x % 16

    def process_frame(self, img, width, height):
        """Convert a BGR uint8 frame to RGB float32 in [-1, 1], shape (1, height, width, 3)."""
        img = Image.fromarray(img[:, :, ::-1]).resize((width, height), Image.Resampling.BILINEAR)
        img = np.array(img).astype(np.float32) / 127.5 - 1.0
        return np.expand_dims(img, axis=0)
class Cartoonizer():
    """Run an AnimeGAN-style ONNX model over images and videos.

    Model inputs are prepared as RGB float32 in [-1, 1] with shape (1, H, W, 3).
    H and W are snapped the same way as in Videocap. Outputs are mapped back to
    uint8 RGB by post_precess.
    """

    def __init__(self, model_path, device="gpu"):
        """Print an environment report and create the ONNX Runtime session.

        Args:
            model_path: path to the .onnx model. Its basename without extension
                becomes ``self.name``.
            device: "gpu" uses CUDAExecutionProvider when ONNX Runtime offers it,
                with CPU kept as fallback. Any other value runs on CPU only.
        """
        self.model_path = model_path
        self.device = device
        self.name = os.path.basename(model_path).rsplit('.', 1)[0]
        # Environment report
        print("\n" + "="*50)
        print("环境检查:")
        print(f"CUDA 版本: {get_cuda_version()}")
        print(f"cuDNN 版本: {get_cudnn_version()}")
        print(f"ONNX Runtime GPU 支持: {check_onnxruntime_gpu()}")
        print("="*50 + "\n")
        providers = self._select_providers(device)
        session_options = ort.SessionOptions()
        session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        try:
            self.sess_land = ort.InferenceSession(model_path, sess_options=session_options,
                                                  providers=providers)
        except Exception as e:
            print(f"创建ONNX会话失败: {str(e)}")
            # Retry explicitly on CPU. Since ORT 1.9, GPU builds refuse to create
            # a session without a providers list, so the old "default providers"
            # retry failed in the same way.
            print("尝试使用CPUExecutionProvider...")
            self.sess_land = ort.InferenceSession(model_path, sess_options=session_options,
                                                  providers=['CPUExecutionProvider'])
        print(f"ONNX Runtime 使用的提供程序: {self.sess_land.get_providers()}")
        print("模型输入信息:")
        for model_input in self.sess_land.get_inputs():
            print(f" Name: {model_input.name}, Shape: {model_input.shape}, Type: {model_input.type}")
        print("模型输出信息:")
        for model_output in self.sess_land.get_outputs():
            print(f" Name: {model_output.name}, Shape: {model_output.shape}, Type: {model_output.type}")

    @staticmethod
    def _select_providers(device):
        """Map the *device* argument to an ONNX Runtime provider list."""
        wants_gpu = str(device).lower() == "gpu"
        if wants_gpu and 'CUDAExecutionProvider' in ort.get_available_providers():
            print("使用CUDAExecutionProvider (CPU作为回退)")
            return ['CUDAExecutionProvider', 'CPUExecutionProvider']
        if wants_gpu:
            print("未检测到CUDAExecutionProvider, 使用CPUExecutionProvider")
        else:
            print("使用CPUExecutionProvider")
        return ['CPUExecutionProvider']

    @staticmethod
    def _fit_size(width, height, limit=1280):
        """Return the model input size (width, height) for a picture of the given size.

        Same rule as Videocap: shrink so the longer edge is at most *limit*, then
        snap each edge down to a multiple of 16 with a floor of 256. This avoids
        zero-sized inputs for tiny images.
        """
        max_edge = max(width, height)
        scale_factor = limit / max_edge if max_edge > limit else 1.

        def snap(x):
            return 256 if x < 256 else x - x % 16

        return (snap(int(round(width * scale_factor))),
                snap(int(round(height * scale_factor))))

    @staticmethod
    def _read_image(path):
        """Load an image as BGR uint8, or return None if it cannot be read/decoded.

        Uses np.fromfile + cv2.imdecode because cv2.imread commonly fails on
        non-ASCII (e.g. Chinese) paths on Windows.
        """
        try:
            data = np.fromfile(path, dtype=np.uint8)
        except OSError:
            return None
        if data.size == 0:
            return None
        return cv2.imdecode(data, cv2.IMREAD_COLOR)

    @staticmethod
    def _write_image(path, bgr_img):
        """Encode *bgr_img* by the extension of *path* and write it (non-ASCII safe).

        Raises:
            IOError: if OpenCV cannot encode the image. cv2.imwrite would have
                failed silently instead.
        """
        ext = os.path.splitext(path)[1] or '.png'
        ok, buf = cv2.imencode(ext, bgr_img)
        if not ok:
            raise IOError(f"无法编码图片: {path}")
        buf.tofile(path)

    def post_precess(self, img, wh):
        """Map a model output from [-1, 1] to uint8 and resize it to wh = (width, height)."""
        img = (img.squeeze() + 1.) / 2 * 255
        img = img.clip(0, 255).astype(np.uint8)
        img = Image.fromarray(img).resize((wh[0], wh[1]), Image.Resampling.BILINEAR)
        img = np.array(img).astype(np.uint8)
        return img

    def _cartoonize_frames(self, vid, size, write_frame, video_name):
        """Run the model on up to vid.total frames and hand each result to write_frame.

        Each result is resized to *size* (width, height) and passed as uint8 RGB.
        Errors are printed and end processing early (best effort), so the
        caller can still finalize whatever was written.

        Returns:
            int: number of frames written.
        """
        input_name = self.sess_land.get_inputs()[0].name  # hoisted out of the loop
        remaining = vid.total
        written = 0
        pbar = tqdm(total=vid.total)
        pbar.set_description(f"处理视频: {video_name}")
        try:
            while remaining > 0:
                frame = vid.read()
                if frame is None:
                    print("警告: 读取到空帧,提前结束")
                    break
                fake_img = self.sess_land.run(None, {input_name: frame})[0]
                write_frame(self.post_precess(fake_img, size))
                pbar.update(1)
                written += 1
                remaining -= 1
        except Exception as e:
            print(f"处理过程中出错: {str(e)}")
        finally:
            pbar.close()
        return written

    def process_video(self, video_path, output_path):
        """Cartoonize a video and re-attach its original audio.

        Uses the ffmpeg pipeline when ffmpeg is available and OpenCV otherwise.
        On any failure, the original video is copied to *output_path*.

        Returns:
            str: output_path.
        """
        temp_dir = tempfile.mkdtemp()
        temp_video_path = os.path.join(temp_dir, "temp_no_audio.mp4")
        try:
            vid = Videocap(video_path, self.name)
            if has_ffmpeg():
                print("使用FFmpeg进行视频编码")
                return self.process_video_with_ffmpeg(vid, video_path, output_path, temp_video_path)
            print("使用OpenCV进行视频编码")
            return self.process_video_with_opencv(vid, video_path, output_path, temp_video_path)
        except Exception as e:
            print(f"视频处理失败: {str(e)}")
            shutil.copyfile(video_path, output_path)
            print(f"已回退到原始视频: {output_path}")
            return output_path
        finally:
            shutil.rmtree(temp_dir, ignore_errors=True)

    def process_video_with_ffmpeg(self, vid, original_video_path, output_path, temp_video_path):
        """Pipe raw RGB24 frames into ffmpeg (libx264), then mux the original audio.

        The video keeps the model resolution (vid.width x vid.height, multiples
        of 16, valid for yuv420p). On any failure the source is decoded again
        with a fresh Videocap and encoded with OpenCV, because *vid* has
        already been consumed.
        """
        try:
            ffmpeg_cmd = [
                'ffmpeg',
                '-y',
                '-hide_banner', '-loglevel', 'error', '-nostats',
                '-f', 'rawvideo',
                '-vcodec', 'rawvideo',
                '-s', f'{vid.width}x{vid.height}',
                '-pix_fmt', 'rgb24',
                '-r', str(vid.fps),
                '-i', '-',
                '-an',  # no audio
                '-vcodec', 'libx264',
                '-preset', 'fast',
                '-crf', '23',
                '-pix_fmt', 'yuv420p',
                temp_video_path
            ]
            # stderr goes to a temp file, not a PIPE. Nobody reads a pipe while
            # frames are written, so a full pipe buffer would block ffmpeg and,
            # in turn, our stdin writes (deadlock).
            with tempfile.TemporaryFile() as stderr_file:
                ffmpeg_process = subprocess.Popen(
                    ffmpeg_cmd,
                    stdin=subprocess.PIPE,
                    stdout=subprocess.DEVNULL,
                    stderr=stderr_file
                )
                try:
                    # post_precess already returns RGB (Videocap converts BGR->RGB
                    # before inference), which is exactly what rgb24 expects. The
                    # former cv2.COLOR_BGR2RGB here swapped red and blue.
                    self._cartoonize_frames(
                        vid,
                        (vid.width, vid.height),
                        lambda rgb: ffmpeg_process.stdin.write(np.ascontiguousarray(rgb).tobytes()),
                        os.path.basename(original_video_path),
                    )
                finally:
                    try:
                        ffmpeg_process.stdin.close()
                    except OSError:
                        pass  # ffmpeg already exited; its stderr says why
                    ffmpeg_process.wait()
                stderr_file.seek(0)
                stderr = stderr_file.read()
            if ffmpeg_process.returncode != 0:
                print(f"FFmpeg编码错误: {stderr.decode('utf-8', errors='replace')}")
                raise RuntimeError("FFmpeg编码失败")
            return add_audio_to_video(temp_video_path, original_video_path, output_path, vid.fps)
        except Exception as e:
            print(f"FFmpeg处理失败: {str(e)}")
            print("回退到OpenCV编码")
            return self.process_video_with_opencv(
                Videocap(original_video_path, self.name),
                original_video_path, output_path, temp_video_path)

    def process_video_with_opencv(self, vid, original_video_path, output_path, temp_video_path):
        """Encode frames with cv2.VideoWriter at the source resolution, then mux audio.

        Codecs are tried in order. The first writer that opens is used;
        otherwise 'mp4v' is forced.
        """
        codec_options = [
            ('mp4v', 'MPEG-4'),  # most reliable choice
            ('avc1', 'H.264/AVC'),
            ('h264', 'H.264/AVC'),
            ('x264', 'H.264/AVC'),
            ('vp09', 'VP9'),
            ('vp80', 'VP8')
        ]
        frame_size = (vid.ori_width, vid.ori_height)
        video_out = None
        for codec, codec_name in codec_options:
            try:
                fourcc = cv2.VideoWriter_fourcc(*codec)
                video_out = cv2.VideoWriter(temp_video_path, fourcc, vid.fps, frame_size)
                if video_out.isOpened():
                    print(f"使用编码器: {codec_name} ({codec})")
                    break
                video_out.release()
            except Exception as e:
                print(f"编码器 {codec} 初始化失败: {str(e)}")
                video_out = None
        if video_out is None or not video_out.isOpened():
            print("警告: 所有编码器初始化失败,使用默认MPEG-4编码器")
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            video_out = cv2.VideoWriter(temp_video_path, fourcc, vid.fps, frame_size)
        try:
            # post_precess yields RGB; VideoWriter expects BGR.
            self._cartoonize_frames(
                vid,
                frame_size,
                lambda rgb: video_out.write(cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)),
                os.path.basename(original_video_path),
            )
        finally:
            video_out.release()
        return add_audio_to_video(temp_video_path, original_video_path, output_path, vid.fps)

    def process_image(self, image_path, output_path):
        """Cartoonize a single image and write it to *output_path*.

        Returns:
            str | None: output_path, or None if the input cannot be read.
            On processing errors, the original image is copied to
            *output_path* as a fallback.
        """
        try:
            img = self._read_image(image_path)
            if img is None:
                print(f"错误: 无法读取图片: {image_path}")
                return None
            ori_height, ori_width = img.shape[:2]
            width, height = self._fit_size(ori_width, ori_height)
            # BGR -> RGB, resize, scale to [-1, 1], add batch axis.
            img_rgb = Image.fromarray(img[:, :, ::-1]).resize((width, height), Image.Resampling.BILINEAR)
            img_np = np.array(img_rgb).astype(np.float32) / 127.5 - 1.0
            input_data = np.expand_dims(img_np, axis=0)
            fake_img = self.sess_land.run(None, {self.sess_land.get_inputs()[0].name: input_data})[0]
            result_img = self.post_precess(fake_img, (ori_width, ori_height))
            self._write_image(output_path, np.ascontiguousarray(result_img[:, :, ::-1]))
            return output_path
        except Exception as e:
            print(f"图片处理失败: {str(e)}")
            shutil.copyfile(image_path, output_path)
            print(f"已回退到原始图片: {output_path}")
            return output_path
def videopic_to_new(params):
    """Cartoonize every image and video located directly in an input directory.

    Args:
        params (dict): configuration with the keys
            "video_dir":  directory to scan (not recursive),
            "model":      path to the ONNX model,
            "output_dir": destination directory (created if missing),
            "device":     optional, "cpu" or "gpu" (default "gpu").

    Returns:
        list: the truthy paths returned by Cartoonizer.process_image /
        process_video, in os.listdir order. Each output is named
        "<stem>_<model name><lower-cased extension>".
    """
    input_dir = params["video_dir"]
    model_path = params["model"]
    output_dir = params["output_dir"]
    device = params.get("device", "gpu")

    os.makedirs(output_dir, exist_ok=True)
    print("ffmpeg已安装,将自动为视频添加音频" if has_ffmpeg() else "警告: ffmpeg未安装,生成的视频将没有声音")

    cartoonizer = Cartoonizer(model_path, device)
    model_name = os.path.basename(model_path).rsplit('.', 1)[0]

    image_exts = frozenset(('.png', '.jpg', '.jpeg', '.bmp', '.tiff', '.webp'))
    video_exts = frozenset(('.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv'))

    processed_files = []
    regular_files = (name for name in os.listdir(input_dir)
                     if os.path.isfile(os.path.join(input_dir, name)))
    for filename in regular_files:
        filepath = os.path.join(input_dir, filename)
        stem, raw_ext = os.path.splitext(filename)
        ext = raw_ext.lower()
        output_name = f"{stem}_{model_name}{ext}"
        output_path = os.path.join(output_dir, output_name)
        try:
            if ext in image_exts:
                print(f"\n处理图片: {filename}")
                result = cartoonizer.process_image(filepath, output_path)
            elif ext in video_exts:
                print(f"\n处理视频: {filename}")
                result = cartoonizer.process_video(filepath, output_path)
            else:
                continue  # unsupported extension
            if result:
                processed_files.append(result)
                print(f"保存为: {output_name}")
        except Exception as e:
            print(f"处理 {filename} 时出错: {str(e)}")

    print("\n处理完成。")
    print(f"共处理文件: {len(processed_files)}")
    return processed_files
def check_video_format(path):
    """Describe the first video stream of *path* as ffprobe's "codec_name,pix_fmt".

    Returns:
        str | None: the stripped ffprobe stdout; a "检查失败: ..." message if
        ffprobe cannot be run; or None (after printing a notice) when ffmpeg
        is not installed.
    """
    if has_ffmpeg():
        probe = [
            'ffprobe', '-v', 'error', '-select_streams', 'v:0',
            '-show_entries', 'stream=codec_name,pix_fmt',
            '-of', 'csv=p=0', path,
        ]
        try:
            completed = subprocess.run(probe, capture_output=True, text=True)
        except Exception as e:
            return f"检查失败: {str(e)}"
        return completed.stdout.strip()
    print("无法检查视频格式: ffmpeg未安装")
    return None
if __name__ == "__main__":
    # Example run configuration -- adjust the paths for your machine.
    params = {
        "video_dir": r"E:\软件视频类型测试\1带货测试\成品\切片法区测试",
        "model": r"E:\python成品\15视频转绘\AnimeGANv3-1.1.0\AnimeGANv3-1.1.0\deploy\AnimeGANv3_Hayao_36.onnx",
        "output_dir": r"E:\软件视频类型测试\1带货测试\成品\成品",
        "device": "cpu"  # run inference on the CPU
    }
    video_suffixes = ('.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv')
    try:
        results = videopic_to_new(params)
        print("\n处理后的文件:")
        for res in results:
            print(f" - {res}")
            # Report the codec / pixel format of every produced video.
            is_video = res.lower().endswith(video_suffixes)
            if is_video:
                print(f" 视频格式: {check_video_format(res)}")
    except Exception as e:
        print(f"\n发生严重错误: {str(e)}")
        import traceback
        traceback.print_exc()
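# 问题: 我安装的CUDA是11.8, 为什么读出来是6.5?
# (Q: CUDA 11.8 is installed, so why is the detected version 6.5?
#  Cause: CUDA 11.1+ ships version.json instead of version.txt, so the old
#  lookup fell through to an older vX.Y toolkit folder that still had a
#  version.txt, or to the first cudart64_*.dll found on PATH -- e.g. a
#  cudart64_65.dll bundled with another program, which parses as "6.5".)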