import argparse
class BaseOptions:
    """Command-line options used at both training and test time.

    Typical use::

        opt = BaseOptions().init().parse_args()

    ``init`` registers every option on ``self.parse`` (an
    ``argparse.ArgumentParser``) and returns that parser; ``parse_args``
    then yields an ``argparse.Namespace`` whose attributes are the values.
    """

    def __init__(self):
        # Public attribute: init() registers all options on it and returns it.
        self.parse = argparse.ArgumentParser(description='enhancement underwater image')

    @staticmethod
    def _str_to_bool(value):
        """Convert a command-line string such as ``'False'`` into a real bool.

        ``type=bool`` is an argparse pitfall: ``bool('False')`` is ``True``
        because every non-empty string is truthy, so a ``type=bool`` option
        can never be switched off from the command line.

        Args:
            value: the raw string from the command line (a bool is passed
                through unchanged).

        Returns:
            bool: the parsed value.

        Raises:
            argparse.ArgumentTypeError: if ``value`` is not a recognised
                boolean spelling; argparse reports this as a usage error.
        """
        if isinstance(value, bool):
            return value
        text = value.strip().lower()
        if text in ('true', 't', 'yes', 'y', '1'):
            return True
        if text in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))

    def init(self):
        """Register all options on ``self.parse`` and return the parser.

        Call this once per instance: argparse raises ``argparse.ArgumentError``
        if the same option string is added to one parser twice.

        Returns:
            argparse.ArgumentParser: the populated parser.
        """
        add = self.parse.add_argument

        # Data and output locations.
        add('--dataset', type=str, default='../underwater_image/512_512/trainset/', help='train_dataset')
        # NOTE(review): the validation set defaults to the training-set path; confirm this is intended.
        add('--valset', type=str, default='../underwater_image/512_512/trainset/', help='val_dataset')
        add('--testset', type=str, default='../underwater_image/512_512/mini_test400/', help='test_dataset')
        add('--save_model', type=str, default='../Methon_data/MutiTransformer/MTGAN2_unpaired/', help='save_model_dir')
        add('--pretrain_model', type=str, default='../Methon_data/MutiTransformer/MTGAN2/', help='pretrained_model_dir')
        add('--save_image', type=str, default='../Methon_data/contrast_image/MutiTransformer/', help='save_image_dir')

        # Training schedule and optimiser settings.
        add('--batch_size', type=int, default=8, help='batch size')
        add('--device', type=str, default='cuda:0', help='set train device')
        add('--epoch_end', type=int, default=200, help='train end epoch ')
        add('--lr', type=float, default=0.0002, help='learning rate ')
        # NOTE(review): the help texts below follow the pix2pix/CycleGAN convention
        # (default epoch_end == niter + niter_decay, lr=2e-4 with beta=0.5);
        # confirm against the training loop that consumes these values.
        add('--niter', type=int, default=100, help='number of epochs at the initial learning rate')
        add('--niter_decay', type=int, default=100, help='number of epochs over which the learning rate is decayed')
        add('--beta', type=float, default=0.5, help='beta1 for the Adam optimizer')
        add('--resolution', type=int, default=512, help='use image resolution')

        # Parsed with _str_to_bool so that e.g. "--image_save_status False" really yields False.
        add('--image_save_status', type=self._str_to_bool, default=False, help='save image')

        # Logging.
        add('--train_log', type=str, default='../Methon_data/MTGAN_Lap1.txt', help='train loss and psnr log')
        add('--train_state', type=str, default='MTGAN2_unpaired: MutiTransformer ', help='train set state')

        # WGAN-GP settings.
        add('--lamb', type=int, default=10, help='WGAN-GP')
        add('--use_GP', type=self._str_to_bool, default=False, help='WGAN-GP')
        return self.parse
# Import the hyperparameters (导入超参数)
from base_option import BaseOptions
# Build the option parser and parse the command line (sys.argv[1:]) into a Namespace.
opt = BaseOptions().init().parse_args(args=None)
# Usage (调用): read an option as an attribute of the parsed namespace
# A bare `opt.device` expression statement has no effect; bind the value so it can be used.
device = opt.device