PyTorch-GAN 项目使用教程
1. 项目的目录结构及介绍
PyTorch-GAN 项目的目录结构如下:
PyTorch-GAN/
├── implementations/
│ ├── auxiliary_classifier_gan/
│ ├── adversarial_autoencoder/
│ ├── began/
│ ├── bicyclegan/
│ ├── boundary_seeking_gan/
│ ├── cluster_gan/
│ ├── conditional_gan/
│ ├── context_conditional_gan/
│ ├── context_encoder/
│ ├── coupled_gan/
│ ├── cyclegan/
│ ├── deep_convolutional_gan/
│ ├── discogan/
│ ├── dragan/
│ ├── dualgan/
│ ├── energy_based_gan/
│ ├── enhanced_super_resolution_gan/
│ ├── gan/
│ ├── infogan/
│ ├── least_squares_gan/
│ ├── munit/
│ ├── pix2pix/
│ ├── pixelda/
│ ├── relativistic_gan/
│ ├── semi_supervised_gan/
│ ├── softmax_gan/
│ ├── stargan/
│ ├── super_resolution_gan/
│ ├── unit/
│ ├── wasserstein_gan/
│ ├── wasserstein_gan_gp/
│ └── wasserstein_gan_div/
├── data/
├── README.md
└── requirements.txt
目录结构介绍
implementations/
:包含各种 GAN 模型的实现代码。data/
:用于存放训练数据。README.md
:项目说明文档。requirements.txt
:项目依赖的 Python 包列表。
2. 项目的启动文件介绍
每个 GAN 模型的启动文件位于 implementations/
目录下的相应子目录中。例如,CycleGAN 的启动文件是 cyclegan.py
,位于 implementations/cyclegan/
目录下。
启动文件示例
以 CycleGAN 为例:
cd implementations/cyclegan/
python3 cyclegan.py
3. 项目的配置文件介绍
PyTorch-GAN 项目没有统一的配置文件,每个 GAN 模型的配置参数通常在启动文件中直接定义。例如,CycleGAN 的配置参数在 cyclegan.py
文件中定义。
配置参数示例
以 CycleGAN 为例,部分配置参数如下:
# cyclegan.py
parser = argparse.ArgumentParser()
parser.add_argument('--epoch', type=int, default=0, help='starting epoch')
parser.add_argument('--n_epochs', type=int, default=200, help='number of epochs of training')
parser.add_argument('--batch_size', type=int, default=1, help='size of the batches')
parser.add_argument('--lr', type=float, default=0.0002, help='adam: learning rate')
parser.add_argument('--b1', type=float, default=0.5, help='adam: decay of first order momentum of gradient')
parser.add_argument('--b2', type=float, default=0.999, help='adam: decay of first order momentum of gradient')
parser.add_argument('--decay_epoch', type=int, default=100, help='epoch from which to start lr decay')
parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation')
parser.add_argument('--img_height', type=int, default=256, help='size of image height')
parser.add_argument('--img_width', type=int, default=256, help='size of image width')
parser.add_argument('--channels', type=int, default=3, help='number of image channels')
parser.add_argument('--sample_interval', type=int, default=500, help='interval between saving generator outputs')
parser.add_argument('--checkpoint_interval', type=int, default=-1, help='interval between saving model checkpoints')
parser.add_argument('--n_residual_blocks', type=int, default=9, help='number of residual blocks in generator')
parser.add_argument('--lambda_cyc', type=float, default=1
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考