Theory notes: what does outputs_cls.detach() mean?

This article explains how PyTorch's .detach() method separates a tensor from the computation graph so that it no longer affects gradient computation. In knowledge distillation, outputs_cls.detach() is used to freeze the target values so they do not drive parameter updates. The method is commonly used to stop gradient flow and in certain regularization strategies.

In PyTorch, .detach() separates a tensor from the current computation graph and returns a new tensor that does not require gradients (requires_grad=False). A detached tensor is excluded from backpropagation: no gradient is computed for it, so any operation performed on it has no effect on gradient computation or on the model's parameter updates.
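A minimal sketch of this behavior (the tensors and values here are illustrative, not taken from the original code):

```python
import torch

x = torch.tensor([2.0, 3.0], requires_grad=True)
y = x * 2            # y is part of the computation graph, requires_grad=True
z = y.detach()       # z shares the same values but is cut off from the graph

print(y.requires_grad)  # True
print(z.requires_grad)  # False

# Gradients flow through y but treat z as a constant
loss = (y * z).sum()
loss.backward()
print(x.grad)           # tensor([ 8., 12.])  -> d(loss)/dx = 2 * z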

In the context of outputs_cls.detach(), this means:

  • outputs_cls is the model's output on part of the input data (for example, the second half of a batch). By default these outputs are connected to the model parameters through the computation graph, so operations on them (such as computing a loss) propagate gradients back to the parameters.

  • Calling outputs_cls.detach() yields a tensor with the same values as outputs_cls but separated from the computation graph. The goal is to use these outputs as a "static" target (a teacher signal) when computing the knowledge-distillation loss, rather than letting them participate in gradient computation. In other words, these outputs act as fixed targets that guide the training of the other part of the data (for example, the first half of the batch), without adjusting, during backpropagation, the parameters that produced them; a sketch of this pattern follows the list.
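A minimal sketch of the pattern described above, assuming a single classifier whose predictions on the second half of a batch serve as detached targets for the first half (the model, shapes, and half-batch split are placeholder assumptions, not the original code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Linear(128, 10)            # stand-in for the real classifier
inputs = torch.randn(32, 128)         # one batch, split in half below
labels = torch.randint(0, 10, (16,))  # labels for the first half only

outputs = model(inputs)
outputs_first, outputs_cls = outputs[:16], outputs[16:]

# Ordinary supervised loss on the first half
ce_loss = F.cross_entropy(outputs_first, labels)

# Distillation loss: outputs_cls.detach() is a constant target, so no
# gradient flows back through the second-half forward pass.
kd_loss = F.kl_div(
    F.log_softmax(outputs_first, dim=1),
    F.softmax(outputs_cls.detach(), dim=1),
    reduction="batchmean",
)

loss = ce_loss + kd_loss
loss.backward()   # parameters receive gradients only via outputs_first
```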

Typical scenarios for using .detach() include:

  • When the gradient of certain tensors must be stopped, for example in knowledge distillation or when training on generated samples, where the generated data should be treated as fixed inputs rather than as quantities to optimize.

  • When implementing certain regularization strategies or custom loss functions that operate on part of the data or on intermediate results, where those operations must not affect the optimization of the model parameters (see the sketch after this list).
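A short sketch of the "generated samples as fixed inputs" case, using a hypothetical generator/discriminator pair purely to show the detach pattern:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

generator = nn.Linear(8, 16)       # placeholder modules, not a real GAN
discriminator = nn.Linear(16, 1)

noise = torch.randn(4, 8)
fake = generator(noise)

# Detach: the discriminator sees the generated samples as fixed inputs,
# so this loss updates only the discriminator, never the generator.
d_loss = F.binary_cross_entropy_with_logits(
    discriminator(fake.detach()),
    torch.zeros(4, 1),
)
d_loss.backward()   # generator.weight.grad remains None
```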

In short, outputs_cls.detach() ensures that the values in outputs_cls do not affect gradient computation or parameter updates in subsequent operations, so they can safely serve as fixed targets in the loss computation.

