Analyze the code:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
# --- Attention Mechanisms ---
class ChannelAttention(nn.Module):
def __init__(self, in_planes, ratio=16):
super(ChannelAttention, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.max_pool = nn.AdaptiveMaxPool2d(1)
mid_ch = max(1, in_planes // ratio)
self.fc1 = nn.Conv2d(in_planes, mid_ch, 1, bias=False)
self.relu1 = nn.ReLU()
self.fc2 = nn.Conv2d(mid_ch, in_planes, 1, bias=False)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
out = avg_out + max_out
return self.sigmoid(out)
class SpatialAttention(nn.Module):
def __init__(self, kernel_size=3):
super(SpatialAttention, self).__init__()
padding = kernel_size // 2
self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
avg_out = torch.mean(x, dim=1, keepdim=True)
max_out, _ = torch.max(x, dim=1, keepdim=True)
x_cat = torch.cat([avg_out, max_out], dim=1)
x_out = self.conv1(x_cat)
return self.sigmoid(x_out)
class CBAM(nn.Module):
def __init__(self, in_planes, ratio=16, kernel_size=3):
super(CBAM, self).__init__()
self.ca = ChannelAttention(in_planes, ratio)
self.sa = SpatialAttention(kernel_size)
def forward(self, x):
x = x * self.ca(x)
x = x * self.sa(x)
return x
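# Illustrative usage (a minimal sketch, not part of the original code; the _demo_* helper
# names and tensor sizes are assumptions): CBAM rescales features without changing shape,
# so it can be dropped into any residual block.
def _demo_cbam():
    x = torch.randn(2, 64, 32, 32)
    y = CBAM(in_planes=64, ratio=16, kernel_size=3)(x)
    assert y.shape == x.shape  # channel then spatial attention, shape preserved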
# NEW: Coordinate Attention (CVPR 2021)
class CoordinateAttention(nn.Module):
"""
Coordinate Attention for Efficient Mobile Network Design (CVPR 2021)
More stable and effective than complex cross-attention for small feature maps
"""
def __init__(self, in_channels, reduction=8):
super(CoordinateAttention, self).__init__()
self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
self.pool_w = nn.AdaptiveAvgPool2d((1, None))
mid_channels = max(8, in_channels // reduction)
self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=1, stride=1, padding=0)
self.bn1 = nn.BatchNorm2d(mid_channels)
self.act = nn.ReLU(inplace=True)
self.conv_h = nn.Conv2d(mid_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.conv_w = nn.Conv2d(mid_channels, in_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x):
identity = x
n, c, h, w = x.size()
# Coordinate information embedding
x_h = self.pool_h(x) # (n, c, h, 1)
        x_w = self.pool_w(x).permute(0, 1, 3, 2)  # (n, c, 1, w) -> (n, c, w, 1)
y = torch.cat([x_h, x_w], dim=2) # (n, c, h+w, 1)
y = self.conv1(y)
y = self.bn1(y)
y = self.act(y)
x_h, x_w = torch.split(y, [h, w], dim=2)
x_w = x_w.permute(0, 1, 3, 2)
a_h = self.conv_h(x_h).sigmoid()
a_w = self.conv_w(x_w).sigmoid()
out = identity * a_h * a_w
return out
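# Illustrative usage (sketch, assumed sizes): Coordinate Attention pools along H and W
# separately and reweights the input via broadcast, so the output shape matches the input.
def _demo_coordinate_attention():
    x = torch.randn(2, 64, 16, 16)
    y = CoordinateAttention(in_channels=64, reduction=8)(x)
    assert y.shape == x.shape  # identity * a_h * a_w by broadcasting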
# NEW: Efficient Channel Attention (ECA-Net, CVPR 2020)
class EfficientChannelAttention(nn.Module):
"""
ECA-Net: Efficient Channel Attention (CVPR 2020)
    Nearly parameter-free (a single small 1-D convolution), very efficient for small networks
"""
def __init__(self, channels, gamma=2, b=1):
super(EfficientChannelAttention, self).__init__()
        # Adaptive kernel size from the ECA paper: k = |(log2(C) + b) / gamma|, forced odd below
        kernel_size = int(abs((math.log2(channels) + b) / gamma))
kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
y = self.avg_pool(x)
y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
y = self.sigmoid(y)
return x * y.expand_as(x)
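# Illustrative usage (sketch, assumed sizes): the 1-D conv kernel size is derived from the
# channel count, e.g. channels=64 gives (log2(64) + 1) / 2 = 3.5, truncated to k = 3.
def _demo_eca():
    x = torch.randn(2, 64, 16, 16)
    y = EfficientChannelAttention(channels=64)(x)
    assert y.shape == x.shape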
class SelfAttentionBlock(nn.Module):
def __init__(self, in_channels, heads=4, dim_head=32):
super(SelfAttentionBlock, self).__init__()
inner_dim = heads * dim_head
self.heads = heads
self.scale = dim_head ** -0.5
self.norm = nn.LayerNorm(in_channels)
self.to_qkv = nn.Conv2d(in_channels, inner_dim * 3, 1, bias=False)
self.to_out = nn.Conv2d(inner_dim, in_channels, 1)
self.ffn = nn.Sequential(
nn.Conv2d(in_channels, in_channels * 2, 1),
nn.GELU(),
nn.Conv2d(in_channels * 2, in_channels, 1)
)
def forward(self, x):
B, C, H, W = x.shape
x_norm = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
qkv = self.to_qkv(x_norm).chunk(3, dim=1)
q, k, v = map(lambda t: t.reshape(B, self.heads, -1, H * W).permute(0, 1, 3, 2), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = dots.softmax(dim=-1)
out = torch.matmul(attn, v)
out = out.permute(0, 1, 3, 2).reshape(B, -1, H, W)
out = self.to_out(out)
x = x + out
x = x + self.ffn(x)
return x
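# Illustrative usage (sketch, assumed sizes): the block treats the H*W positions as tokens,
# so attention memory grows as O((H*W)^2); it is meant for small feature maps.
def _demo_self_attention_block():
    x = torch.randn(2, 64, 8, 8)
    y = SelfAttentionBlock(in_channels=64, heads=4, dim_head=32)(x)
    assert y.shape == x.shape  # residual attention + residual FFN keep the shape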
# --- Convolutional Blocks ---
class conv_block1(nn.Module):
def __init__(self, in_ch, out_ch, padding=0):
super(conv_block1, self).__init__()
self.net = nn.Sequential(
nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, padding=padding),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True),
nn.Conv2d(out_ch, out_ch, kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.net(x)
class conv_block2(nn.Module):
def __init__(self, in_ch, out_ch, padding=0, use_cbam=True):
super(conv_block2, self).__init__()
self.use_cbam = use_cbam
self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, padding=padding)
self.bn1 = nn.BatchNorm2d(out_ch)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=1, stride=1, padding=0)
self.bn2 = nn.BatchNorm2d(out_ch)
if self.use_cbam:
self.attention = CBAM(out_ch)
else:
self.attention = ChannelAttention(out_ch)
if in_ch != out_ch or padding != 0:
self.shortcut = nn.Sequential(
nn.Conv2d(in_ch, out_ch, kernel_size=1, padding=padding),
nn.BatchNorm2d(out_ch)
)
else:
self.shortcut = nn.Identity()
def forward(self, x):
residual = self.shortcut(x)
y = self.relu(self.bn1(self.conv1(x)))
y = self.bn2(self.conv2(y))
if self.use_cbam:
y = self.attention(y)
else:
y = y * self.attention(y)
y += residual
y = self.relu(y)
return y
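# Illustrative usage (sketch, assumed sizes): conv_block2 is a 1x1 residual block with CBAM
# (or plain channel attention) applied to the main branch before the skip connection.
def _demo_conv_block2():
    x = torch.randn(2, 32, 16, 16)
    y = conv_block2(in_ch=32, out_ch=64, padding=0, use_cbam=True)(x)
    assert y.shape == (2, 64, 16, 16)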
class conv_block3(nn.Module):
def __init__(self, in_ch, out_ch, padding=0):
super(conv_block3, self).__init__()
self.net = nn.Sequential(
nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=padding),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True),
nn.Conv2d(out_ch, out_ch, kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.net(x)
class conv11_block(nn.Module):
def __init__(self, in_ch):
super(conv11_block, self).__init__()
self.net = nn.Sequential(
nn.Conv2d(in_ch, 2 * in_ch, kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(2 * in_ch),
nn.ReLU(inplace=True),
nn.Conv2d(2 * in_ch, in_ch, kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(in_ch),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.net(x)
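# Illustrative shape checks (sketch, assumed sizes): conv_block1 uses only 1x1 convolutions,
# so a non-zero `padding` would enlarge the spatial size by 2*padding; conv_block3 starts
# with a 3x3 conv and shrinks the map by 2 when padding=0; conv11_block preserves shape.
def _demo_plain_blocks():
    x = torch.randn(2, 32, 16, 16)
    assert conv_block1(32, 64, padding=0)(x).shape == (2, 64, 16, 16)
    assert conv_block3(32, 64, padding=0)(x).shape == (2, 64, 14, 14)
    assert conv11_block(32)(x).shape == (2, 32, 16, 16)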
# NEW: Coordinate Attention Fusion Block (Replaces Cross-Attention)
class CoordinateAttentionFusionBlock(nn.Module):
"""
    Uses Coordinate Attention in place of the more complex Cross-Attention:
    more stable, fewer parameters, and better suited to small feature maps.
"""
def __init__(self, in_ch, out_ch, padding=0):
super(CoordinateAttentionFusionBlock, self).__init__()
# Main feature extraction
self.net_main = nn.Sequential(
nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, padding=padding),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True),
nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True)
)
# Coordinate Attention for spatial-channel modeling
self.coord_attn = CoordinateAttention(out_ch, reduction=8)
# Feature fusion with guidance
self.fusion_conv = nn.Sequential(
nn.Conv2d(out_ch * 2, out_ch, kernel_size=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True)
)
# Shortcut
self.shortcut = nn.Sequential(
nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, padding=padding),
nn.BatchNorm2d(out_ch)
)
self.relu = nn.ReLU(inplace=True)
# Side output for alignment
side_ch_out = max(16, out_ch // 16)
self.net_side_out = nn.Sequential(
nn.Conv2d(out_ch, side_ch_out, kernel_size=3, stride=1, padding=0),
nn.BatchNorm2d(side_ch_out)
)
def forward(self, x_main, x_side):
residual = self.shortcut(x_main)
# Extract main features
y_main = self.net_main(x_main)
# Apply coordinate attention
y_attn = self.coord_attn(y_main)
# Adaptive fusion with guidance from x_side
# Resize x_side to match y_attn if needed
if y_attn.shape[2:] != x_side.shape[2:]:
x_side_resized = F.interpolate(x_side, size=y_attn.shape[2:], mode='bilinear', align_corners=False)
else:
x_side_resized = x_side
# Concatenate and fuse
fused = torch.cat([y_attn, x_side_resized], dim=1)
y = self.fusion_conv(fused)
# Residual connection
y = self.relu(y + residual)
# Side output
y_side = self.net_side_out(y)
y_side = torch.softmax(y_side, dim=1)
return y, y_side
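# Illustrative usage (sketch, assumed sizes): x_side is assumed to carry out_ch channels; it
# is bilinearly resized inside forward() if its spatial size differs from the main branch.
# The side output comes from a 3x3 conv with no padding, hence the 2-pixel shrink.
def _demo_coordinate_attention_fusion_block():
    x_main = torch.randn(2, 32, 16, 16)
    x_side = torch.randn(2, 64, 8, 8)  # upsampled to 16x16 inside forward()
    block = CoordinateAttentionFusionBlock(in_ch=32, out_ch=64, padding=0)
    y, y_side = block(x_main, x_side)
    assert y.shape == (2, 64, 16, 16)
    assert y_side.shape == (2, 16, 14, 14)  # side_ch_out = max(16, 64 // 16) = 16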
# NEW: Dynamic Feature Alignment Module (replaces the GRL)
class DynamicFeatureAlignment(nn.Module):
"""
Dynamic Feature Alignment using learnable weights
    Replaces the GRL (gradient reversal layer) with learnable feature alignment, which is more stable.
    A simplified variant inspired by FDA (Fourier Domain Adaptation).
"""
def __init__(self, feature_dim):
super(DynamicFeatureAlignment, self).__init__()
# Style transfer parameters (learnable)
        self.alpha = nn.Parameter(torch.tensor(0.1))  # small initial value so alignment starts gently
# Feature statistics alignment
self.align_conv = nn.Sequential(
nn.Conv2d(feature_dim, feature_dim, kernel_size=1),
nn.BatchNorm2d(feature_dim),
nn.ReLU(inplace=True)
)
def forward(self, x):
"""
        Soft alignment based on feature statistics rather than adversarial training.
"""
# Calculate feature statistics
mean = x.mean(dim=[2, 3], keepdim=True)
var = x.var(dim=[2, 3], keepdim=True)
# Normalize
x_norm = (x - mean) / (var + 1e-5).sqrt()
# Apply learnable alignment
x_aligned = self.align_conv(x_norm)
# Soft blending with original features
x_out = self.alpha * x_aligned + (1 - self.alpha) * x
return x_out
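# Illustrative usage (sketch, assumed sizes): the module blends re-projected, statistics-
# normalised features with the originals through the learnable alpha (initialised at 0.1),
# so early in training the output stays close to the input.
def _demo_dynamic_feature_alignment():
    x = torch.randn(2, 64, 16, 16)
    y = DynamicFeatureAlignment(feature_dim=64)(x)
    assert y.shape == x.shape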
# NEW: Multi-Scale Feature Aggregation (enhances multi-scale features)
class MultiScaleFusionBlock(nn.Module):
"""
Multi-scale feature aggregation with efficient attention
    Used to enhance feature representation capacity
"""
def __init__(self, in_channels):
super(MultiScaleFusionBlock, self).__init__()
# Multi-scale convolutions
self.branch1 = nn.Sequential(
nn.Conv2d(in_channels, in_channels // 4, kernel_size=1),
nn.BatchNorm2d(in_channels // 4),
nn.ReLU(inplace=True)
)
self.branch2 = nn.Sequential(
nn.Conv2d(in_channels, in_channels // 4, kernel_size=3, padding=1),
nn.BatchNorm2d(in_channels // 4),
nn.ReLU(inplace=True)
)
self.branch3 = nn.Sequential(
nn.Conv2d(in_channels, in_channels // 4, kernel_size=5, padding=2),
nn.BatchNorm2d(in_channels // 4),
nn.ReLU(inplace=True)
)
self.branch4 = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(in_channels, in_channels // 4, kernel_size=1),
nn.BatchNorm2d(in_channels // 4),
nn.ReLU(inplace=True)
)
        # Fusion back to in_channels (the concatenated branches give 4 * (in_channels // 4)
        # channels, which only equals in_channels when in_channels is divisible by 4)
        self.fusion = nn.Sequential(
            nn.Conv2d(4 * (in_channels // 4), in_channels, kernel_size=1),
nn.BatchNorm2d(in_channels),
nn.ReLU(inplace=True)
)
# ECA attention for refinement
self.eca = EfficientChannelAttention(in_channels)
def forward(self, x):
b1 = self.branch1(x)
b2 = self.branch2(x)
b3 = self.branch3(x)
b4 = F.interpolate(self.branch4(x), size=x.shape[2:], mode='bilinear', align_corners=False)
concat = torch.cat([b1, b2, b3, b4], dim=1)
out = self.fusion(concat)
out = self.eca(out)
return out + x # Residual
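# Illustrative usage (sketch, assumed sizes): the block fuses 1x1/3x3/5x5/global-pool
# branches, refines the result with ECA, and adds a residual, so the shape is preserved.
def _demo_multi_scale_fusion_block():
    x = torch.randn(2, 64, 16, 16)
    y = MultiScaleFusionBlock(in_channels=64)(x)
    assert y.shape == x.shape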
# Keep original conv_block4 for compatibility
class conv_block4(nn.Module):
def __init__(self, in_ch, out_ch, padding=0):
super(conv_block4, self).__init__()
self.net_main = nn.Sequential(
nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, padding=padding),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True),
nn.Conv2d(out_ch, 2 * out_ch, kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(2 * out_ch),
nn.ReLU(inplace=True),
nn.Conv2d(2 * out_ch, out_ch, kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(out_ch)
)
self.net_side = nn.Sequential(
nn.Conv2d(out_ch, max(1, int(out_ch / 15)), kernel_size=3, stride=1, padding=0),
nn.BatchNorm2d(max(1, int(out_ch / 15)))
)
self.conv = nn.Sequential(
nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, padding=padding),
nn.BatchNorm2d(out_ch)
)
self.relu = nn.ReLU(inplace=True)
def forward(self, x_main, x_side):
y_main = self.net_main(x_main)
y = y_main + self.conv(x_main) - x_side
y_side = self.net_side(y_main - x_side)
y_side = torch.softmax(y_side, dim=1)
return y, y_side
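# Illustrative usage (sketch, assumed sizes): conv_block4 expects x_side to have out_ch
# channels and the same spatial size as the main branch (all main convs are 1x1 with
# padding=0); the side head outputs max(1, int(out_ch / 15)) channels via a 3x3 conv.
def _demo_conv_block4():
    x_main = torch.randn(2, 32, 16, 16)
    x_side = torch.randn(2, 64, 16, 16)
    y, y_side = conv_block4(in_ch=32, out_ch=64, padding=0)(x_main, x_side)
    assert y.shape == (2, 64, 16, 16)
    assert y_side.shape == (2, 4, 14, 14)  # int(64 / 15) == 4; the 3x3 conv shrinks by 2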