# -*- coding: utf-8 -*-
# Changes that make the gating converge faster: 1. a larger beta_l0; 2. the learning rate of log_alpha scaled by 2.0; 3. an added entropy regulariser.
from __future__ import annotations
import math
import os
import random
import time
from collections import deque
from pathlib import Path
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms
from sklearn.cluster import KMeans
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import (
silhouette_score, silhouette_samples,
calinski_harabasz_score, davies_bouldin_score,
)
from sklearn.manifold import TSNE
try:
    import umap  # only the umap-learn package ships the UMAP class
HAS_UMAP = hasattr(umap, "UMAP") or hasattr(umap, "umap_")
except ImportError:
HAS_UMAP = False
from datetime import datetime
from matplotlib.patches import Rectangle
import warnings
# -------------------------- Global configuration -------------------------- #
class CFG:
# Paths
data_root: str = r"D:\dataset\TILDA_8class_73"
save_root: str = r"D:\SCI_exp\7_29\exp_file"
# Dataset & DL
batch_size: int = 128
num_workers: int = 0 # tune to your CPU
    img_size: int = 224  # inputs are upscaled to 224×224 for ResNet-18
# Model dimensions (§3.5.1)
d_backbone: int = 512
d_proj: int = 128
K_max: int = 3
mem_size: int = 4096
# Optimisation (§3.5.1)
lr_warmup: float = 1e-3
lr_joint: float = 3e-4
lr_ft: float = 1e-4
weight_decay: float = 5e-4
    n_epochs_warmup: int = 15
    n_epochs_joint: int = 150
    n_epochs_ft: int = 25
# Loss hyper‑params
lambda1: float = 0.5 # push–pull
alpha_proto: float = 0.1
scale_ce: float = 30.0
    gamma_se: float = 20.0  # self-expression weight
    # ---------- Hard-Concrete ----------
    tau0_hc: float = 1.5       # initial temperature
    tau_min_hc: float = 0.15   # minimum temperature
    anneal_epochs_hc: int = 5
    gamma_hc: float = -0.1     # stretch lower bound
    zeta_hc: float = 1.1       # stretch upper bound
    beta_l0: float = 5e-2      # L0 regularisation coefficient
    hc_threshold: float = 0.35
# Misc
seed: int = 42
device: str = "cuda" if torch.cuda.is_available() else "cpu"
# ---------- datetime ---------- #
def get_timestamp():
    """Return the current timestamp, formatted YYYYMMDD_HHMMSS."""
    return datetime.now().strftime("%Y%m%d_%H%M%S")
# ---------- diagnostics ---------- #
MAX_SAMPLED = 5_000  # None → use every sample
timestamp = get_timestamp()
DIAG_DIR = Path(CFG.save_root) / f"diagnostics_{timestamp}"  # folder name carries the timestamp
DIAG_DIR.mkdir(parents=True, exist_ok=True)
# -------------------------- Reproducibility -------------------------- #
torch.manual_seed(CFG.seed)
random.seed(CFG.seed)
np.random.seed(CFG.seed)  # numpy-based paths (sklearn, shuffling in C code) also draw from numpy
# -------------------------- Utility functions -------------------------- #
def L2_normalise(t: torch.Tensor, dim: int = 1, eps: float = 1e-12) -> torch.Tensor:
return F.normalize(t, p=2, dim=dim, eps=eps)
def pairwise_cosine(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""Compute cosine similarity between all pairs in *x* and *y*."""
x = L2_normalise(x)
y = L2_normalise(y)
return x @ y.T # (N, M)
# -------------------------- Memory bank (FIFO queue) -------------------------- #
class MemoryBank:
"""Fixed‑size FIFO queue storing (p, q, y_c). All tensors are detached."""
def __init__(self, dim: int, size: int):
self.size = size
self.dim = dim
self.ptr = 0
self.is_full = False
# pre‑allocate
self.p_bank = torch.zeros(size, dim, device=CFG.device)
self.q_bank = torch.zeros_like(self.p_bank)
self.y_bank = torch.zeros(size, dtype=torch.long, device=CFG.device)
@torch.no_grad()
def enqueue(self, p: torch.Tensor, q: torch.Tensor, y: torch.Tensor):
b = p.size(0)
if b > self.size:
p, q, y = p[-self.size:], q[-self.size:], y[-self.size:]
b = self.size
idx = (torch.arange(b, device=CFG.device) + self.ptr) % self.size
self.p_bank[idx] = p.detach()
self.q_bank[idx] = q.detach()
self.y_bank[idx] = y.detach()
        wrapped = (self.ptr + b) >= self.size  # detect wrap even when b does not divide size
        self.ptr = (self.ptr + b) % self.size
        if wrapped:
            self.is_full = True
def get(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
valid = self.size if self.is_full else self.ptr
return (
self.p_bank[:valid].detach(),
self.q_bank[:valid].detach(),
self.y_bank[:valid].detach(),
)
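# A minimal usage sketch for MemoryBank (illustrative only; nothing below is
# executed). The queue overwrites its oldest slots once `size` is exceeded,
# and `get()` returns only the slots written so far:
#
#   mb = MemoryBank(dim=CFG.d_proj, size=8)
#   p = q = torch.randn(5, CFG.d_proj, device=CFG.device)
#   y = torch.zeros(5, dtype=torch.long, device=CFG.device)
#   mb.enqueue(p, q, y)                 # slots 0-4 filled, ptr -> 5
#   mb.enqueue(p, q, y)                 # wraps around, bank now full
#   bank_p, bank_q, bank_y = mb.get()   # all 8 slots returned from here on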
# -------------------------- Projection heads -------------------------- #
class MLPHead(nn.Module):
def __init__(self, in_dim: int, out_dim: int):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(in_dim, out_dim//2, bias=False),
nn.BatchNorm1d(out_dim//2),
nn.ReLU(inplace=True),
nn.Linear(out_dim//2, out_dim, bias=True),
)
def forward(self, x: torch.Tensor):
return self.mlp(x)
# -------------------------- Cosine classifier -------------------------- #
class CosineLinear(nn.Module):
"""Cosine classifier with fixed scale *s* (Eq. CE)."""
def __init__(self, in_dim: int, n_classes: int, s: float = CFG.scale_ce):
super().__init__()
self.s = s
self.weight = nn.Parameter(torch.randn(n_classes, in_dim))
nn.init.xavier_uniform_(self.weight)
def forward(self, x: torch.Tensor): # x ∈ ℝ^{B×d_p}
x = L2_normalise(x)
w = L2_normalise(self.weight)
# logits = s * cos(θ)
return self.s * (x @ w.T)
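# Worked form of the cosine-classifier logit above (a recap, not new
# behaviour): for a feature x and class weight w_c,
#
#   logit_c = s · cos(θ_c) = s · ⟨x / ‖x‖, w_c / ‖w_c‖⟩
#
# so every logit lies in [−s, s]. The fixed scale s (CFG.scale_ce = 30)
# restores enough dynamic range for the softmax/CE after normalisation
# squeezes all cosines into [−1, 1]; without it the softmax would be far too
# flat to train.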
# -------------------------- BaPSTO model -------------------------- #
class BaPSTO(nn.Module):
"""Backbone + DASSER heads + BPGSNet prototypes & gates."""
def __init__(self, n_classes: int):
super().__init__()
# --- Backbone (ResNet‑18) ------------------------------------------------
resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
pretrained_path = Path(CFG.save_root) / "resnet18_best_TILDA_8class_73_7446.pth"
if pretrained_path.exists():
print(f"Loading pretrained weights from {pretrained_path}")
pretrained = torch.load(pretrained_path, map_location=CFG.device, weights_only=True)
            # Build a temporary model so the pretrained checkpoint maps onto the right keys
temp_model = models.resnet18()
temp_model.fc = nn.Linear(temp_model.fc.in_features, n_classes)
temp_model.load_state_dict(pretrained["state_dict"], strict=False)
            # Copy the pretrained weights into our backbone (everything except the fc layer)
resnet_dict = resnet.state_dict()
pretrained_dict = {k: v for k, v in temp_model.state_dict().items()
if k in resnet_dict and 'fc' not in k}
resnet_dict.update(pretrained_dict)
resnet.load_state_dict(resnet_dict)
print("✓ Successfully loaded pretrained backbone weights!")
else:
print(f"⚠️ Pretrained weights not found at {pretrained_path}. Using ImageNet weights.")
# --- Backbone ------------------------------------------------
in_feat = resnet.fc.in_features # 512
resnet.fc = nn.Identity()
self.backbone = resnet
# --- Projection heads ---------------------------------------------------
self.g_SA = MLPHead(CFG.d_backbone, CFG.d_proj)
self.g_FV = MLPHead(CFG.d_backbone, CFG.d_proj)
# Cosine classifier (coarse level)
self.classifier = CosineLinear(CFG.d_proj, n_classes)
# --- BPGSNet prototypes & gate logits -----------------------------------
self.prototypes = nn.Parameter(
torch.randn(n_classes, CFG.K_max, CFG.d_proj)
) # (K_C, K_max, d_p)
nn.init.xavier_uniform_(self.prototypes)
self.log_alpha = nn.Parameter(
            torch.randn(n_classes, CFG.K_max) * 0.01  # small random initialisation
) # (K_C, K_max)
self.register_buffer("global_step", torch.tensor(0, dtype=torch.long))
# ---------------- Forward pass ---------------- #
def forward(self, x: torch.Tensor, y_c: torch.Tensor, mem_bank: MemoryBank, use_bpgs: bool = True
) -> tuple[torch.Tensor, dict[str, float], torch.Tensor, torch.Tensor]:
"""Return full loss components (Section §3.3 & §3.4)."""
B = x.size(0)
# --- Backbone & projections -------------------------------------------
z = self.backbone(x) # (B, 512)
p = L2_normalise(self.g_SA(z)) # (B, d_p)
q = L2_normalise(self.g_FV(z)) # (B, d_p)
bank_p, bank_q, bank_y = mem_bank.get()
# ---------------- DASSER losses ---------------- #
L_SA, L_ortho, L_ce_dasser, L_se = self._dasser_losses(
p, q, y_c, bank_p, bank_q, bank_y
)
total_loss = (
L_SA
+ L_ortho
+ L_ce_dasser
+ CFG.gamma_se * L_se # NEW
)
stats = {
"loss": total_loss.item(),
"L_SA": L_SA.item(),
"L_ortho": L_ortho.item(),
"L_ce_dasser": L_ce_dasser.item(),
"L_se": L_se.item(), # NEW
}
# ---------------- BPGSNet (conditional) -------- #
if use_bpgs:
L_ce_bpgs, L_proto, L_gate, coarse_logits = self._bpgs_losses(q, y_c)
total_loss = total_loss + L_ce_bpgs + L_proto + L_gate
stats.update({
"L_ce_bpgs": L_ce_bpgs.item(),
"L_proto": L_proto.item(),
"L_gate": L_gate.item(),
})
else:
coarse_logits = None
return total_loss, stats, p.detach(), q.detach()
# ---------------------- Internal helpers ---------------------- #
def _dasser_losses(
self,
p: torch.Tensor,
q: torch.Tensor,
y_c: torch.Tensor,
bank_p: torch.Tensor,
bank_q: torch.Tensor,
bank_y: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
DASSER 损失:
• 语义对齐 L_SA
• 正交 L_ortho
• 粗粒度 CE L_ce
• 自表示 L_se (NEW)
"""
        # ---------- concatenate batch + memory ---------- #
        p_all = torch.cat([p, bank_p], dim=0) if bank_p.numel() > 0 else p
        q_all = torch.cat([q, bank_q], dim=0) if bank_q.numel() > 0 else q
        y_all = torch.cat([y_c, bank_y], dim=0) if bank_y.numel() > 0 else y_c
        # ---------- 1) semantic alignment (original) ---------- #
        G = pairwise_cosine(p_all, p_all)  # (N, N)
        G.fill_diagonal_(0.0)
        same = y_all.unsqueeze(0) == y_all.unsqueeze(1)
        diff = ~same
        L_SA = ((same * (1 - G)).sum() +
                CFG.lambda1 * (diff * G.clamp_min(0)).sum()) / (p_all.size(0) ** 2)
        # ---------- 2) orthogonality (original) --------- #
        L_ortho = (1.0 / CFG.d_proj) * (p_all @ q_all.T).pow(2).sum()
        # ---------- 3) self-expression (NEW) ------------ #
        C_logits = pairwise_cosine(p_all, p_all)  # recompute: G was modified in place above
        C_logits.fill_diagonal_(-1e4)             # ≈ −∞ → softmax ≈ 0 (no self-loops)
        C = F.softmax(C_logits, dim=1)            # row-normalised affinity
        Q_recon = C @ q_all                       # linear reconstruction from neighbours
        L_se = F.mse_loss(Q_recon, q_all)
        # ---------- 4) coarse-grained CE (original) ----- #
        logits_coarse = self.classifier(p)
        L_ce = F.cross_entropy(logits_coarse, y_c)
        return L_SA, L_ortho, L_ce, L_se
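    # Equation recap of the four DASSER terms above (notation follows the code;
    # N = batch + memory size, G_ij = cos(p_i, p_j), λ1 = CFG.lambda1):
    #
    #   L_SA    = [ Σ_{y_i = y_j} (1 − G_ij) + λ1 · Σ_{y_i ≠ y_j} max(G_ij, 0) ] / N²
    #             (G_ii is zeroed first, so the diagonal only adds a constant)
    #   L_ortho = (1 / d_proj) · Σ_ij (p_i · q_j)²        # decorrelate the two heads
    #   L_se    = ‖ softmax_row(Ĝ) · Q − Q ‖² / (N · d)   # each q_i reconstructed from
    #                                                     # its p-space neighbours
    #   L_ce    = CE(s · cos-logits(p), y_c)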
    # ------------- Lives inside the BaPSTO class; drop-in replacement for the original function ------------- #
def _bpgs_losses(
self, q: torch.Tensor, y_c: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
计算 BPGSNet 损失(正确的 log-sum-exp 版)
"""
B = q.size(0) # q是batch*128的矩阵,获得批次大小
K_C, K_M = self.prototypes.size(0), self.prototypes.size(1)
# K_C 是类别数,K_M 是每个类别的原型数
        # (1) squared Euclidean distance to every prototype
        d = ((q.unsqueeze(1).unsqueeze(2) - self.prototypes.unsqueeze(0)) ** 2).sum(-1)  # (B,K_C,K_M)
        s = 30.0  # logit scale (matches CFG.scale_ce)
        # ===== (2) annealed temperature τ =====
        # linear annealing of τ
        epoch = self.global_step.item() / self.steps_per_epoch
        tau = max(CFG.tau_min_hc,
                  CFG.tau0_hc - (CFG.tau0_hc - CFG.tau_min_hc) *
                  min(1., epoch / CFG.anneal_epochs_hc))
        # ----- (3) Hard-Concrete gates -----
        log_alpha = self.log_alpha  # (C,K)
        z, _s = self._sample_hardConcrete(log_alpha, tau)  # z: (C,K)
        g = z.unsqueeze(0)  # (1,C,K), broadcast over the batch
        # ----- (4) coarse logits -----
        mask_logits = -d * s + torch.log(g + 1e-12)          # (B,C,K)
        coarse_logits = torch.logsumexp(mask_logits, dim=2)  # (B,C)
        # ----- (5) losses -----
        L_ce = F.cross_entropy(coarse_logits, y_c)
        y_hat = torch.softmax(mask_logits.detach(), dim=2)  # soft assignment, stop-grad
        L_proto = CFG.alpha_proto * (y_hat * d).mean()
        # ---------- analytic L0 regulariser for the Hard-Concrete gates ----------
        temp = (log_alpha - tau * math.log(-CFG.gamma_hc / CFG.zeta_hc))  # (C,K)
        p_active = torch.sigmoid(temp)  # analytic activation probability P(z > 0)
        # entropy regulariser (NEW): pushes gate probabilities towards 0 / 1
        pa = torch.sigmoid(log_alpha)
        entropy_penalty = 0.05 * (pa * torch.log(pa + 1e-8) + (1 - pa) * torch.log(1 - pa + 1e-8)).mean()
        # L0 term (NEW): beta_l0 controls the global sparsity rate
        L_gate = CFG.beta_l0 * p_active.mean() - entropy_penalty
return L_ce, L_proto, L_gate, coarse_logits
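    # Recap of the coarse logit above: a log-sum-exp over gated sub-cluster
    # scores, i.e. a mixture in probability space,
    #
    #   logit_c = log Σ_k  g_{ck} · exp(−s · d_{ck})
    #
    # so a closed gate (g_{ck} ≈ 0) removes prototype k of class c from the
    # mixture (log(g + 1e-12) ≈ −27 dominates), while g_{ck} = 1 leaves the
    # plain distance score. L_proto then pulls each sample only towards the
    # sub-prototypes that currently claim it (soft assignment, stop-grad).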
    def _sample_hardConcrete(self, log_alpha, tau):
        r"""Return z ~ HardConcrete and its stretched, unclipped \tilde{z}."""
        u = torch.rand_like(log_alpha).clamp_(1e-6, 1 - 1e-6)
        s = torch.sigmoid((log_alpha + torch.log(u) - torch.log(1 - u)) / tau)
        s = s * (CFG.zeta_hc - CFG.gamma_hc) + CFG.gamma_hc  # stretch to (gamma, zeta)
        z_hard = s.clamp(0.0, 1.0)
        z = z_hard.detach() + s - s.detach()  # straight-through: forward uses z_hard, gradient flows via s
        return z, s  # z for the forward pass, s for the gradient
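# Hard-Concrete recap (Louizos et al., 2018), matching the sampler above:
#
#   u  ~ Uniform(0, 1)
#   s  = sigmoid((log α + log u − log(1 − u)) / τ)   # Concrete / Gumbel-sigmoid sample
#   s̃  = s · (ζ − γ) + γ                             # stretch to (γ, ζ)
#   z  = clamp(s̃, 0, 1)                              # exact zeros / ones possible
#
# and the analytic gate-activation probability used by the L0 term:
#
#   P(z > 0) = sigmoid(log α − τ · log(−γ / ζ))
#
# With γ = −0.1 and ζ = 1.1 the stretched interval straddles [0, 1], so the
# gates carry point masses at exactly 0 and 1 while remaining reparameterisable.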
# -------------------------- K-means++ initialisation -------------------------- #
@torch.no_grad()
def kmeans_init(model: BaPSTO, loader: DataLoader):
"""Use q‑features to initialise prototypes with K‑means++ (§3.4.1)."""
print("[Init] Running K‑means++ for prototype initialisation...")
model.eval()
all_q, all_y = [], []
for x, y in loader:
x = x.to(CFG.device)
z = L2_normalise(model.g_FV(model.backbone(x)))
all_q.append(z.cpu())
all_y.append(y)
all_q = torch.cat(all_q) # (N, d_p)
all_y = torch.cat(all_y) # (N,)
for c in range(model.prototypes.size(0)):
feats = all_q[all_y == c]
kmeans = KMeans(
n_clusters=CFG.K_max,
init="k-means++",
n_init=10,
max_iter=100,
random_state=CFG.seed,
).fit(feats.numpy())
centroids = torch.from_numpy(kmeans.cluster_centers_).to(CFG.device)
centroids = L2_normalise(centroids) # (K_max, d_p)
model.prototypes.data[c] = centroids
print("[Init] Prototype initialisation done.")
# -------------------------- Training utilities -------------------------- #
def accuracy(output: torch.Tensor, target: torch.Tensor) -> float:
"""Compute top‑1 accuracy (coarse)."""
with torch.no_grad():
pred = output.argmax(dim=1)
correct = pred.eq(target).sum().item()
return correct / target.size(0)
@torch.no_grad()
def _collect_Q_labels(model: BaPSTO, loader: DataLoader):
    """Iterate over *loader* and return (Q features, coarse IDs, prototype IDs), capped at MAX_SAMPLED."""
model.eval()
qs, cls, subs = [], [], []
for x, y in loader:
x = x.to(CFG.device)
q = L2_normalise(model.g_FV(model.backbone(x))) # (B,d)
        # —— nearest-prototype index —— #
d = ((q.unsqueeze(1).unsqueeze(2) - model.prototypes.unsqueeze(0))**2).sum(-1) # (B,C,K)
proto_id = d.view(d.size(0), -1).argmin(dim=1) # flatten idx = C*K + k
qs.append(q.cpu())
cls.append(y)
subs.append(proto_id.cpu())
if MAX_SAMPLED and (sum(len(t) for t in qs) >= MAX_SAMPLED):
break
Q = torch.cat(qs)[:MAX_SAMPLED] # (N,d)
Yc = torch.cat(cls)[:MAX_SAMPLED] # coarse
Ysub = torch.cat(subs)[:MAX_SAMPLED] # pseudo-fine
return Q.numpy(), Yc.numpy(), Ysub.numpy()
def _plot_heatmap(mat: np.ndarray,
                  title: str,
                  path: Path,
                  boxes: list[tuple[int,int]] | None = None):
    """
    mat   : similarity matrix, already sorted
    boxes : [(row_start, row_end), ...]; coordinates in the sorted index system
    """
plt.figure(figsize=(6, 5))
ax = plt.gca()
im = ax.imshow(mat, cmap="viridis", aspect="auto")
plt.colorbar(im)
    if boxes:  # draw one box per coarse class
for s, e in boxes:
w = e - s
rect = Rectangle((s - .5, s - .5), w, w,
linewidth=1.5, edgecolor="white",
facecolor="none")
ax.add_patch(rect)
plt.title(title)
plt.tight_layout()
plt.savefig(path, dpi=300)
plt.close()
def compute_and_save_diagnostics(model: BaPSTO, loader: DataLoader, tag: str):
    """
    • Compute three internal clustering metrics and save them as CSV.
    • Draw up to five figures (C heatmap, t-SNE / UMAP, Laplacian spectrum,
      silhouette bars, gate heatmap (optional)).
    """
print(f"[Diag] computing metrics ({tag}) ...")
timestamp = get_timestamp()
Q, Yc, Ysub = _collect_Q_labels(model, loader)
    # ========== 1) clustering metrics ========== #
sil = silhouette_score(Q, Ysub, metric="cosine")
ch = calinski_harabasz_score(Q, Ysub)
db = davies_bouldin_score(Q, Ysub)
pd.DataFrame(
{"tag":[tag], "silhouette":[sil], "calinski":[ch], "davies":[db]}
).to_csv(DIAG_DIR / f"cluster_metrics_{tag}_{timestamp}.csv", index=False)
    # ========== 2) C heatmap & Laplacian ========== #
    GRAPH_LEVEL = 'coarse'  # switch to 'sub' to inspect the fine-grained graph
    # ① —— similarity matrix (always over all samples; used for the heatmap) —— #
    norms = np.linalg.norm(Q, axis=1, keepdims=True)  # (N, 1)
    P_all = (Q @ Q.T) / norms / norms.T  # cosine similarity: divide by row AND column norms
    np.fill_diagonal(P_all, -1e4)        # remove self-loops
    C_heat = torch.softmax(torch.tensor(P_all), dim=1).cpu().numpy()
    # —— heatmap ordering (always the same, independent of GRAPH_LEVEL) —— #
    order = np.lexsort((Ysub, Yc))  # coarse class first, then sub-cluster
    # order = np.argsort(Yc)        # alternative: order by coarse class only
    # —— start/end rows (columns) of each coarse class —— #
coarse_sorted = Yc[order]
bounds = [] # [(start,end),...]
start = 0
for i in range(1, len(coarse_sorted)):
if coarse_sorted[i] != coarse_sorted[i-1]:
bounds.append((start, i)) # [start, end)
start = i
bounds.append((start, len(coarse_sorted)))
    # —— draw the heatmap and pass the class boundaries as boxes —— #
_plot_heatmap(C_heat[order][:, order],
f"C heatmap ({tag})",
DIAG_DIR / f"C_heatmap_{tag}_{timestamp}.png",
boxes=bounds)
    # ② —— graph for the Laplacian; optionally masked at coarse/sub level —— #
    P_graph = P_all.copy()  # copy of the global similarity matrix
    if GRAPH_LEVEL == 'coarse':
        P_graph[Yc[:, None] != Yc[None, :]] = -1e4      # keep only edges within a coarse class
    elif GRAPH_LEVEL == 'sub':
        P_graph[Ysub[:, None] != Ysub[None, :]] = -1e4  # keep only edges within a sub-cluster
C_graph = torch.softmax(torch.tensor(P_graph), dim=1).cpu().numpy()
D = np.diag(C_graph.sum(1))
L = D - (C_graph + C_graph.T) / 2
eigs = np.sort(np.linalg.eigvalsh(L))[:30]
plt.figure(); plt.plot(eigs, marker='o')
plt.title(f"Laplacian spectrum ({GRAPH_LEVEL or 'global'} | {tag})")
plt.tight_layout()
plt.savefig(DIAG_DIR / f"laplacian_{tag}_{timestamp}.png", dpi=300); plt.close()
    # ========== 3) t-SNE / UMAP (with legend; ≤ 20 colours) ========== #
    warnings.filterwarnings("ignore", message="n_jobs value 1")
    focus_cls = 1  # None → embed all classes; set to a coarse ID (e.g. 3) to focus on that class
    sel = slice(None) if focus_cls is None else (Yc == focus_cls)
Q_sel, Ysub_sel = Q[sel], Ysub[sel]
    # -- choose UMAP or t-SNE --
    if HAS_UMAP:
        reducer_cls = umap.UMAP if hasattr(umap, "UMAP") else umap.umap_.UMAP
reducer = reducer_cls(n_neighbors=30, min_dist=0.1, random_state=CFG.seed)
method = "UMAP"
else:
reducer = TSNE(perplexity=30, init="pca", random_state=CFG.seed)
method = "t-SNE"
emb = reducer.fit_transform(Q_sel) # (N,2)
# ---------- scatter ---------- #
unique_sub = np.unique(Ysub_sel)
    try:               # newer Matplotlib (≥ 3.7)
        cmap = plt.get_cmap("tab20", min(len(unique_sub), 20))
    except TypeError:  # older Matplotlib (< 3.7)
        cmap = plt.cm.get_cmap("tab20", min(len(unique_sub), 20))
plt.figure(figsize=(5, 5))
for i, s_id in enumerate(unique_sub):
pts = Ysub_sel == s_id
plt.scatter(emb[pts, 0], emb[pts, 1],
color=cmap(i % 20), s=6, alpha=0.7,
label=str(s_id) if len(unique_sub) <= 20 else None)
if len(unique_sub) <= 20:
plt.legend(markerscale=2, bbox_to_anchor=(1.02, 1), borderaxespad=0.)
title = f"{method} ({tag})" if focus_cls is None else f"{method} cls={focus_cls} ({tag})"
plt.title(title)
plt.tight_layout()
plt.savefig(DIAG_DIR / f"embed_{tag}_{timestamp}.png", dpi=300)
plt.close()
# ========== 4) Silhouette bars ========== #
sil_samples = silhouette_samples(Q, Ysub, metric="cosine")
order = np.argsort(Ysub)
plt.figure(figsize=(6,4))
plt.barh(np.arange(len(sil_samples)), sil_samples[order], color="steelblue")
plt.title(f"Silhouette per sample ({tag})"); plt.xlabel("coefficient")
plt.tight_layout(); plt.savefig(DIAG_DIR / f"silhouette_bar_{tag}_{timestamp}.png", dpi=300); plt.close()
print(f"[Diag] saved to {DIAG_DIR}")
def create_dataloaders() -> Tuple[DataLoader, DataLoader, int]:
"""Load train/val as ImageFolder and return dataloaders + K_C."""
train_dir = Path(CFG.data_root) / "train"
val_dir = Path(CFG.data_root) / "test"
classes = sorted([d.name for d in train_dir.iterdir() if d.is_dir()])
K_C = len(classes)
transform_train = transforms.Compose(
[
transforms.Grayscale(num_output_channels=3),
transforms.Resize((CFG.img_size, CFG.img_size)),
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(10),
transforms.RandomResizedCrop(CFG.img_size, scale=(0.8, 1.0)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
]
)
transform_val = transforms.Compose(
[
transforms.Grayscale(num_output_channels=3),
transforms.Resize((CFG.img_size, CFG.img_size)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
]
)
train_ds = datasets.ImageFolder(str(train_dir), transform=transform_train)
val_ds = datasets.ImageFolder(str(val_dir), transform=transform_val)
train_loader = DataLoader(
train_ds,
batch_size=CFG.batch_size,
shuffle=True,
num_workers=CFG.num_workers,
pin_memory=True,
drop_last=True,
)
val_loader = DataLoader(
val_ds,
batch_size=CFG.batch_size,
shuffle=False,
num_workers=CFG.num_workers,
pin_memory=True,
)
return train_loader, val_loader, K_C
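# Expected on-disk layout for the ImageFolder loaders above (a sketch; the
# class-folder names are placeholders — only the train/ and test/ sub-folders
# are assumed by the code, and K_C is inferred from the folder count):
#
#   D:\dataset\TILDA_8class_73\
#   ├── train\
#   │   ├── class_0\   *.png ...
#   │   └── ...        (8 class folders in total)
#   └── test\
#       ├── class_0\   *.png ...
#       └── ...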
# -------------------------- Main training routine -------------------------- #
def train():
    best_ckpt_path = None  # full filename of the best joint checkpoint
best_acc = 0.0
best_epoch = -1
train_loader, val_loader, K_C = create_dataloaders()
model = BaPSTO(K_C).to(CFG.device)
model.steps_per_epoch = len(train_loader)
#print(model)
mb = MemoryBank(dim=CFG.d_proj, size=CFG.mem_size)
    warmup_weights_path = Path(CFG.save_root) / "bapsto_warmup_complete.pth"
    # Check whether pre-saved warm-up weights exist
    if warmup_weights_path.exists():
        print(f"Found pre-trained warm-up weights, loading: {warmup_weights_path}")
        checkpoint = torch.load(warmup_weights_path, map_location=CFG.device, weights_only=True)
        model.load_state_dict(checkpoint["state_dict"])
        print("✓ Warm-up weights loaded; skipping the warm-up phase!")
else:
# ---------- Phase 1: DASSER warm‑up (backbone frozen) ---------- #
print("\n==== Phase 1 | DASSER warm‑up ====")
for p in model.backbone.parameters():
p.requires_grad = False
        # —— freeze prototypes and gate logits —— #
        model.prototypes.requires_grad = False
        model.log_alpha.requires_grad = False
optimizer = optim.AdamW(
filter(lambda p: p.requires_grad, model.parameters()),
lr=CFG.lr_warmup,
weight_decay=CFG.weight_decay,
betas=(0.9, 0.95),
)
scheduler = CosineAnnealingLR(optimizer, T_max=len(train_loader) * CFG.n_epochs_warmup)
for epoch in range(CFG.n_epochs_warmup):
run_epoch(train_loader, model, mb, optimizer, scheduler, epoch, phase="warmup")
        # Save the weights once warm-up completes
torch.save(
{"epoch": CFG.n_epochs_warmup, "state_dict": model.state_dict()},
warmup_weights_path
)
print(f"✓ Warm-up完成,模型权重已保存至: {warmup_weights_path}")
# after warm‑up loop, before Phase 2 header
    kmeans_init(model, train_loader)  # prototype initialisation (newly added)
print("K‑means initialisation done. Prototypes are now ready.")
compute_and_save_diagnostics(model, train_loader, tag="after_kmeans")
# ---------- Phase 2: Joint optimisation (all params trainable) ---------- #
print("\n==== Phase 2 | Joint optimisation ====")
for p in model.backbone.parameters():
p.requires_grad = True
    # —— unfreeze prototypes and gate logits —— #
    model.prototypes.requires_grad = True
    model.log_alpha.requires_grad = True
param_groups = [
{"params": [p for n,p in model.named_parameters() if n!='log_alpha'],
"lr": CFG.lr_joint},
{"params": [model.log_alpha], "lr": CFG.lr_joint * 2.0}
]
optimizer = optim.AdamW(
param_groups,
weight_decay=CFG.weight_decay,
betas=(0.9, 0.95),
)
scheduler = CosineAnnealingLR(optimizer, T_max=len(train_loader) * CFG.n_epochs_joint)
best_acc = 0.0
best_epoch = -1
epochs_no_improve = 0
for epoch in range(CFG.n_epochs_joint):
stats = run_epoch(train_loader, model, mb, optimizer, scheduler, epoch, phase="joint")
        # ───────────────────────────────────────────
        if (epoch + 1) % 1 == 0:  # validate every epoch
            # —— every 5 epochs, additionally dump gate & clustering diagnostics —— #
            if (epoch + 1) % 5 == 0:
                timestamp = get_timestamp()
                gate_prob = torch.sigmoid(model.log_alpha.detach()).cpu().numpy()  # (K_C, K_max)
_plot_heatmap(
gate_prob,
f"Gate prob (ep{epoch+1})",
DIAG_DIR / f"gate_ep{epoch+1}_{timestamp}.png",
)
compute_and_save_diagnostics(
model, train_loader, tag=f"joint_ep{epoch+1}"
)
            # ---------- metrics ----------
            val_loss, val_acc, per_cls_acc, auc = metrics_on_loader(val_loader, model)
            train_acc = metrics_on_loader(train_loader, model)[1]  # overall train accuracy only
            print(f"[Val] ep {epoch+1:02d} | loss {val_loss:.3f} | "
                  f"acc {val_acc:.3f} | train-acc {train_acc:.3f} |\n"
                  f" per-cls-acc {np.round(per_cls_acc, 2)} |\n"
                  f" AUC {np.round(auc, 2)}")
# —— checkpoint —— #
if val_acc > best_acc:
best_acc = val_acc
best_epoch = epoch
epochs_no_improve = 0
                best_ckpt_path = save_ckpt(model, epoch, tag="best_joint",
                                           acc=val_acc,
                                           optimizer=optimizer,
                                           scheduler=scheduler)  # passed along for resuming
else:
epochs_no_improve += 1
            # —— gate pruning —— #
            if epoch + 1 >= 10:  # train 10 epochs before pruning
                prune_gates(model, threshold=0.25, min_keep=1, hc_threshold=CFG.hc_threshold)
# —— early stopping —— #
if epochs_no_improve >= 50:
print("Early stopping triggered in joint phase.")
break
        # ───────────────────────────────────────────
        # global_step now advances per optimisation step inside run_epoch, so
        # the Hard-Concrete temperature actually anneals; do NOT reset it here.
        print(model.prototypes.grad.norm())  # non-zero ⇒ L_proto really updates the prototypes
    # After joint training, rename the best checkpoint to embed its accuracy
    best_acc_int = round(best_acc * 1e4)  # e.g. 0.7068 → 7068
joint_ckpt_path = Path(CFG.save_root) / "bapsto_best_joint.pth"
renamed_path = Path(CFG.save_root) / f"bapsto_best_joint_{best_acc_int}.pth"
if joint_ckpt_path.exists():
joint_ckpt_path.rename(renamed_path)
        best_ckpt_path = renamed_path  # keep the path in sync for fine-tuning
        print(f"✓ Best joint checkpoint renamed: {renamed_path.name} "
              f"(epoch {best_epoch+1}, ACC: {best_acc:.4f})")
# ---------- Phase 3: Fine‑tune (prototypes & gates frozen) ---------- #
print("\n==== Phase 3 | Fine‑tuning ====")
best_ft_acc = 0.0
best_ft_epoch = -1
    # Load the best joint weights if available
if best_ckpt_path is not None and Path(best_ckpt_path).exists():
ckpt = torch.load(best_ckpt_path, map_location=CFG.device, weights_only=True)
model.load_state_dict(ckpt["state_dict"])
        epoch_loaded = ckpt["epoch"] + 1  # human-readable, 1-based epoch
        acc_loaded = ckpt.get("acc", -1)  # placeholder when older checkpoints lack acc
print(f"✓ loaded best joint ckpt (epoch {epoch_loaded}, ACC {acc_loaded:.4f})")
else:
print("⚠️ best_ckpt_path 未找到,继续沿用上一轮权重。")
for param in [model.prototypes, model.log_alpha]:
param.requires_grad = False
for p in model.parameters():
if p.requires_grad:
p.grad = None # clear any stale gradients
optimizer = optim.AdamW(
filter(lambda p: p.requires_grad, model.parameters()),
lr=CFG.lr_ft,
weight_decay=CFG.weight_decay,
betas=(0.9, 0.95),
)
scheduler = CosineAnnealingLR(optimizer, T_max=len(train_loader) * CFG.n_epochs_ft)
for epoch in range(CFG.n_epochs_ft):
run_epoch(train_loader, model, mb, optimizer, scheduler, epoch, phase="finetune")
        if (epoch + 1) % 1 == 0:  # evaluate every epoch
val_acc = evaluate(val_loader, model)
print(f"[FT] ep {epoch+1:02d} | acc {val_acc:.4f}")
            # ① save a per-epoch snapshot (optional)
            save_ckpt(model, epoch, tag="ft")
            # ② track the best fine-tune checkpoint
            if val_acc > best_ft_acc:
                best_ft_acc = val_acc
                best_ft_epoch = epoch
                best_ft_acc_int = round(best_ft_acc * 1e4)  # e.g. 0.7068 → 7068
                best_ft_ckpt_path = Path(CFG.save_root) / f"bapsto_best_ft_{best_ft_acc_int}.pth"
                saved_path = save_ckpt(model, epoch, tag="best_ft", acc=val_acc)
                # rename the saved file so it carries the accuracy in its name
                saved_path.replace(best_ft_ckpt_path)
                print(f"✓ Best fine-tune checkpoint saved as: {best_ft_ckpt_path.name} "
                      f"(epoch {best_ft_epoch+1}, ACC: {best_ft_acc:.4f})")
print(f"Training completed. Best FT ACC {best_ft_acc:.4f}")
# -------------------------- Helper functions -------------------------- #
def run_epoch(loader, model, mem_bank: MemoryBank, optimizer, scheduler, epoch, phase:str):
model.train()
running = {"loss": 0.0}
use_bpgs = (phase != "warmup")
for step, (x, y) in enumerate(loader):
x, y = x.to(CFG.device), y.to(CFG.device)
optimizer.zero_grad()
loss, stats, p_det, q_det = model(x, y, mem_bank, use_bpgs=use_bpgs)
loss.backward()
optimizer.step()
        scheduler.step()
        if use_bpgs:
            model.global_step += 1  # per-step counter driving the Hard-Concrete τ annealing
        mem_bank.enqueue(p_det, q_det, y.detach())
# accumulate
for k, v in stats.items():
running[k] = running.get(k, 0.0) + v
        # ★★★★★ Hard-Concrete gradient health check ★★★★★
        if phase == "joint" and step % 100 == 0:
            # ─── Hard-Concrete monitoring ───
            tau_now = max(
                CFG.tau_min_hc,
                CFG.tau0_hc - (CFG.tau0_hc - CFG.tau_min_hc) *
                min(1.0, model.global_step.item() /
                    (model.steps_per_epoch * CFG.anneal_epochs_hc))
            )
            pa = torch.sigmoid(model.log_alpha)  # (C,K) analytic activation probabilities
            grad_nm = (model.log_alpha.grad.detach().norm().item()
                       if model.log_alpha.grad is not None else 0.0)
            print(f"[DBG] τ={tau_now:.3f} p̄={pa.mean():.3f} "
                  f"min={pa.min():.2f} max={pa.max():.2f} "
                  f"alive={(pa > 0.25).sum().item()}/{pa.numel()} "  # 0.25 matches the prune threshold
                  f"‖∇α‖={grad_nm:.2e}")
        # ★★★★★ end of monitoring block ★★★★★
if (step + 1) % 50 == 0:
avg_loss = running["loss"] / (step + 1)
print(
f"Epoch[{phase} {epoch+1}] Step {step+1}/{len(loader)} | "
f"loss: {avg_loss:.4f}",
end="\r",
)
# epoch summary
print(f"Epoch [{phase} {epoch+1}]: " + ', '.join(f"{k}: {running[k]:.4f}" for k in running))
return running
@torch.no_grad()
def evaluate(loader, model):
model.eval()
total_correct, total_samples = 0, 0
K_C, K_M = model.prototypes.size(0), model.prototypes.size(1)
gate_hard = (model.log_alpha > 0).float() # (K_C,K_M)
for x, y in loader:
x, y = x.to(CFG.device), y.to(CFG.device)
b = x.size(0)
        # --- features & distances ---
q = L2_normalise(model.g_FV(model.backbone(x))) # (b,d_p)
d = ((q.unsqueeze(1).unsqueeze(2) - model.prototypes.unsqueeze(0))**2).sum(-1) # (b,K_C,K_M)
        s = 30.0  # scale for logits
        # --- sub-cluster logits & coarse logits ---
        mask_logits = -d * s + torch.log(gate_hard + 1e-12)  # (b,K_C,K_M); log space, so the two terms add
coarse_logits = torch.logsumexp(mask_logits, dim=2) # (b,K_C)
        # --- accumulate accuracy ---
total_correct += coarse_logits.argmax(1).eq(y).sum().item()
total_samples += b
return total_correct / total_samples
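# Note on evaluation-time gating: `model.log_alpha > 0` is the deterministic
# limit of the Hard-Concrete gate (sigmoid(0) = 0.5), i.e. a sub-cluster is
# kept iff its analytic activation probability exceeds 0.5. This is a harder
# criterion than the stochastic gates sampled during training, so train-time
# and eval-time scores are not perfectly comparable.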
@torch.no_grad()
def metrics_on_loader(loader, model):
"""
返回:
loss_avg – 均值交叉熵
acc – overall top-1
per_cls_acc (C,) – 每个 coarse 类别准确率
auc (C,) – 每类 one-vs-rest ROC-AUC
"""
model.eval()
n_cls = model.prototypes.size(0)
total_loss, total_correct, total_samples = 0., 0, 0
    # —— accumulate logits / labels over the whole loader —— #
    logits_all, labels_all = [], []
    ce_fn = nn.CrossEntropyLoss(reduction="sum")  # sum now, divide at the end
for x, y in loader:
x, y = x.to(CFG.device), y.to(CFG.device)
        # forward pass with hard gates (the whole function already runs under no_grad)
        q = L2_normalise(model.g_FV(model.backbone(x)))
        d = ((q.unsqueeze(1).unsqueeze(2) - model.prototypes.unsqueeze(0)) ** 2).sum(-1)
        logits = torch.logsumexp(-d * 30 + torch.log((model.log_alpha > 0).float() + 1e-12), dim=2)
total_loss += ce_fn(logits, y).item()
total_correct += logits.argmax(1).eq(y).sum().item()
total_samples += y.size(0)
logits_all.append(logits.cpu())
labels_all.append(y.cpu())
    # —— overall —— #
    loss_avg = total_loss / total_samples
    acc = total_correct / total_samples
    # —— concatenate & convert to numpy —— #
    logits_all = torch.cat(logits_all).numpy()
    labels_all = torch.cat(labels_all).numpy()
    # —— per-class ACC —— #
per_cls_acc = np.zeros(n_cls)
for c in range(n_cls):
mask = labels_all == c
if mask.any():
per_cls_acc[c] = (logits_all[mask].argmax(1) == c).mean()
    # —— per-class AUC —— #
    try:
        from sklearn.metrics import roc_auc_score
        prob = torch.softmax(torch.from_numpy(logits_all), dim=1).numpy()
        auc = roc_auc_score(labels_all, prob, multi_class="ovr", average=None)
    except Exception:  # fails when a class is missing or has too few samples
        auc = np.full(n_cls, np.nan)
return loss_avg, acc, per_cls_acc, auc
def save_ckpt(model,
epoch:int,
tag:str,
acc:float|None=None,
optimizer=None,
scheduler=None):
"""
通用保存函数
• 返回 ckpt 文件完整路径,方便上层记录
• 可选把 opt / sched state_dict 一起存进去,便于 resume
"""
save_dir = Path(CFG.save_root)
save_dir.mkdir(parents=True, exist_ok=True)
# -------- 路径策略 -------- #
if tag == "best_joint": # 只保留一个最新最优 joint
ckpt_path = save_dir / "bapsto_best_joint.pth"
else: # 其他阶段带时间戳
ckpt_path = save_dir / f"bapsto_{tag}_epoch{epoch+1}_{get_timestamp()}.pth"
# -------- 组装 payload -------- #
# • vars(CFG) 可以拿到用户自己在 CFG 里写的字段
# • 再过滤掉 __ 开头的内部键、防止把 Python meta-data 也 dump 进去
cfg_dict = {k: v for k, v in vars(CFG).items()
if not k.startswith("__")}
payload = {
"epoch": epoch,
"state_dict": model.state_dict(),
"cfg": cfg_dict, # ← 改在这里
}
if acc is not None:
payload["acc"] = acc
if optimizer is not None:
payload["optimizer"] = optimizer.state_dict()
if scheduler is not None:
payload["scheduler"] = scheduler.state_dict()
torch.save(payload, ckpt_path)
print(f"✓ checkpoint saved to {ckpt_path}")
return ckpt_path
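# A minimal resume sketch built on the payload saved above (illustrative only;
# `ckpt_file` is a placeholder path, and the optimizer / scheduler entries are
# present only if they were passed to save_ckpt):
#
#   ckpt = torch.load(ckpt_file, map_location=CFG.device, weights_only=True)
#   model.load_state_dict(ckpt["state_dict"])
#   if "optimizer" in ckpt:
#       optimizer.load_state_dict(ckpt["optimizer"])
#   if "scheduler" in ckpt:
#       scheduler.load_state_dict(ckpt["scheduler"])
#   start_epoch = ckpt["epoch"] + 1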
@torch.no_grad()
def prune_gates(model: BaPSTO, threshold=0.05, min_keep=2, hc_threshold=0.35):
    """
    Disable sub-clusters whose gate activation probability falls below a threshold.
    Pruned gates get log_alpha = -10, so sigmoid(-10) ≈ 0 and the gated
    log-sum-exp in _bpgs_losses effectively drops them, while at least
    `min_keep` sub-clusters per coarse class are always retained.
    The Hard-Concrete threshold `hc_threshold` (e.g. 0.35) additionally
    disables gates with low activation probability.
    """
    # sigmoid activation probabilities (K_C, K_max)
    p_active = torch.sigmoid(model.log_alpha)
    mask = (p_active < threshold)
    # also disable gates below the Hard-Concrete threshold
    low_weight_mask = (p_active < hc_threshold)
    mask = mask | low_weight_mask
    # ensure at least `min_keep` sub-clusters survive per coarse class:
    # un-mask every masked gate beyond the first (K_max - min_keep) in each row
    keep_mask = (mask.cumsum(1) > (CFG.K_max - min_keep))
    mask = mask & ~keep_mask
pruned = mask.sum().item()
if pruned == 0:
return
model.log_alpha.data[mask] = -10.0 # Set log_alpha of pruned sub-clusters to a very low value
print(f"Pruned {pruned} sub-clusters (ḡ<{threshold}, keep≥{min_keep}/class)")
    # Optional: reassign samples from pruned sub-clusters to the surviving ones.
    # This would need a cache of recent q-features on the model (`model.q` is a
    # hypothetical attribute this script never sets), so the block is guarded
    # and normally skipped; the gated log-sum-exp already ignores pruned
    # prototypes, so skipping it is safe.
    q_cache = getattr(model, "q", None)
    if q_cache is not None:
        pruned_clusters = mask.sum(dim=1) > 0  # (K_C,) classes that lost a sub-cluster
        for c in range(model.prototypes.size(0)):
            if pruned_clusters[c]:
                active_indices = ~mask[c]
                active_prototypes = model.prototypes[c][active_indices]
                sim = pairwise_cosine(q_cache, active_prototypes)
                best_active = sim.argmax(dim=1)  # nearest surviving prototype (max cosine)
                print(f"Reassigning samples of class {c} to {int(active_indices.sum())} active sub-clusters.")
# -------------------------- Entrypoint -------------------------- #
if __name__ == "__main__":
os.makedirs(CFG.save_root, exist_ok=True)
start = time.time()
train()
print(f"Total runtime: {(time.time() - start) / 3600:.2f} h")