def blotless_update(D, X, Y, sparsity_pattern, block_size=4):
    """Update the dictionary and coefficients one block of atoms at a time.

    For every consecutive group of ``block_size`` atoms, the contribution of
    all *other* atoms is removed from ``Y`` and the group is re-estimated by
    ``itertls_update``.

    Args:
        D: Dictionary matrix; atoms are its columns (indexed ``D[:, atoms]``).
        X: Coefficient matrix; its rows correspond to atoms
            (indexed ``X[atoms, :]``).
        Y: Data matrix approximated by ``D @ X``.
        sparsity_pattern: Array indexed by atom along its first axis; the rows
            of the current block are forwarded to ``itertls_update``.
        block_size: Number of atoms updated jointly.

    Returns:
        The tuple ``(D, X)``. Both arrays are modified in place.
    """
    # The atom count is the number of dictionary columns.
    num_atoms = D.shape[1]
    # NOTE(review): atoms left over when num_atoms is not a multiple of
    # block_size are never updated (same as the original floor division).
    num_blocks = num_atoms // block_size
    for block_idx in range(num_blocks):
        # 1. Indices of the atoms in the current block.
        block = np.arange(block_idx * block_size, (block_idx + 1) * block_size)
        # 2. Residual with this block's contribution removed. A boolean mask
        #    selects "every atom except the block".
        others = np.ones(num_atoms, dtype=bool)
        others[block] = False
        Y_residual = Y - D[:, others] @ X[others, :]
        # 3. Current atoms and coefficients of the block (fancy indexing
        #    yields copies, so the solver cannot alias D or X).
        D_T = D[:, block]
        X_T = X[block, :]
        # 4. Re-estimate the block.
        D_T_new, X_T_new = itertls_update(
            D_T, X_T, Y_residual, sparsity_pattern[block]
        )
        # 5. Write the block back.
        D[:, block] = D_T_new
        X[block, :] = X_T_new
    return D, X
def itertls_update(D_T, X_T, Y_residual, sparsity_pattern_T, max_iters=50, tol=1e-6):
    """Iteratively re-estimate one block of atoms and its coefficients.

    Each iteration solves a total-least-squares problem built from the
    residual and the current coefficient estimate, extracts new atoms and
    coefficients from the solution, and projects the coefficients back onto
    the given sparsity pattern.

    The helpers ``solve_tls_problem``, ``extract_from_Z`` and
    ``project_to_sparsity`` are not defined in this section and must be
    provided elsewhere.

    Args:
        D_T: Current atoms of the block; returned unchanged if no iteration
            runs.
        X_T: Current coefficients of the block (copied, never modified).
        Y_residual: Data with the contribution of all other atoms removed.
        sparsity_pattern_T: Sparsity pattern passed to
            ``project_to_sparsity``.
        max_iters: Upper bound on the number of iterations.
        tol: Relative tolerance on the change of the coefficient estimate
            between two iterations.

    Returns:
        The tuple ``(D_T_new, X_hat)`` of updated atoms and coefficients.
    """
    X_hat = X_T.copy()  # current estimate
    D_T_new = D_T  # keeps the return value defined when max_iters == 0
    for _ in range(max_iters):
        # a/b. Solve  Y_residual^T ~= X_T^T @ D_T^T  (an A @ Z ~= B problem)
        #      by total least squares.
        Z = solve_tls_problem(Y_residual, X_hat)
        # c. Recover the atoms and the new coefficients from Z.
        D_T_new, X_new = extract_from_Z(Z)
        # d. Keep the sparsity pattern: zero positions are forced back to 0.
        X_prev = X_hat
        X_hat = project_to_sparsity(X_new, sparsity_pattern_T)
        # e. Stop once the estimate no longer moves (relative Frobenius change).
        change = np.linalg.norm(X_hat - X_prev)
        if change <= tol * max(np.linalg.norm(X_prev), 1.0):
            break
    return D_T_new, X_hat
这是个伪代码,请帮我编辑一个:
import os
import cv2
import numpy as np
from skimage import io, transform, filters
from skimage.util import random_noise
import xml.etree.ElementTree as ET # 新增XML解析库
from parameter import IMAGE_SIZE, PATCH_SIZE, OVERLAP_RATIO, DATA_ROOT, SAVE_PATH, TRAIN_RATIO
def load_sar_dataset(data_root):
    """Load the images and VOC-style XML annotations of the dataset.

    Images are read as grayscale from ``<data_root>/images`` and resized to
    ``IMAGE_SIZE x IMAGE_SIZE``. The matching annotation file
    ``<data_root>/annotations/<stem>.xml`` is parsed, and its bounding boxes
    are rescaled into the coordinate frame of the resized image.

    Args:
        data_root: Dataset root directory containing ``images`` and
            ``annotations`` sub-directories.

    Returns:
        ``(images, annotations)``: a list of 2-D image arrays and a parallel
        list of dicts ``{"objects": [{"bbox": [x1, y1, x2, y2],
        "scatter_intensity": 1.0}, ...]}``. ``scatter_intensity`` is a
        constant placeholder; nothing is read from the XML for it.

    Raises:
        Whatever ``ET.parse`` raises when an annotation file is missing or
        malformed.
    """
    images = []
    annotations = []
    img_dir = os.path.join(data_root, "images")
    anno_dir = os.path.join(data_root, "annotations")
    # os.listdir returns entries in arbitrary order; sorting keeps the
    # dataset order (and any later truncation) reproducible.
    for img_name in sorted(os.listdir(img_dir)):
        if not img_name.endswith((".png", ".jpg")):
            continue
        # Read as grayscale and remember the native size before resizing.
        img = io.imread(os.path.join(img_dir, img_name), as_gray=True)
        orig_h, orig_w = img.shape[:2]
        img = transform.resize(img, (IMAGE_SIZE, IMAGE_SIZE))
        # Boxes are stored in original pixel coordinates, so they must be
        # scaled by the same factors as the image.
        scale_x = IMAGE_SIZE / orig_w
        scale_y = IMAGE_SIZE / orig_h

        # Swap only the real extension; str.replace would also rewrite a
        # ".png"/".jpg" occurring in the middle of the file name.
        anno_name = os.path.splitext(img_name)[0] + ".xml"
        root = ET.parse(os.path.join(anno_dir, anno_name)).getroot()

        # Common VOC layout: <object><bndbox><xmin/ymin/xmax/ymax>.
        obj_list = []
        for obj in root.findall("object"):
            bbox = obj.find("bndbox")
            x1 = int(float(bbox.find("xmin").text) * scale_x)
            y1 = int(float(bbox.find("ymin").text) * scale_y)
            x2 = int(float(bbox.find("xmax").text) * scale_x)
            y2 = int(float(bbox.find("ymax").text) * scale_y)
            obj_list.append({"bbox": [x1, y1, x2, y2], "scatter_intensity": 1.0})

        images.append(img)
        annotations.append({"objects": obj_list})
    return images, annotations
def adaptive_wavelet_denoise(img, wavelet_level=3):
    """Raise every pixel to at least its local adaptive threshold.

    NOTE(review): despite the name, no wavelet transform is performed and
    ``wavelet_level`` is not used by this implementation.

    Args:
        img: Input image array.
        wavelet_level: Accepted for interface compatibility; unused.

    Returns:
        An array in which each pixel is the input value where it exceeds the
        local threshold, and the local threshold value otherwise.
    """
    # Simplified noise estimate derived from the global image variance.
    sigma_est = np.sqrt(np.var(img) / 2)
    # Per-pixel threshold map computed over 5-pixel neighbourhoods.
    local_thr = filters.threshold_local(img, block_size=5, offset=sigma_est * 1.5)
    above_threshold = img > local_thr
    return np.where(above_threshold, img, local_thr)
def data_augmentation(img):
    """Build augmented variants of an image: rotations, rescalings, noise.

    Args:
        img: Source image array.

    Returns:
        A list of 10 images in this order: the input itself, three rotations
        (90, 180, 270 degrees), three rescaled copies (factors 0.8, 1.0, 1.2,
        each resized back to ``IMAGE_SIZE x IMAGE_SIZE``), and three noisy
        copies (target SNR 10, 15, 20 dB).
    """
    aug_imgs = [img]

    # Rotations. NOTE(review): these move the objects, so annotations of the
    # source image no longer match the rotated copies.
    for angle in (90, 180, 270):
        aug_imgs.append(transform.rotate(img, angle, preserve_range=True))

    # Rescale, then resize back to the common size so the geometry (and thus
    # the annotations) stays aligned; only the resampling differs.
    for scale in (0.8, 1.0, 1.2):
        scaled = transform.rescale(img, scale, preserve_range=True)
        aug_imgs.append(transform.resize(scaled, (IMAGE_SIZE, IMAGE_SIZE)))

    # Additive noise at a given SNR: SNR = P_signal / P_noise, so the noise
    # variance must be scaled by the measured signal power rather than
    # assuming unit power.
    # NOTE(review): assumes img is already a float image in the value range
    # random_noise works in (it converts integer images internally).
    signal_power = float(np.mean(np.square(np.asarray(img, dtype=float))))
    for snr in (10, 15, 20):
        noise_var = signal_power / (10 ** (snr / 10))
        aug_imgs.append(random_noise(img, var=noise_var))

    return aug_imgs
def _stack_patches(patches, patch_size):
    """Stack flattened patches as columns; an empty list gives shape (d, 0)."""
    if not patches:
        return np.empty((patch_size * patch_size, 0), dtype=np.float64)
    return np.array(patches, dtype=np.float64).T


def generate_patches(images, annotations, patch_size=None, max_patches=10000):
    """Cut images into non-overlapping patches and split them by label.

    A patch is labelled "object" when its centre lies strictly inside any
    annotated bounding box, and "clutter" otherwise. Every patch is min-max
    normalised to [0, 1] and flattened.

    Args:
        images: Iterable of 2-D image arrays.
        annotations: Parallel iterable of dicts with an ``"objects"`` list
            whose items carry ``"bbox": [x1, y1, x2, y2]``.
        patch_size: Side length of the square patches; defaults to the
            module-level ``PATCH_SIZE``.
        max_patches: Hard cap on the number of patches kept per class (the
            first ones encountered are kept).

    Returns:
        ``(obj_patches, clutter_patches)``, two float64 arrays of shape
        ``(patch_size**2, n)`` with one patch per column. ``n`` may be 0, in
        which case the array is still 2-D.
    """
    size = PATCH_SIZE if patch_size is None else patch_size
    step = size  # no overlap between neighbouring patches
    half = size // 2
    obj_patches, clutter_patches = [], []
    for img, anno in zip(images, annotations):
        bboxes = [obj["bbox"] for obj in anno["objects"]]
        for y in range(0, img.shape[0] - size + 1, step):
            for x in range(0, img.shape[1] - size + 1, step):
                cx, cy = x + half, y + half
                inside = any(
                    x1 < cx < x2 and y1 < cy < y2 for x1, y1, x2, y2 in bboxes
                )
                bucket = obj_patches if inside else clutter_patches
                # Only the first max_patches per class are returned, so do
                # not spend time or memory on the rest.
                if len(bucket) >= max_patches:
                    continue
                patch = img[y:y + size, x:x + size]
                low, high = np.min(patch), np.max(patch)
                # The epsilon guards against division by zero on flat patches.
                patch = (patch - low) / (high - low + 1e-9)
                bucket.append(patch.ravel())
    return _stack_patches(obj_patches, size), _stack_patches(clutter_patches, size)
# -------------------------- 预处理执行入口 --------------------------
def preprocess_pipeline(max_images=None):
    """Run the full preprocessing chain and persist its results.

    Steps: load the dataset, filter and augment every image, split the
    augmented set at random into train and test, save the test images and
    annotations, cut both subsets into object/clutter patches, and save the
    four patch matrices as ``.npy`` files under ``SAVE_PATH``.

    Args:
        max_images: If given, only the first ``max_images`` loaded images are
            processed.

    Returns:
        ``(train_obj, train_clutter, test_obj, test_clutter)`` as returned by
        ``generate_patches``.
    """
    # 1. Load the raw data.
    print("正在加载SAR-aircraft-1.0数据集...")
    images, annotations = load_sar_dataset(DATA_ROOT)
    if max_images is not None:
        images = images[:max_images]
        annotations = annotations[:max_images]

    # 2. Filtering and augmentation; the annotation is repeated once per
    #    augmented copy.
    # NOTE(review): rotated copies reuse the un-rotated bounding boxes, so
    # their labels are misaligned — confirm whether this is intended.
    print("正在进行数据增强与自适应去噪...")
    aug_images = []
    aug_annotations = []
    for img, anno in zip(images, annotations):
        aug_imgs = data_augmentation(adaptive_wavelet_denoise(img))
        aug_images.extend(aug_imgs)
        aug_annotations.extend([anno] * len(aug_imgs))

    # 3. Random train/test split.
    # NOTE(review): the split happens after augmentation, so variants of one
    # source image can land on both sides (train/test leakage).
    print("正在划分训练集与测试集...")
    num_total = len(aug_images)
    num_train = int(num_total * TRAIN_RATIO)
    train_indices = np.random.choice(num_total, num_train, replace=False)
    # A set makes each membership test O(1) instead of scanning the array.
    train_index_set = set(train_indices.tolist())
    test_indices = [i for i in range(num_total) if i not in train_index_set]
    train_images = [aug_images[i] for i in train_indices]
    train_annotations = [aug_annotations[i] for i in train_indices]
    test_images = [aug_images[i] for i in test_indices]
    test_annotations = [aug_annotations[i] for i in test_indices]

    # np.save does not create missing directories.
    os.makedirs(SAVE_PATH, exist_ok=True)
    # Keep the test images and annotations for the later detection stage.
    # The annotations are dicts, so loading them back needs allow_pickle=True.
    np.save(os.path.join(SAVE_PATH, "test_images.npy"), test_images)
    np.save(os.path.join(SAVE_PATH, "test_annotations.npy"), test_annotations)

    # 4. Cut object and clutter patches.
    print("正在生成目标块与杂波块...")
    train_obj, train_clutter = generate_patches(train_images, train_annotations)
    test_obj, test_clutter = generate_patches(test_images, test_annotations)

    # Persist the results so later runs can skip the computation.
    np.save(os.path.join(SAVE_PATH, "train_obj.npy"), train_obj)
    np.save(os.path.join(SAVE_PATH, "train_clutter.npy"), train_clutter)
    np.save(os.path.join(SAVE_PATH, "test_obj.npy"), test_obj)
    np.save(os.path.join(SAVE_PATH, "test_clutter.npy"), test_clutter)
    print(f"预处理完成!样本保存至{os.path.join(SAVE_PATH)}")
    print(f"训练集:目标块{train_obj.shape[1]}个,杂波块{train_clutter.shape[1]}个")
    print(f"测试集:目标块{test_obj.shape[1]}个,杂波块{test_clutter.shape[1]}个")
    return train_obj, train_clutter, test_obj, test_clutter
# Run the preprocessing (needed on the first run; later runs can load the
# saved .npy files directly instead).
train_obj, train_clutter, test_obj, test_clutter = preprocess_pipeline(max_images=500)
# Note that was fused onto the statement above (translated): "the SAR image
# target-detection code that follows once the preprocessing results are ready".