import atexit
import math
import os
import shutil
import tempfile
from typing import Tuple, Dict, List

import cv2
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from flask import Flask, render_template, request, jsonify, send_file
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
# -----------------------------------------------------------------------------
# 1. 基础配置(设备自动检测、临时文件路径)
# -----------------------------------------------------------------------------
app = Flask(__name__)
app.config['MAX_CONTENT_LENGTH'] = 50 * 1024 * 1024  # maximum upload size: 50 MB
# Scratch directory for uploads and rendered figures. mkdtemp() only creates
# the directory; it is never removed automatically, so delete it explicitly
# when the interpreter exits.
UPLOAD_FOLDER = tempfile.mkdtemp()
atexit.register(shutil.rmtree, UPLOAD_FOLDER, ignore_errors=True)
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
# 自动检测GPU/CPU
def get_available_device() -> str:
    """Return "/GPU:0" when TensorFlow lists at least one GPU, else "/CPU:0"."""
    gpus = tf.config.list_physical_devices("GPU")
    if gpus:
        return "/GPU:0"
    return "/CPU:0"
# Resolve the compute device once at import time and print which one is used.
DEVICE = get_available_device()
print(f"使用计算设备: {DEVICE}")
# -----------------------------------------------------------------------------
# 2. FSIM配置类(参数集中管理)
# -----------------------------------------------------------------------------
class FSIMConfig:
    """Bundle of tunable parameters read by the FSIM pipeline functions.

    Raises:
        ValueError: if ``win_size`` is even.
    """

    def __init__(
        self,
        win_size: int = 11,
        sigma: float = 1.5,
        alpha: float = 0.84,
        beta: float = 0.16,
        target_size: Tuple[int, int] = (512, 512),
        nscale: int = 4,
        norient: int = 4,
        min_wave_length: int = 6,
        mult: float = 2.0,
        sigma_onf: float = 0.55,
        d_theta_on_sigma: float = 1.2
    ):
        window_is_even = (win_size % 2 == 0)
        if window_is_even:
            raise ValueError("窗口大小必须为奇数")

        # Sliding-window geometry and Gaussian weighting.
        self.win_size, self.half_win = win_size, win_size // 2
        self.sigma = sigma

        # Exponents used when combining FSIM with the chroma similarity.
        self.alpha, self.beta = alpha, beta

        # Size every image is resized/padded to before analysis.
        self.target_size = target_size

        # Log-Gabor filter bank parameters.
        self.nscale, self.norient = nscale, norient
        self.min_wave_length = min_wave_length
        self.mult = mult
        self.sigma_onf = sigma_onf
        self.d_theta_on_sigma = d_theta_on_sigma

        # Small constant added to denominators / square roots.
        self.epsilon = tf.constant(1e-8, dtype=tf.float32)
# Global configuration instance (default parameters) read by the functions below.
config = FSIMConfig()
# -----------------------------------------------------------------------------
# 3. 核心工具函数(图像预处理、特征提取、FSIM计算)
# -----------------------------------------------------------------------------
def preprocess_image(img_path: str) -> tf.Tensor:
    """Load an image and letterbox it to ``config.target_size``.

    Steps: read with OpenCV, convert BGR to RGB, resize keeping the aspect
    ratio, pad with black borders to the target size, scale to [0, 1].

    Args:
        img_path: Path of the image file to read.

    Returns:
        A float32 tensor of the padded RGB image with values divided by 255.

    Raises:
        ValueError: if OpenCV cannot read the file.
    """
    img = cv2.imread(img_path)
    if img is None:
        raise ValueError(f"图像读取失败: {img_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    target_h, target_w = config.target_size
    h, w = img.shape[:2]

    # Aspect-preserving resize. Clamp each side to at least one pixel: for
    # very elongated images the truncated size would otherwise be 0 and
    # cv2.resize would fail.
    scale = min(target_w / w, target_h / h)
    new_w = min(target_w, max(1, int(w * scale)))
    new_h = min(target_h, max(1, int(h * scale)))
    img_resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_CUBIC)

    # Centre the resized image; any odd leftover pixel goes right/bottom.
    pad_left = (target_w - new_w) // 2
    pad_right = target_w - new_w - pad_left
    pad_top = (target_h - new_h) // 2
    pad_bottom = target_h - new_h - pad_top
    img_padded = cv2.copyMakeBorder(
        img_resized, pad_top, pad_bottom, pad_left, pad_right,
        cv2.BORDER_CONSTANT, value=(0, 0, 0)
    )
    return tf.cast(img_padded, tf.float32) / 255.0
def lowpassfilter(sze: Tuple[int, int], cutoff: float, n: int) -> tf.Tensor:
    """Build a frequency-domain low-pass filter of the form 1 / (1 + (r/cutoff)^(2n)).

    Args:
        sze: ``(rows, cols)`` of the filter; each entry must be convertible
            with ``int()``.
        cutoff: Cut-off frequency in normalised units, ``0 < cutoff <= 0.5``.
        n: Filter order (``>= 1``); larger values give a sharper transition.

    Returns:
        A float32 tensor of shape ``(rows, cols)`` with the zero frequency
        moved to index ``[0, 0]`` via ``ifftshift``.

    Raises:
        ValueError: if ``cutoff`` or ``n`` is out of range.
    """
    rows, cols = int(sze[0]), int(sze[1])
    if not 0.0 < cutoff <= 0.5:
        raise ValueError("cutoff must be in the range (0, 0.5]")
    if n < 1:
        raise ValueError("n must be >= 1")

    def _axis(length: int) -> tf.Tensor:
        # Exactly `length` normalised frequency samples. An even length uses
        # [-length/2, length/2 - 1]; the previous range produced length + 1
        # samples for even sizes, giving a filter of the wrong shape.
        if length % 2:
            half = (length - 1) // 2
            coords = tf.range(-half, half + 1, dtype=tf.float32)
            return coords / float(max(length - 1, 1))  # guard length == 1
        coords = tf.range(-(length // 2), length // 2, dtype=tf.float32)
        return coords / float(length)

    x_grid, y_grid = tf.meshgrid(_axis(cols), _axis(rows))
    radius = tf.sqrt(x_grid ** 2 + y_grid ** 2)
    filter_val = 1.0 / (1.0 + (radius / cutoff) ** (2 * n))
    return tf.signal.ifftshift(filter_val)
def phasecong2(img: tf.Tensor) -> tf.Tensor:
    """Compute a phase-congruency map for one single-channel image.

    The image is filtered in the frequency domain with a bank of
    ``config.nscale`` x ``config.norient`` log-Gabor filters; per orientation
    the local energy is reduced by an estimated noise threshold, and the
    result is normalised by the summed filter amplitudes.

    All intermediate maths runs in float64/complex128 so FFT outputs and
    accumulators share one dtype (TensorFlow does not promote mixed
    float32/float64 operands).

    Args:
        img: Rank-2 tensor ``(rows, cols)`` with a statically known shape.

    Returns:
        A float32 tensor of the same shape with values clipped to [0, 1].

    Raises:
        ValueError: if ``img`` is not rank 2.
    """
    real_t, cplx_t = tf.float64, tf.complex128
    img = tf.convert_to_tensor(img)
    if img.shape.rank != 2:
        raise ValueError("phasecong2 expects a 2-D (rows, cols) image")
    rows, cols = int(img.shape[0]), int(img.shape[1])
    eps = tf.cast(config.epsilon, real_t)

    def _axis(length: int) -> tf.Tensor:
        # Exactly `length` normalised frequency samples (the even case must
        # stop at length/2 - 1, otherwise the grid is one sample too large
        # and no longer matches the FFT of the image).
        if length % 2:
            half = (length - 1) // 2
            coords = tf.range(-half, half + 1, dtype=real_t)
            return coords / max(length - 1, 1)
        coords = tf.range(-(length // 2), length // 2, dtype=real_t)
        return coords / length

    def _median(values: tf.Tensor) -> tf.Tensor:
        # TensorFlow has no reduce_median; sort and average the middle pair.
        ordered = tf.sort(tf.reshape(values, [-1]))
        count = int(ordered.shape[0])
        mid = count // 2
        if count % 2:
            return ordered[mid]
        return 0.5 * (ordered[mid - 1] + ordered[mid])

    img_fft = tf.signal.fft2d(tf.cast(img, cplx_t))

    # Polar frequency grid with the zero frequency at index [0, 0].
    x_grid, y_grid = tf.meshgrid(_axis(cols), _axis(rows))
    radius = tf.signal.ifftshift(tf.sqrt(x_grid ** 2 + y_grid ** 2))
    theta = tf.signal.ifftshift(tf.math.atan2(-y_grid, x_grid))
    dc_mask = tf.equal(radius, 0.0)
    # Avoid log(0) at the DC term; its filter value is zeroed below.
    radius = tf.where(dc_mask, tf.ones_like(radius), radius)
    sintheta = tf.math.sin(theta)
    costheta = tf.math.cos(theta)

    # Global low-pass filter applied to every log-Gabor filter.
    lp = tf.cast(lowpassfilter((rows, cols), 0.45, 15), real_t)

    # Radial (scale) components.
    two_log_sigma_sq = 2.0 * math.log(config.sigma_onf) ** 2
    log_gabors = []
    for s in range(config.nscale):
        wavelength = config.min_wave_length * (config.mult ** s)
        fo = 1.0 / wavelength
        log_gabor = tf.exp(-tf.square(tf.math.log(radius / fo)) / two_log_sigma_sq) * lp
        log_gabors.append(tf.where(dc_mask, tf.zeros_like(log_gabor), log_gabor))

    # Angular (orientation) components.
    theta_sigma = math.pi / config.norient / config.d_theta_on_sigma
    spreads = []
    for k in range(config.norient):
        angl = k * math.pi / config.norient
        ds = sintheta * math.cos(angl) - costheta * math.sin(angl)
        dc = costheta * math.cos(angl) + sintheta * math.sin(angl)
        dtheta = tf.abs(tf.math.atan2(ds, dc))
        spreads.append(tf.exp(-tf.square(dtheta) / (2.0 * theta_sigma ** 2)))

    zeros = tf.zeros((rows, cols), dtype=real_t)
    energy_all = zeros
    an_all = zeros
    spatial_norm = math.sqrt(rows * cols)

    for spread in spreads:
        sum_e, sum_o, sum_an = zeros, zeros, zeros
        responses = []      # complex filter responses, one per scale
        ifft_filters = []   # spatial-domain filters, one per scale

        for log_gabor in log_gabors:
            filter_c = tf.cast(log_gabor * spread, cplx_t)
            ifft_filters.append(tf.math.real(tf.signal.ifft2d(filter_c)) * spatial_norm)
            eo = tf.signal.ifft2d(img_fft * filter_c)
            responses.append(eo)
            sum_an += tf.abs(eo)
            sum_e += tf.math.real(eo)
            sum_o += tf.math.imag(eo)

        # Unit vector of the mean phase direction.
        x_energy = tf.sqrt(sum_e ** 2 + sum_o ** 2 + eps)
        mean_e = sum_e / x_energy
        mean_o = sum_o / x_energy

        # Local energy, reusing the responses computed above.
        energy = zeros
        for eo in responses:
            even = tf.math.real(eo)
            odd = tf.math.imag(eo)
            energy += even * mean_e + odd * mean_o - tf.abs(even * mean_o - odd * mean_e)

        # Noise estimate taken from the smallest-scale response.
        median_e2n = _median(tf.square(tf.abs(responses[0])))
        mean_e2n = -median_e2n / math.log(0.5)
        em_n = tf.reduce_sum(tf.square(log_gabors[0] * spread))
        noise_power = mean_e2n / em_n

        # The threshold is one scalar per orientation built from image-wide
        # sums. Using per-pixel maps here can make the value under the square
        # root negative and fill the result with NaN.
        sum_an2 = tf.reduce_sum(tf.add_n([tf.square(f) for f in ifft_filters]))
        sum_ai_aj = tf.zeros([], dtype=real_t)
        for si in range(config.nscale):
            for sj in range(si + 1, config.nscale):
                sum_ai_aj += tf.reduce_sum(ifft_filters[si] * ifft_filters[sj])

        est_noise_energy2 = 2.0 * noise_power * sum_an2 + 4.0 * noise_power * sum_ai_aj
        tau = tf.sqrt(est_noise_energy2 / 2.0)
        est_noise_energy = tau * math.sqrt(math.pi / 2)
        est_noise_sigma = tf.sqrt((2 - math.pi / 2) * tau ** 2)
        threshold = (est_noise_energy + 2.0 * est_noise_sigma) / 1.7

        energy_all += tf.nn.relu(energy - threshold)
        an_all += sum_an

    pc = tf.cast(energy_all / (an_all + eps), tf.float32)
    return tf.clip_by_value(pc, 0.0, 1.0)
def extract_features(img: tf.Tensor) -> Dict[str, tf.Tensor]:
    """Extract the per-image features used for FSIM / FSIMc.

    The RGB image is converted to YIQ, each channel is box-filtered and
    subsampled, and then the phase congruency and gradient magnitude of the
    luminance channel are computed.

    Args:
        img: Rank-3 RGB tensor ``(H, W, 3)`` with a statically known shape.

    Returns:
        Dict with keys ``pc_y`` (phase congruency), ``g_y`` (gradient
        magnitude), ``i`` / ``q`` (subsampled chroma), ``y_down`` (subsampled
        luminance) and ``img_original`` (the input, kept for visualisation).

    Raises:
        ValueError: if ``img`` is not an ``(H, W, 3)`` tensor.
    """
    img = tf.convert_to_tensor(img)
    if img.shape.rank != 3 or img.shape[-1] != 3:
        raise ValueError("extract_features expects an (H, W, 3) RGB image")
    height, width = int(img.shape[0]), int(img.shape[1])
    r, g, b = img[..., 0], img[..., 1], img[..., 2]

    # RGB -> YIQ
    y = 0.299 * r + 0.587 * g + 0.114 * b
    i = 0.596 * r - 0.274 * g - 0.322 * b
    q = 0.211 * r - 0.523 * g + 0.312 * b

    # Subsampling factor: the smaller side divided by 256, at least 1.
    f = max(1, round(min(height, width) / 256))
    ave_kernel = tf.ones((f, f, 1, 1), dtype=tf.float32) / float(f * f)

    def _conv_same(channel: tf.Tensor, kernel: tf.Tensor) -> tf.Tensor:
        # tf.nn.conv2d needs a rank-4 input, so add batch and channel axes
        # around the 2-D plane and strip them again afterwards.
        batched = channel[tf.newaxis, ..., tf.newaxis]
        out = tf.nn.conv2d(batched, kernel, strides=[1, 1, 1, 1], padding='SAME')
        return out[0, ..., 0]

    def downsample(channel: tf.Tensor) -> tf.Tensor:
        # Box-filter, then keep every f-th pixel.
        return _conv_same(channel, ave_kernel)[::f, ::f]

    y_down = downsample(y)
    i_down = downsample(i)
    q_down = downsample(q)

    # Phase congruency of the luminance channel.
    pc_y = phasecong2(y_down)

    # Gradient magnitude from a 3x3 kernel pair (weights 3/10/3, divided by 16).
    dx = tf.constant([[3, 0, -3], [10, 0, -10], [3, 0, -3]], dtype=tf.float32) / 16
    dy = tf.constant([[3, 10, 3], [0, 0, 0], [-3, -10, -3]], dtype=tf.float32) / 16
    ix = _conv_same(y_down, tf.reshape(dx, (3, 3, 1, 1)))
    iy = _conv_same(y_down, tf.reshape(dy, (3, 3, 1, 1)))
    g_y = tf.sqrt(ix ** 2 + iy ** 2 + config.epsilon)

    return {
        "pc_y": pc_y, "g_y": g_y, "i": i_down, "q": q_down,
        "y_down": y_down, "img_original": img  # original kept for visualisation
    }
def calculate_fsim(ref_feats: Dict[str, tf.Tensor], dis_feats: Dict[str, tf.Tensor]) -> float:
    """Compute the FSIM score for one reference/comparison feature pair.

    For every pixel a Gaussian-weighted window is extracted from the phase
    congruency and gradient maps; the windowed correlations of both maps are
    multiplied and averaged with the mean windowed gradient as the weight.

    Args:
        ref_feats: Features of the reference image (needs ``pc_y``, ``g_y``).
        dis_feats: Features of the comparison image (same keys, same shape).

    Returns:
        The score as a built-in ``float`` in [0, 1]. A built-in float is
        returned (not a NumPy scalar) so the value is JSON-serialisable.
    """
    pc_ref, g_ref = ref_feats["pc_y"], ref_feats["g_y"]
    pc_dis, g_dis = dis_feats["pc_y"], dis_feats["g_y"]
    win = config.win_size

    # Normalised Gaussian window, broadcast over all patches.
    half_win = config.half_win
    offsets = tf.range(-half_win, half_win + 1, dtype=tf.float32)
    x_grid, y_grid = tf.meshgrid(offsets, offsets)
    gaussian = tf.exp(-(x_grid ** 2 + y_grid ** 2) / (2 * config.sigma ** 2))
    gaussian /= tf.reduce_sum(gaussian)
    gaussian = tf.reshape(gaussian, (win, win, 1))

    def extract_patches(channel: tf.Tensor) -> tf.Tensor:
        # One win x win patch per pixel (SAME padding keeps the pixel count).
        batched = channel[tf.newaxis, ..., tf.newaxis]
        patches = tf.image.extract_patches(
            images=batched,
            sizes=[1, win, win, 1],
            strides=[1, 1, 1, 1],
            rates=[1, 1, 1, 1],
            padding="SAME"
        )
        return tf.reshape(patches, (-1, win, win, 1))

    pc_ref_patches = extract_patches(pc_ref) * gaussian
    pc_dis_patches = extract_patches(pc_dis) * gaussian
    g_ref_patches = extract_patches(g_ref) * gaussian
    g_dis_patches = extract_patches(g_dis) * gaussian

    def window_sim(x_patches: tf.Tensor, y_patches: tf.Tensor) -> tf.Tensor:
        # Per-window correlation coefficient, negative values clamped to 0.
        axes = [1, 2, 3]
        x_centered = x_patches - tf.reduce_mean(x_patches, axis=axes, keepdims=True)
        y_centered = y_patches - tf.reduce_mean(y_patches, axis=axes, keepdims=True)
        cov = tf.reduce_mean(x_centered * y_centered, axis=axes)
        var_x = tf.reduce_mean(x_centered ** 2, axis=axes)
        var_y = tf.reduce_mean(y_centered ** 2, axis=axes)
        sim = cov / (tf.sqrt(var_x * var_y) + config.epsilon)
        return tf.maximum(sim, 0.0)

    pc_sim = window_sim(pc_ref_patches, pc_dis_patches)
    g_sim = window_sim(g_ref_patches, g_dis_patches)

    # Weighted average, weight = mean windowed gradient of both images.
    local_weight = (
        tf.reduce_mean(g_ref_patches, axis=[1, 2, 3])
        + tf.reduce_mean(g_dis_patches, axis=[1, 2, 3])
    ) / 2
    total_sim = tf.reduce_sum(pc_sim * g_sim * local_weight)
    total_weight = tf.reduce_sum(local_weight)
    score = tf.clip_by_value(total_sim / (total_weight + config.epsilon), 0.0, 1.0)
    return float(score.numpy())
def calculate_fsimc(ref_feats: Dict[str, tf.Tensor], dis_feats: Dict[str, tf.Tensor]) -> float:
    """Compute the FSIMc score (FSIM combined with chroma similarity).

    The FSIM score is raised to ``config.alpha`` and multiplied by the mean
    I/Q chroma similarity raised to ``config.beta``.

    Args:
        ref_feats: Features of the reference image (needs ``pc_y``, ``g_y``,
            ``i``, ``q``).
        dis_feats: Features of the comparison image (same keys, same shape).

    Returns:
        The score as a built-in ``float`` in [0, 1].
    """
    fsim = float(calculate_fsim(ref_feats, dis_feats))

    i_ref, q_ref = ref_feats["i"], ref_feats["q"]
    i_dis, q_dis = dis_feats["i"], dis_feats["q"]
    win = config.win_size

    # Normalised Gaussian window, broadcast over all patches.
    half_win = config.half_win
    offsets = tf.range(-half_win, half_win + 1, dtype=tf.float32)
    x_grid, y_grid = tf.meshgrid(offsets, offsets)
    gaussian = tf.exp(-(x_grid ** 2 + y_grid ** 2) / (2 * config.sigma ** 2))
    gaussian /= tf.reduce_sum(gaussian)
    gaussian = tf.reshape(gaussian, (win, win, 1))

    def extract_patches(channel: tf.Tensor) -> tf.Tensor:
        # One win x win patch per pixel (SAME padding keeps the pixel count).
        batched = channel[tf.newaxis, ..., tf.newaxis]
        patches = tf.image.extract_patches(
            images=batched,
            sizes=[1, win, win, 1],
            strides=[1, 1, 1, 1],
            rates=[1, 1, 1, 1],
            padding="SAME"
        )
        return tf.reshape(patches, (-1, win, win, 1))

    def chroma_sim(chroma_ref: tf.Tensor, chroma_dis: tf.Tensor) -> float:
        ref_patches = extract_patches(chroma_ref) * gaussian
        dis_patches = extract_patches(chroma_dis) * gaussian

        # SSIM-style similarity on the chroma windows.
        c2 = (0.08) ** 2
        axes = [1, 2, 3]
        # keepdims=True gives broadcastable means; tf.expand_dims only accepts
        # a single axis, so it cannot restore three axes in one call.
        ref_mean_k = tf.reduce_mean(ref_patches, axis=axes, keepdims=True)
        dis_mean_k = tf.reduce_mean(dis_patches, axis=axes, keepdims=True)
        ref_centered = ref_patches - ref_mean_k
        dis_centered = dis_patches - dis_mean_k
        cov = tf.reduce_mean(ref_centered * dis_centered, axis=axes)
        var_ref = tf.reduce_mean(ref_centered ** 2, axis=axes)
        var_dis = tf.reduce_mean(dis_centered ** 2, axis=axes)
        ref_mean = tf.reshape(ref_mean_k, [-1])
        dis_mean = tf.reshape(dis_mean_k, [-1])

        numerator = (2 * ref_mean * dis_mean + c2) * (2 * cov + c2)
        denominator = (ref_mean ** 2 + dis_mean ** 2 + c2) * (var_ref + var_dis + c2)
        sim = tf.maximum(numerator / (denominator + config.epsilon), 0.0)
        return float(tf.reduce_mean(sim).numpy())

    sc_i = chroma_sim(i_ref, i_dis)
    sc_q = chroma_sim(q_ref, q_dis)
    sc = (sc_i + sc_q) / 2

    fsimc = (fsim ** config.alpha) * (sc ** config.beta)
    return float(np.clip(fsimc, 0.0, 1.0))
def generate_visualization(ref_feats: Dict[str, tf.Tensor], dis_feats: Dict[str, tf.Tensor], fsim: float, fsimc: float) -> str:
    """Render a 2x3 comparison figure to a PNG and return its path.

    Top row: reference image, comparison image, score text. Bottom row:
    both phase-congruency maps and the difference of the min-max normalised
    gradient maps (reference minus comparison).

    Args:
        ref_feats: Features of the reference image.
        dis_feats: Features of the comparison image.
        fsim: FSIM score shown in the figure.
        fsimc: FSIMc score shown in the figure.

    Returns:
        Path of the PNG written into ``app.config['UPLOAD_FOLDER']``.
    """
    # Convert tensors to NumPy arrays for matplotlib.
    ref_img = tf.cast(ref_feats["img_original"] * 255, tf.uint8).numpy()
    dis_img = tf.cast(dis_feats["img_original"] * 255, tf.uint8).numpy()
    pc_ref = ref_feats["pc_y"].numpy()
    pc_dis = dis_feats["pc_y"].numpy()
    g_ref = ref_feats["g_y"].numpy()
    g_dis = dis_feats["g_y"].numpy()

    # Min-max normalise each gradient map before differencing.
    g_ref_norm = (g_ref - g_ref.min()) / (g_ref.max() - g_ref.min() + 1e-8)
    g_dis_norm = (g_dis - g_dis.min()) / (g_dis.max() - g_dis.min() + 1e-8)
    grad_diff = g_ref_norm - g_dis_norm

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    try:
        fig.suptitle(f"FSIM: {fsim:.4f} | FSIMc: {fsimc:.4f}", fontsize=16, fontweight="bold")

        axes[0, 0].imshow(ref_img)
        axes[0, 0].set_title("参考图像", fontsize=12)
        axes[0, 0].axis("off")

        axes[0, 1].imshow(dis_img)
        axes[0, 1].set_title("比较图像", fontsize=12)
        axes[0, 1].axis("off")

        axes[0, 2].text(0.5, 0.5, f"FSIM: {fsim:.4f}\nFSIMc: {fsimc:.4f}",
                        ha="center", va="center", fontsize=14, fontweight="bold")
        axes[0, 2].set_title("相似度分数", fontsize=12)
        axes[0, 2].axis("off")

        im1 = axes[1, 0].imshow(pc_ref, cmap="jet")
        axes[1, 0].set_title("参考图相位一致性", fontsize=12)
        axes[1, 0].axis("off")
        plt.colorbar(im1, ax=axes[1, 0], fraction=0.046, pad=0.04)

        im2 = axes[1, 1].imshow(pc_dis, cmap="jet")
        axes[1, 1].set_title("比较图相位一致性", fontsize=12)
        axes[1, 1].axis("off")
        plt.colorbar(im2, ax=axes[1, 1], fraction=0.046, pad=0.04)

        im3 = axes[1, 2].imshow(grad_diff, cmap="coolwarm", vmin=-1, vmax=1)
        axes[1, 2].set_title("梯度差异(参考-比较)", fontsize=12)
        axes[1, 2].axis("off")
        plt.colorbar(im3, ax=axes[1, 2], fraction=0.046, pad=0.04)

        # delete=False: the file must outlive this call so it can be served later.
        with tempfile.NamedTemporaryFile(
            delete=False, suffix=".png", dir=app.config['UPLOAD_FOLDER']
        ) as temp_file:
            FigureCanvas(fig).print_png(temp_file)
            return temp_file.name
    finally:
        # pyplot keeps every figure alive until it is closed; without this
        # each request would leak one figure.
        plt.close(fig)
# -----------------------------------------------------------------------------
# 4. Flask路由(前端页面+API接口)
# -----------------------------------------------------------------------------
@app.route('/')
def index():
    """Main page: render the image upload / result template."""
    page = render_template('index.html')
    return page
def _save_upload(storage, saved_paths: List[str]) -> Tuple[str, str]:
    """Store one uploaded file under a server-chosen name.

    The client-supplied filename is never used as a path (it is untrusted and
    could contain ``..`` or collide with another upload); only a short ASCII
    alphanumeric extension is kept as the suffix.

    Args:
        storage: Uploaded file object providing ``filename`` and ``save()``.
        saved_paths: List that receives the created path, so the caller can
            delete it afterwards even if saving fails midway.

    Returns:
        ``(display_name, saved_path)``.
    """
    display_name = os.path.basename(storage.filename.replace("\\", "/")) or "upload"
    ext = os.path.splitext(display_name)[1].lower()
    if not (1 < len(ext) <= 6 and ext[1:].isascii() and ext[1:].isalnum()):
        ext = ""
    fd, path = tempfile.mkstemp(suffix=ext, dir=app.config['UPLOAD_FOLDER'])
    os.close(fd)
    saved_paths.append(path)
    storage.save(path)
    return display_name, path


@app.route('/upload', methods=['POST'])
def upload():
    """Upload endpoint: score each comparison image against the reference.

    Expects multipart form fields ``reference`` (one file) and
    ``comparison[]`` (one or more files). Always responds with JSON; failures
    are reported as ``{"status": "error", "msg": ...}``.
    """
    saved_paths: List[str] = []
    try:
        # 1. Validate the uploaded files.
        if 'reference' not in request.files:
            return jsonify({"status": "error", "msg": "未上传参考图像"})
        ref_file = request.files['reference']
        if not ref_file.filename:
            return jsonify({"status": "error", "msg": "参考图像文件名不能为空"})
        # Entries with an empty filename are empty form slots; ignore them and
        # fail if nothing usable remains.
        comp_files = [f for f in request.files.getlist('comparison[]') if f.filename]
        if not comp_files:
            return jsonify({"status": "error", "msg": "未上传比较图像"})

        # 2. Save to temporary files.
        ref_name, ref_path = _save_upload(ref_file, saved_paths)
        comparisons = [_save_upload(f, saved_paths) for f in comp_files]

        # 3. Preprocess, extract features, score, and render.
        results = []
        with tf.device(DEVICE):
            ref_feats = extract_features(preprocess_image(ref_path))
            for comp_name, comp_path in comparisons:
                comp_feats = extract_features(preprocess_image(comp_path))
                # float(): NumPy scalars are not JSON-serialisable.
                fsim = float(calculate_fsim(ref_feats, comp_feats))
                fsimc = float(calculate_fsimc(ref_feats, comp_feats))
                vis_path = generate_visualization(ref_feats, comp_feats, fsim, fsimc)
                results.append({
                    "filename": comp_name,
                    "fsim": round(fsim, 4),
                    "fsimc": round(fsimc, 4),
                    "vis_filename": os.path.basename(vis_path)
                })

        # 4. Return the results.
        return jsonify({
            "status": "success",
            "ref_filename": ref_name,
            "results": results
        })
    except Exception as e:
        # Top-level boundary: log the traceback, report the message to the client.
        app.logger.exception("upload processing failed")
        return jsonify({"status": "error", "msg": str(e)})
    finally:
        # The uploads are only needed during processing; the rendered
        # visualisations stay in place so they can be fetched afterwards.
        for path in saved_paths:
            try:
                os.remove(path)
            except OSError:
                pass
@app.route('/visualization/<filename>')
def get_visualization(filename):
    """Serve a previously rendered visualisation PNG.

    Only a bare ``*.png`` file name located directly inside the upload
    folder is accepted; anything else (path separators, ``..``, other
    extensions, directories) yields a 404.
    """
    safe_name = os.path.basename(filename.replace("\\", "/"))
    if safe_name != filename or not safe_name.lower().endswith(".png"):
        return "图片不存在", 404
    vis_path = os.path.join(app.config['UPLOAD_FOLDER'], safe_name)
    if os.path.isfile(vis_path):
        return send_file(vis_path, mimetype='image/png')
    return "图片不存在", 404
# -----------------------------------------------------------------------------
# 5. 启动服务
# -----------------------------------------------------------------------------
if __name__ == '__main__':
    # Make sure the temporary folder exists before serving.
    upload_dir_missing = not os.path.exists(UPLOAD_FOLDER)
    if upload_dir_missing:
        os.makedirs(UPLOAD_FOLDER)
    # Start the Flask server (debug=False for production, True for development).
    server_options = {"host": '0.0.0.0', "port": 5000, "debug": False}
    app.run(**server_options)
# (stray web-page text "最新发布" / "Latest release" — not part of the program)