import os.path as osp
from collections import OrderedDict
import math
import copy
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.cuda.amp import GradScaler, autocast
from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.metrics import compute_accuracy
from dassl.utils import load_pretrained_weights, load_checkpoint
from dassl.optim import build_optimizer, build_lr_scheduler
from clip import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
try:
# Prefer relative import to avoid PYTHONPATH issues
from .capid_modules import (
CrossAttentionCoupler,
DiffusionPromptGenerator,
InteractiveGate,
save_debug_image,
)
except Exception: # fallback to absolute if needed
from trainers.capid_modules import (
CrossAttentionCoupler,
DiffusionPromptGenerator,
InteractiveGate,
save_debug_image,
)
_tokenizer = _Tokenizer()
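# CAPID helper interfaces assumed by this file (defined in trainers/capid_modules.py):
#   CrossAttentionCoupler(p_txt, p_vis) -> (p_txt_ca, p_vis_ca, attn_maps)
#   DiffusionPromptGenerator.sample(x, cond=..., steps=..., noise_std=...) -> residual with x's shape
#   DiffusionPromptGenerator.diffusion_loss(x, cond, noise_std=...) -> scalar loss
#   InteractiveGate(inst_feat, txt_feat, img_feat) -> scalar gate-strength tensor
#   save_debug_image(tensor, path) -> writes a heatmap image to disk
# Note: these signatures are inferred from the call sites below, not from capid_modules itself.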
class CrossAttentivePromptBridge(nn.Module):
"""Bridge deep text/vision prompts with bi-directional cross-attention.
- Projects text (512) and vision (768) prompts to a common dim (default 512).
- Runs two multi-head attentions: text<-vision and vision<-text.
- Residual fuse with small alpha, then project back to original dims.
- Expects lists of tensors per depth: [ (n_ctx, 512) ], [ (n_ctx, 768) ].
"""
def __init__(self, dim_text: int = 512, dim_vision: int = 768, dim_common: int = 512,
heads: int = 4, dropout: float = 0.0, alpha: float = 0.1):
super().__init__()
self.txt_to_common = nn.Linear(dim_text, dim_common, bias=False)
self.vis_to_common = nn.Linear(dim_vision, dim_common, bias=False)
self.common_to_txt = nn.Linear(dim_common, dim_text, bias=False)
self.common_to_vis = nn.Linear(dim_common, dim_vision, bias=False)
self.attn_tq = nn.MultiheadAttention(dim_common, heads, dropout=dropout, batch_first=True)
self.attn_vq = nn.MultiheadAttention(dim_common, heads, dropout=dropout, batch_first=True)
self.alpha = alpha
def forward(self, deep_txt_list, deep_vis_list, alpha: float = None):
if alpha is None:
alpha = self.alpha
alpha = float(max(0.0, min(1.0, alpha)))
# Stack to (L, n_ctx, C)
txt = torch.stack(deep_txt_list, dim=0) # (L, n_ctx, 512)
vis = torch.stack(deep_vis_list, dim=0) # (L, n_ctx, 768)
L, n_ctx_t, dt = txt.shape
L2, n_ctx_v, dv = vis.shape
assert L == L2 and n_ctx_t == n_ctx_v, "Text/Vision deep prompts must align in depth and n_ctx"
S = L * n_ctx_t
txt_seq = txt.reshape(S, dt)
vis_seq = vis.reshape(S, dv)
t = self.txt_to_common(txt_seq).unsqueeze(0) # (1, S, dc)
v = self.vis_to_common(vis_seq).unsqueeze(0) # (1, S, dc)
# bi-directional cross-attention
t2, _ = self.attn_tq(t, v, v)
v2, _ = self.attn_vq(v, t, t)
# stabilize and residual blend
t2 = F.layer_norm(t2, t2.shape[-1:])
v2 = F.layer_norm(v2, v2.shape[-1:])
t_out = (1.0 - alpha) * t + alpha * t2
v_out = (1.0 - alpha) * v + alpha * v2
# back to original dims and list form
t_out = self.common_to_txt(t_out.squeeze(0)).reshape(L, n_ctx_t, dt)
v_out = self.common_to_vis(v_out.squeeze(0)).reshape(L, n_ctx_t, dv)
out_txt_list = [t_out[i] for i in range(L)]
out_vis_list = [v_out[i] for i in range(L)]
return out_txt_list, out_vis_list
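# Minimal usage sketch for the bridge (shapes only; random tensors stand in for real deep prompts):
#   bridge = CrossAttentivePromptBridge()
#   deep_txt = [torch.randn(2, 512) for _ in range(3)]   # 3 depths, n_ctx = 2
#   deep_vis = [torch.randn(2, 768) for _ in range(3)]
#   txt_out, vis_out = bridge(deep_txt, deep_vis, alpha=0.1)  # same list lengths and per-tensor shapes as the inputs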
def load_clip_to_cpu(cfg):
backbone_name = cfg.MODEL.BACKBONE.NAME
url = clip._MODELS[backbone_name]
model_path = clip._download(url)
try:
# loading JIT archive
model = torch.jit.load(model_path, map_location="cpu").eval()
state_dict = None
except RuntimeError:
state_dict = torch.load(model_path, map_location="cpu")
    design_details = {"trainer": 'MaPLe',
                      "vision_depth": 0,
                      "language_depth": 0,
                      "vision_ctx": 0,
                      "language_ctx": 0,
                      "maple_length": cfg.TRAINER.MAPLE.N_CTX}
model = clip.build_model(state_dict or model.state_dict(), design_details)
return model
class TextEncoder(nn.Module):
def __init__(self, clip_model):
super().__init__()
self.transformer = clip_model.transformer
self.positional_embedding = clip_model.positional_embedding
self.ln_final = clip_model.ln_final
self.text_projection = clip_model.text_projection
self.dtype = clip_model.dtype
def forward(self, prompts, tokenized_prompts, compound_prompts_deeper_text):
x = prompts + self.positional_embedding.type(self.dtype)
x = x.permute(1, 0, 2) # NLD -> LND
        # Pass as a list, since nn.Sequential cannot take multiple arguments in forward()
        combined = [x, compound_prompts_deeper_text, 0]  # the third element is a counter for the current prompt depth
outputs = self.transformer(combined)
x = outputs[0] # extract the x back from here
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x).type(self.dtype)
# x.shape = [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection
return x
class MultiModalPromptLearner(nn.Module):
def __init__(self, cfg, classnames, clip_model):
super().__init__()
n_cls = len(classnames)
n_ctx = cfg.TRAINER.MAPLE.N_CTX
ctx_init = cfg.TRAINER.MAPLE.CTX_INIT
dtype = clip_model.dtype
ctx_dim = clip_model.ln_final.weight.shape[0]
clip_imsize = clip_model.visual.input_resolution
cfg_imsize = cfg.INPUT.SIZE[0]
# Default is 1, which is compound shallow prompting
assert cfg.TRAINER.MAPLE.PROMPT_DEPTH >= 1, "For MaPLe, PROMPT_DEPTH should be >= 1"
self.compound_prompts_depth = cfg.TRAINER.MAPLE.PROMPT_DEPTH # max=12, but will create 11 such shared prompts
        assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal clip_imsize ({clip_imsize})"
        if ctx_init and n_ctx <= 4:
            # use given words to initialize context vectors
            ctx_init = ctx_init.replace("_", " ")
prompt = clip.tokenize(ctx_init)
with torch.no_grad():
embedding = clip_model.token_embedding(prompt).type(dtype)
ctx_vectors = embedding[0, 1: 1 + n_ctx, :]
prompt_prefix = ctx_init
else:
# random initialization
ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
nn.init.normal_(ctx_vectors, std=0.02)
prompt_prefix = " ".join(["X"] * n_ctx)
print('MaPLe design: Multi-modal Prompt Learning')
print(f'Initial context: "{prompt_prefix}"')
print(f"Number of MaPLe context words (tokens): {n_ctx}")
        # These below relate to the shallow prompts
        # Linear layer that projects the 512-dim text context tokens to 768-dim for the vision branch
self.proj = nn.Linear(ctx_dim, 768)
self.proj.half()
self.ctx = nn.Parameter(ctx_vectors)
# These below parameters related to the shared prompts
# Define the compound prompts for the deeper layers
# Minimum can be 1, which defaults to shallow MaPLe
# compound prompts
self.compound_prompts_text = nn.ParameterList([nn.Parameter(torch.empty(n_ctx, 512))
for _ in range(self.compound_prompts_depth - 1)])
for single_para in self.compound_prompts_text:
nn.init.normal_(single_para, std=0.02)
# Also make corresponding projection layers, for each prompt
single_layer = nn.Linear(ctx_dim, 768)
self.compound_prompt_projections = _get_clones(single_layer, self.compound_prompts_depth - 1)
classnames = [name.replace("_", " ") for name in classnames]
name_lens = [len(_tokenizer.encode(name)) for name in classnames]
prompts = [prompt_prefix + " " + name + "." for name in classnames]
tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]) # (n_cls, n_tkn)
with torch.no_grad():
embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)
# These token vectors will be saved when in save_model(),
# but they should be ignored in load_model() as we want to use
# those computed using the current class names
self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS
self.register_buffer("token_suffix", embedding[:, 1 + n_ctx:, :]) # CLS, EOS
self.n_cls = n_cls
self.n_ctx = n_ctx
self.tokenized_prompts = tokenized_prompts # torch.Tensor
self.name_lens = name_lens
# --- Optional CAPID modules integrated at the prompt learner level ---
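        # Config keys read from cfg.TRAINER.CAPID (defaults shown are the getattr fallbacks used below):
        #   ENABLE, CA_ENABLE, DIFF_ENABLE, GATE_ENABLE (all False by default)
        #   CA_ALPHA=0.1, DIFF_SCALE=0.05, GATE_MAX=0.5, CA_MODE='bridge', CA_SHALLOW=False
        #   CA_DEPTH=1, CA_HEADS=4, CA_DROPOUT=0.0, DIFF_STEPS=2, DIFF_NOISE_STD=0.1, CFG_SCALE=1.0
        #   GATE_ALPHA=1.0, GATE_BETA=0.0, INSTRUCTION='', DEBUG_SAVE=False, DEBUG_FREQ=200, DEBUG_DIR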
self._clip_model_ref = clip_model
capid = getattr(cfg.TRAINER, "CAPID", None)
self.capid_enabled = bool(getattr(capid, "ENABLE", False)) if capid is not None else False
if self.capid_enabled:
self.ca_enabled = bool(getattr(capid, "CA_ENABLE", False))
self.diff_enabled = bool(getattr(capid, "DIFF_ENABLE", False))
self.gate_enabled = bool(getattr(capid, "GATE_ENABLE", False))
# Conservative safety knobs (with robust defaults)
self.ca_alpha = float(getattr(capid, "CA_ALPHA", 0.1)) # residual blend factor for CA
self.diff_scale = float(getattr(capid, "DIFF_SCALE", 0.05)) # residual scale for DIFF
self.gate_max = float(getattr(capid, "GATE_MAX", 0.5)) # clamp gate strength upper bound
# CA mode: 'bridge' (default) couples deep prompts in CustomCLIP; 'shallow' applies here
self.ca_mode = str(getattr(capid, "CA_MODE", "bridge")).lower()
self.ca_shallow = bool(getattr(capid, "CA_SHALLOW", False))
if self.ca_enabled:
self.ca = CrossAttentionCoupler(
dim_text=512, dim_vision=768,
depth=int(getattr(capid, "CA_DEPTH", 1)),
heads=int(getattr(capid, "CA_HEADS", 4)),
dropout=float(getattr(capid, "CA_DROPOUT", 0.0)),
)
if self.diff_enabled:
self.diff_text = DiffusionPromptGenerator(channels=512, cond_channels=512)
self.diff_vision = DiffusionPromptGenerator(channels=768, cond_channels=768)
self.diff_steps = int(getattr(capid, "DIFF_STEPS", 2))
self.diff_noise = float(getattr(capid, "DIFF_NOISE_STD", 0.1))
self.cfg_scale = float(getattr(capid, "CFG_SCALE", 1.0))
if self.gate_enabled:
self.gate = InteractiveGate(
alpha=float(getattr(capid, "GATE_ALPHA", 1.0)),
beta=float(getattr(capid, "GATE_BETA", 0.0)),
)
# Instruction text for gating (optional)
self.instruction_text = str(getattr(capid, "INSTRUCTION", ""))
# Debugging
self.debug_save = bool(getattr(capid, "DEBUG_SAVE", False))
self.debug_freq = int(getattr(capid, "DEBUG_FREQ", 200))
self.debug_dir = str(getattr(capid, "DEBUG_DIR", "output/capid_debug"))
self._debug_step = 0
self.capid_applied = False
@torch.no_grad()
def _encode_instruction(self, text: str):
if text is None or len(text.strip()) == 0:
return None
try:
tokens = clip.tokenize([text]) # (1, L)
cm = self._clip_model_ref
emb = cm.token_embedding(tokens.to(cm.text_projection.device).type(cm.dtype))
x = emb + cm.positional_embedding.type(cm.dtype)
x = x.permute(1, 0, 2)
x = cm.transformer(x)
x = x.permute(1, 0, 2)
x = cm.ln_final(x).type(cm.dtype)
feat = x[torch.arange(x.shape[0]), tokens.argmax(dim=-1).to(x.device)] @ cm.text_projection
return feat
except Exception:
return None
def construct_prompts(self, ctx, prefix, suffix, label=None):
# dim0 is either batch_size (during training) or n_cls (during testing)
# ctx: context tokens, with shape of (dim0, n_ctx, ctx_dim)
# prefix: the sos token, with shape of (n_cls, 1, ctx_dim)
# suffix: remaining tokens, with shape of (n_cls, *, ctx_dim)
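        # resulting prompts have shape (dim0, 1 + n_ctx + *, ctx_dim), i.e. [SOS] [ctx tokens] [class tokens, EOS, padding]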
if label is not None:
prefix = prefix[label]
suffix = suffix[label]
prompts = torch.cat(
[
prefix, # (dim0, 1, dim)
ctx, # (dim0, n_ctx, dim)
suffix, # (dim0, *, dim)
],
dim=1,
)
return prompts
def forward(self):
ctx = self.ctx
if ctx.dim() == 2:
ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)
prefix = self.token_prefix
suffix = self.token_suffix
prompts = self.construct_prompts(ctx, prefix, suffix)
# Before returning, need to transform
# prompts to 768 for the visual side
visual_deep_prompts = []
for index, layer in enumerate(self.compound_prompt_projections):
visual_deep_prompts.append(layer(self.compound_prompts_text[index]))
# CAPID optional coupling/generation inside prompt learner
# Align projection dtype with context dtype to avoid Half/Float mismatch after loading checkpoints
if hasattr(self.proj, "weight") and self.proj.weight.dtype != self.ctx.dtype:
self.proj.to(self.ctx.dtype)
shared_ctx = self.proj(self.ctx) # (n_ctx, 768)
if getattr(self, "capid_enabled", False):
# Expand to per-class vision tokens
vis_tokens = shared_ctx.unsqueeze(0).expand(self.n_cls, -1, -1)
gate_strength = 1.0
if getattr(self, "gate_enabled", False):
inst_feat = self._encode_instruction(self.instruction_text)
# Use a lightweight text summary by averaging prompt tokens
try:
txt_feat = prompts.mean(dim=1) # (n_cls, 512)
except Exception:
txt_feat = None
g_tensor = self.gate(inst_feat, txt_feat, None)
gate_strength = max(0.0, min(self.gate_max, float(g_tensor.item())))
            # Safe DIFF: only apply when the configuration has a non-zero effect
            should_diff = (
                getattr(self, "diff_enabled", False)
                and (getattr(self, "cfg_scale", 0.0) > 0.0)
                and (getattr(self, "diff_noise", 0.0) > 0.0)
                and (getattr(self, "diff_steps", 0) > 0)
            )
            delta_txt = delta_vis = None
            if should_diff:
                cond_txt_pl = prompts.mean(dim=1)     # (n_cls, 512)
                cond_vis_pl = vis_tokens.mean(dim=1)  # (n_cls, 768)
                delta_txt = self.diff_text.sample(prompts, cond=cond_txt_pl, steps=self.diff_steps,
                                                  noise_std=self.diff_noise)
                delta_txt = F.layer_norm(delta_txt, delta_txt.shape[-1:])
                prompts = prompts + self.diff_scale * self.cfg_scale * gate_strength * delta_txt
                delta_vis = self.diff_vision.sample(vis_tokens, cond=cond_vis_pl, steps=self.diff_steps,
                                                    noise_std=self.diff_noise)
                delta_vis = F.layer_norm(delta_vis, delta_vis.shape[-1:])
                vis_tokens = vis_tokens + self.diff_scale * self.cfg_scale * gate_strength * delta_vis
attn_maps = None
# Only apply shallow CA here when explicitly enabled
if getattr(self, "ca_enabled", False) and getattr(self, "ca_shallow", False) and getattr(self, "ca_mode", "bridge") != "bridge":
# Residual CA with small alpha
p_in, v_in = prompts, vis_tokens
p_ca, v_ca, attn_maps = self.ca(p_in, v_in)
p_ca = F.layer_norm(p_ca, p_ca.shape[-1:])
v_ca = F.layer_norm(v_ca, v_ca.shape[-1:])
alpha = max(0.0, min(1.0, float(self.ca_alpha)))
prompts = (1.0 - alpha) * p_in + alpha * p_ca
vis_tokens = (1.0 - alpha) * v_in + alpha * v_ca
shared_ctx = vis_tokens.mean(dim=0)
# Debug saves
if getattr(self, "debug_save", False):
self._debug_step += 1
if self._debug_step % max(1, self.debug_freq) == 0:
try:
if attn_maps is not None and len(attn_maps) > 0:
a = attn_maps[0][0]
# Robust: handle 4D (B,H,Lq,Lk) and 3D (B,Lq,Lk)
if a.dim() == 4:
a_vis = a.mean(dim=1)[0]
elif a.dim() == 3:
a_vis = a[0]
else:
a_vis = a.flatten(1).unsqueeze(0)
out_path = osp.join(self.debug_dir, f"pl_attn_layer0_{self._debug_step:06d}.png")
save_debug_image(a_vis, out_path)
if getattr(self, "diff_enabled", False):
try:
dt = (delta_txt[0].norm(dim=-1, keepdim=False))
dv = (delta_vis[0].norm(dim=-1, keepdim=False))
save_debug_image(dt.unsqueeze(0), osp.join(self.debug_dir, f"pl_delta_txt_norm_{self._debug_step:06d}.png"))
save_debug_image(dv.unsqueeze(0), osp.join(self.debug_dir, f"pl_delta_vis_norm_{self._debug_step:06d}.png"))
except Exception:
pass
except Exception:
pass
self.capid_applied = True
        # Return text prompts (512-dim) together with the 768-dim shared context and both sets of deep prompts
return prompts, shared_ctx, self.compound_prompts_text, visual_deep_prompts
class CustomCLIP(nn.Module):
def __init__(self, cfg, classnames, clip_model):
super().__init__()
self.cfg = cfg
self.prompt_learner = MultiModalPromptLearner(cfg, classnames, clip_model)
self.tokenized_prompts = self.prompt_learner.tokenized_prompts
self.image_encoder = clip_model.visual
self.text_encoder = TextEncoder(clip_model)
self.logit_scale = clip_model.logit_scale
self.dtype = clip_model.dtype
# Keep a lightweight reference for encoding free-form instructions
self._clip_model_ref = clip_model
# CAPID modules (optional)
capid = cfg.TRAINER.CAPID
self.capid_enabled = bool(getattr(capid, "ENABLE", False))
if self.capid_enabled:
self.ca_enabled = bool(getattr(capid, "CA_ENABLE", False))
self.diff_enabled = bool(getattr(capid, "DIFF_ENABLE", False))
self.gate_enabled = bool(getattr(capid, "GATE_ENABLE", False))
self.diff_loss_weight = float(getattr(capid, "DIFF_LOSS_WEIGHT", 0.1))
# Conservative safety knobs (mirror prompt learner)
self.ca_alpha = float(getattr(capid, "CA_ALPHA", 0.1))
self.diff_scale = float(getattr(capid, "DIFF_SCALE", 0.05))
self.gate_max = float(getattr(capid, "GATE_MAX", 0.5))
self.ca_mode = str(getattr(capid, "CA_MODE", "bridge")).lower()
if self.ca_enabled:
self.ca = CrossAttentionCoupler(
dim_text=512, dim_vision=768,
depth=int(getattr(capid, "CA_DEPTH", 1)),
heads=int(getattr(capid, "CA_HEADS", 4)),
dropout=float(getattr(capid, "CA_DROPOUT", 0.0)),
)
# Bridge module for deep compound prompts (text 512 <-> vision 768)
if self.ca_mode == "bridge":
self.ca_bridge = CrossAttentivePromptBridge(
dim_text=512, dim_vision=768,
dim_common=512,
heads=int(getattr(capid, "CA_HEADS", 4)),
dropout=float(getattr(capid, "CA_DROPOUT", 0.0)),
alpha=float(getattr(capid, "CA_ALPHA", 0.1)),
)
if self.diff_enabled:
self.diff_text = DiffusionPromptGenerator(channels=512, cond_channels=512)
self.diff_vision = DiffusionPromptGenerator(channels=768, cond_channels=768)
self.diff_steps = int(getattr(capid, "DIFF_STEPS", 2))
self.diff_noise = float(getattr(capid, "DIFF_NOISE_STD", 0.1))
self.cfg_scale = float(getattr(capid, "CFG_SCALE", 1.0))
if self.gate_enabled:
self.gate = InteractiveGate(
alpha=float(getattr(capid, "GATE_ALPHA", 1.0)),
beta=float(getattr(capid, "GATE_BETA", 0.0)),
)
# Debug state
self.debug_save = bool(getattr(capid, "DEBUG_SAVE", False))
self.debug_freq = int(getattr(capid, "DEBUG_FREQ", 200))
self.debug_dir = str(getattr(capid, "DEBUG_DIR", "output/capid_debug"))
self._debug_step = 0
@torch.no_grad()
def _encode_instruction(self, text: str):
if text is None or len(text.strip()) == 0:
return None
# Lightweight proxy: mean-pooled token embeddings (no transformer), dtype/device-safe
try:
tokens = clip.tokenize([text]) # (1, L)
emb = self._clip_model_ref.token_embedding(tokens.to(self.logit_scale.device).type(self.dtype)) # (1,L,512)
feat = emb.mean(dim=1) # (1,512)
return feat
except Exception:
return None
def forward(self, image, label=None):
tokenized_prompts = self.tokenized_prompts
logit_scale = self.logit_scale.exp()
prompts, shared_ctx, deep_compound_prompts_text, deep_compound_prompts_vision = self.prompt_learner()
# Bridge deep prompts before encoders if enabled
if getattr(self, "capid_enabled", False) and getattr(self, "ca_enabled", False) and getattr(self, "ca_mode", "bridge") == "bridge":
try:
deep_compound_prompts_text, deep_compound_prompts_vision = self.ca_bridge(
deep_compound_prompts_text,
deep_compound_prompts_vision,
alpha=float(getattr(self, "ca_alpha", 0.1)),
)
except Exception:
# Fallback: keep original if any shape issue
pass
# CAPID optional pipeline
if getattr(self, "capid_enabled", False) and not getattr(self.prompt_learner, "capid_applied", False):
# Prepare per-class vision tokens from shared_ctx for coupling/diffusion
# shared_ctx: (n_ctx, 768) -> (n_cls, n_ctx, 768)
vis_tokens = shared_ctx.unsqueeze(0).expand(self.prompt_learner.n_cls, -1, -1)
gate_strength = 1.0
if getattr(self, "gate_enabled", False):
instruction = getattr(self.cfg.TRAINER.CAPID, "INSTRUCTION", "")
inst_feat = self._encode_instruction(instruction)
# Use text-only gating by default to avoid extra compute; img_feat kept None
# Compute a quick baseline text feature from current prompts (detached)
try:
txt_feat_base = self.text_encoder(prompts, tokenized_prompts, deep_compound_prompts_text).detach()
except Exception:
txt_feat_base = None
g_tensor = self.gate(inst_feat, txt_feat_base, None)
gate_strength = max(0.0, min(self.gate_max, float(g_tensor.item())))
            # Safe DIFF: only apply when the configuration has a non-zero effect
            should_diff = (
                getattr(self, "diff_enabled", False)
                and (getattr(self, "cfg_scale", 0.0) > 0.0)
                and (getattr(self, "diff_noise", 0.0) > 0.0)
                and (getattr(self, "diff_steps", 0) > 0)
            )
            delta_txt = delta_vis = None
            if should_diff:
                cond_txt = prompts.mean(dim=1)     # (n_cls, 512)
                cond_vis = vis_tokens.mean(dim=1)  # (n_cls, 768)
                delta_txt = self.diff_text.sample(prompts, cond=cond_txt, steps=self.diff_steps, noise_std=self.diff_noise)
                delta_txt = F.layer_norm(delta_txt, delta_txt.shape[-1:])
                prompts = prompts + self.diff_scale * self.cfg_scale * gate_strength * delta_txt
                delta_vis = self.diff_vision.sample(vis_tokens, cond=cond_vis, steps=self.diff_steps,
                                                    noise_std=self.diff_noise)
                delta_vis = F.layer_norm(delta_vis, delta_vis.shape[-1:])
                vis_tokens = vis_tokens + self.diff_scale * self.cfg_scale * gate_strength * delta_vis
attn_maps = None
# If using bridge mode, skip shallow CA here
if getattr(self, "ca_enabled", False) and getattr(self, "ca_mode", "bridge") != "bridge":
p_in, v_in = prompts, vis_tokens
p_ca, v_ca, attn_maps = self.ca(p_in, v_in)
p_ca = F.layer_norm(p_ca, p_ca.shape[-1:])
v_ca = F.layer_norm(v_ca, v_ca.shape[-1:])
alpha = max(0.0, min(1.0, float(self.ca_alpha)))
prompts = (1.0 - alpha) * p_in + alpha * p_ca
vis_tokens = (1.0 - alpha) * v_in + alpha * v_ca
# Reduce back to shared_ctx shape expected by vision encoder
shared_ctx = vis_tokens.mean(dim=0) # (n_ctx, 768)
# Debug saves (very lightweight)
if getattr(self, "debug_save", False):
self._debug_step += 1
if self._debug_step % max(1, self.debug_freq) == 0:
try:
if attn_maps is not None and len(attn_maps) > 0:
a = attn_maps[0][0]
if a.dim() == 4:
a_vis = a.mean(dim=1)[0]
elif a.dim() == 3:
a_vis = a[0]
else:
a_vis = a.flatten(1).unsqueeze(0)
out_path = osp.join(self.debug_dir, f"attn_layer0_{self._debug_step:06d}.png")
save_debug_image(a_vis, out_path)
if getattr(self, "diff_enabled", False):
# Save magnitude heatmaps for first class' deltas
try:
dt = (delta_txt[0].norm(dim=-1, keepdim=False)) # (L_text,)
dv = (delta_vis[0].norm(dim=-1, keepdim=False)) # (L_vis,)
# Expand to 2D for visualization
save_debug_image(dt.unsqueeze(0), osp.join(self.debug_dir, f"delta_txt_norm_{self._debug_step:06d}.png"))
save_debug_image(dv.unsqueeze(0), osp.join(self.debug_dir, f"delta_vis_norm_{self._debug_step:06d}.png"))
except Exception:
pass
except Exception:
pass
text_features = self.text_encoder(prompts, tokenized_prompts, deep_compound_prompts_text)
image_features = self.image_encoder(image.type(self.dtype), shared_ctx, deep_compound_prompts_vision)
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
logits = logit_scale * image_features @ text_features.t()
if self.prompt_learner.training:
loss = F.cross_entropy(logits, label)
if getattr(self, "capid_enabled", False) and getattr(self, "diff_enabled", False) and (getattr(self, "diff_loss_weight", 0.0) > 0):
n_cls = self.prompt_learner.n_cls
vis_tokens = shared_ctx.unsqueeze(0).expand(n_cls, -1, -1) # (n_cls, n_ctx, 768)
                # Conditioning: text side uses the token-mean of prompts; vision side uses the token-mean of shared_ctx
cond_txt = prompts.mean(dim=1) # (n_cls, 512)
cond_vis = shared_ctx.mean(dim=0, keepdim=True).expand(n_cls, -1) # (n_cls, 768)
try:
l_txt = self.diff_text.diffusion_loss(prompts, cond_txt, noise_std=float(getattr(self, "diff_noise", 0.1)))
except Exception:
l_txt = torch.tensor(0.0, device=loss.device, dtype=loss.dtype)
try:
l_vis = self.diff_vision.diffusion_loss(vis_tokens, cond_vis, noise_std=float(getattr(self, "diff_noise", 0.1)))
except Exception:
l_vis = torch.tensor(0.0, device=loss.device, dtype=loss.dtype)
loss = loss + self.diff_loss_weight * (l_txt + l_vis)
return loss
return logits
def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
@TRAINER_REGISTRY.register()
class MaPLe(TrainerX):
def check_cfg(self, cfg):
assert cfg.TRAINER.MAPLE.PREC in ["fp16", "fp32", "amp"]
def build_model(self):
cfg = self.cfg
classnames = self.dm.dataset.classnames
print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
clip_model = load_clip_to_cpu(cfg)
if cfg.TRAINER.MAPLE.PREC == "fp32" or cfg.TRAINER.MAPLE.PREC == "amp":
# CLIP's default precision is fp16
clip_model.float()
print("Building custom CLIP")
self.model = CustomCLIP(cfg, classnames, clip_model)
print("Turning off gradients in both the image and the text encoder")
# Default: only update prompt_learner (MaPLe). If CAPID enabled, also allow
# bridge/CA/DIFF/Gate small modules to learn.
capid_cfg = getattr(self.cfg.TRAINER, "CAPID", None)
capid_on = bool(getattr(capid_cfg, "ENABLE", False)) if capid_cfg is not None else False
capid_train_only = bool(getattr(capid_cfg, "TRAIN_ONLY_CAPID", False)) if capid_cfg is not None else False
# Freeze CLIP backbone under _clip_model_ref; only train open-track prompt subset + CAPID small modules
for name, param in self.model.named_parameters():
# hard block CLIP backbone
if "prompt_learner._clip_model_ref" in name:
param.requires_grad_(False)
continue
if capid_on and capid_train_only:
# train only CAPID modules
allow = (
("ca_bridge" in name)
or (".ca." in name or name.endswith(".ca"))
or ("diff_text" in name) or ("diff_vision" in name)
or ("gate" in name)
)
else:
# open-track prompt subset + CAPID modules (+VPT)
allow = (
(
name.startswith("prompt_learner.ctx")
or name.startswith("prompt_learner.proj")
or name.startswith("prompt_learner.compound_prompts_text.0")
or name.startswith("prompt_learner.compound_prompt_projections.0")
)
or (capid_on and (
("ca_bridge" in name)
or (".ca." in name or name.endswith(".ca"))
or ("diff_text" in name) or ("diff_vision" in name)
or ("gate" in name)
))
or ("VPT" in name)
)
param.requires_grad_(bool(allow))
# Double check
enabled = set()
for name, param in self.model.named_parameters():
if param.requires_grad:
enabled.add(name)
print(f"Parameters to be updated: {enabled}")
if cfg.MODEL.INIT_WEIGHTS:
load_pretrained_weights(self.model, cfg.MODEL.INIT_WEIGHTS)
self.model.to(self.device)
        # NOTE: the full model is passed to the optimizer; only parameters with requires_grad=True (set above) are updated
self.optim = build_optimizer(self.model, cfg.OPTIM)
self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
self.register_model("MultiModalPromptLearner", self.model, self.optim, self.sched)
self.scaler = GradScaler() if cfg.TRAINER.MAPLE.PREC == "amp" else None
# Note that multi-gpu training could be slow because CLIP's size is
# big, which slows down the copy operation in DataParallel
device_count = torch.cuda.device_count()
if device_count > 1:
print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!")
self.model = nn.DataParallel(self.model)
def forward_backward(self, batch):
image, label = self.parse_batch_train(batch)
model = self.model
optim = self.optim
scaler = self.scaler
prec = self.cfg.TRAINER.MAPLE.PREC
if prec == "amp":
with autocast():
loss = model(image, label)
optim.zero_grad()
scaler.scale(loss).backward()
scaler.step(optim)
scaler.update()
else:
loss = model(image, label)
optim.zero_grad()
loss.backward()
optim.step()
loss_summary = {"loss": loss.item()}
if (self.batch_idx + 1) == self.num_batches:
self.update_lr()
return loss_summary
def parse_batch_train(self, batch):
input = batch["img"]
label = batch["label"]
input = input.to(self.device)
label = label.to(self.device)
return input, label
def load_model(self, directory, epoch=None):
if not directory:
print("Note that load_model() is skipped as no pretrained model is given")
return
names = self.get_model_names()
# By default, the best model is loaded
model_file = "model-best.pth.tar"
if epoch is not None:
model_file = "model.pth.tar-" + str(epoch)
for name in names:
model_path = osp.join(directory, name, model_file)
if not osp.exists(model_path):
raise FileNotFoundError('Model not found at "{}"'.format(model_path))
checkpoint = load_checkpoint(model_path)
state_dict = checkpoint["state_dict"]
epoch = checkpoint["epoch"]
# Ignore fixed token vectors
if "prompt_learner.token_prefix" in state_dict:
del state_dict["prompt_learner.token_prefix"]
if "prompt_learner.token_suffix" in state_dict:
del state_dict["prompt_learner.token_suffix"]
print("Loading weights to {} " 'from "{}" (epoch = {})'.format(name, model_path, epoch))
# set strict=False
self._models[name].load_state_dict(state_dict, strict=False)