2024亚太杯A题第四问,复杂水下图像的处理

问题4要求建立一个物理模型

J. Mar. Sci. Eng. 2024, 12, 1790. https://doi.org/10.3390/jmse12101790

研究提出了一种新的水下图像增强方法,该方法结合了高效融合边缘检测(EFED)和多尺度彩色并行频分注意力模块(MCPFA),以解决水下图像中常见的局部色偏和细节模糊的问题,实现了水下图像的有效增强。EFED模块通过快速提取图像边缘信息,为后续特征提取提供关键指导。MCPFA模块通过在多个颜色空间中并行处理,有效地校正了水下图像中的颜色失真。同时,注意力机制的引入使网络能够自适应地聚焦于图像中的感兴趣区域,进一步提高了增强效果。

正在复现中。。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import torch.nn.functional as F
import os

# 数据集类,用于加载和预处理图像
class UnderwaterDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images = [f for f in os.listdir(root_dir) if os.path.isfile(os.path.join(root_dir, f))]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_name = self.images[idx]
        img_path = os.path.join(self.root_dir, img_name)
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image

# 数据预处理步骤
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 创建数据集和数据加载器
dataset = UnderwaterDataset(root_dir=r"C:\Users\ASUS\Desktop\附件一", transform=transform)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

class EFED(nn.Module):
    def __init__(self):
        super(EFED, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(128, 1, kernel_size=3, padding=1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.conv3(x)
        return x

class MCPFA(nn.Module):
    def __init__(self):
        super(MCPFA, self).__init__()
        # 假设 x 的通道数是 4(因为 edge_map 的通道数是 1,与原图合并后是 4)
        self.attention = nn.MultiheadAttention(embed_dim=4, num_heads=1)

    def forward(self, x):
        # x 的形状应该是 (N, C, H, W)
        N, C, H, W = x.size()
        # 将 x 调整为 (L, N, E) 形状,其中 L = H * W,N = N,E = C
        x = x.view(N, C, -1).transpose(1, 2)  # 调整形状为 (L, N, E)
        x, _ = self.attention(x, x, x)
        x = x.transpose(1, 2).view(N, C, H, W)  # 恢复原始形状
        return x

class UnderwaterImageEnhancement(nn.Module):
    def __init__(self):
        super(UnderwaterImageEnhancement, self).__init__()
        self.efed = EFED()
        self.mcpfa = MCPFA()
        self.conv_final = nn.Conv2d(4, 3, kernel_size=1)  # 定义conv_final

    def forward(self, x):
        edge_map = self.efed(x)  # 获取边缘图
        x = torch.cat([x, edge_map], dim=1)  # 将边缘图与原图合并
        x = self.mcpfa(x)  # 应用注意力机制
        x = F.interpolate(x, size=(256, 256), mode='bilinear', align_corners=False)  # 上采样
        x = self.conv_final(x)  # 调整通道数以匹配目标图像的形状
        return x

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UnderwaterImageEnhancement().to(device)

if not list(model.parameters()):
    print("模型没有参数")
else:
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.MSELoss()  # 定义损失函数

num_epochs = 10
for epoch in range(num_epochs):
    for images in dataloader:
        images = images.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, images)  # 计算损失
        loss.backward()  # 反向传播
        optimizer.step()  # 更新权重
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')

torch.save(model.state_dict(), 'underwater_image_enhancement_model.pth')

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值