import torch
import torch.nn.functional as F
from PIL import Image
import numpy as np
import argparse
import os
import torchvision.transforms as transforms
def process_image_frequency(image_path, output_path, source_freq=16, target_freq=14):
    """
    Low-pass an image by resampling it through a small square grid and save the result.

    The image is shrunk to ``target_freq`` x ``target_freq`` pixels with area
    averaging, then enlarged back to its original height and width with bicubic
    interpolation, which discards detail finer than the small grid. The
    original comments describe this as FAR's frequency processing method.

    Parameters:
    - image_path: path of the input image; it is always converted to RGB.
    - output_path: path the processed image is written to.
    - source_freq: source frequency level. Only reported in the progress
      message; it does not influence the computation.
    - target_freq: target frequency level, used as the side length (in pixels)
      of the intermediate downsampled image. Must be >= 1.

    Returns:
    The processed image as a ``uint8`` NumPy array of shape (H, W, 3).

    Raises:
    ValueError: if ``target_freq`` is smaller than 1.
    """
    # Fail early with a clear message instead of an opaque error from
    # F.interpolate when the intermediate size would be zero or negative.
    if target_freq < 1:
        raise ValueError(f"target_freq must be a positive integer, got {target_freq}")

    # Load the image; the context manager guarantees the underlying file is
    # closed once the RGB copy has been made.
    with Image.open(image_path) as source_img:
        img = source_img.convert('RGB')

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # map [0, 1] to [-1, 1]
    ])
    img_tensor = transform(img).unsqueeze(0)  # add a batch dimension: [1, C, H, W]

    # Remember the original spatial size so it can be restored after downsampling.
    _, _, H, W = img_tensor.shape
    print(f"原始图像尺寸: {H}x{W}")
    print(f"处理中: 从频率等级 {source_freq} 降低到 {target_freq}")

    # 1. Downsample to the resolution that corresponds to the target frequency.
    target_size = (target_freq, target_freq)
    downsampled = F.interpolate(img_tensor, size=target_size, mode='area')

    # 2. Upsample back to the original resolution.
    processed = F.interpolate(downsampled, size=(H, W), mode='bicubic')

    # Convert the tensor back to an image array: [-1, 1] -> [0, 1] -> [0, 255].
    # Bicubic interpolation can overshoot, hence the clamp before quantizing.
    processed = processed * 0.5 + 0.5
    processed = processed.clamp(0, 1)
    processed = processed.squeeze(0).permute(1, 2, 0).cpu().numpy()  # [H, W, C]
    processed = (processed * 255).astype(np.uint8)

    # Save the result.
    output_img = Image.fromarray(processed)
    output_img.save(output_path)
    print(f"处理完成! 已保存到: {output_path}")
    return processed
def main():
    """
    Command-line entry point: parse the options and process a single image.

    When ``--output`` is missing or empty, the output path is derived from the
    input path by inserting ``_freq<target_freq>`` before the file extension.
    """
    cli = argparse.ArgumentParser(description='图像频率处理工具')
    cli.add_argument('--input', type=str, required=True, help='输入图像路径')
    cli.add_argument('--output', type=str, help='输出图像路径')
    cli.add_argument('--source_freq', type=int, default=16, help='源频率等级')
    cli.add_argument('--target_freq', type=int, default=14, help='目标频率等级')
    opts = cli.parse_args()

    destination = opts.output
    if not destination:
        # No output path given: append the frequency level to the input file name.
        stem, suffix = os.path.splitext(opts.input)
        destination = f"{stem}_freq{opts.target_freq}{suffix}"

    process_image_frequency(
        image_path=opts.input,
        output_path=destination,
        source_freq=opts.source_freq,
        target_freq=opts.target_freq,
    )
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
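# NOTE: The text below is residue from the web page this script was copied
# from; it is not part of the program:
#   "212345"
#   "Latest recommended article published 2025-04-29 11:53:09"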