前言
我们在nnUNet框架下设计自己的网络结构时,通常需要继承nnUNet的若干个类,在此基础上进行修改。同时,有时候我们为了loss function以及网络结构的设计,可能会让网络有多个不同的输入,代表同一主旨的某个不同的特征。言归正传,本文讨论我们怎么继承nnUNet,使得可以在一次训练中导入两个输入和两个输出。
对nnUNet的nnUNetTrainer的优化
虽然我们是继承nnUNet的nnUNetTrainer,但是还是有需要优化nnUNetTrainer的地方
比如107-109行代码
self.my_init_kwargs = {}
for k in inspect.signature(self.__init__).parameters.keys():
self.my_init_kwargs[k] = locals()[k]
需要改为
self.my_init_kwargs = {}
for k in inspect.signature(nnUNetTrainer.__init__).parameters.keys():
self.my_init_kwargs[k] = locals()[k]
这是因为 inspect.signature() 是基于当前类的代码结构进行检查的,而不是基于运行时的实际参数传递。
inspect.signature() 查看的是类的静态定义,它会返回当前类 nnUNetTrainerToothWNet 的 init 方法签名,而不是父类 nnUNetTrainer 的签名,所以需要将nnUNetTrainer中这几行中的self.__init__改为nnUNetTrainer.init。
未完待续