The _apply() Method Explained

This article analyzes the _apply method of the BaseModel class in YOLOv8. The method applies to(), cpu(), cuda(), half() and similar operations to model tensors that are neither parameters nor registered buffers, and also touches on module transfer and autograd handling.


1. The _apply method in class BaseModel(nn.Module):

def _apply(self, fn):
    # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
    self = super()._apply(fn)
    m = self.model[-1]  # Detect()
    if isinstance(m, (Detect, Segment)):
        m.stride = fn(m.stride)
        m.grid = list(map(fn, m.grid))
        if isinstance(m.anchor_grid, list):
            m.anchor_grid = list(map(fn, m.anchor_grid))
    return self

This is the code from YOLOv8. In my understanding, its purpose is to apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers: attributes such as stride, grid, and anchor_grid on the Detect/Segment head are plain tensors, so nn.Module's own _apply would otherwise skip them.
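To see why the override is needed, here is a minimal sketch (the `Head` module and its `stride` attribute are hypothetical stand-ins for the YOLO detection head): a plain tensor attribute is not touched by `nn.Module._apply`, so we convert it ourselves after calling `super()._apply(fn)`.

```python
import torch
import torch.nn as nn

class Head(nn.Module):
    """Hypothetical module with a plain tensor attribute, mimicking Detect."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 1)
        # Plain attribute: not a Parameter, not a registered buffer.
        self.stride = torch.tensor([8.0, 16.0, 32.0])

    def _apply(self, fn):
        # Let nn.Module._apply handle parameters and buffers...
        self = super()._apply(fn)
        # ...then apply fn to the loose tensor attribute ourselves,
        # mirroring the YOLO override.
        self.stride = fn(self.stride)
        return self

m = Head().half()
print(m.conv.weight.dtype)  # torch.float16 (handled by nn.Module._apply)
print(m.stride.dtype)       # torch.float16 (handled by the override)
```

Without the `_apply` override, `m.stride` would remain float32 after `half()`, which is exactly the mismatch the YOLO code guards against.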

_apply(fn)

The _apply(fn) method is applied recursively to every submodule and then to the module itself. _apply() is an "internal-use-only" interface implemented specifically for parameters and buffers, while apply is the "public" interface (Python does not strictly enforce public/private on classes; the convention is a single leading underscore for internal names). Note that apply passes each submodule to fn, whereas _apply passes tensors; even so, apply can reproduce what _apply does if fn is written appropriately.
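A small sketch of that contrast, under the assumption above that apply's fn receives modules (children first, then the module itself), so a tensor-level conversion can be emulated by iterating each module's own parameters and buffers inside fn:

```python
import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 2))

# apply(fn) calls fn on each *submodule*, children before parents.
seen = []
net.apply(lambda module: seen.append(type(module).__name__))
print(seen)  # ['Linear', 'Sequential']

# The effect of _apply(fn) (which operates on tensors) can be
# reproduced by transforming each module's own tensors inside fn:
def to_double(module):
    for p in module.parameters(recurse=False):
        p.data = p.data.double()          # convert parameters in place
    for k, b in module._buffers.items():
        if b is not None:
            module._buffers[k] = b.double()  # convert buffers

net.apply(to_double)
print(net[0].weight.dtype)  # torch.float64
```

This is only an illustration of the equivalence; in practice one should use the built-in conversions (`.double()`, `.half()`, `.to()`), which go through _apply and also handle gradients correctly.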

Usage 1: when a module is moved to CPU/GPU, _apply() is invoked. For example, executing net.cuda() calls:

def cuda(self: T, device: Optional[Union[int, device]] = None) -> T:
    r"""Moves all model parameters and buffers to the GPU.

    This also makes associated parameters and buffers different objects. So
    it should be called before constructing optimizer if the module will
    live on GPU while being optimized.

    .. note::
        This method modifies the module in-place.

    Args:
        device (int, optional): if specified, all parameters will be
            copied to that device

    Returns:
        Module: self
    """
    return self._apply(lambda t: t.cuda(device))

Inside _apply(fn), three steps are performed: first, _apply is called recursively on self.children(); second, fn is applied to the parameters in self._parameters and to their gradients; third, fn is applied to the buffers in self._buffers.

def _apply(self, fn):
    for module in self.children():
        module._apply(fn)

    def compute_should_use_set_data(tensor, tensor_applied):
        if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
            # If the new tensor has compatible tensor type as the existing tensor,
            # the current behavior is to change the tensor in-place using `.data =`,
            # and the future behavior is to overwrite the existing tensor. However,
            # changing the current behavior is a BC-breaking change, and we want it
            # to happen in future releases. So for now we introduce the
            # `torch.__future__.get_overwrite_module_params_on_conversion()`
            # global flag to let the user control whether they want the future
            # behavior of overwriting the existing tensor or not.
            return not torch.__future__.get_overwrite_module_params_on_conversion()
        else:
            return False

    for key, param in self._parameters.items():
        if param is not None:
            # Tensors stored in modules are graph leaves, and we don't want to
            # track autograd history of `param_applied`, so we have to use
            # `with torch.no_grad():`
            with torch.no_grad():
                param_applied = fn(param)
            should_use_set_data = compute_should_use_set_data(param, param_applied)
            if should_use_set_data:
                param.data = param_applied
            else:
                assert isinstance(param, Parameter)
                assert param.is_leaf
                self._parameters[key] = Parameter(param_applied, param.requires_grad)

            if param.grad is not None:
                with torch.no_grad():
                    grad_applied = fn(param.grad)
                should_use_set_data = compute_should_use_set_data(param.grad, grad_applied)
                if should_use_set_data:
                    param.grad.data = grad_applied
                else:
                    assert param.grad.is_leaf
                    self._parameters[key].grad = grad_applied.requires_grad_(param.grad.requires_grad)

    for key, buf in self._buffers.items():
        if buf is not None:
            self._buffers[key] = fn(buf)
    return self
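The three steps can be observed directly. A minimal sketch, assuming only standard PyTorch behavior: after `.half()` (which routes through `_apply(lambda t: t.half() ...)`), the parameters, their gradients, and the registered buffers have all been transformed by fn.

```python
import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(2, 1), nn.BatchNorm1d(1))
net(torch.randn(4, 2)).sum().backward()  # populate param.grad

net.half()  # internally: self._apply(lambda t: t.half() if t.is_floating_point() else t)

print(net[0].weight.dtype)        # parameter         -> torch.float16
print(net[0].weight.grad.dtype)   # its gradient      -> torch.float16
print(net[1].running_mean.dtype)  # registered buffer -> torch.float16
```

Note that integer buffers such as BatchNorm's `num_batches_tracked` are left untouched, because the fn used by `half()` only converts floating-point tensors.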

I looked up many references on this point, and most of them only discuss the apply() method, so I still have some doubts; readers who know the details are welcome to explain in the comments.
