关于model must be a LightingModule or torch._dynamo.OptimizedModule got TemporalFusionTransformer问题

之前也写过该问日的解决方法,但是学长在调试代码的时候问我的问题让我决定完善这个问题的答复。

检查自己的运行文件中import的写法

我的运行文件名为train.py文件,如下显示为我的import写法:

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
import torch
import wandb
import yaml
import os
import argparse
import torch.nn as nn
import random
from config import Cfg
from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.data.encoders import EncoderNormalizer, TorchNormalizer
from pytorch_forecasting.metrics import QuantileLoss, SMAPE, MAE, RMSE, MAPE, NormalDistributionLoss
from pytorch_lightning.callbacks import ModelCheckpoint

检查base_model.py文件中的import写法 

如下是base_model.py文件中的import写法

form lightning.pytorch import LightningModule,Trainer

在base_model.py文件中我们只需要关注这一句就可以了。

如何修改 

两者区别可以看出:我们自己的运行文件中使用的是import pytorch_lightning as pl导入pytorch lightning库,而base_model.py文件中使用的是form lightning.pytorch import LightningModule,Trainer 导入pytorch lightning库,所以将自己的运行文件中所有pytorch_lightning替换为lightning.pytorch,代码也是可以运行的。

修改base_model.py文件中的import写法 

以我的代码为例,我们也可以在base_model.py文件中将form lightning.pytorch import LightningModule,Trainer修改为form pytorch_lightning import LightningModule,Trainer,但是在代码后续运行时会在

trainer.test(
        model=best_tft,
        dataloaders=test_dataloader,
    )

这一步继续报错显示model must be a LightingModule or torch._dynamo.OptimizedModule got TemporalFusionTransformer,这时候为了得到数据集的测试结果我通常会将这一部分与训练部分分开为两个py文件进行运行,但是这样需要我们在切换两个文件的同时也修改base_model.py文件中的import写法,过于繁琐。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值