问题描述
使用darts
做时间序列预测时,使用optuna
做超参数优化,调用study.optimize(objective, timeout=7200)
,报错:ValueError: Expected a parent
原因
新版本的pytorch_lightning
改名为lightning
,在函数objective
中采用了pytorch_lightning
的EarlyStopping
的回调函数,同时采用了from optuna.integration import PyTorchLightningPruningCallback
的PyTorchLightningPruningCallback
回调函数,在optuna
的3.5版本中采用了lightning.pytorch
来代替pytorch_lightning
,因此同时存在pytorch_lightning
和lightning.pytorch
两个新老版本引发了冲突。
解决方式
- 定义OptunaPruning类替代PyTorchLightningPruningCallback,使两个版本兼容(不推荐)。
class OptunaPruning(PyTorchLightningPruningCall