1. Framework

2. main
import argparse
import torch
from torch.utils.data import DataLoader
from torch import autograd, optim
from torchvision.transforms import transforms
from torchvision import datasets
import numpy as np
from ACGAN import *
import os
import json
import sys
from tqdm import tqdm
from torch.autograd import Variable
from ShowFeatureMap import *
def getArgs(argv=None):
    """Parse command-line options for ACGAN training.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) falls
            back to ``sys.argv``, preserving the original no-argument call
            signature while making the parser testable.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parse = argparse.ArgumentParser()
    parse.add_argument("--action", type=str, default="train&val", help="train/val")
    parse.add_argument("--rootpath", type=str, default='/media/yuanxingWorkSpace/ImageAugmentation/ACGAN/ISBClassifyLabelData', help="path")
    parse.add_argument("--epoch", type=int, default=100)
    parse.add_argument('--arch', '-a', metavar='ARCH', default='ACGAN', help='ACGAN')
    parse.add_argument("--batch_size", type=int, default=1)
    # NOTE(review): passing --shuffle on the CLI stores a string, not a bool;
    # the training code hard-codes shuffle=True anyway, so this is unused.
    parse.add_argument("--shuffle", default=True)
    parse.add_argument('--dataset', default='ISBClassifyLabelData', help='ISBClassifyLabelData')
    # Help texts below previously contradicted the actual defaults
    # (said 0.001 / 1e-4 / n=10); corrected to match the real values.
    parse.add_argument('--lr', type=float, default=0.0001, help='Learning Rate. Default=0.0001')
    parse.add_argument("--weight_decay", "--wd", default=0, type=float, help="Weight decay, Default: 0")
    parse.add_argument("--momentum", default=0.9, type=float, help="Momentum, Default: 0.9")
    parse.add_argument("--step", type=int, default=20, help="Sets the learning rate to the initial LR decayed by momentum every n epochs, Default: n=20")
    parse.add_argument("--latent_dim", type=int, default=100)
    parse.add_argument("--n_classes", type=int, default=3)
    args = parse.parse_args(argv)
    return args
def getDataset(args):
    """Build ImageFolder datasets for the train/val splits and wrap them in loaders.

    Also dumps the index->class-name mapping of the training set to
    ``class_indices.json`` in the working directory.

    Args:
        args: parsed namespace; uses ``rootpath`` and ``batch_size``.

    Returns:
        (train_loader, val_loader) tuple of DataLoaders over grayscale,
        center-cropped (480x480) tensors.
    """
    # Same preprocessing for both splits: 1-channel grayscale, 480 center crop.
    pipelines = {}
    for split in ("train", "val"):
        pipelines[split] = transforms.Compose([
            transforms.Grayscale(num_output_channels=1),
            transforms.CenterCrop(480),
            transforms.ToTensor(),
        ])
    assert os.path.exists(args.rootpath), "{} path does not exist.".format(args.rootpath)
    train_dataset = datasets.ImageFolder(root=os.path.join(args.rootpath, "train"), transform=pipelines["train"])
    # Invert class_to_idx so the JSON maps index -> class name.
    index_to_class = {idx: name for name, idx in train_dataset.class_to_idx.items()}
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json.dumps(index_to_class, indent=4))
    # Worker count: capped at 4, at the CPU count, and forced to 0 (main
    # process) when batch_size is 1.
    nw = min([os.cpu_count(), args.batch_size if args.batch_size > 1 else 0, 4])
    print('Using {} dataloader workers every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=nw,
    )
    validate_dataset = datasets.ImageFolder(root=os.path.join(args.rootpath, "val"), transform=pipelines["val"])
    val_loader = torch.utils.data.DataLoader(
        validate_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=nw,
    )
    return train_loader, val_loader
def getModel(args):
    """Instantiate the generator/discriminator pair for the requested architecture.

    Args:
        args: parsed namespace; uses ``arch``.

    Returns:
        (Gmodel, Dmodel) tuple of freshly constructed modules.

    Raises:
        ValueError: if ``args.arch`` is not a supported architecture.
            (The original fell through to ``return Gmodel, Dmodel`` with both
            names unbound, crashing with UnboundLocalError.)
    """
    if args.arch == 'ACGAN':
        return Generator(), Discriminator()
    raise ValueError("Unsupported architecture: {!r} (expected 'ACGAN')".format(args.arch))
def train(Dmodel, Gmodel, adversarial_loss, auxiliary_loss, Doptimizer, Goptimizer, train_loader, val_loader, args):
    """ACGAN training loop: alternate generator/discriminator updates per batch,
    dump feature maps each step, run a no-grad generator pass over the val set
    each epoch, and save the generator weights at the end.

    Relies on the module-level global ``device`` set in the ``__main__`` block.

    Args:
        Dmodel, Gmodel: discriminator / generator modules.
        adversarial_loss: BCE loss on the real/fake validity head.
        auxiliary_loss: cross-entropy loss on the class head.
        Doptimizer, Goptimizer: optimizers for the respective models.
        train_loader, val_loader: DataLoaders from getDataset().
        args: parsed namespace; uses epoch, latent_dim, n_classes, arch, batch_size.
    """
    num_epochs = args.epoch
    for epoch in range(1, num_epochs + 1):
        Dmodel = Dmodel.train()
        Gmodel = Gmodel.train()
        train_bar = tqdm(train_loader, file=sys.stdout)
        for i, (imgs, labels) in enumerate(train_bar):
            batch_size = imgs.shape[0]
            # Adversarial ground truths: 1.0 for real, 0.0 for fake.
            valid = torch.FloatTensor(batch_size, 1).fill_(1.0).to(device)
            fake = torch.FloatTensor(batch_size, 1).fill_(0.0).to(device)
            real_imgs = torch.FloatTensor(imgs).to(device)
            labels = torch.LongTensor(labels).to(device)
            # --- Generator step ---
            Goptimizer.zero_grad()
            # Noise z ~ N(0, 1) and uniformly random target classes.
            z = torch.FloatTensor(np.random.normal(0, 1, (batch_size, args.latent_dim))).to(device)
            gen_labels = torch.LongTensor(np.random.randint(0, args.n_classes, batch_size)).to(device)
            gen_imgs = Gmodel(z, gen_labels)
            # Side effect: writes real|generated comparison images to disk every batch.
            showGFeatureMap(real_imgs, gen_imgs)
            validity, pred_label = Dmodel(gen_imgs)
            # G wants fakes judged valid AND classified as their conditioning label.
            g_loss = 0.5 * (adversarial_loss(validity, valid) + auxiliary_loss(pred_label, gen_labels))
            g_loss.backward()
            Goptimizer.step()
            # --- Discriminator step ---
            Doptimizer.zero_grad()
            real_pred, real_aux = Dmodel(real_imgs)
            d_real_loss = (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, labels)) / 2
            # detach(): no gradient flows back into the generator here.
            fake_pred, fake_aux = Dmodel(gen_imgs.detach())
            d_fake_loss = (adversarial_loss(fake_pred, fake) + auxiliary_loss(fake_aux, gen_labels)) / 2
            d_loss = (d_real_loss + d_fake_loss) / 2
            # Classification accuracy of the aux head over both real and fake samples.
            pred = np.concatenate([real_aux.data.cpu().numpy(), fake_aux.data.cpu().numpy()], axis=0)
            gt = np.concatenate([labels.data.cpu().numpy(), gen_labels.data.cpu().numpy()], axis=0)
            d_acc = np.mean(np.argmax(pred, axis=1) == gt)
            d_loss.backward()
            Doptimizer.step()
            print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %d%%] [G loss: %f]" % (epoch, num_epochs, i, len(train_loader), d_loss.item(), 100 * d_acc, g_loss.item()))
        # Per-epoch validation: generate one image per val sample's label and
        # dump the comparison to disk; no metrics are computed.
        # NOTE(review): indentation was lost in this source; placing the val
        # loop inside the epoch loop and the save after all epochs is the most
        # plausible reconstruction — confirm against the original file.
        Gmodel.eval()
        with torch.no_grad():
            for i, (valRealImg, valLabels) in enumerate(val_loader):
                # NOTE(review): noise shape (1, 100) hard-codes latent_dim=100
                # and assumes val batch_size == 1 — confirm.
                valNoise = torch.FloatTensor(np.random.normal(0, 1, (1, 100))).to(device)
                valLabels = torch.LongTensor(valLabels).to(device)
                valGImg = Gmodel(valNoise, valLabels)
                showValGFeatureMap(valRealImg, valGImg)
    # Persist only the generator; path and filename encode arch/batch/epoch.
    torch.save(Gmodel.state_dict(), r"/media/yuanxingWorkSpace/ImageAugmentation/ACGAN/ISB-ACGAN/Gmodel/" + str(args.arch) + '_' + str(args.batch_size) + '_' + str(args.epoch) + '.pth')
if __name__ == '__main__':
    # Restrict this process to two GPUs; the DataParallel device_ids below
    # index into this visible set.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1"
    # Module-level global: train() and the model-building code read ``device``.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)
    args = getArgs()
    print('**************************')
    print('models:%s,\nepoch:%s,\nbatch size:%s\ndataset:%s' % \
        (args.arch, args.epoch, args.batch_size, args.dataset))
    print('**************************')
    Gmodel, Dmodel = getModel(args)
    # Wrap in DataParallel only when more than one GPU is visible.
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        Dmodel = nn.DataParallel(Dmodel, device_ids=[0, 1])
        Gmodel = nn.DataParallel(Gmodel, device_ids=[0, 1])
    Gmodel.to(device)
    Dmodel.to(device)
    train_loader, val_loader = getDataset(args)
    # Adam with beta1=0.5 for both networks (common GAN setting).
    Goptimizer = optim.Adam(Gmodel.parameters(), lr=args.lr, betas=(0.5, 0.999))
    Doptimizer = optim.Adam(Dmodel.parameters(), lr=args.lr, betas=(0.5, 0.999))
    # BCE on the validity head, cross-entropy on the class head.
    adversarial_loss = torch.nn.BCELoss()
    auxiliary_loss = torch.nn.CrossEntropyLoss()
    # Default action "train&val" contains 'train', so training runs by default.
    if 'train' in args.action:
        train(Dmodel, Gmodel, adversarial_loss, auxiliary_loss, Doptimizer, Goptimizer, train_loader, val_loader, args)
3. ACGAN
import torch.nn as nn
import torch.nn.functional as F
import torch
class Generator(nn.Module):
    """ACGAN conditional generator: noise + class label -> image in [-1, 1].

    The label is embedded into the latent space and multiplied element-wise
    with the noise vector, projected to a feature map, then upsampled twice
    (4x total) by the conv stack.

    Args:
        n_classes: number of conditioning classes (default 3, as before).
        latent_dim: noise/embedding dimensionality (default 100, as before).
        img_size: output image side length; must be divisible by 4
            (default 480, as before).
        channels: output channel count (default 1 grayscale, as before).
    """

    def __init__(self, n_classes=3, latent_dim=100, img_size=480, channels=1):
        super(Generator, self).__init__()
        self.label_emb = nn.Embedding(n_classes, latent_dim)
        # Two 2x Upsample stages below restore the full resolution.
        self.init_size = img_size // 4
        self.l1 = nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size ** 2))
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            # NOTE(review): 0.8 is eps by position (likely intended as
            # momentum); kept as-is to preserve behavior — confirm intent.
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, noise, labels):
        """Generate images.

        Args:
            noise: FloatTensor of shape (batch, latent_dim).
            labels: LongTensor of shape (batch,) with values in [0, n_classes).

        Returns:
            FloatTensor of shape (batch, channels, img_size, img_size),
            values in [-1, 1] (Tanh).
        """
        gen_input = torch.mul(self.label_emb(labels), noise)
        out = self.l1(gen_input)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks(out)
        return img
class Discriminator(nn.Module):
    """ACGAN discriminator with two heads: real/fake validity and class scores.

    Four stride-2 conv blocks downsample the input 16x, then two linear heads
    produce a sigmoid validity score and a softmax class distribution.

    Args:
        n_classes: number of classes for the auxiliary head (default 3, as before).
        channels: input image channels (default 1 grayscale, as before).
        img_size: input image side length; must be divisible by 16
            (default 480, as before).
    """

    def __init__(self, n_classes=3, channels=1, img_size=480):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            """Conv (stride 2 downsample) + LeakyReLU + Dropout2d, optional BatchNorm."""
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if bn:
                # NOTE(review): 0.8 is eps by position (likely intended as
                # momentum); kept as-is to preserve behavior — confirm intent.
                block.append(nn.BatchNorm2d(out_filters, 0.8))
            return block

        self.conv_blocks = nn.Sequential(
            *discriminator_block(channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )
        # Spatial size after four stride-2 convs.
        ds_size = img_size // 2 ** 4
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())
        # dim=1 made explicit: nn.Softmax() with no dim is deprecated and
        # resolves to dim=1 for 2-D input anyway, so behavior is unchanged.
        self.aux_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, n_classes), nn.Softmax(dim=1))

    def forward(self, img):
        """Score images.

        Args:
            img: FloatTensor of shape (batch, channels, img_size, img_size).

        Returns:
            (validity, label): validity is (batch, 1) in [0, 1];
            label is (batch, n_classes), rows summing to 1.
        """
        out = self.conv_blocks(img)
        out = out.view(out.shape[0], -1)
        validity = self.adv_layer(out)
        label = self.aux_layer(out)
        return validity, label
4. ShowFeatureMap
import numpy as np
import imageio
def showGFeatureMap(real, featureMap, out_dir="/media/yuanxingWorkSpace/ImageAugmentation/ACGAN/ISB-ACGAN/GFeatureMap/"):
    """Save real and generated images side by side, one PNG per channel.

    Args:
        real: real image tensor; squeeze(0) assumes batch size 1 — TODO confirm.
        featureMap: generated image tensor of the same shape as ``real``.
        out_dir: destination directory (trailing slash); defaults to the
            original hard-coded training path, now overridable.

    NOTE(review): the generator output is Tanh in [-1, 1]; ``* 255`` on
    negative values wraps around under uint8 — confirm whether rescaling to
    [0, 1] was intended before saving.
    """
    real_chw = real.squeeze(0).detach().cpu().numpy()
    gen_chw = featureMap.squeeze(0).detach().cpu().numpy()
    # Concatenate along width so each saved image shows real | generated.
    concatImg = np.concatenate((real_chw, gen_chw), axis=2)
    featureMapNum = gen_chw.shape[0]
    for index in range(1, featureMapNum + 1):
        imageio.imwrite(out_dir + str(index) + ".png",
                        (concatImg[index - 1] * 255).astype("uint8"))
def showValGFeatureMap(real, featureMap, out_dir="/media/yuanxingWorkSpace/ImageAugmentation/ACGAN/ISB-ACGAN/valGFeatureMap/"):
    """Save validation real and generated images side by side, one PNG per channel.

    Args:
        real: real validation image tensor; squeeze(0) assumes batch size 1 — TODO confirm.
        featureMap: generated image tensor of the same shape as ``real``.
        out_dir: destination directory (trailing slash); defaults to the
            original hard-coded validation path, now overridable.

    NOTE(review): the generator output is Tanh in [-1, 1]; ``* 255`` on
    negative values wraps around under uint8 — confirm whether rescaling to
    [0, 1] was intended before saving.
    """
    real_chw = real.squeeze(0).detach().cpu().numpy()
    gen_chw = featureMap.squeeze(0).detach().cpu().numpy()
    # Concatenate along width so each saved image shows real | generated.
    concatImg = np.concatenate((real_chw, gen_chw), axis=2)
    featureMapNum = gen_chw.shape[0]
    for index in range(1, featureMapNum + 1):
        imageio.imwrite(out_dir + str(index) + ".png",
                        (concatImg[index - 1] * 255).astype("uint8"))