require 'torch'
require 'nn'
require 'optim'

-- Default hyper-parameters. Every key can be overridden by an environment
-- variable with the same name (e.g. `batchSize=32 gpu=0`).
opt = {
   dataset = 'lsun',       -- imagenet / lsun / folder
   batchSize = 64,
   loadSize = 96,
   fineSize = 64,
   nz = 100,               -- # of dim for Z
   ngf = 64,               -- # of gen filters in first conv layer
   ndf = 64,               -- # of discrim filters in first conv layer
   nThreads = 4,           -- # of data loading threads to use
   niter = 25,             -- # of iter at starting learning rate
   lr = 0.0002,            -- initial learning rate for adam
   beta1 = 0.5,            -- momentum term of adam
   ntrain = math.huge,     -- # of examples per epoch. math.huge for full dataset
   display = 1,            -- display samples while training. 0 = false
   display_id = 10,        -- display window id.
   gpu = 1,                -- gpu = 0 is CPU mode. gpu=X is GPU mode on GPU X
   name = 'experiment1',
   noise = 'normal',       -- uniform / normal
}

-- One-line argument parser: an environment variable named after a key
-- overrides that key's default. Values that parse as numbers are stored as
-- numbers; anything else is kept as the raw string.
for k in pairs(opt) do opt[k] = tonumber(os.getenv(k)) or os.getenv(k) or opt[k] end
print(opt)
if opt.display == 0 then opt.display = false end

-- A fresh random seed is drawn on every run unless the `manualSeed`
-- environment variable is set. The seed is printed so a run can be
-- reproduced later by exporting that value.
opt.manualSeed = tonumber(os.getenv('manualSeed')) or torch.random(1, 10000)
print("Random Seed: " .. opt.manualSeed)
torch.manualSeed(opt.manualSeed)
torch.setnumthreads(1)
torch.setdefaulttensortype('torch.FloatTensor')
-- Build the data loader implemented in data/data.lua for opt.dataset, using
-- opt.nThreads loading threads, then report how many examples it exposes.
local DataLoader = paths.dofile('data/data.lua')
local data = DataLoader.new(opt.nThreads, opt.dataset, opt)
do
   local nExamples = data:size()
   print("Dataset: " .. opt.dataset, " Size: ", nExamples)
end
---------------------------------------------------------------------------- | |
--- Initialise one module in place; meant to be passed to `Module:apply`.
-- * Any module whose type name contains 'Convolution': weights ~ N(0, 0.02)
--   and the bias is removed via `noBias()`.
-- * Any module whose type name contains 'BatchNormalization': scale (if
--   present) ~ N(1, 0.02), shift (if present) = 0.
-- * Every other module is left untouched.
-- @param m an nn module
local function weights_init(m)
   local kind = torch.type(m)

   if kind:find('Convolution') then
      m.weight:normal(0.0, 0.02)
      m:noBias()
      return
   end

   if not kind:find('BatchNormalization') then
      return
   end
   -- affine parameters are optional on BatchNormalization
   if m.weight then m.weight:normal(1.0, 0.02) end
   if m.bias then m.bias:fill(0) end
end
-- Shared model sizes and short aliases for the nn layer constructors.
local nc = 3 -- # of image channels
local nz, ndf, ngf = opt.nz, opt.ndf, opt.ngf
-- BCE targets: real images are labelled 1, generated images 0.
local real_label, fake_label = 1, 0
local SpatialBatchNormalization, SpatialConvolution, SpatialFullConvolution =
   nn.SpatialBatchNormalization, nn.SpatialConvolution, nn.SpatialFullConvolution
-- Generator: maps a noise batch Z of shape (nz) x 1 x 1 to images of shape
-- (nc) x 64 x 64. After the first layer, each stride-2 transposed
-- convolution (kernel 4, pad 1) doubles the spatial size. Every hidden layer
-- is followed by BatchNorm + in-place ReLU, and the output goes through Tanh.
local netG = nn.Sequential()
-- Z -> (ngf*8) x 4 x 4
netG:add(SpatialFullConvolution(nz, ngf * 8, 4, 4))
netG:add(SpatialBatchNormalization(ngf * 8)):add(nn.ReLU(true))
-- (ngf*8) x 4 x 4 -> (ngf*4) x 8 x 8 -> (ngf*2) x 16 x 16 -> (ngf) x 32 x 32
for _, mult in ipairs({ {8, 4}, {4, 2}, {2, 1} }) do
   local nIn, nOut = ngf * mult[1], ngf * mult[2]
   netG:add(SpatialFullConvolution(nIn, nOut, 4, 4, 2, 2, 1, 1))
   netG:add(SpatialBatchNormalization(nOut)):add(nn.ReLU(true))
end
-- (ngf) x 32 x 32 -> (nc) x 64 x 64
netG:add(SpatialFullConvolution(ngf, nc, 4, 4, 2, 2, 1, 1))
netG:add(nn.Tanh())
netG:apply(weights_init)
-- Discriminator: maps images of shape (nc) x 64 x 64 to one sigmoid score per
-- sample, trained towards real_label for real data. Each stride-2
-- convolution (kernel 4, pad 1) halves the spatial size. LeakyReLU(0.2) is
-- used throughout, with BatchNorm on every hidden layer except the first.
local netD = nn.Sequential()
-- (nc) x 64 x 64 -> (ndf) x 32 x 32, no BatchNorm on the input layer
netD:add(SpatialConvolution(nc, ndf, 4, 4, 2, 2, 1, 1))
netD:add(nn.LeakyReLU(0.2, true))
-- (ndf) x 32 x 32 -> (ndf*2) x 16 x 16 -> (ndf*4) x 8 x 8 -> (ndf*8) x 4 x 4
for _, mult in ipairs({ {1, 2}, {2, 4}, {4, 8} }) do
   local nIn, nOut = ndf * mult[1], ndf * mult[2]
   netD:add(SpatialConvolution(nIn, nOut, 4, 4, 2, 2, 1, 1))
   netD:add(SpatialBatchNormalization(nOut)):add(nn.LeakyReLU(0.2, true))
end
-- (ndf*8) x 4 x 4 -> 1 x 1 x 1, squashed by Sigmoid
netD:add(SpatialConvolution(ndf * 8, 1, 4, 4))
netD:add(nn.Sigmoid())
-- flatten the trailing 1 x 1 x 1 map to a single value per sample
netD:add(nn.View(1):setNumInputDims(3))
netD:apply(weights_init)
-- Binary cross-entropy between D's sigmoid output and the 0/1 target labels.
local criterion = nn.BCECriterion()
---------------------------------------------------------------------------
-- Independent Adam state tables for G and D, so each network keeps its own
-- optimiser history. Declared local so they do not leak into _G; the
-- training loop later in this chunk still sees them.
local optimStateG = {
   learningRate = opt.lr,
   beta1 = opt.beta1,
}
local optimStateD = {
   learningRate = opt.lr,
   beta1 = opt.beta1,
}
---------------------------------------------------------------------------- | |
-- Pre-allocated work buffers reused across iterations; the training closures
-- overwrite their contents in place rather than allocating new tensors.
local input = torch.Tensor(opt.batchSize, 3, opt.fineSize, opt.fineSize)
local noise = torch.Tensor(opt.batchSize, nz, 1, 1)
local label = torch.Tensor(opt.batchSize)
-- Latest D and G losses, assigned inside the closures.
local errD, errG
-- Timers: whole epoch, single iteration, and time spent fetching data.
local epoch_tm, tm, data_tm = torch.Timer(), torch.Timer(), torch.Timer()
---------------------------------------------------------------------------- | |
-- GPU mode (opt.gpu > 0): select the device and rebind the work buffers to
-- CUDA copies. If the optional cudnn package loads, enable its benchmark mode
-- and convert both networks with cudnn.convert. Finally move networks and
-- criterion to the GPU.
if opt.gpu > 0 then
   require 'cunn'
   cutorch.setDevice(opt.gpu)
   input = input:cuda()
   noise = noise:cuda()
   label = label:cuda()

   local haveCudnn = pcall(require, 'cudnn')
   if haveCudnn then
      cudnn.benchmark = true
      for _, net in ipairs({ netG, netD }) do
         cudnn.convert(net, cudnn)
      end
   end

   netD:cuda(); netG:cuda(); criterion:cuda()
end
-- Flattened parameter / gradient tensors handed to optim.adam and to the
-- training closures.
local parametersD, gradParametersD = netD:getParameters()
local parametersG, gradParametersG = netG:getParameters()

-- Display client, loaded only when display is enabled. Local so it does not
-- leak into _G.
local disp
if opt.display then disp = require 'display' end

-- Noise batch sampled once here and never regenerated, so the generator
-- samples shown during training always come from the same Z.
local noise_vis = noise:clone()
if opt.noise == 'uniform' then
   noise_vis:uniform(-1, 1)
elseif opt.noise == 'normal' then
   noise_vis:normal(0, 1)
else
   -- Fail fast: for any other value neither noise_vis nor the per-step noise
   -- would ever be filled, and training would run on uninitialised memory.
   error("unknown opt.noise '" .. tostring(opt.noise) .. "' (expected 'uniform' or 'normal')")
end
-- Closure for optim.adam over D's parameters. The `x` argument is unused;
-- netD's weights are used directly.
-- Zeroes gradParametersD, then accumulates into it the BCE gradients for one
-- real batch (target real_label) and one freshly generated batch (target
-- fake_label).
-- State left behind that fGx relies on: `noise` holds the sampled Z,
-- `input` holds G(Z), and netD.output is D's response to that fake batch.
-- Returns errD = loss(real) + loss(fake), and gradParametersD.
local fDx = function(x)
   gradParametersD:zero()

   -- Score the current contents of `input` against a constant target and
   -- backpropagate into D's gradient buffer; returns the loss.
   local function accumulate(target)
      label:fill(target)
      local prediction = netD:forward(input)
      local loss = criterion:forward(prediction, label)
      netD:backward(input, criterion:backward(prediction, label))
      return loss
   end

   -- real batch; data_tm measures only the time spent fetching it
   data_tm:reset(); data_tm:resume()
   local realBatch = data:getBatch()
   data_tm:stop()
   input:copy(realBatch)
   local lossReal = accumulate(real_label)

   -- fake batch generated from freshly sampled noise
   if opt.noise == 'uniform' then
      noise:uniform(-1, 1)
   elseif opt.noise == 'normal' then
      noise:normal(0, 1)
   end
   input:copy(netG:forward(noise))
   local lossFake = accumulate(fake_label)

   errD = lossReal + lossFake
   return errD, gradParametersD
end
-- Closure for optim.adam over G's parameters (`x` unused).
-- It does not sample noise or re-run G and D. Instead it reuses what fDx
-- left behind: `noise`, the fake batch in `input`, and netD.output.
-- NOTE(review): the training loop steps D with optim.adam before calling
-- this, so netD.output (and the activations used by updateGradInput) come
-- from D before that update, while the weights are the updated ones. This is
-- an approximation kept to save a forward pass.
-- Returns errG and gradParametersG.
local fGx = function(x)
   gradParametersG:zero()

   -- G is scored against the *real* label: it is rewarded when D is fooled.
   label:fill(real_label)
   local dOnFake = netD.output
   errG = criterion:forward(dOnFake, label)

   -- Gradient w.r.t. D's input only (updateGradInput, not backward, so D's
   -- parameter gradients are not touched), then backpropagate through G.
   local gradWrtFake = netD:updateGradInput(input, criterion:backward(dOnFake, label))
   netG:backward(noise, gradWrtFake)

   return errG, gradParametersG
end
-- train
-- Each epoch walks over min(data:size(), opt.ntrain) examples in steps of
-- opt.batchSize. Every step makes one Adam update of D followed by one of G.
-- At the end of the epoch both networks are checkpointed.
for epoch = 1, opt.niter do
   epoch_tm:reset()
   local counter = 0
   local nTrain = math.min(data:size(), opt.ntrain)
   local nBatches = math.floor(nTrain / opt.batchSize)

   for i = 1, nTrain, opt.batchSize do
      tm:reset()
      -- (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
      optim.adam(fDx, parametersD, optimStateD)
      -- (2) Update G network: maximize log(D(G(z)))
      optim.adam(fGx, parametersG, optimStateG)

      -- display: every 10th iteration, render G on the fixed noise_vis batch
      -- next to a real batch.
      -- NOTE(review): data:getBatch() here pulls an extra batch that is not
      -- used for training.
      counter = counter + 1
      if counter % 10 == 0 and opt.display then
         local fake = netG:forward(noise_vis)
         local real = data:getBatch()
         disp.image(fake, {win=opt.display_id, title=opt.name})
         disp.image(real, {win=opt.display_id * 3, title=opt.name})
      end

      -- logging, once per iteration (a loss prints as -1 if not yet computed)
      print(('Epoch: [%d][%8d / %8d]\t Time: %.3f DataTime: %.3f '
                .. ' Err_G: %.4f Err_D: %.4f'):format(
              epoch, ((i-1) / opt.batchSize), nBatches,
              tm:time().real, data_tm:time().real,
              errG or -1, errD or -1))
   end

   -- checkpoint both networks for this epoch
   paths.mkdir('checkpoints')
   parametersD, gradParametersD = nil, nil -- nil them to avoid spiking memory
   parametersG, gradParametersG = nil, nil
   torch.save('checkpoints/' .. opt.name .. '_' .. epoch .. '_net_G.t7', netG:clearState())
   torch.save('checkpoints/' .. opt.name .. '_' .. epoch .. '_net_D.t7', netD:clearState())
   -- Re-flatten. The closures and the optim.adam calls read these same
   -- variables, so they pick up the new tensors on the next epoch.
   parametersD, gradParametersD = netD:getParameters() -- reflatten the params and get them
   parametersG, gradParametersG = netG:getParameters()
   print(('End of epoch %d / %d \t Time Taken: %.3f'):format(
            epoch, opt.niter, epoch_tm:time().real))
end