R e s o u r c e s \rm Resources Resources
📋 P a p e r \rm Paper Paper >> T r a i n i n g G e n e r a t i v e A d v e r s a r i a l N e t w o r k s w i t h L i m i t e d D a t a \rm Training~Generative~Adversarial~Networks~with~Limited~Data Training Generative Adversarial Networks with Limited Data
💻 C o d e s \rm Codes Codes >>tensorflow
/pytorch
📰 B l o g s \rm Blogs Blogs >> C S D N : \rm 优快云: CSDN: T r a i n i n g G e n e r a t i v e A d v e r s a r i a l N e t w o r k s w i t h L i m i t e d D a t a \rm Training~Generative~Adversarial~Networks~with~Limited~Data Training Generative Adversarial Networks with Limited Data
A n a l y s i s \rm Analysis Analysis
1. 1. 1. 计算 L a d v L_{adv} Ladv 时,前馈真实样本或伪样本给鉴别器 D D D;
## 对生成伪样本
fake_logits = self.run_D(fake_img, fake_c)
## 对真实训练样本
real_logits = self.run_D(real_img, real_c)
2.
2.
2. run_D
的具体细节为:
def run_D(self, img, c):
## 使用定义的含 p 的数据增强流水线作 aug
img = self.augment_pipe(img)
logits = self.D(img, c)
return logits
3.
3.
3. 记录
L
a
d
v
L_{adv}
Ladv 和
E
[
s
i
g
n
(
D
t
r
a
i
n
)
]
{\mathbb E}[\rm sign (D_{train})]
E[sign(Dtrain)] (这里使用的是 WGAN
,所以
L
a
d
v
L_{adv}
Ladv 计算方式比较简单,
min
/
max
\min/\max
min/max ?_logits
即可)
training_stats.report('Loss/scores/real', real_logits)
training_stats.report('Loss/signs/real' , real_logits.sign()) ## 👈
4. 4. 4. 具体的,每一个被 r e p o r t e d \rm reported reported 的状态(统计数据 s t a t i s t i c \rm statistic statistic)被记录了 3 个统计量:
## `elems` 是一个形参,这里考虑 tensor `logits`
moments = torch.stack([
torch.ones_like(elems).sum(), ## 记录数量 (count)
elems.sum(), ## 求和
## 计算 E[sign(D_train)] 只需要前两个统计量,即:moments[1]/moments[0]
elems.square().sum(),
])
5.
5.
5. 累计前面
4.
4.
4. 记录的数据 moments of real_logits
,累积(会使用到 _moments.add_(moments)
)
N
=
4
N=4
N=4 次迭代(
i
t
e
r
/
m
i
n
i
b
a
t
c
h
\rm iter/minibatch
iter/minibatch)
6.
6.
6. 通过获取
E
[
s
i
g
n
(
D
t
r
a
i
n
)
]
\mathbb E[\rm sign(D_{train})]
E[sign(Dtrain)] 来动态更新 p
# Execute ADA heuristic.
if (ada_stats is not None) and \ ## 是否使用 ADA 这一项技术
(batch_idx % ada_interval == 0): ## N 值
ada_stats.update()
adjust = np.sign(ada_stats['Loss/signs/real'] - ada_target) \ ## ada_target 是 r_t 的阈值,文中设置是 0.6
## ada_stats['Loss/signs/real'] = moments_of_real_logits[1]/moments_of_real_logits[0]
* (batch_size * ada_interval) / (ada_kimg * 1000) ## 增益,the gain := (BxN)/SCALE, `B` is batch size, `N` is # of batches; all-in-all, it is FIXED.
## 更新 p 值
augment_pipe.p.copy_((augment_pipe.p + adjust)\ ## D 偏强,则 adjust 为正,Aug 强度适当增大;D 偏弱,则 adjust 为负,Aug 强度适当减弱
.max(misc.constant(0, device=device))) ## clip/truncate,限制概率在有效范围
U s a g e o f A u g P i p e \rm Usage~of~AugPipe Usage of AugPipe
引用脚本文件 ./training/augment.py
,直接初始化 nn.Module
模块——
import augment
aug_pipe = augment.AugmentPipe()
## input --type=torch.tensor --size=(N,C,H,W)
aug_input = aug_pipe(input)
B
T
W
\rm BTW
BTW,这个项目对于 pytorch
多线程、多卡并行训练编程有非常好的借鉴性,木奉👍!