import os
import time
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model, optimizers
import cv2
from pathlib import Path
import argparse
import json
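# Batch deblurring evaluation script for a multi-scale recurrent generator trimmed down
# for a GTX 1650: it restores a trained checkpoint, deblurs every image in a folder,
# and reports PSNR/SSIM against the corresponding sharp (ground-truth) images.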
class ResnetBlock(layers.Layer):
def __init__(self, filters, kernel_size, name=None):
super(ResnetBlock, self).__init__(name=name)
self.filters = filters
self.kernel_size = kernel_size
self.conv1 = layers.Conv2D(filters, kernel_size, padding='same',
kernel_initializer='glorot_uniform',
bias_initializer='zeros')
self.conv2 = layers.Conv2D(filters, kernel_size, padding='same',
kernel_initializer='glorot_uniform',
bias_initializer='zeros',
activation=None)
self.relu = layers.ReLU()
def call(self, inputs):
x = self.relu(inputs)
x = self.conv1(x)
x = self.relu(x)
x = self.conv2(x)
return inputs + x
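# The block above is a pre-activation residual unit:
#     out = x + conv2(relu(conv1(relu(x))))
# so `filters` must match the channel count of `inputs` for the skip connection to be valid.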
class BasicConvLSTMCell(layers.Layer):
def __init__(self, shape, kernel_size, filters, name=None):
super(BasicConvLSTMCell, self).__init__(name=name)
self.shape = shape
self.kernel_size = kernel_size
self.filters = filters
        # Input gate
        self.conv_i = layers.Conv2D(filters, kernel_size, padding='same', activation='sigmoid')
        # Forget gate
        self.conv_f = layers.Conv2D(filters, kernel_size, padding='same', activation='sigmoid')
        # Output gate
        self.conv_o = layers.Conv2D(filters, kernel_size, padding='same', activation='sigmoid')
        # Candidate cell state
        self.conv_c = layers.Conv2D(filters, kernel_size, padding='same', activation='tanh')
def zero_state(self, batch_size):
h, w = self.shape
return (tf.zeros([batch_size, h, w, self.filters]),
tf.zeros([batch_size, h, w, self.filters]))
def call(self, inputs, state):
h_prev, c_prev = state
        # Concatenate the input with the previous hidden state
        combined = tf.concat([inputs, h_prev], axis=-1)
        # Compute the gates
i = self.conv_i(combined)
f = self.conv_f(combined)
o = self.conv_o(combined)
c_tilde = self.conv_c(combined)
        # Update the cell state and hidden state
c = f * c_prev + i * c_tilde
h = o * tf.nn.tanh(c)
return h, (h, c)
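# The cell above implements a standard ConvLSTM update, each gate produced by its own
# convolution over the concatenated [input, previous hidden state] tensor (bias terms
# omitted in the equations below for brevity):
#     i = sigmoid(W_i * [x, h_prev])    # input gate
#     f = sigmoid(W_f * [x, h_prev])    # forget gate
#     o = sigmoid(W_o * [x, h_prev])    # output gate
#     c~ = tanh(W_c * [x, h_prev])      # candidate state
#     c = f * c_prev + i * c~
#     h = o * tanh(c)
# Minimal usage sketch (shapes chosen purely for illustration):
#     cell = BasicConvLSTMCell([32, 32], [3, 3], 64)
#     state = cell.zero_state(batch_size=1)
#     h, state = cell(tf.zeros([1, 32, 32, 64]), state)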
class DeblurGenerator(Model):
def __init__(self, n_levels=2, scale=0.5, chns=3, model_type='color'):
super(DeblurGenerator, self).__init__()
self.n_levels = n_levels
self.scale = scale
self.chns = chns
self.model_type = model_type
if model_type == 'lstm':
self.lstm_cell = BasicConvLSTMCell([32, 32], [3, 3], 64)
        # Create layers for each scale level - matches the GTX 1650-optimized configuration
self.conv_layers = {}
self.resnet_blocks = {}
self.deconv_layers = {}
for i in range(n_levels):
            # Encoder layers (reduced channel counts)
self.conv_layers[f'enc1_1_{i}'] = layers.Conv2D(16, [5, 5], padding='same', activation='relu',
kernel_initializer='glorot_uniform')
self.resnet_blocks[f'enc1_2_{i}'] = ResnetBlock(16, 5)
self.resnet_blocks[f'enc1_3_{i}'] = ResnetBlock(16, 5)
self.conv_layers[f'enc2_1_{i}'] = layers.Conv2D(32, [5, 5], strides=2, padding='same', activation='relu',
kernel_initializer='glorot_uniform')
self.resnet_blocks[f'enc2_2_{i}'] = ResnetBlock(32, 5)
self.resnet_blocks[f'enc2_3_{i}'] = ResnetBlock(32, 5)
self.conv_layers[f'enc3_1_{i}'] = layers.Conv2D(64, [5, 5], strides=2, padding='same', activation='relu',
kernel_initializer='glorot_uniform')
self.resnet_blocks[f'enc3_2_{i}'] = ResnetBlock(64, 5)
            # Decoder layers
self.resnet_blocks[f'dec3_1_{i}'] = ResnetBlock(64, 5)
self.deconv_layers[f'dec2_4_{i}'] = layers.Conv2DTranspose(32, [4, 4], strides=2, padding='same',
kernel_initializer='glorot_uniform')
self.resnet_blocks[f'dec2_3_{i}'] = ResnetBlock(32, 5)
self.resnet_blocks[f'dec2_2_{i}'] = ResnetBlock(32, 5)
self.deconv_layers[f'dec1_4_{i}'] = layers.Conv2DTranspose(16, [4, 4], strides=2, padding='same',
kernel_initializer='glorot_uniform')
self.resnet_blocks[f'dec1_3_{i}'] = ResnetBlock(16, 5)
self.resnet_blocks[f'dec1_2_{i}'] = ResnetBlock(16, 5)
self.conv_layers[f'dec1_0_{i}'] = layers.Conv2D(chns, [5, 5], padding='same',
kernel_initializer='glorot_uniform')
def call(self, inputs, training=None):
batch_size = tf.shape(inputs)[0]
h = tf.shape(inputs)[1]
w = tf.shape(inputs)[2]
if self.model_type == 'lstm':
rnn_state = self.lstm_cell.zero_state(batch_size)
x_unwrap = []
inp_pred = inputs
for i in range(self.n_levels):
scale = self.scale ** (self.n_levels - i - 1)
hi = tf.cast(tf.round(tf.cast(h, tf.float32) * scale), tf.int32)
wi = tf.cast(tf.round(tf.cast(w, tf.float32) * scale), tf.int32)
            # Resize the blurred input and the previous prediction to this scale
inp_blur = tf.image.resize(inputs, [hi, wi], method='bilinear')
inp_pred = tf.stop_gradient(tf.image.resize(inp_pred, [hi, wi], method='bilinear'))
inp_all = tf.concat([inp_blur, inp_pred], axis=3)
if self.model_type == 'lstm':
rnn_state = (tf.image.resize(rnn_state[0], [hi // 4, wi // 4], method='bilinear'),
tf.image.resize(rnn_state[1], [hi // 4, wi // 4], method='bilinear'))
            # Encoder (simplified)
conv1_1 = self.conv_layers[f'enc1_1_{i}'](inp_all)
conv1_2 = self.resnet_blocks[f'enc1_2_{i}'](conv1_1)
conv1_3 = self.resnet_blocks[f'enc1_3_{i}'](conv1_2)
conv2_1 = self.conv_layers[f'enc2_1_{i}'](conv1_3)
conv2_2 = self.resnet_blocks[f'enc2_2_{i}'](conv2_1)
conv2_3 = self.resnet_blocks[f'enc2_3_{i}'](conv2_2)
conv3_1 = self.conv_layers[f'enc3_1_{i}'](conv2_3)
conv3_2 = self.resnet_blocks[f'enc3_2_{i}'](conv3_1)
if self.model_type == 'lstm':
deconv3_2, rnn_state = self.lstm_cell(conv3_2, rnn_state)
else:
deconv3_2 = conv3_2
            # Decoder (simplified)
deconv3_1 = self.resnet_blocks[f'dec3_1_{i}'](deconv3_2)
deconv2_4 = self.deconv_layers[f'dec2_4_{i}'](deconv3_1)
cat2 = deconv2_4 + conv2_3
deconv2_3 = self.resnet_blocks[f'dec2_3_{i}'](cat2)
deconv2_2 = self.resnet_blocks[f'dec2_2_{i}'](deconv2_3)
deconv1_4 = self.deconv_layers[f'dec1_4_{i}'](deconv2_2)
cat1 = deconv1_4 + conv1_3
deconv1_3 = self.resnet_blocks[f'dec1_3_{i}'](cat1)
deconv1_2 = self.resnet_blocks[f'dec1_2_{i}'](deconv1_3)
inp_pred = self.conv_layers[f'dec1_0_{i}'](deconv1_2)
x_unwrap.append(inp_pred)
return x_unwrap
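# Minimal usage sketch for the generator (it returns n_levels predictions, coarse to
# fine; the last entry is the full-resolution estimate):
#     gen = DeblurGenerator(n_levels=2, scale=0.5, chns=3, model_type='color')
#     blurred = tf.zeros([1, 128, 128, 3])   # placeholder input
#     outputs = gen(blurred)
#     restored = outputs[-1]                 # shape [1, 128, 128, 3]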
def calculate_psnr(img1, img2):
    """Compute PSNR between two images (expects pixel values in [0, 255])."""
    # Cast to float64 first so that uint8 subtraction cannot wrap around
    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return 100
    max_pixel = 255.0
    psnr = 20 * np.log10(max_pixel / np.sqrt(mse))
    return psnr
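# Worked example (hand-checked): if every pixel of img1 differs from img2 by exactly 1,
# MSE = 1 and PSNR = 20 * log10(255 / 1) ≈ 48.13 dB; identical images return 100 as a
# sentinel for "infinite" PSNR.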
def calculate_ssim(img1, img2):
"""简化版SSIM计算"""
# 转换为灰度图像
if len(img1.shape) == 3:
img1 = cv2.cvtColor(img1, cv2.COLOR_RGB2GRAY)
if len(img2.shape) == 3:
img2 = cv2.cvtColor(img2, cv2.COLOR_RGB2GRAY)
img1 = img1.astype(np.float64)
img2 = img2.astype(np.float64)
    # Stabilizing constants: c1 = (0.01 * 255)**2, c2 = (0.03 * 255)**2
    c1 = 6.5025
    c2 = 58.5225
    # Local means via an 11x11 Gaussian window (sigma = 1.5)
mu1 = cv2.GaussianBlur(img1, (11, 11), 1.5)
mu2 = cv2.GaussianBlur(img2, (11, 11), 1.5)
mu1_sq = mu1 * mu1
mu2_sq = mu2 * mu2
mu1_mu2 = mu1 * mu2
    # Local variances and covariance
sigma1_sq = cv2.GaussianBlur(img1 * img1, (11, 11), 1.5) - mu1_sq
sigma2_sq = cv2.GaussianBlur(img2 * img2, (11, 11), 1.5) - mu2_sq
sigma12 = cv2.GaussianBlur(img1 * img2, (11, 11), 1.5) - mu1_mu2
    # SSIM map and its mean
numerator = (2 * mu1_mu2 + c1) * (2 * sigma12 + c2)
denominator = (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2)
ssim_map = numerator / denominator
return np.mean(ssim_map)
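# Reference formula implemented above (standard single-scale SSIM, 11x11 Gaussian
# window with sigma = 1.5):
#     SSIM(x, y) = ((2*mu_x*mu_y + c1) * (2*sigma_xy + c2)) /
#                  ((mu_x**2 + mu_y**2 + c1) * (sigma_x**2 + sigma_y**2 + c2))
# where c1 = (0.01 * L)**2 and c2 = (0.03 * L)**2 for dynamic range L = 255.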
class DeblurEvaluator:
def __init__(self, args):
self.args = args
self.chns = 3 if args.model == 'color' else 1
self.train_dir = os.path.join('./checkpoints', args.model)
        # Build the model
self.build_generator()
        # Optimizer (only needed so the checkpoint can be restored cleanly)
self.optimizer = optimizers.Adam(learning_rate=1e-4)
if args.mixed_precision:
self.optimizer = tf.keras.mixed_precision.LossScaleOptimizer(self.optimizer)
        # Checkpoint setup
self.checkpoint = tf.train.Checkpoint(
optimizer=self.optimizer,
generator=self.generator
)
self.checkpoint_manager = tf.train.CheckpointManager(
self.checkpoint, self.train_dir, max_to_keep=3
)
        # Load the trained weights
self.load_model()
def build_generator(self):
"""构建生成器模型"""
self.generator = DeblurGenerator(
n_levels=2, # GTX 1650优化版本
scale=0.5,
chns=self.chns,
model_type=self.args.model
)
# 构建模型 - 使用较小的128尺寸
dummy_input = tf.zeros([1, 128, 128, self.chns])
_ = self.generator(dummy_input)
print("生成器模型构建完成(GTX 1650优化版)")
def load_model(self):
"""加载训练好的模型"""
if self.checkpoint_manager.latest_checkpoint:
# 使用expect_partial()来忽略不匹配的检查点部分
status = self.checkpoint.restore(self.checkpoint_manager.latest_checkpoint)
status.expect_partial() # 忽略优化器状态等不匹配的部分
print(f"成功加载模型: {self.checkpoint_manager.latest_checkpoint}")
print("注意: 使用GTX 1650优化版本架构")
else:
print("未找到训练好的模型!请先训练模型。")
print(f"查找路径: {self.train_dir}")
# 列出checkpoint目录内容
if os.path.exists(self.train_dir):
print(f"目录内容: {os.listdir(self.train_dir)}")
exit(1)
def preprocess_image(self, image_path, target_size=None):
"""预处理单张图像"""
img = cv2.imread(image_path)
if img is None:
print(f"无法读取图像: {image_path}")
return None, None, None
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
original_h, original_w = img.shape[:2]
        # Convert to float32 and normalize to [0, 1]
img_input = img.astype(np.float32) / 255.0
if self.args.model != 'color':
img_input = cv2.cvtColor(img_input, cv2.COLOR_RGB2GRAY)
img_input = np.expand_dims(img_input, -1)
        # If a target size was given, resize and pad the image
if target_size:
target_h, target_w = target_size
if original_h > target_h or original_w > target_w:
scale = min(target_h / original_h, target_w / original_w)
new_h, new_w = int(original_h * scale), int(original_w * scale)
                img_input = cv2.resize(img_input, (new_w, new_h))
                # cv2.resize drops a trailing singleton channel; restore it for the gray model
                if img_input.ndim == 2:
                    img_input = np.expand_dims(img_input, -1)
                # Pad to the target size
pad_h = target_h - new_h
pad_w = target_w - new_w
img_input = np.pad(img_input, ((0, pad_h), (0, pad_w), (0, 0)), 'edge')
return img_input, (new_h, new_w), (original_h, original_w)
else:
new_h, new_w = original_h, original_w
pad_h = target_h - original_h
pad_w = target_w - original_w
img_input = np.pad(img_input, ((0, pad_h), (0, pad_w), (0, 0)), 'edge')
return img_input, (new_h, new_w), (original_h, original_w)
else:
return img_input, (original_h, original_w), (original_h, original_w)
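    # Worked example for the resize-and-pad path above: a 256x192 (h x w) input with
    # target_size=(128, 128) gives scale = min(128/256, 128/192) = 0.5, so the image is
    # resized to 128x96 and edge-padded on the right by 32 columns to 128x128;
    # processed_size=(128, 96) is returned so postprocess_image can crop the padding off.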
def postprocess_image(self, result, processed_size, original_size):
"""后处理图像到原始尺寸"""
processed_h, processed_w = processed_size
original_h, original_w = original_size
        # Crop off the padding added during preprocessing
result = result[:processed_h, :processed_w]
        # Resize back to the original size if it differs
if (processed_h, processed_w) != (original_h, original_w):
result = cv2.resize(result, (original_w, original_h))
        # Convert to uint8 in [0, 255]
result = np.clip(result * 255, 0, 255).astype(np.uint8)
if self.args.model != 'color':
result = cv2.cvtColor(result, cv2.COLOR_GRAY2RGB)
return result
    def deblur_image(self, image_path, target_size=(128, 128)):  # use a smaller working size
        """Deblur a single image."""
        # Preprocess
img_input, processed_size, original_size = self.preprocess_image(image_path, target_size)
if img_input is None:
            return None, None
        # Add a batch dimension
img_batch = np.expand_dims(img_input, 0)
        # Inference
start_time = time.time()
predictions = self.generator(img_batch, training=False)
duration = time.time() - start_time
        # Take the output of the finest scale
result = predictions[-1][0].numpy()
        # Post-process
result = self.postprocess_image(result, processed_size, original_size)
return result, duration
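    # Usage sketch (hypothetical file names; assumes a trained checkpoint was restored):
    #     evaluator = DeblurEvaluator(parse_args())
    #     result, seconds = evaluator.deblur_image('some_blurred_image.png')
    #     cv2.imwrite('deblurred.png', cv2.cvtColor(result, cv2.COLOR_RGB2BGR))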
def evaluate_dataset(self, blur_path, clear_path, output_path):
"""评估整个数据集"""
blur_path = Path(blur_path)
clear_path = Path(clear_path)
output_path = Path(output_path)
        # Create the output directory
output_path.mkdir(parents=True, exist_ok=True)
        # Collect all blurred images
blur_images = sorted(list(blur_path.glob("*.jpg")) + list(blur_path.glob("*.png")))
if len(blur_images) == 0:
print(f"在 {blur_path} 中未找到图像文件")
return
print(f"找到 {len(blur_images)} 张测试图像")
results = []
total_psnr = 0
total_ssim = 0
total_time = 0
processed_count = 0
for i, blur_img_path in enumerate(blur_images):
print(f"\n处理图像 {i + 1}/{len(blur_images)}: {blur_img_path.name}")
# 查找对应的清晰图像
clear_img_path = clear_path / blur_img_path.name
if not clear_img_path.exists():
print(f"警告: 未找到对应的清晰图像 {clear_img_path}")
# 如果没有清晰图像,仍然进行去模糊但不计算指标
deblurred_result, process_time = self.deblur_image(str(blur_img_path))
if deblurred_result is not None:
output_file = output_path / f"deblurred_{blur_img_path.name}"
result_bgr = cv2.cvtColor(deblurred_result, cv2.COLOR_RGB2BGR)
cv2.imwrite(str(output_file), result_bgr)
print(f"已保存去模糊图像(无指标计算): {output_file}")
continue
            # Deblur
deblurred_result, process_time = self.deblur_image(str(blur_img_path))
if deblurred_result is None:
continue
            # Read the sharp (ground-truth) image
clear_img = cv2.imread(str(clear_img_path))
if clear_img is None:
print(f"无法读取清晰图像: {clear_img_path}")
continue
clear_img = cv2.cvtColor(clear_img, cv2.COLOR_BGR2RGB)
            # Make sure the sizes match
h, w = clear_img.shape[:2]
deblurred_result = cv2.resize(deblurred_result, (w, h))
            # Compute metrics
psnr_value = calculate_psnr(clear_img, deblurred_result)
ssim_value = calculate_ssim(clear_img, deblurred_result)
            # Save the deblurred image
output_file = output_path / f"deblurred_{blur_img_path.name}"
result_bgr = cv2.cvtColor(deblurred_result, cv2.COLOR_RGB2BGR)
cv2.imwrite(str(output_file), result_bgr)
            # Record the per-image results
result_data = {
'filename': blur_img_path.name,
'psnr': float(psnr_value),
'ssim': float(ssim_value),
'time': float(process_time)
}
results.append(result_data)
total_psnr += psnr_value
total_ssim += ssim_value
total_time += process_time
processed_count += 1
print(f"PSNR: {psnr_value:.2f} dB, SSIM: {ssim_value:.4f}, 时间: {process_time:.3f}s")
if processed_count == 0:
print("没有成功处理任何图像用于指标计算")
return
        # Averages
avg_psnr = total_psnr / processed_count
avg_ssim = total_ssim / processed_count
avg_time = total_time / processed_count
        # Print a summary
        print("\n=== Evaluation summary ===")
        print(f"Images processed: {processed_count}")
        print(f"Average PSNR: {avg_psnr:.2f} dB")
        print(f"Average SSIM: {avg_ssim:.4f}")
        print(f"Average processing time: {avg_time:.3f}s")
        # Save detailed results to a JSON file
summary_results = {
'summary': {
'total_images': processed_count,
'average_psnr': float(avg_psnr),
'average_ssim': float(avg_ssim),
'average_time': float(avg_time)
},
'detailed_results': results
}
json_path = output_path / "evaluation_results.json"
with open(json_path, 'w', encoding='utf-8') as f:
json.dump(summary_results, f, indent=2, ensure_ascii=False)
print(f"详细结果已保存到: {json_path}")
# 保存简单的文本报告
txt_path = output_path / "evaluation_report.txt"
with open(txt_path, 'w', encoding='utf-8') as f:
f.write("=== 图像去模糊评估报告 ===\n\n")
f.write(f"处理图像数量: {processed_count}\n")
f.write(f"平均PSNR: {avg_psnr:.2f} dB\n")
f.write(f"平均SSIM: {avg_ssim:.4f}\n")
f.write(f"平均处理时间: {avg_time:.3f}s\n\n")
f.write("详细结果:\n")
f.write("-" * 60 + "\n")
f.write(f"{'文件名':<30} {'PSNR':<10} {'SSIM':<10} {'时间':<10}\n")
f.write("-" * 60 + "\n")
for result in results:
f.write(
f"{result['filename']:<30} {result['psnr']:<10.2f} {result['ssim']:<10.4f} {result['time']:<10.3f}\n")
print(f"文本报告已保存到: {txt_path}")
return avg_psnr, avg_ssim, avg_time
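# Shape of the evaluation_results.json file written above (values shown are purely
# illustrative placeholders, not real measurements):
#     {
#       "summary": {"total_images": 2, "average_psnr": 28.5,
#                   "average_ssim": 0.91, "average_time": 0.12},
#       "detailed_results": [
#         {"filename": "img_0001.png", "psnr": 28.1, "ssim": 0.90, "time": 0.13},
#         ...
#       ]
#     }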
def parse_args():
    parser = argparse.ArgumentParser(description='Batch image deblurring test and evaluation (GTX 1650-optimized version)')
parser.add_argument('--model', type=str, default='color', help='model type: [lstm | gray | color]')
    parser.add_argument('--gpu_id', type=str, default='0', help='GPU id to use; a negative value falls back to CPU')
parser.add_argument('--mixed_precision', action='store_true', help='enable mixed precision')
parser.add_argument('--blur_path', type=str, default=r'D:\混凝土裂缝数据集\val_blurred',
help='path to blurred images')
parser.add_argument('--clear_path', type=str, default=r'D:\混凝土裂缝数据集\val',
help='path to clear images (ground truth)')
parser.add_argument('--output_path', type=str, default='./evaluation_results',
help='output path for deblurred images and results')
return parser.parse_args()
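# Example invocation (the script name and paths are placeholders; adjust to your setup):
#     python deblur_eval.py --model color --gpu_id 0 --mixed_precision \
#         --blur_path ./val_blurred --clear_path ./val --output_path ./evaluation_results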
def main():
args = parse_args()
    # GPU / memory configuration
if int(args.gpu_id) >= 0:
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
try:
tf.config.experimental.set_visible_devices(gpus[int(args.gpu_id)], 'GPU')
                # Hard memory cap for the GTX 1650; memory growth is deliberately not set,
                # since TensorFlow rejects it in combination with an explicit memory limit
                tf.config.experimental.set_virtual_device_configuration(
                    gpus[int(args.gpu_id)],
                    [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=2048)]  # use only 2 GB
                )
                print(f"Using GPU: {args.gpu_id}, memory limit: 2 GB")
except RuntimeError as e:
print(f"GPU设置错误: {e}")
else:
print("未找到GPU,使用CPU")
else:
tf.config.set_visible_devices([], 'GPU')
print("使用CPU")
    # Enable mixed precision (recommended)
if args.mixed_precision:
policy = tf.keras.mixed_precision.Policy('mixed_float16')
tf.keras.mixed_precision.set_global_policy(policy)
print("启用混合精度")
    # Create the evaluator
evaluator = DeblurEvaluator(args)
    # Run the evaluation
    print("Starting batch deblurring and evaluation...")
    # evaluate_dataset returns None when nothing could be processed
    metrics = evaluator.evaluate_dataset(
        args.blur_path,
        args.clear_path,
        args.output_path
    )
    if metrics is None:
        print("\nEvaluation produced no results.")
        return
    print("\nEvaluation finished!")
if __name__ == '__main__':
main()