import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import DropPath, to_2tuple
from .slide import SlideAttention
from .dat_blocks import *
# from .nat import NeighborhoodAttention2D
from .qna import FusedKQnA
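
# LayerScale (as in CaiT, "Going Deeper with Image Transformers"): multiplies
# each channel by a small learnable weight so deep residual stacks start
# close to the identity mapping.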
class LayerScale(nn.Module):
def __init__(self,
dim: int,
inplace: bool = False,
init_values: float = 1e-5):
super().__init__()
self.inplace = inplace
self.weight = nn.Parameter(torch.ones(dim) * init_values)
def forward(self, x):
if self.inplace:
return x.mul_(self.weight.view(-1, 1, 1))
else:
return x * self.weight.view(-1, 1, 1)
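
# TransformerStage stacks `depths` blocks; block d's mixer is chosen by the
# one-letter code stage_spec[d]:
#   'L' local window attention      'S' shifted-window attention
#   'D' deformable attention (DAT)  'P' pyramid (spatial-reduction) attention
#   'Q' fused QnA                   'X' depthwise conv (ConvNeXt-style)
#   'E' slide attention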
class TransformerStage(nn.Module):
def __init__(self, fmap_size, window_size, ns_per_pt,
dim_in, dim_embed, depths, stage_spec, n_groups,
use_pe, sr_ratio,
heads, heads_q, stride,
offset_range_factor,
dwc_pe, no_off, fixed_pe,
attn_drop, proj_drop, expansion, drop, drop_path_rate,
use_dwc_mlp, ksize, nat_ksize,
k_qna, nq_qna, qna_activation,
layer_scale_value, use_lpu, log_cpb):
super().__init__()
fmap_size = to_2tuple(fmap_size)
self.depths = depths
hc = dim_embed // heads
assert dim_embed == heads * hc
self.proj = nn.Conv2d(dim_in, dim_embed, 1, 1, 0) if dim_in != dim_embed else nn.Identity()
self.stage_spec = stage_spec
self.use_lpu = use_lpu
self.ln_cnvnxt = nn.ModuleDict(
{str(d): LayerNormProxy(dim_embed) for d in range(depths) if stage_spec[d] == 'X'}
)
self.layer_norms = nn.ModuleList(
[LayerNormProxy(dim_embed) if stage_spec[d // 2] != 'X' else nn.Identity() for d in range(2 * depths)]
)
mlp_fn = TransformerMLPWithConv if use_dwc_mlp else TransformerMLP
self.mlps = nn.ModuleList(
[
mlp_fn(dim_embed, expansion, drop) for _ in range(depths)
]
)
self.attns = nn.ModuleList()
self.drop_path = nn.ModuleList()
self.layer_scales = nn.ModuleList(
[
LayerScale(dim_embed, init_values=layer_scale_value) if layer_scale_value > 0.0 else nn.Identity()
for _ in range(2 * depths)
]
)
self.local_perception_units = nn.ModuleList(
[
nn.Conv2d(dim_embed, dim_embed, kernel_size=3, stride=1, padding=1,
groups=dim_embed) if use_lpu else nn.Identity()
for _ in range(depths)
]
)
for i in range(depths):
            if stage_spec[i] == 'L':  # window attention
self.attns.append(
LocalAttention(dim_embed, heads, window_size, attn_drop, proj_drop)
)
            elif stage_spec[i] == 'D':  # deformable attention
self.attns.append(
DAttentionBaseline(fmap_size, fmap_size, heads,
hc, n_groups, attn_drop, proj_drop,
stride, offset_range_factor, use_pe, dwc_pe,
no_off, fixed_pe, ksize, log_cpb)
)
            elif stage_spec[i] == 'S':  # shifted-window attention
shift_size = math.ceil(window_size / 2)
self.attns.append(
ShiftWindowAttention(dim_embed, heads, window_size, attn_drop, proj_drop, shift_size, fmap_size)
)
# elif stage_spec[i] == 'N':
# self.attns.append(
# NeighborhoodAttention2D(dim_embed, nat_ksize, heads, attn_drop, proj_drop)
# )
            elif stage_spec[i] == 'P':  # pyramid (spatial-reduction) attention
self.attns.append(
PyramidAttention(dim_embed, heads, attn_drop, proj_drop, sr_ratio)
)
elif stage_spec[i] == 'Q':
self.attns.append(
FusedKQnA(nq_qna, dim_embed, heads_q, k_qna, 1, 0, qna_activation)
)
elif self.stage_spec[i] == 'X':
self.attns.append(
nn.Conv2d(dim_embed, dim_embed, kernel_size=window_size, padding=window_size // 2, groups=dim_embed)
)
elif self.stage_spec[i] == 'E':
self.attns.append(
SlideAttention(dim_embed, heads, 3)
)
else:
raise NotImplementedError(f'Spec: {stage_spec[i]} is not supported.')
self.drop_path.append(DropPath(drop_path_rate[i]) if drop_path_rate[i] > 0.0 else nn.Identity())
def forward(self, x):
x = self.proj(x)
for d in range(self.depths):
if self.use_lpu:
x0 = x
x = self.local_perception_units[d](x.contiguous())
x = x + x0
if self.stage_spec[d] == 'X':
x0 = x
x = self.attns[d](x)
x = self.mlps[d](self.ln_cnvnxt[str(d)](x))
x = self.drop_path[d](x) + x0
#else:
# x0 = x
# x, pos, ref = self.attns[d](self.layer_norms[2 * d](x))
# x = self.layer_scales[2 * d](x)
# x = self.drop_path[d](x) + x0
# x0 = x
# x = self.mlps[d](self.layer_norms[2 * d + 1](x))
# x = self.layer_scales[2 * d + 1](x)
# x = self.drop_path[d](x) + x0
else:
x0 = x
                x = self.attns[d](self.layer_norms[2 * d](x))  # attention here returns only x (no pos/ref)
x = self.layer_scales[2 * d](x)
x = self.drop_path[d](x) + x0
x0 = x
x = self.mlps[d](self.layer_norms[2 * d + 1](x))
x = self.layer_scales[2 * d + 1](x)
x = self.drop_path[d](x) + x0
return x
# Optional attention modules kept for reference (not used below):
# from model.attention.EMA import EMA as EMAAttention
# from model.attention.SKAttention import SKAttention
# from model.attention.SpatialAttentionModule import SpatialAttentionModule
# from model.attention.CBAM import ChannelAttention
class NEW_Att(nn.Module):
def __init__(self, in_channels, rate=4):
super(NEW_Att, self).__init__()
        # Channel attention sub-module
        self.channel_attention = nn.Sequential(
            nn.Linear(in_channels, int(in_channels / rate)),  # squeeze channels to 1/rate
            nn.ReLU(inplace=True),
            nn.Linear(int(in_channels / rate), in_channels)   # restore the original channel count
        )
        # Spatial attention sub-module
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(in_channels, int(in_channels / rate), kernel_size=7, padding=3),  # 7x7 conv, squeeze to 1/rate
            nn.BatchNorm2d(int(in_channels / rate)),
            nn.ReLU(inplace=True),
            nn.Conv2d(int(in_channels / rate), in_channels, kernel_size=7, padding=3),  # 7x7 conv, restore channels
            nn.BatchNorm2d(in_channels)
        )
    # Channel shuffle (as in ShuffleNet): interleave channels across groups
    def channel_shuffle(self, x, groups):
        batchsize, num_channels, height, width = x.size()
        channels_per_group = num_channels // groups
        # reshape into (B, groups, C/groups, H, W)
        x = x.view(batchsize, groups, channels_per_group, height, width)
        # swap the group and channel dims to mix channels between groups
        x = torch.transpose(x, 1, 2).contiguous()
        # flatten back to (B, C, H, W)
        x = x.view(batchsize, -1, height, width)
        return x
    def forward(self, x):
        b, c, h, w = x.shape
        x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)  # (B, H*W, C) for the linear layers
        x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)
        x_channel_att = x_att_permute.permute(0, 3, 1, 2).sigmoid()  # back to (B, C, H, W), gate in [0, 1]
        x = x * x_channel_att  # apply channel attention
        x = self.channel_shuffle(x, groups=4)  # group count is a tunable choice (2 or more also work)
        x_spatial_att = self.spatial_attention(x).sigmoid()
        out = x * x_spatial_att  # apply spatial attention
        return out
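
# A minimal shape sketch (hypothetical sizes): NEW_Att preserves the input
# shape, so it can be dropped into the stem without resizing:
#   att = NEW_Att(in_channels=96)
#   y = att(torch.randn(2, 96, 56, 56))  # -> (2, 96, 56, 56)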
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.bn = nn.BatchNorm2d(channels)
        # Dilated depthwise 3x3 followed by a pointwise 1x1; spatial size preserved.
        self.conv = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, groups=channels, stride=1, padding=2, dilation=2),
            nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0, groups=1),
        )
        # self.attention1 = NEW_Att(channels)

    def forward(self, x):
        x1 = self.conv(x)
        x2 = self.conv(x1)
        x3 = x + F.relu(self.bn(x2))
        return x3
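
# Design note: the 3x3 depthwise conv with dilation=2 and padding=2 enlarges
# the receptive field without changing the spatial size; the 1x1 conv mixes
# channels, so the residual addition always sees matching shapes.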
# Multi-scale feature fusion neck
class FusionNeck(nn.Module):
def __init__(self, dims):
super().__init__()
self.dims = dims
self.conv4_to_3 = nn.Conv2d(dims[3], dims[2], 1)
self.conv3_to_2 = nn.Conv2d(dims[2], dims[1], 1)
self.conv2_to_1 = nn.Conv2d(dims[1], dims[0], 1)
self.smooth = nn.Conv2d(dims[0], dims[0], 3, padding=1)
def forward(self, features):
x1, x2, x3, x4 = features
fused_3 = x3 + F.interpolate(self.conv4_to_3(x4), size=x3.shape[-2:], mode='bilinear', align_corners=False)
fused_2 = x2 + F.interpolate(self.conv3_to_2(fused_3), size=x2.shape[-2:], mode='bilinear', align_corners=False)
fused_1 = x1 + F.interpolate(self.conv2_to_1(fused_2), size=x1.shape[-2:], mode='bilinear', align_corners=False)
return self.smooth(fused_1)
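
# A shape sketch under the default dims [96, 192, 384, 768] and a 224x224
# input (hypothetical feature maps f56 ... f7 at strides 4/8/16/32):
#   neck = FusionNeck([96, 192, 384, 768])
#   out = neck([f56, f28, f14, f7])  # f56: (B,96,56,56) ... f7: (B,768,7,7)
#   # out: (B, 96, 56, 56) - top-down FPN-style fusion, smoothed by a 3x3 conv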
class DAT(nn.Module):
def __init__(self, img_size, patch_size, num_classes, expansion,dim_stem, dims,
depths, heads, heads_q,window_sizes,drop_rate, attn_drop_rate,
drop_path_rate,strides,offset_range_factor,stage_spec,groups,use_pes,
dwc_pes,sr_ratios,lower_lr_kvs,fixed_pes,no_offs,ns_per_pts,use_dwc_mlps,
use_conv_patches,ksizes,ksize_qnas,nqs,qna_activation,nat_ksizes,
layer_scale_values,use_lpus,log_cpb,with_fusion_neck=False,**kwargs):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_channels=3,
out_channels=3,
kernel_size=7,
stride=2,
padding=3,
groups=3),
nn.Conv2d(in_channels=3,
out_channels=dim_stem,
kernel_size=1,
stride=1,
padding=0,
groups=1),
LayerNormProxy(dim_stem),
nn.MaxPool2d(kernel_size=2,stride=2),
            ResidualBlock(dim_stem),  # residual blocks need matching in/out dims, so no downsampling here
ResidualBlock(dim_stem),
NEW_Att(dim_stem)
)
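        # Stem stride: the 7x7 depthwise conv has stride 2 and the 2x2 max-pool
        # halves again, so the stem downsamples by 4 (224 -> 56), matching patch_size=4.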
# self.conv = nn.Sequential(
# nn.Conv2d(3, dim_stem // 2, 3, patch_size // 2, 1),
# LayerNormProxy(dim_stem // 2),
# nn.GELU(),
# nn.Conv2d(dim_stem // 2, dim_stem, 3, patch_size // 2, 1),
# LayerNormProxy(dim_stem)
# ) if use_conv_patches else nn.Sequential(
# nn.Conv2d(3, dim_stem, patch_size, patch_size, 0),
# LayerNormProxy(dim_stem),
# NEW_Att(dim_stem)
# )
        img_size = img_size // patch_size  # feature-map size after the stem (stride 4)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
self.stages = nn.ModuleList()
for i in range(4):
dim1 = dim_stem if i == 0 else dims[i - 1] * 2
dim2 = dims[i]
self.stages.append(
TransformerStage(
img_size, window_sizes[i], ns_per_pts[i],
dim1, dim2, depths[i],
stage_spec[i], groups[i], use_pes[i],
sr_ratios[i], heads[i], heads_q[i], strides[i],
offset_range_factor[i],
dwc_pes[i], no_offs[i], fixed_pes[i],
attn_drop_rate, drop_rate, expansion, drop_rate,
dpr[sum(depths[:i]):sum(depths[:i + 1])], use_dwc_mlps[i],
ksizes[i], nat_ksizes[i], ksize_qnas[i], nqs[i], qna_activation,
layer_scale_values[i], use_lpus[i], log_cpb[i]
)
)
img_size = img_size // 2
self.down_projs = nn.ModuleList()
# for i in range(3):
# self.down_projs.append(
# nn.Sequential(
# nn.Conv2d(dims[i], dims[i + 1], 3, 2, 1, bias=False),
# LayerNormProxy(dims[i + 1])
# )
# if use_conv_patches else nn.Sequential(
# nn.Conv2d(dims[i], dims[i + 1], 2, 2, 0, bias=False),
# LayerNormProxy(dims[i + 1])
# )
# )
for i in range(3):
self.down_projs.append(
nn.Sequential(
nn.Conv2d(dims[i], dims[i], kernel_size=3, stride=2, padding=1, bias=False, groups=dims[i]),
nn.Conv2d(in_channels=dims[i], out_channels=dims[i+1], kernel_size=1, stride=1, padding=0, groups=1),
LayerNormProxy(dims[i + 1])
) if use_conv_patches else nn.Sequential(
nn.Conv2d(dims[i], dims[i], 2, 2, 0, bias=False,groups=dims[i]),
nn.Conv2d(in_channels=dims[i], out_channels=dims[i + 1], kernel_size=1, stride=1, padding=0, groups=1),
LayerNormProxy(dims[i + 1])
# nn.Conv2d(dims[i], dims[i + 1], 2, 2, 0, bias=False),
))
self.with_fusion_neck = with_fusion_neck
if with_fusion_neck:
self.fusion_neck = FusionNeck(dims=[dims[0], dims[1], dims[2], dims[3]])
        # Classification head; both paths end with global average pooling, so
        # the head input is a channel vector (dims[0] fused, dims[-1] otherwise).
        self.mlp_head = nn.Sequential(
            nn.Linear(dims[0] if with_fusion_neck else dims[-1], num_classes)
        )
# self.cls_norm = LayerNormProxy(dims[-1])
# self.cls_head = nn.Linear(dims[-1], num_classes)
#
# self.lower_lr_kvs = lower_lr_kvs
#
# self.reset_parameters()
#
# def reset_parameters(self):
#
# for m in self.parameters():
# if isinstance(m, (nn.Linear, nn.Conv2d)):
# nn.init.kaiming_normal_(m.weight)
# nn.init.zeros_(m.bias)
#
# @torch.no_grad()
# def load_pretrained(self, state_dict, lookup_22k):
#
# new_state_dict = {}
# for state_key, state_value in state_dict.items():
# keys = state_key.split('.')
# m = self
# for key in keys:
# if key.isdigit():
# m = m[int(key)]
# else:
# m = getattr(m, key)
# if m.shape == state_value.shape:
# new_state_dict[state_key] = state_value
# else:
# # Ignore different shapes
# if 'relative_position_index' in keys:
# new_state_dict[state_key] = m.data
# if 'q_grid' in keys:
# new_state_dict[state_key] = m.data
# if 'reference' in keys:
# new_state_dict[state_key] = m.data
# # Bicubic Interpolation
# if 'relative_position_bias_table' in keys:
# n, c = state_value.size()
# l_side = int(math.sqrt(n))
# assert n == l_side ** 2
# L = int(math.sqrt(m.shape[0]))
# pre_interp = state_value.reshape(1, l_side, l_side, c).permute(0, 3, 1, 2)
# post_interp = F.interpolate(pre_interp, (L, L), mode='bicubic')
# new_state_dict[state_key] = post_interp.reshape(c, L ** 2).permute(1, 0)
# if 'rpe_table' in keys:
# c, h, w = state_value.size()
# C, H, W = m.data.size()
# pre_interp = state_value.unsqueeze(0)
# post_interp = F.interpolate(pre_interp, (H, W), mode='bicubic')
# new_state_dict[state_key] = post_interp.squeeze(0)
# if 'cls_head' in keys:
# new_state_dict[state_key] = state_value[lookup_22k]
#
# msg = self.load_state_dict(new_state_dict, strict=False)
# return msg
#
# @torch.jit.ignore
# def no_weight_decay(self):
# return {'absolute_pos_embed'}
# @torch.jit.ignore
# def no_weight_decay_keywords(self):
# return {'relative_position_bias_table', 'rpe_table'}
def forward(self, x):
stage_outputs = []
x = self.conv(x)
        stage_outputs.append(x)  # stem output: (B, dim_stem, 56, 56) for a 224 input
for i in range(4):
x = self.stages[i](x)
stage_outputs.append(x)
if i < 3:
x = self.down_projs[i](x)
stage_outputs.append(x)
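        # stage_outputs index layout: [0] stem (1/4), [1] stage0 (1/4), [2] down0 (1/8),
        # [3] stage1 (1/8), [4] down1 (1/16), [5] stage2 (1/16), [6] down2 (1/32), [7] stage3 (1/32)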
        if self.with_fusion_neck:
            key_features = [
                stage_outputs[1],  # stage 0 output (1/4, dims[0])
                stage_outputs[3],  # stage 1 output (1/8, dims[1])
                stage_outputs[5],  # stage 2 output (1/16, dims[2])
                stage_outputs[7]   # stage 3 output (1/32, dims[3])
            ]
            fused_feat = self.fusion_neck(key_features)
            x = fused_feat.mean(dim=[2, 3])  # GAP over the fused feature map
        else:
            x = x.mean(dim=[2, 3])  # default GAP over the last stage
return {
'final_output': self.mlp_head(x),
'stage_outputs': stage_outputs,
'fused_feature': fused_feat if self.with_fusion_neck else None
}
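
    # Note: forward returns a dict; the classification logits live under
    # 'final_output', so callers must index model(x)['final_output'].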
#def forward(self, x):
# x = self.conv(x)
# for i in range(4):
# x = self.stages[i](x)
# if i < 3:
#x = self.down_projs[i](x)
# x = self.cls_norm(x)
# x = F.adaptive_avg_pool2d(x, 1)
# x = torch.flatten(x, 1)
# x = self.cls_head(x)
# x = x.mean(dim=[2, 3])
#return self.mlp_head(x)
# return x
def dat(num_classes=10, dim_stem=96, has_logits: bool = True):
model = DAT(
img_size=224,
patch_size=4,
num_classes=num_classes,
expansion=4,
dim_stem=dim_stem,
dims=[96, 192, 384, 768],
heads_q=[3, 6, 12, 24],
        depths=[2, 2, 6, 2],  # number of transformer blocks in each stage
        stage_spec=[['L', 'S'], ['L', 'S'], ['S', 'D', 'S', 'D', 'S', 'D'], ['S', 'D']],  # attention type per block in each stage
heads=[3, 6, 12, 24],
window_sizes=[7, 7, 7, 7],
groups=[-1, -1, 3, 6],
ns_per_pts=[4, 2, 2, 2],
use_dwc_mlps=[False, False, False, False],
use_conv_patches=False,
ksizes=[3, 3, 3, 3],
ksize_qnas=[3, 3, 3, 3],
nqs=[2, 2, 6, 2],
qna_activation='exp',
nat_ksizes=[3, 3, 3, 3],
layer_scale_values=[-1, -1, -1, -1],
use_lpus=[False, False, False, False],
log_cpb=[False, False, False, False],
use_pes=[False, False, True, True],
dwc_pes=[False, False, False, False],
strides=[1, 1, 1, 1],
sr_ratios=[-1, -1, -1, -1],
offset_range_factor=[-1, -1, 2, 2],
no_offs=[False, False, False, False],
fixed_pes=[False, False, False, False],
lower_lr_kvs={},
drop_rate=0.0,
attn_drop_rate=0.0,
drop_path_rate=0.2,
        representation_size=640 if has_logits else None,  # unused: absorbed by **kwargs
with_fusion_neck=False,
)
    return model


# ======================================================================
# Training script (separate file; imports the model above via DAT.dat1)
# ======================================================================
import torch.nn as nn
from tqdm.auto import tqdm
from getCWTData import CWT  # CWT dataset (wavelet transform already implemented there)
from utils import pad4      # helper already implemented in utils
from sklearn.metrics import confusion_matrix
import seaborn as sns
import gc, warnings, time, sys, torch, os
from torch import optim
import numpy as np
from torch.utils.data import DataLoader
import pandas as pd
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from sklearn.metrics import (f1_score, accuracy_score, precision_score, recall_score, )
from contextlib import contextmanager
from torchinfo import summary
from thop import profile
from DAT.dat1 import dat
#from CNN import cnn
#from transformer import swin_t
from torchvision.models import resnet18, resnet34, resnet50, resnet101, resnet152, resnext50_32x4d, resnext101_32x8d, wide_resnet50_2, wide_resnet101_2
from torchvision.models import alexnet,Inception3,inception_v3,squeezenet1_0
# from torchvision.models import ConvNeXt, convnext_tiny, convnext_small, convnext_base, convnext_large
from torchvision.models import DenseNet, densenet121, densenet161, densenet169, densenet201
from torchvision.models import GoogLeNet, googlenet
from torchvision.models import MNASNet, mnasnet1_0, mnasnet1_3, mnasnet0_5, mnasnet0_75
# from torchvision.models import regnet_x_800mf, regnet_x_8gf
from torchvision.models import VGG, vgg11, vgg13, vgg16, vgg19, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn, vgg11_bn, vgg13_bn
#from torchvision.models import vit_b_16, vit_h_14, vit_b_32, vit_l_16, vit_l_32, vision_transformer
#from torchvision.models import swin_v2_t, swin_t, swin_b, swin_s, swin_v2_b, swin_v2_s
# from torchvision.models import maxvit_t
gc.collect()                       # force garbage collection to free CPU memory
torch.cuda.empty_cache()           # release cached GPU memory
warnings.filterwarnings("ignore")  # silence warnings
myseed = 6
np.random.seed(myseed)
# random.seed(myseed)
torch.manual_seed(myseed)
if torch.cuda.is_available():
torch.cuda.manual_seed(myseed)
torch.cuda.manual_seed_all(myseed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
# Load data
'''
--------------------------------------------------------
'''
HP="PU"
'''
--------------------------------------------------------
'''
path = "data/"+HP
train = CWT(path, 224, "train")
val = CWT(path, 224, "val")
test = CWT(path, 224, "test")
print(len(train))
print(len(val))
print(len(test))
# Hyperparameters
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = dat(num_classes=10).to(device)
#model = alexnet(num_classes=10).to(device)
#model = vgg16(num_classes=10).to(device)
#model = inception_v3(num_classes=10,aux_logits=False).to(device)
# Inception V3 is a strong, efficient CNN baseline; to switch models, replace inception_v3 above
batch_size = 4
lr = 1e-3
EPOCHS = 320
loss_F = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr, weight_decay=1e-5)
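# Note: torch.optim.SGD defaults to momentum=0, so this is plain SGD with L2
# weight decay; pass momentum=0.9 to enable momentum if desired.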
# Count model parameters
total_parameters = 0
for i in model.parameters():
    total_parameters += i.numel()
print('Total parameters: %.4f M' % (total_parameters / 1000000.0))
train_loader = DataLoader(train, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test, batch_size=batch_size, shuffle=False)  # keep test order stable for the results CSV
# Checkpoint name; training can optionally be logged with wandb
_weight_name = "cwt_swinT_" + str(myseed) + "_b" + str(batch_size) + "_lr" + str(lr)
# wandb.init(project='cwt_swinT_', name=time.strftime('%m%d%H%M%S-' + _weight_name))
best_acc = 0
# Per-epoch train/val accuracy and loss, saved to CSV after training
df = pd.DataFrame()
df_train_acc = []
df_train_loss = []
df_valid_acc = []
df_valid_loss = []
start_time = time.time()
# Create the output directory once, so checkpoints and figures share one folder
time_now = time.strftime("%m%d-%H%M", time.localtime())
path_fig = 'fig/' + HP + '/' + time_now + '/'
if not os.path.exists(path_fig):
    os.makedirs(path_fig)
for epoch in range(EPOCHS):
# ---------- Train ----------
model.train()
train_loss = []
train_accs = []
for batch in tqdm(train_loader):
imgs, labels = batch
        logits = model(imgs.to(device))['final_output']  # DAT.forward returns a dict; take the logits
loss = loss_F(logits, labels.long().to(device))
optimizer.zero_grad()
loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)  # clip gradients for stability
optimizer.step()
# Compute the accuracy for current batch.
acc = (logits.argmax(dim=1) == labels.to(device)).float().mean()
train_loss.append(loss.item())
train_accs.append(acc)
    # Mean loss and accuracy over the epoch
train_loss = sum(train_loss) / len(train_loss)
train_acc = sum(train_accs) / len(train_accs)
    # Record this epoch's training loss and accuracy
df_train_loss.append(train_loss)
df_train_acc.append(train_acc.cpu().numpy())
print(f"[ Train | {epoch + 1:03d}/{EPOCHS:03d} ] loss = {train_loss:.5f}, acc = {train_acc:.5f}")
# lr_scheduler.step()
# ---------- Validation ----------
model.eval()
valid_loss = []
valid_accs = []
for batch in tqdm(val_loader):
imgs, labels = batch
with torch.no_grad():
            logits = model(imgs.to(device))['final_output']
loss = loss_F(logits, labels.long().to(device))
acc = (logits.argmax(dim=1) == labels.to(device)).float().mean()
valid_loss.append(loss.item())
valid_accs.append(acc)
valid_loss = sum(valid_loss) / len(valid_loss)
valid_acc = sum(valid_accs) / len(valid_accs)
# wandb.log({"valid_loss": valid_loss, "epoch": epoch + 1})
# wandb.log({"valid_acc": valid_acc, "epoch": epoch + 1})
df_valid_loss.append(valid_loss)
df_valid_acc.append(valid_acc.cpu().numpy())
print(f"[ Valid | {epoch + 1:03d}/{EPOCHS:03d} ] loss = {valid_loss:.5f}, acc = {valid_acc:.5f}")
    end_time = time.time()
    # If this epoch's validation accuracy is the best so far, report it and save the weights
if valid_acc > best_acc:
print(
f"[ Valid | {epoch + 1:03d}/{EPOCHS:03d} ] loss = {valid_loss:.5f}, acc = {valid_acc:.5f} " + "\033[1;31;40m -> best \033[0m ")
print(f"Best model found at epoch {epoch + 1}, saving model")
torch.save(model.state_dict(), f"{path_fig + _weight_name}_best.ckpt")
best_acc = valid_acc
print("--" * 30)
# Save the train/val curves to CSV
df["epoch"] = [pad4(i) for i in range(1, EPOCHS + 1)]
df["tarin_acc"] = df_train_acc
df["train_loss"] = df_train_loss
df["valid_acc"] = df_valid_acc
df["valid_loss"] = df_valid_loss
df.to_csv(path_fig + "训练集和验证集结果.csv", index=False)
# ---------- Test ----------
# To run inference from the best checkpoint instead, uncomment:
# model.load_state_dict(torch.load(f"{path_fig + _weight_name}_best.ckpt"))
prediction = []
true = []
features = []
labels = []
model.eval()
with torch.no_grad():
for x, y in tqdm(test_loader):
        test_pred = model(x.to(device))['final_output']
        # predicted class indices
        test_label = np.argmax(test_pred.cpu().numpy(), axis=1)
prediction += test_label.tolist()
true += y.cpu().numpy().tolist()
for batch in tqdm(test_loader):
imgs, batch_labels = batch
        batch_features = model(imgs.to(device))['final_output'].cpu().numpy()
features.append(batch_features)
labels.append(batch_labels.numpy())
features = np.concatenate(features)
labels = np.concatenate(labels)
# Save test-set predictions
df2 = pd.DataFrame()
df2["Id"] = [pad4(i) for i in range(1, len(test) + 1)]
df2["Prediction"] = prediction
df2["True"] = true
df2.to_csv(path_fig +"测试推理预测.csv", index=False)
accuracy = accuracy_score(df2["True"], df2["Prediction"])
F1 = f1_score(
df2["True"],
df2["Prediction"],
average="macro",
labels=np.unique(df2["Prediction"]),
)
recall = recall_score(
df2["True"],
df2["Prediction"],
average="macro",
labels=np.unique(df2["Prediction"]),
)
precision = precision_score(
df2["True"],
df2["Prediction"],
average="macro",
labels=np.unique(df2["Prediction"]),
)
print("accuracy:", accuracy)
print("F1:", F1)
print("recall:", recall)
print("precision:", precision)
# Total run time
run_time = end_time - start_time
print("Run time:", run_time, "s")
#torch.save(model, path_fig + "model.pth")  # uncomment to save the trained weights
plt.rcParams['font.sans-serif'] = ['Microsoft YaHei']  # font capable of rendering CJK text in plots
cm = confusion_matrix(df2["True"], df2["Prediction"], labels=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
cm = np.round(cm, 2)
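# Row-normalised confusion matrix: each row sums to 1, so the diagonal entries
# are per-class recall.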
plt.figure(figsize=(10, 7))
heatmap = sns.heatmap(cm, annot=True, cmap="Blues", annot_kws={"size": 16,"fontweight": "bold"}, cbar=True)
plt.xlabel("Predicted", fontsize=18)
plt.ylabel("True", fontsize=18)
plt.xticks(size=14)
plt.yticks(size=14)
# Set the font size of the colorbar label
cbar = heatmap.collections[0].colorbar
cbar.ax.tick_params(labelsize=15)
plt.savefig(path_fig + 'confusion_matrix.jpg', dpi=600, bbox_inches='tight', pad_inches=0.1)
plt.show()
#plt.clf()
# Training/validation curves
df = pd.DataFrame({
'train_loss': df_train_loss,
'train_acc': df_train_acc,
'valid_loss': df_valid_loss,
'valid_acc': df_valid_acc
})
# Plot training and validation accuracy curves
plt.plot(range(1, len(df_train_acc) + 1), df_train_acc, label='Train Accuracy')
plt.plot(range(1, len(df_valid_acc) + 1), df_valid_acc, label='Valid Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()
plt.savefig(path_fig + 'curves.jpg', dpi=600, bbox_inches='tight', pad_inches=0.1)
#plt.clf()
plt.show()
# ---------- t-SNE ----------
# Reduce the collected features to 2-D with t-SNE
tsne = TSNE(n_components=2)
reduced_features = tsne.fit_transform(features)
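# Note: `features` here are the model's 10-dim class logits, reduced to 2-D;
# t-SNE on logits shows class separability rather than raw embeddings.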
# Plot the t-SNE embedding
plt.figure(figsize=(6, 5))
scatter = plt.scatter(reduced_features[:, 0], reduced_features[:, 1], c=labels, cmap='rainbow_r')
plt.legend(handles=scatter.legend_elements()[0], labels=range(len(set(labels))))
# plt.title("T-SNE Visualization")
# plt.xlim(0, 1)
# plt.ylim(-1, 1)
plt.savefig(path_fig + 'tsne.jpg', dpi=600,bbox_inches='tight', pad_inches=0.1)
#plt.clf()
plt.show()
@contextmanager
def stdout_redirected(to=None):
"""
上下文管理器,用于临时将stdout重定向到文件或控制台。
使用方法:`with stdout_redirected(to='output.txt'):`
"""
if to is None:
yield
else:
sys.stdout.flush()
original_stdout = sys.stdout
with open(to, 'w') as file:
sys.stdout = file
try:
yield file
finally:
sys.stdout = original_stdout
dummy_input = torch.randn(1, 3, 224, 224)
flops, params = profile(model, (dummy_input.to(device),))
print('FLOPs: ', flops, 'params: ', params)
print('FLOPs: %.4f G, params: %.4f M' % (flops / 1000000000.0, params / 1000000.0))
F_time = format(run_time, '.4f')            # run time with 4 decimals
output_file = path_fig + F_time + "_.txt"   # file that will hold the model summary
with stdout_redirected(to=output_file):     # dump the model structure to the file
    summary(model, input_size=(1, 3, 224, 224), depth=9, device=device)
with open(output_file, "r+") as f:
old = f.read()
f.seek(0)
f.write('train:' + str(len(train)) +
' val: ' + str(len(val)) +
' test: ' + str(len(test))+'\n'+
'accuracy: '+ str(accuracy)+
' F1: '+ str(F1) + '\n' +
'recall: '+ str(recall) +
' precision: '+ str(precision) +'\n'
'FLOPs: '+str(int(flops))+ ' = '+str('%.4fM '% (flops / 1000000.0)) +
' = '+str('%.4fG'% (flops / 1000000000.0))+'\n')
f.write(old)
with open(output_file, 'a') as f:
f.write('\n'+
'*****************************'+
'\n\n'+
str(model))