python __call__的作用,是可以将对象作为方法使用的关键 分析nn.Module源码

代码举例

import torch.nn as nn

class LSTMClassifier(nn.Module):
    """
    This is the simple RNN model we will be using to perform Sentiment Analysis.
    """

    def __init__(self, embedding_dim, hidden_dim, vocab_size):
        """
        Initialize the model by settingg up the various layers.
        """
        super(LSTMClassifier, self).__init__()

        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim)
        self.dense = nn.Linear(in_features=hidden_dim, out_features=1)
        self.sig = nn.Sigmoid()
        
        self.word_dict = None

    def forward(self, x):
        """
        Perform a forward pass of our model on some input.
        """
        x = x.t()
        lengths = x[0,:]
        reviews = x[1:,:]
        embeds = self.embedding(reviews)
        lstm_out, _ = self.lstm(embeds)
        out = self.dense(lstm_out)
        out = out[lengths - 1, range(len(lengths))]
        return self.sig(out.squeeze())

 

 

def train(model, train_loader, epochs, optimizer, loss_fn, device):
    for epoch in range(1, epochs + 1):
        model.train()
        total_loss = 0
        for batch in train_loader:         
            batch_X, batch_y = batch
            print("batch_X.shape=",batch_X.shape)
            print("batch_y.shape=",batch_y.shape)
            batch_X = batch_X.to(device)
            batch_y = batch_y.to(device)
           
            # TODO: Complete this train method to train the model provided.
            optimizer.zero_grad()

            # get predictions from model
            y_pred = model(batch_X)

            # perform backprop
            loss = criterion(y_pred, batch_y)
            loss.backward()
            optimizer.step()
            total_loss += loss.data.item()
        print("Epoch: {}, BCELoss: {}".format(epoch, total_loss / len(train_loader)))

 

该类的使用LSTMClassifier对象model 实际调用的是forward (),原因是nn.Module()中的__call__方法中调用了forward方法

#第三方调用model代码场景如下

import torch.optim as optim

from train.model import LSTMClassifier

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LSTMClassifier(32, 100, 5000).to(device)
optimizer = optim.Adam(model.parameters(),lr=0.001)
loss_fn = torch.nn.BCELoss()

train(model, train_sample_dl, 5, optimizer, loss_fn, device)

 

 

 

 

#上面类的父类nn.Module中的__call__方法中调用了forward方法,如下

def __call__(self, *input, **kwargs):
    for hook in self._forward_pre_hooks.values():
        result = hook(self, input)
        if result is not None:
            if not isinstance(result, tuple):
                result = (result,)
            input = result
    if torch._C._get_tracing_state():
        result = self._slow_forward(*input, **kwargs)
    else:
        result = self.forward(*input, **kwargs)
    for hook in self._forward_hooks.values():
        hook_result = hook(self, input, result)
        if hook_result is not None:
            result = hook_result
    if len(self._backward_hooks) > 0:
        var = result
        while not isinstance(var, torch.Tensor):
            if isinstance(var, dict):
                var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
            else:
                var = var[0]
        grad_fn = var.grad_fn
        if grad_fn is not None:
            for hook in self._backward_hooks.values():
                wrapper = functools.partial(hook, self)
                functools.update_wrapper(wrapper, hook)
                grad_fn.register_hook(wrapper)
    return result

 

 

理解该现象的关键

__call__

_forward_pre_hooks

参考

https://www.cnblogs.com/SBJBA/p/11355412.html

nn.Module方法逐个分析

https://blog.youkuaiyun.com/lishuiwang/article/details/104505675/

https://www.jb51.net/article/184033.htm

nn.Module源码

https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py

FAILED [ 50%] ultralytics\nn\tasks.py:1139 ([doctest] ultralytics.nn.tasks.temporary_modules) 1144 where you've moved a module from one location to another, but you still want to support the old import 1145 paths for backwards compatibility. 1146 1147 Args: 1148 modules (dict, optional): A dictionary mapping old module paths to new module paths. 1149 attributes (dict, optional): A dictionary mapping old module attributes to new module attributes. 1150 1151 Examples: 1152 >>> from ultralytics.nn.tasks import DetectionModel 1153 >>> with temporary_modules({"old.module": "new.module"}, {"old.module.attribute": "new.module.attribute"}): UNEXPECTED EXCEPTION: ModuleNotFoundError("No module named 'old'") Traceback (most recent call last): File "C:\Users\Taodawang\.conda\envs\yolov11\Lib\doctest.py", line 1355, in __run exec(compile(example.source, filename, "single", File "<doctest ultralytics.nn.tasks.temporary_modules[1]>", line 1, in <module> File "C:\Users\Taodawang\.conda\envs\yolov11\Lib\contextlib.py", line 137, in __enter__ return next(self.gen) ^^^^^^^^^^^^^^ File "C:\D_pan\pystudy\yolov11\ultralytics-v8.3.133\ultralytics\nn\tasks.py", line 1174, in temporary_modules setattr(import_module(old_module), old_attr, getattr(import_module(new_module), new_attr)) ^^^^^^^^^^^^^^^^^^^^^^^^^ File "C:\Users\Taodawang\.conda\envs\yolov11\Lib\importlib\__init__.py", line 126, in import_module return _bootstrap._gcd_import(name[level:], package, level) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "<frozen importlib._bootstrap>", line 1204, in _gcd_import File "<frozen importlib._bootstrap>", line 1176, in _find_and_load File "<frozen importlib._bootstrap>", line 1126, in _find_and_load_unlocked File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed File "<frozen importlib._bootstrap>", line 1204, in _gcd_import File "<frozen importlib._bootstrap>", line 1176, in _find_and_load File "<frozen importlib._bootstrap>", line 1140, in _find_and_load_unlocked ModuleNotFoundError: No module named 'old'
最新发布
06-06
### 问题分析 当用户在使用 `ultralytics.nn.tasks.temporary_modules` 时出现 `ModuleNotFoundError: No module named 'old'` 的错误,这通常表明 Python 解释器无法找到名为 `old` 的模块。以下是可能的原因和解决方案: 1. **模块名称错误**:可能是代码中引用了不存在的模块 `old`,需要检查代码逻辑是否正确。 2. **依赖未安装**:如果 `old` 是某个第三方库的一部分,则需要确保该库已正确安装。 3. **路径问题**:Python 解释器可能未将包含 `old` 模块的目录添加到其搜索路径中。 --- ### 解决方案 #### 1. 检查代码中的模块引用 确认代码中是否存在对 `old` 模块的直接或间接引用。如果 `old` 并非预期的模块名称,则需要修正代码以引用正确的模块[^1]。 #### 2. 安装缺失的依赖 如果 `old` 是某个特定库的一部分,可以尝试通过以下命令安装该库: ```bash pip install <库名> ``` 例如,如果 `old` 属于某个已知库(如 `old-module`),则运行: ```bash pip install old-module ``` #### 3. 检查 Python 路径 确保包含 `old` 模块的目录已添加到 Python 的搜索路径中。可以通过以下代码检查当前路径: ```python import sys print(sys.path) ``` 如果目标模块所在的目录未列出,可以通过以下方式临时添加: ```python import sys sys.path.append('/path/to/module') ``` #### 4. 验证 ultralytics 的安装 确保 `ultralytics` 库已正确安装,并且版本与文档一致。可以运行以下命令重新安装: ```bash pip install ultralytics --upgrade ``` #### 5. 检查 ultralytics 的源码 如果以上方法无效,可以检查 `ultralytics.nn.tasks.temporary_modules` 的源码,确认是否存在对 `old` 模块的引用。如果确实存在但未定义,则可能是代码中的 bug,需联系开发者或提交 issue[^2]。 --- ### 示例代码 以下是一个简单的调试脚本,用于定位问题: ```python try: import ultralytics.nn.tasks.temporary_modules except ModuleNotFoundError as e: print(f"Error: {e}") import sys print("Current Python Path:") print(sys.path) ``` --- ###
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值