copy.deepcopy(train_model) raises: Only Tensors created explicitly by the user support the deepcopy

While training a PyTorch model, calling `copy.deepcopy()` on the model raised a RuntimeError, because tensors that were not created explicitly by the user do not support deepcopy. The problem was traced to a model submodule that returned `self.features`; changing it to return a local variable `features` resolved the issue. The before/after code below shows the change.

Error message:

RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment

Possible cause:

During training it is common to run validation alongside training, and the most concise way to do so is usually to deep-copy the model under training with `copy.deepcopy()`. For example, my validation.py does:

    val_model = copy.deepcopy(train_model)

However, because of the limitations of `copy.deepcopy()`, calling `copy.deepcopy(model)` may hit this error: "Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment". The full traceback:

  File "/home/users/xinxin.li/HAT-dev-toolchain/hat/engine/ddp_trainer.py", line 359, in _with_exception
    fn(*args)
  File "/home/users/xinxin.li/HAT-dev-toolchain/tools/train.py", line 186, in train_entrance
    trainer.fit()
  File "/home/users/xinxin.li/HAT-dev-toolchain/hat/engine/loop_base.py", line 523, in fit
    storage=self.storage,
  File "/home/users/xinxin.li/HAT-dev-toolchain/hat/engine/loop_base.py", line 73, in on_epoch_end
    cb.on_epoch_end(**kwargs)
  File "/home/users/xinxin.li/HAT-dev-toolchain/hat/callbacks/validation.py", line 207, in on_epoch_end
    self._do_val(epoch_id, model, ema_model, device, val_metrics)
  File "/home/users/xinxin.li/HAT-dev-toolchain/hat/callbacks/validation.py", line 163, in _do_val
    val_model = self._select_and_init_val_model(train_model=eval_model)
  File "/home/users/xinxin.li/HAT-dev-toolchain/hat/callbacks/validation.py", line 147, in _select_and_init_val_model
    val_model = copy.deepcopy(train_model)
  File "/home/users/xinxin.li/anaconda3/envs/python36/lib/python3.6/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/home/users/xinxin.li/anaconda3/envs/python36/lib/python3.6/copy.py", line 280, in _reconstruct
    state = deepcopy(state, memo)
  File "/home/users/xinxin.li/anaconda3/envs/python36/lib/python3.6/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/home/users/xinxin.li/anaconda3/envs/python36/lib/python3.6/copy.py", line 240, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/home/users/xinxin.li/anaconda3/envs/python36/lib/python3.6/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/home/users/xinxin.li/anaconda3/envs/python36/lib/python3.6/copy.py", line 306, in _reconstruct
    value = deepcopy(value, memo)
  File "/home/users/xinxin.li/anaconda3/envs/python36/lib/python3.6/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/home/users/xinxin.li/anaconda3/envs/python36/lib/python3.6/copy.py", line 280, in _reconstruct
    state = deepcopy(state, memo)
  File "/home/users/xinxin.li/anaconda3/envs/python36/lib/python3.6/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/home/users/xinxin.li/anaconda3/envs/python36/lib/python3.6/copy.py", line 240, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/home/users/xinxin.li/anaconda3/envs/python36/lib/python3.6/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/home/users/xinxin.li/anaconda3/envs/python36/lib/python3.6/copy.py", line 306, in _reconstruct
    value = deepcopy(value, memo)
  File "/home/users/xinxin.li/anaconda3/envs/python36/lib/python3.6/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/home/users/xinxin.li/anaconda3/envs/python36/lib/python3.6/copy.py", line 280, in _reconstruct
    state = deepcopy(state, memo)
  File "/home/users/xinxin.li/anaconda3/envs/python36/lib/python3.6/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/home/users/xinxin.li/anaconda3/envs/python36/lib/python3.6/copy.py", line 240, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/home/users/xinxin.li/anaconda3/envs/python36/lib/python3.6/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/home/users/xinxin.li/anaconda3/envs/python36/lib/python3.6/copy.py", line 240, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/home/users/xinxin.li/anaconda3/envs/python36/lib/python3.6/copy.py", line 161, in deepcopy
    y = copier(memo)
  File "/home/users/xinxin.li/anaconda3/envs/python36/lib/python3.6/site-packages/torch/_tensor.py", line 85, in __deepcopy__
    raise RuntimeError("Only Tensors created explicitly by the user "
RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment
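To see why this happens, here is a minimal standalone sketch (a hypothetical `CachingNet`, not my actual model) that reproduces the error: once forward() has cached a non-leaf tensor on the module, the module itself can no longer be deep-copied.

    import copy
    import torch
    import torch.nn as nn

    class CachingNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 8, 3, padding=1)

        def forward(self, x):
            # The conv output has a grad_fn, i.e. it is not a graph leaf;
            # assigning it to self puts it into the module's __dict__.
            self.features = self.conv(x)
            return self.features

    model = CachingNet()
    model(torch.randn(1, 3, 16, 16))  # forward pass caches a non-leaf tensor on the module
    copy.deepcopy(model)              # raises: Only Tensors created explicitly by the user ...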

How to troubleshoot:

1. Open /home/users/xinxin.li/anaconda3/envs/python36/lib/python3.6/copy.py, set a breakpoint (or a temporary print) at the `_deepcopy_dict` line shown in the traceback, and output the corresponding key and value (see the sketch after this list).

2. Re-run the program. The key/value printed just before the error tells you which attribute of which submodule deepcopy chokes on; map that back to the corresponding line of your network definition, and that is where the error comes from.
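For reference, the temporary instrumentation can look roughly like this (a sketch of CPython 3.6's `_deepcopy_dict`, the frame at line 240 in the traceback, with one print added; revert it once the offending attribute is found):

    # copy.py, _deepcopy_dict (Python 3.6), with a temporary debug print
    def _deepcopy_dict(x, memo, deepcopy=deepcopy):
        y = {}
        memo[id(x)] = y
        for key, value in x.items():
            print('deepcopying attribute:', key, type(value))  # last key printed before the crash is the culprit
            y[deepcopy(key, memo)] = deepcopy(value, memo)
        return y

If you prefer not to edit the standard library, wrapping `copy.deepcopy(sub)` in a try/except over `train_model.named_modules()` will also narrow down the failing submodule.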

My diagnosis:

A submodule of my model built and returned `self.features`, so after the first forward pass the module held non-leaf tensors (outputs of autograd operations), which `torch.Tensor.__deepcopy__` refuses to copy. After changing it to build and return a local variable instead, the error went away.

Code before the fix:

    def forward(self, input_image):
        # Intermediate outputs are stored on self, so after forward() the module
        # holds non-leaf tensors and copy.deepcopy(model) fails.
        self.features = []
        x = (input_image - 0.45) / 0.225
        x = self.encoder.conv1(x)
        x = self.encoder.bn1(x)
        self.features.append(self.encoder.relu(x))
        self.features.append(self.encoder.layer1(self.encoder.maxpool(self.features[-1])))
        self.features.append(self.encoder.layer2(self.features[-1]))
        self.features.append(self.encoder.layer3(self.features[-1]))
        self.features.append(self.encoder.layer4(self.features[-1]))

        return self.features

Code after the fix:

    def forward(self, input_image):
        # Fixed: features is a local list; nothing is cached on the module.
        features = []
        x = (input_image - 0.45) / 0.225
        x = self.encoder.conv1(x)
        x = self.encoder.bn1(x)
        features.append(self.encoder.relu(x))
        features.append(self.encoder.layer1(self.encoder.maxpool(features[-1])))
        features.append(self.encoder.layer2(features[-1]))
        features.append(self.encoder.layer3(features[-1]))
        features.append(self.encoder.layer4(features[-1]))

        return features
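If rewriting the model is not an option, another common workaround is to skip deepcopy altogether and copy only the weights into a freshly built model. A minimal sketch, assuming the model can be rebuilt from its config; `build_model()` is a placeholder for however train_model was constructed:

    import torch

    def make_val_model(train_model: torch.nn.Module, build_model, device: str = "cuda"):
        # A fresh instance carries no cached activations or other non-leaf tensors.
        val_model = build_model()
        # state_dict() holds only parameters and buffers, which load cleanly.
        val_model.load_state_dict(train_model.state_dict())
        return val_model.to(device).eval()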

Reference: "解决使用copy.deepcopy()拷贝Tensor或model时报错只支持用户显式创建的Tensor问题" (Arnold-FY-Chen, CSDN blog).
