require 'torch'
require 'nn'
require 'optim'
opt = {
   dataset = 'lsun',       -- imagenet / lsun / folder
   batchSize = 64,
   loadSize = 96,
   fineSize = 64,
   nz = 100,               -- # of dim for Z
   ngf = 64,               -- # of gen filters in first conv layer
   ndf = 64,               -- # of discrim filters in first conv layer
   nThreads = 4,           -- # of data loading threads to use
   niter = 25,             -- # of iter at starting learning rate
   lr = 0.0002,            -- initial learning rate for adam
   beta1 = 0.5,            -- momentum term of adam
   ntrain = math.huge,     -- # of examples per epoch. math.huge for full dataset
   display = 1,            -- display samples while training. 0 = false
   display_id = 10,        -- display window id.
   gpu = 1,                -- gpu = 0 is CPU mode. gpu=X is GPU mode on GPU X
   name = 'experiment1',
   noise = 'normal',       -- uniform / normal
}
-- one-line argument parser. parses environment variables to override the defaults
for k,v in pairs(opt) do opt[k] = tonumber(os.getenv(k)) or os.getenv(k) or opt[k] end
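-- For example (path hypothetical; DATA_ROOT is consumed by the loaders under
-- data/, not by this script):
--   DATA_ROOT=/path/to/images dataset=folder ngf=128 gpu=1 th main.lua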
print(opt)
if opt.display == 0 then opt.display = false end

opt.manualSeed = torch.random(1, 10000) -- fix seed
print("Random Seed: " .. opt.manualSeed)
torch.manualSeed(opt.manualSeed)
torch.setnumthreads(1)
torch.setdefaulttensortype('torch.FloatTensor')
-- create data loader
local DataLoader = paths.dofile('data/data.lua')
local data = DataLoader.new(opt.nThreads, opt.dataset, opt)
print("Dataset: " .. opt.dataset, " Size: ", data:size())
----------------------------------------------------------------------------
local function weights_init(m)
   local name = torch.type(m)
   if name:find('Convolution') then
      m.weight:normal(0.0, 0.02)
      m:noBias()
   elseif name:find('BatchNormalization') then
      if m.weight then m.weight:normal(1.0, 0.02) end
      if m.bias then m.bias:fill(0) end
   end
end
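-- DCGAN initialization (Radford et al., 2015): conv weights ~ N(0, 0.02),
-- batch-norm gain ~ N(1, 0.02) with zero shift. Convolution biases are
-- removed entirely; for layers followed by batch norm the bias would be
-- redundant anyway, since the normalization subtracts the mean.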
local nc = 3
local nz = opt.nz
local ndf = opt.ndf
local ngf = opt.ngf
local real_label = 1
local fake_label = 0

local SpatialBatchNormalization = nn.SpatialBatchNormalization
local SpatialConvolution = nn.SpatialConvolution
local SpatialFullConvolution = nn.SpatialFullConvolution
local netG = nn.Sequential()
-- input is Z, going into a convolution
netG:add(SpatialFullConvolution(nz, ngf * 8, 4, 4))
netG:add(SpatialBatchNormalization(ngf * 8)):add(nn.ReLU(true))
-- state size: (ngf*8) x 4 x 4
netG:add(SpatialFullConvolution(ngf * 8, ngf * 4, 4, 4, 2, 2, 1, 1))
netG:add(SpatialBatchNormalization(ngf * 4)):add(nn.ReLU(true))
-- state size: (ngf*4) x 8 x 8
netG:add(SpatialFullConvolution(ngf * 4, ngf * 2, 4, 4, 2, 2, 1, 1))
netG:add(SpatialBatchNormalization(ngf * 2)):add(nn.ReLU(true))
-- state size: (ngf*2) x 16 x 16
netG:add(SpatialFullConvolution(ngf * 2, ngf, 4, 4, 2, 2, 1, 1))
netG:add(SpatialBatchNormalization(ngf)):add(nn.ReLU(true))
-- state size: (ngf) x 32 x 32
netG:add(SpatialFullConvolution(ngf, nc, 4, 4, 2, 2, 1, 1))
netG:add(nn.Tanh())
-- state size: (nc) x 64 x 64

netG:apply(weights_init)
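-- Each SpatialFullConvolution (transposed convolution) above upsamples by
--   out = (in - 1) * stride - 2 * pad + kernel
-- so the 1x1 latent vector grows 1 -> 4 -> 8 -> 16 -> 32 -> 64: the first
-- layer gives (1 - 1)*1 - 0 + 4 = 4, and each stride-2 layer doubles the
-- spatial size, e.g. (4 - 1)*2 - 2*1 + 4 = 8.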
local netD = nn.Sequential()
-- input is (nc) x 64 x 64
netD:add(SpatialConvolution(nc, ndf, 4, 4, 2, 2, 1, 1))
netD:add(nn.LeakyReLU(0.2, true))
-- state size: (ndf) x 32 x 32
netD:add(SpatialConvolution(ndf, ndf * 2, 4, 4, 2, 2, 1, 1))
netD:add(SpatialBatchNormalization(ndf * 2)):add(nn.LeakyReLU(0.2, true))
-- state size: (ndf*2) x 16 x 16
netD:add(SpatialConvolution(ndf * 2, ndf * 4, 4, 4, 2, 2, 1, 1))
netD:add(SpatialBatchNormalization(ndf * 4)):add(nn.LeakyReLU(0.2, true))
-- state size: (ndf*4) x 8 x 8
netD:add(SpatialConvolution(ndf * 4, ndf * 8, 4, 4, 2, 2, 1, 1))
netD:add(SpatialBatchNormalization(ndf * 8)):add(nn.LeakyReLU(0.2, true))
-- state size: (ndf*8) x 4 x 4
netD:add(SpatialConvolution(ndf * 8, 1, 4, 4))
netD:add(nn.Sigmoid())
-- state size: 1 x 1 x 1
netD:add(nn.View(1):setNumInputDims(3))
-- state size: 1

netD:apply(weights_init)
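-- The discriminator mirrors the generator with strided convolutions:
--   out = floor((in + 2 * pad - kernel) / stride) + 1
-- shrinking 64 -> 32 -> 16 -> 8 -> 4 -> 1 (e.g. floor((64 + 2 - 4)/2) + 1 = 32);
-- the final Sigmoid plus View turns each 1x1x1 map into a per-example
-- probability that the input is real.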
local criterion = nn.BCECriterion()
---------------------------------------------------------------------------
optimStateG = {
   learningRate = opt.lr,
   beta1 = opt.beta1,
}
optimStateD = {
   learningRate = opt.lr,
   beta1 = opt.beta1,
}
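-- Both nets use Adam with beta1 = 0.5 rather than the default 0.9; the DCGAN
-- paper found the lower momentum term necessary to avoid training oscillation.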
----------------------------------------------------------------------------
local input = torch.Tensor(opt.batchSize, 3, opt.fineSize, opt.fineSize)
local noise = torch.Tensor(opt.batchSize, nz, 1, 1)
local label = torch.Tensor(opt.batchSize)
local errD, errG
local epoch_tm = torch.Timer()
local tm = torch.Timer()
local data_tm = torch.Timer()
----------------------------------------------------------------------------
if opt.gpu > 0 then
   require 'cunn'
   cutorch.setDevice(opt.gpu)
   input = input:cuda(); noise = noise:cuda(); label = label:cuda()
   if pcall(require, 'cudnn') then -- pcall both tests for and loads cudnn
      cudnn.benchmark = true
      cudnn.convert(netG, cudnn)
      cudnn.convert(netD, cudnn)
   end
   netD:cuda(); netG:cuda(); criterion:cuda()
end
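-- cudnn.convert swaps the nn.Spatial* modules for their cudnn equivalents in
-- place; cudnn.benchmark = true lets cuDNN autotune convolution algorithms for
-- the fixed input sizes, trading a slower first iteration for faster steady state.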
local parametersD, gradParametersD = netD:getParameters()
local parametersG, gradParametersG = netG:getParameters()

if opt.display then disp = require 'display' end

noise_vis = noise:clone()
if opt.noise == 'uniform' then
   noise_vis:uniform(-1, 1)
elseif opt.noise == 'normal' then
   noise_vis:normal(0, 1)
end
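-- noise_vis is sampled once and reused for every display call, so the
-- visualized samples show how the generator's output for a fixed set of
-- latent codes evolves over training.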
-- create closure to evaluate f(X) and df/dX of discriminator
local fDx = function(x)
   gradParametersD:zero()

   -- train with real
   data_tm:reset(); data_tm:resume()
   local real = data:getBatch()
   data_tm:stop()
   input:copy(real)
   label:fill(real_label)

   local output = netD:forward(input)
   local errD_real = criterion:forward(output, label)
   local df_do = criterion:backward(output, label)
   netD:backward(input, df_do)

   -- train with fake
   if opt.noise == 'uniform' then -- regenerate random noise
      noise:uniform(-1, 1)
   elseif opt.noise == 'normal' then
      noise:normal(0, 1)
   end
   local fake = netG:forward(noise)
   input:copy(fake)
   label:fill(fake_label)

   local output = netD:forward(input)
   local errD_fake = criterion:forward(output, label)
   local df_do = criterion:backward(output, label)
   netD:backward(input, df_do)

   errD = errD_real + errD_fake

   return errD, gradParametersD
end
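-- With BCE against real labels (1) and fake labels (0), fDx accumulates
-- gradients of the standard discriminator loss
--   errD = -log D(x) - log(1 - D(G(z)))
-- averaged over the minibatch; both the real and fake passes call
-- netD:backward, so the two gradient contributions sum in gradParametersD.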
-- create closure to evaluate f(X) and df/dX of generator
local fGx = function(x)
   gradParametersG:zero()

   --[[ the three lines below were already executed in fDx, so save computation
   noise:uniform(-1, 1) -- regenerate random noise
   local fake = netG:forward(noise)
   input:copy(fake) ]]--
   label:fill(real_label) -- fake labels are real for generator cost

   local output = netD.output -- netD:forward(input) was already executed in fDx, so save computation
   errG = criterion:forward(output, label)
   local df_do = criterion:backward(output, label)
   local df_dg = netD:updateGradInput(input, df_do)

   netG:backward(noise, df_dg)
   return errG, gradParametersG
end
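-- fGx implements the non-saturating generator objective: instead of
-- minimizing log(1 - D(G(z))), it maximizes log D(G(z)) by labeling fakes as
-- real. Note the use of netD:updateGradInput rather than netD:backward: this
-- backpropagates through the discriminator to get gradients w.r.t. the fake
-- images without accumulating anything into gradParametersD.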
-- train
for epoch = 1, opt.niter do
   epoch_tm:reset()
   local counter = 0
   for i = 1, math.min(data:size(), opt.ntrain), opt.batchSize do
      tm:reset()
      -- (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
      optim.adam(fDx, parametersD, optimStateD)

      -- (2) Update G network: maximize log(D(G(z)))
      optim.adam(fGx, parametersG, optimStateG)

      -- display
      counter = counter + 1
      if counter % 10 == 0 and opt.display then
         local fake = netG:forward(noise_vis)
         local real = data:getBatch()
         disp.image(fake, {win=opt.display_id, title=opt.name})
         disp.image(real, {win=opt.display_id * 3, title=opt.name})
      end

      -- logging
      if ((i-1) / opt.batchSize) % 1 == 0 then
         print(('Epoch: [%d][%8d / %8d]\t Time: %.3f  DataTime: %.3f  '
                   .. '  Err_G: %.4f  Err_D: %.4f'):format(
                 epoch, ((i-1) / opt.batchSize),
                 math.floor(math.min(data:size(), opt.ntrain) / opt.batchSize),
                 tm:time().real, data_tm:time().real,
                 errG and errG or -1, errD and errD or -1))
      end
   end
   paths.mkdir('checkpoints')
   parametersD, gradParametersD = nil, nil -- nil them to avoid spiking memory
   parametersG, gradParametersG = nil, nil
   torch.save('checkpoints/' .. opt.name .. '_' .. epoch .. '_net_G.t7', netG:clearState())
   torch.save('checkpoints/' .. opt.name .. '_' .. epoch .. '_net_D.t7', netD:clearState())
   parametersD, gradParametersD = netD:getParameters() -- reflatten the params and get them
   parametersG, gradParametersG = netG:getParameters()
   print(('End of epoch %d / %d \t Time Taken: %.3f'):format(
      epoch, opt.niter, epoch_tm:time().real))
end
dcgan.torch/main.lua
This post presents the training pipeline of a generative adversarial network (GAN) implemented in Torch (Lua), showing in detail how to configure the network hyperparameters, build the generator and discriminator, and produce high-quality images through an alternating optimization of the two networks.
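A trained generator checkpoint can be reloaded and sampled on its own. Below is a minimal sketch, assuming a CPU run with the default experiment name and noise settings (the path and epoch number are hypothetical, and checkpoints saved from a GPU/cudnn run need cunn/cudnn loaded before torch.load):

require 'torch'
require 'nn'
torch.setdefaulttensortype('torch.FloatTensor') -- checkpoints were saved as float

local netG = torch.load('checkpoints/experiment1_25_net_G.t7') -- hypothetical: epoch 25 of 'experiment1'
netG:evaluate()

local z = torch.Tensor(16, 100, 1, 1):normal(0, 1) -- match opt.nz and opt.noise
local images = netG:forward(z)                     -- 16 x 3 x 64 x 64, values in [-1, 1]
images:add(1):div(2)                               -- rescale to [0, 1] for saving or display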