问题4要求建立一个物理模型。
J. Mar. Sci. Eng. 2024, 12, 1790. https://doi.org/10.3390/jmse12101790
该研究提出了一种新的水下图像增强方法，该方法结合了高效融合边缘检测（EFED）和多尺度彩色并行频分注意力模块（MCPFA），以解决水下图像中常见的局部色偏和细节模糊的问题，实现了水下图像的有效增强。EFED模块通过快速提取图像边缘信息，为后续特征提取提供关键指导。MCPFA模块通过在多个颜色空间中并行处理，有效地校正了水下图像中的颜色失真。同时，注意力机制的引入使网络能够自适应地聚焦于图像中的感兴趣区域，进一步提高了增强效果。
正在复现中，以下为初步实现代码。
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import torch.nn.functional as F
import os
# Dataset class: loads the images in one directory and preprocesses them.
class UnderwaterDataset(Dataset):
    """Unlabelled image dataset made of the image files directly inside a directory.

    Each item is a single image (no target); the directory is not scanned
    recursively.

    Args:
        root_dir: directory containing the images.
        transform: optional callable applied to each RGB ``PIL.Image``.
        extensions: iterable of file suffixes (case-insensitive) to accept.
            ``None`` uses ``DEFAULT_EXTENSIONS``.
    """

    # Suffixes of common raster formats that PIL can open.
    DEFAULT_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp")

    def __init__(self, root_dir, transform=None, extensions=None):
        self.root_dir = root_dir
        self.transform = transform
        if extensions is None:
            extensions = self.DEFAULT_EXTENSIONS
        allowed = tuple(ext.lower() for ext in extensions)
        # Sorted so the index -> file mapping is deterministic (os.listdir order
        # is arbitrary), and filtered so stray non-image files (desktop.ini,
        # Thumbs.db, notes, ...) do not crash Image.open in __getitem__.
        self.images = sorted(
            name
            for name in os.listdir(root_dir)
            if os.path.isfile(os.path.join(root_dir, name))
            and name.lower().endswith(allowed)
        )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return image ``idx`` as RGB, passed through ``transform`` if one is set."""
        img_path = os.path.join(self.root_dir, self.images[idx])
        # The context manager releases the file handle once the RGB copy exists.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image
# Preprocessing applied to every image:
#   1. resize to a fixed 256x256 so images can be stacked into batches,
#   2. convert the PIL image to a float tensor with values in [0, 1],
#   3. normalize each RGB channel with the ImageNet mean / std statistics.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Build the dataset and the data loader.
# The image directory can be overridden with the UNDERWATER_DATA_DIR environment
# variable so the script runs on other machines; the hard-coded path remains the
# fallback, so existing runs are unaffected.
data_dir = os.environ.get("UNDERWATER_DATA_DIR", r"C:\Users\ASUS\Desktop\附件一")
dataset = UnderwaterDataset(root_dir=data_dir, transform=transform)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
class EFED(nn.Module):
    """Small convolutional branch producing a single-channel map from an RGB image.

    Named after the paper's "efficient fusion edge detection" module. Here it is
    three 3x3 convolutions (3 -> 64 -> 128 -> 1 channels, ReLU after the first
    two). ``padding=1`` keeps the spatial size unchanged. The map is learned
    end-to-end; nothing in this module explicitly constrains it to edges.
    """

    def __init__(self):
        super().__init__()
        # Creation order (conv1, conv2, conv3) fixes parameter init order and
        # the state_dict key names.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(128, 1, kernel_size=3, padding=1)

    def forward(self, x):
        """Map an (N, 3, H, W) tensor to an (N, 1, H, W) tensor (no final activation)."""
        hidden = F.relu(self.conv2(F.relu(self.conv1(x))))
        return self.conv3(hidden)
class MCPFA(nn.Module):
    """Per-image spatial self-attention over the fused image + edge-map features.

    Named after the paper's MCPFA module, but this implementation is a single
    ``nn.MultiheadAttention`` layer. It does no colour-space or frequency
    processing. Every pixel of an image attends to a pooled summary of that
    same image, so images in a batch never exchange information.

    Args:
        embed_dim: number of input channels (default 4 = 3 RGB + 1 edge channel).
        num_heads: number of attention heads; must divide ``embed_dim``.
        kv_size: keys/values are adaptively average-pooled to at most
            ``kv_size x kv_size`` positions. This bounds the attention cost to
            ``H*W*kv_size**2`` per image instead of ``(H*W)**2``, which would
            not fit in memory at 256x256.

    Raises:
        ValueError: if ``kv_size`` is smaller than 1.
    """

    def __init__(self, embed_dim=4, num_heads=1, kv_size=16):
        super(MCPFA, self).__init__()
        if kv_size < 1:
            raise ValueError(f"kv_size must be >= 1, got {kv_size}")
        self.kv_size = kv_size
        # Attribute name kept as `attention` so existing checkpoints still load.
        self.attention = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads)

    def forward(self, x):
        """Apply attention to an (N, C, H, W) tensor; returns the same shape."""
        n, c, h, w = x.shape
        # nn.MultiheadAttention (batch_first=False) expects (seq, batch, embed).
        # Passing (N, H*W, C) instead would make it attend across the batch axis
        # and mix different images, so the sequence axis must be H*W.
        query = x.flatten(2).permute(2, 0, 1)  # (H*W, N, C)
        pooled = F.adaptive_avg_pool2d(x, (min(h, self.kv_size), min(w, self.kv_size)))
        key_value = pooled.flatten(2).permute(2, 0, 1)  # (h'*w', N, C)
        out, _ = self.attention(query, key_value, key_value, need_weights=False)
        return out.permute(1, 2, 0).reshape(n, c, h, w)
class UnderwaterImageEnhancement(nn.Module):
    """Image-to-image network: EFED branch -> concat -> MCPFA -> 1x1 conv to RGB.

    An (N, 3, H, W) input is concatenated with the single-channel EFED map into
    4 channels. The result goes through MCPFA and is projected back to 3
    channels by a 1x1 convolution.

    Args:
        output_size: optional ``(height, width)`` (or a single int for a square)
            for the returned image. ``None`` (default) keeps the input's spatial
            size, so the output can be compared directly with a same-sized
            target in the loss.
    """

    def __init__(self, output_size=None):
        super(UnderwaterImageEnhancement, self).__init__()
        # Submodules keep their names and creation order, so parameter init
        # under a fixed seed and the state_dict keys are unchanged.
        self.efed = EFED()
        self.mcpfa = MCPFA()
        self.conv_final = nn.Conv2d(4, 3, kernel_size=1)  # 4 fused channels -> 3 (RGB)
        if isinstance(output_size, int):
            output_size = (output_size, output_size)
        self.output_size = None if output_size is None else tuple(output_size)

    def forward(self, x):
        edge_map = self.efed(x)  # single-channel map, same H x W as x
        x = torch.cat([x, edge_map], dim=1)  # append it as a 4th channel
        x = self.mcpfa(x)  # attention over the fused features
        # Resize only when an explicit size was requested. A fixed 256x256 resize
        # here would make the output disagree with any non-256 input and break
        # the MSE against it (for 256x256 inputs it was a no-op).
        if self.output_size is not None and tuple(x.shape[-2:]) != self.output_size:
            x = F.interpolate(x, size=self.output_size, mode='bilinear', align_corners=False)
        return self.conv_final(x)
# Train on the GPU when available, then save the learned weights.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UnderwaterImageEnhancement().to(device)
if not list(model.parameters()):
    print("模型没有参数")
else:
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.MSELoss()  # loss function
    num_epochs = 10
    # NOTE(review): the loss compares each output with its own (normalized)
    # input, so this objective only teaches the network to reproduce the input.
    # Real enhancement needs reference (clean) images or an unsupervised
    # enhancement loss — TODO confirm the intended training target.
    model.train()
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for images in dataloader:
            images = images.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, images)  # compute loss
            loss.backward()  # backpropagate
            optimizer.step()  # update weights
            # Weight by batch size: the last batch may be smaller.
            epoch_loss += loss.item() * images.size(0)
        # Report the mean over the whole epoch, not just the last batch.
        mean_loss = epoch_loss / len(dataloader.dataset)
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {mean_loss:.4f}')
    torch.save(model.state_dict(), 'underwater_image_enhancement_model.pth')