import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
import argparse
import trimesh
from pathlib import Path
def normalize_disparity_map(disparity_map):
    """Scale a disparity map into uint8 [0, 255] for visualization.

    Negative disparities are clamped to zero first; an all-zero map is
    returned as zeros to avoid a division by zero.
    """
    clamped = np.clip(disparity_map, 0.0, None)
    peak = clamped.max()
    if peak == 0:
        return np.zeros_like(clamped, dtype=np.uint8)
    return (clamped / peak * 255).astype(np.uint8)
def visualize_disparity_map(disparity_map, gt_map=None, save_path=None):
    """Visualize or save a disparity map (ground truth optional).

    When ``gt_map`` is given it is normalized and concatenated to the right
    of the prediction. With ``save_path=None`` the figure is shown
    interactively; otherwise it is written to disk as a grayscale image.
    """
    disp_vis = normalize_disparity_map(disparity_map)
    if gt_map is not None:
        # NOTE(review): assumes gt_map has the same height as disparity_map;
        # np.concatenate raises otherwise — confirm against the caller.
        gt_vis = normalize_disparity_map(gt_map)
        concat_map = np.concatenate([disp_vis, gt_vis], axis=1)
        labels = ['Disparity', 'Ground Truth']
    else:
        concat_map = disp_vis
        labels = ['Disparity']
    if save_path is None:
        plt.figure(figsize=(10, 4))
        plt.imshow(concat_map, cmap='gray')
        plt.title(' | '.join(labels))
        plt.axis('off')
        plt.show()
    else:
        # Fix: os.makedirs('') raises FileNotFoundError when save_path has no
        # directory component — only create a directory if one is present.
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        plt.imsave(save_path, concat_map, cmap='gray', vmin=0, vmax=255)
        print(f"📊 视差图已保存至: {save_path}")
def task1_compute_disparity_map_simple(
    ref_img: np.ndarray,
    sec_img: np.ndarray,
    window_size: int,
    disparity_range: tuple,
    matching_function: str
):
    """Exhaustive sliding-window block matching between a rectified pair.

    For every reference pixel, windows in the second image at horizontal
    offsets d in [disparity_range[0], disparity_range[1]) are scored with
    SSD, SAD, or negated normalized correlation, and the best-scoring
    (lowest-cost, first on ties) disparity is kept.
    """
    height, width = ref_img.shape
    d_lo, d_hi = disparity_range
    half = window_size // 2
    out = np.zeros((height, width), dtype=np.float32)
    # Zero-pad the borders so every pixel has a full window.
    ref_padded = np.pad(ref_img, half, mode='constant')
    sec_padded = np.pad(sec_img, half, mode='constant')
    for row in range(half, height + half):
        for col in range(half, width + half):
            ref_win = ref_padded[row - half:row + half + 1, col - half:col + half + 1]
            best_cost = np.inf
            best_d = d_lo
            for d in range(d_lo, d_hi):
                match_col = col - d
                if match_col < half or match_col >= width + half:
                    continue  # candidate window falls outside the image
                sec_win = sec_padded[row - half:row + half + 1,
                                     match_col - half:match_col + half + 1]
                if matching_function == 'SSD':
                    cost = np.sum((ref_win - sec_win) ** 2)
                elif matching_function == 'SAD':
                    cost = np.sum(np.abs(ref_win - sec_win))
                elif matching_function == 'normalized_correlation':
                    rc = ref_win - ref_win.mean()
                    sc = sec_win - sec_win.mean()
                    numerator = np.sum(rc * sc)
                    denominator = np.sqrt(np.sum(rc ** 2) * np.sum(sc ** 2)) + 1e-8
                    # Higher correlation is better, so negate to minimize.
                    cost = -numerator / denominator
                else:
                    raise ValueError(f"Unknown function: {matching_function}")
                if cost < best_cost:
                    best_cost = cost
                    best_d = d
            out[row - half, col - half] = float(best_d)
    return out
def task1_simple_disparity(ref_img, sec_img, gt_map, img_name='tsukuba'):
    """Run Task 1 over a grid of window sizes and matching functions.

    Every (window size, cost function) combination is evaluated over the
    Tsukuba disparity range [0, 16), its map is saved as a PNG under
    output/, and (map, window, function, range) tuples are returned.
    """
    search_range = (0, 16)  # Tsukuba's actual disparity range is 0-16
    lo, hi = search_range
    results = []
    for win in [5, 9]:
        for func in ['SSD', 'SAD']:
            print(f"🔧 计算中: ws={win}, func={func}")
            dmap = task1_compute_disparity_map_simple(
                ref_img, sec_img, win, search_range, func)
            results.append((dmap, win, func, search_range))
            # Persist each combination's visualization immediately.
            visualize_disparity_map(
                dmap, gt_map,
                save_path=f"output/task1_{img_name}_{win}_{lo}_{hi}_{func}.png"
            )
    return results
def task2_compute_depth_map(disparity_map, baseline=0.1, focal_length=350.0):
    """Convert disparity to depth with z = f*B/d (Tsukuba parameters).

    Non-positive disparities, and depths over 50 m (implausible for the
    Tsukuba scene), are marked invalid and set to zero.
    """
    with np.errstate(divide='ignore', invalid='ignore'):
        depth = (focal_length * baseline) / (disparity_map + 1e-6)
    invalid = (disparity_map <= 0) | (depth > 50)
    depth[invalid] = 0
    return depth
def task2_visualize_pointcloud(
    ref_img: np.ndarray,
    disparity_map: np.ndarray,
    save_path: str = 'output/task2_tsukuba.ply',
    baseline: float = 0.1,       # Tsukuba baseline: 0.1 m (10 cm)
    focal_length: float = 350.0  # Tsukuba focal length in pixels
):
    """Back-project a disparity map into a colored 3D point cloud (.ply).

    Depth comes from z = f*B/d; x/y are recovered with a pinhole model
    whose principal point is assumed at the image center. Points with
    invalid or out-of-range depth are dropped before export.
    """
    depth_map = task2_compute_depth_map(disparity_map, baseline, focal_length)
    H, W = ref_img.shape[:2]
    y_coords, x_coords = np.mgrid[0:H, 0:W]
    # Pinhole back-projection (principal point at the image center).
    points = np.stack([
        (x_coords - W / 2) * depth_map / focal_length,
        (y_coords - H / 2) * depth_map / focal_length,
        depth_map
    ], axis=-1).reshape(-1, 3)
    colors = cv2.cvtColor(ref_img, cv2.COLOR_BGR2RGB).reshape(-1, 3)
    # Keep only finite depths within a plausible range.
    valid_mask = (
        ~np.isnan(points[:, 2]) &
        ~np.isinf(points[:, 2]) &
        (points[:, 2] > 0.5) &   # drop points too close to the camera
        (points[:, 2] < 50)      # drop points too far away
    )
    points = points[valid_mask]
    colors = colors[valid_mask]
    # Fix: os.makedirs('') raises FileNotFoundError when save_path has no
    # directory component — only create a directory if one is present.
    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    pointcloud = trimesh.PointCloud(vertices=points, colors=colors)
    pointcloud.export(save_path, file_type='ply')
    print(f"✅ 点云已保存至: {save_path}(共 {len(points)} 个点)")
def task3_compute_disparity_map_dp(ref_img, sec_img):
    """Scanline stereo matching via dynamic programming (Viterbi).

    For each row, a 5x5-window SSD cost volume over disparities [0, 16)
    is built, then a smoothness penalty of 8 * |d - d'| between adjacent
    columns is enforced with a forward Viterbi pass and backtracking.
    """
    height, width = ref_img.shape
    win = 5
    half = win // 2
    n_disp = 16   # matches the Tsukuba disparity range
    smooth = 8    # per-step disparity-jump penalty
    # Zero-pad the borders, same scheme as Task 1.
    ref_padded = np.pad(ref_img, half, mode='constant')
    sec_padded = np.pad(sec_img, half, mode='constant')
    result = np.zeros((height, width), dtype=np.float32)
    disp_axis = np.arange(n_disp)
    for row in range(half, height + half):
        # Unary SSD costs per column and candidate disparity.
        unary = np.full((width, n_disp), np.inf, dtype=np.float32)
        for col in range(half, width + half):
            ref_win = ref_padded[row - half:row + half + 1, col - half:col + half + 1]
            for d in range(n_disp):
                m_col = col - d
                if m_col < half or m_col >= width + half:
                    continue  # window would fall outside the image
                sec_win = sec_padded[row - half:row + half + 1,
                                     m_col - half:m_col + half + 1]
                unary[col - half, d] = np.sum((ref_win - sec_win) ** 2)
        # Forward pass: minimal accumulated cost plus backpointers.
        acc = np.zeros((width, n_disp))
        back = np.zeros((width, n_disp), dtype=int)
        acc[0] = unary[0]
        for col in range(1, width):
            for d in range(n_disp):
                trans = acc[col - 1] + smooth * np.abs(disp_axis - d)
                src = np.argmin(trans)
                acc[col, d] = unary[col, d] + trans[src]
                back[col, d] = src
        # Backtrack the cheapest path from the last column.
        d_cur = np.argmin(acc[-1])
        row_path = []
        for col in range(width - 1, -1, -1):
            row_path.append(d_cur)
            d_cur = back[col, d_cur]
        result[row - half, :] = row_path[::-1]
    return result
def main(tasks):
    """Entry point: load the Tsukuba pair and run the requested tasks.

    ``tasks`` is a string of task digits, e.g. '02' runs task 0 and,
    where applicable, the point-cloud export of task 2.
    """
    script_dir = Path(__file__).resolve().parent
    data_dir = script_dir / "data"
    output_dir = script_dir / "output"
    output_dir.mkdir(exist_ok=True)

    def load_image(filename, grayscale=False):
        # Returns None when the file is missing or cannot be decoded.
        filepath = data_dir / filename
        if not filepath.exists():
            print(f"❌ 文件不存在: {filepath}")
            return None
        flag = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
        img = cv2.imread(str(filepath), flag)
        if img is None:
            print(f"⚠️ 图像无法解码: {filepath}")
        return img

    print(f"📁 正在从 {data_dir} 加载数据...")
    # Only the Tsukuba data is loaded; the GT is read as grayscale.
    tsukuba_img1 = load_image("tsukuba1.jpg")
    tsukuba_img2 = load_image("tsukuba2.jpg")
    tsukuba_gt = load_image("tsukuba_gt.jpg", grayscale=True)
    if tsukuba_img1 is None or tsukuba_img2 is None:
        print("🛑 至少缺少 tsukuba1.jpg 或 tsukuba2.jpg,请检查 data/ 目录!")
        return
    # Keep uint8 grayscale for StereoBM (requires CV_8UC1 input) and
    # float32 copies for the hand-written matchers.
    gray1_u8 = cv2.cvtColor(tsukuba_img1, cv2.COLOR_BGR2GRAY)
    gray2_u8 = cv2.cvtColor(tsukuba_img2, cv2.COLOR_BGR2GRAY)
    tsukuba_img1_gray = gray1_u8.astype(np.float32)
    tsukuba_img2_gray = gray2_u8.astype(np.float32)
    tsukuba_gt_gray = tsukuba_gt.astype(np.float32) if tsukuba_gt is not None else None
    if tsukuba_gt_gray is not None:
        print(f"✅ 成功加载图像: tsukuba1.jpg, tsukuba2.jpg, tsukuba_gt.jpg")
    else:
        print(f"✅ 成功加载图像: tsukuba1.jpg, tsukuba2.jpg(未找到 ground truth)")
    # -------------------------------
    # Task 0: OpenCV StereoBM baseline
    # -------------------------------
    if '0' in tasks:
        print('🔧 Running task0: OpenCV StereoBM...')
        stereo = cv2.StereoBM.create(numDisparities=16, blockSize=15)
        # Fix: StereoBM.compute requires 8-bit single-channel images; the
        # float32 arrays used by the other tasks would raise cv2.error.
        disp_raw = stereo.compute(gray1_u8, gray2_u8).astype(np.float32)
        disparity_cv2 = disp_raw / 16.0  # decode fixed-point disparities
        disparity_cv2[disparity_cv2 <= 0] = 0
        visualize_disparity_map(
            disparity_cv2, tsukuba_gt_gray,
            save_path="output/task0_tsukuba_cv2.png"
        )
        if '2' in tasks:
            task2_visualize_pointcloud(
                tsukuba_img1, disparity_cv2,
                save_path='output/task2_tsukuba_cv2.ply'
            )
    # -------------------------------
    # Task 1: simple block matching
    # -------------------------------
    if '1' in tasks:
        print('🔍 Running task1: Simple Matching...')
        maps = task1_simple_disparity(tsukuba_img1_gray, tsukuba_img2_gray, tsukuba_gt_gray, 'tsukuba')
        if '2' in tasks:
            for m, ws, mf, dr in maps:
                dmin, dmax = dr
                task2_visualize_pointcloud(
                    tsukuba_img1, m,
                    save_path=f'output/task2_tsukuba_{ws}_{dmin}_{dmax}_{mf}.ply'
                )
    # -------------------------------
    # Task 3: dynamic programming
    # -------------------------------
    if '3' in tasks:
        print('🔄 Running task3: Dynamic Programming...')
        disparity_dp = task3_compute_disparity_map_dp(tsukuba_img1_gray, tsukuba_img2_gray)
        disparity_dp[disparity_dp <= 0] = 0
        visualize_disparity_map(
            disparity_dp, tsukuba_gt_gray,
            save_path='output/task3_tsukuba_dp.png'
        )
        if '2' in tasks:
            task2_visualize_pointcloud(
                tsukuba_img1, disparity_dp,
                save_path='output/task2_tsukuba_dp.ply'
            )
if __name__ == '__main__':
    # Fix: removed scrape artifacts (a full-width '。' after the main() call
    # and a stray '最新发布' line) that made the file a SyntaxError.
    parser = argparse.ArgumentParser(description='Homework 4: Stereo Matching (Tsukuba Only)')
    parser.add_argument('--tasks', type=str, default='02', help='要运行的任务,例如 0, 02, 123')
    args = parser.parse_args()
    main(args.tasks)