运行starganV2是遇到三个报错以及解决方法

错误一

CheckpointIO(ospj(args.checkpoint_dir, '{:06d}_nets.ckpt'), data_parallel=True, **self.nets)
报错信息为:FileNotFoundError: [WinError 3] 系统找不到指定的路径。: '{:'

报错代码

self.ckptios = [
                CheckpointIO(ospj(args.checkpoint_dir, '{:06d}_nets.ckpt'), data_parallel=True, **self.nets),
                CheckpointIO(ospj(args.checkpoint_dir, '{:06d}_nets_ema.ckpt'), data_parallel=True, **self.nets_ema),
                CheckpointIO(ospj(args.checkpoint_dir, '{:06d}_optims.ckpt'), **self.optims)]

查看了源代码,ospj正常输入应该为args.checkpoint_dir/{:06d}_nets.ckpt,但是实际输出为{:06d}_nets.ckpt。导致os.makedirs(os.path.dirname(fname_template), exist_ok=True)运行出错,不能正常创建指定文件夹。

class CheckpointIO(object):
    def __init__(self, fname_template, data_parallel=False, **kwargs):
        os.makedirs(os.path.dirname(fname_template), exist_ok=True)
        self.fname_template = fname_template
        self.module_dict = kwargs
        self.data_parallel = data_parallel

改正一

self.ckptios = [
                CheckpointIO(ospj(args.checkpoint_dir, 'nets_{:06d}.ckpt'), data_parallel=True, **self.nets),
                CheckpointIO(ospj(args.checkpoint_dir, 'nets_ema_{:06d}.ckpt'), data_parallel=True, **self.nets_ema),
                CheckpointIO(ospj(args.checkpoint_dir, 'optims_{:06d}.ckpt'), **self.optims)]

错误解决

报错二

报错信息为x, y = next(self.iter),AttributeError: 'InputFetcher' object has no attribute 'iter'

这是没有在__init__中加载self.iter

改正二

在data_loader.py中的InputFetcher类修改为

class InputFetcher:
    def __init__(self, loader, loader_ref=None, latent_dim=16, mode=''):
        self.loader = loader
        self.loader_ref = loader_ref
        self.latent_dim = latent_dim
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.mode = mode
        self.iter = iter(self.loader)  # 初始化 self.iter

    def _fetch_inputs(self):
        try:
            x, y = next(self.iter)
        except (AttributeError, StopIteration):
            self.iter = iter(self.loader)
            x, y = next(self.iter)
        return x, y

    def _fetch_refs(self):
        try:
            x, x2, y = next(self.iter_ref)
        except (AttributeError, StopIteration):
            self.iter_ref = iter(self.loader_ref)
            x, x2, y = next(self.iter_ref)
        return x, x2, y

    def __next__(self):
        x, y = self._fetch_inputs()
        if self.mode == 'train':
            x_ref, x_ref2, y_ref = self._fetch_refs()
            z_trg = torch.randn(x.size(0), self.latent_dim)
            z_trg2 = torch.randn(x.size(0), self.latent_dim)
            inputs = Munch(x_src=x, y_src=y, y_ref=y_ref,
                           x_ref=x_ref, x_ref2=x_ref2,
                           z_trg=z_trg, z_trg2=z_trg2)
        elif self.mode == 'val':
            x_ref, y_ref = self._fetch_inputs()
            inputs = Munch(x_src=x, y_src=y,
                           x_ref=x_ref, y_ref=y_ref)
        elif self.mode == 'test':
            inputs = Munch(x=x, y=y)
        else:
            raise NotImplementedError

        return Munch({k: v.to(self.device)
                      for k, v in inputs.items()})

    # 使其兼容 for-in 语法
    def __iter__(self):
        return self

问题解决

报错三

AttributeError: Can’t pickle local object “get_train_loader< locals><\lambda>

参考如下博文
https://blog.youkuaiyun.com/genous110/article/details/115474244

改正三

在data_loader.py中的get_train_loader()函数附近添加如下代码

# 定义一个常规函数来替代 lambda
def random_crop_transform(x, crop_fn, probability):
    if random.random() < probability:
        return crop_fn(x)
    else:
        return x

# 直接使用函数进行变换
class RandomCropTransform:
    def __init__(self, crop_fn, probability):
        self.crop_fn = crop_fn
        self.probability = probability

    def __call__(self, x):
        return random_crop_transform(x, self.crop_fn, self.probability)

然后将get_train_loader()函数中的rand_crop变量修改为

rand_crop = RandomCropTransform(crop, prob)

问题解决

完结撒花

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值