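Preface:
The code in this post is excerpted and adapted from: https://github.com/yitong91/StoryGAN
Paper: https://arxiv.org/abs/1812.02784
Overview: StoryGAN, the paper that introduced the Story Visualization task
Architecture: StoryEncoder + ContextEncoder + upsampling generator + D_img + D_Story
Abstract: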
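We propose a new task called story visualization. Given a multi-sentence paragraph, the story is visualized by generating a sequence of images, one for each sentence. In contrast to video generation, story visualization focuses less on the continuity of the generated images (frames) and more on the global consistency across dynamic scenes and characters, a challenge that has not been addressed by any single-image or video generation method. We therefore propose StoryGAN, a story-to-image-sequence generation model based on the sequential conditional GAN framework. Our model is unique in that it contains a deep context encoder that dynamically tracks the story flow, and two discriminators at the story and image levels that improve the image quality and the consistency of the generated sequences. To evaluate the model, we modified existing datasets to create the CLEVR-SV and Pororo-SV datasets. Empirically, StoryGAN outperforms state-of-the-art models in image quality, contextual consistency metrics, and human evaluation.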
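I. Introduction
1) Challenges
a) The generated images must be coherent and consistent with one another.
b) Each character's appearance and the background must also stay consistent across images.
2) Similarities to and differences from video generation
a) Video generation focuses on capturing motion features to keep the transitions between actions smooth; story visualization focuses on the individual static frames and does not model motion.
b) Video generation usually assumes a largely fixed background, whereas a story model has to capture how the scene changes at different points in the storyline. Seen this way, illustrating a story amounts to picking out the key frames of a long video.
3) Model architecture: omitted here.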
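II. Related Work
1) Text-to-image generation: flow-based, VAE-based and GAN-based models.
2) Video generation: text-to-video and image-to-video. Story visualization differs in two ways: first, it always produces static frames rather than continuous motion; second, the input differs: video generation produces n frames (a short clip) from a single sentence, whereas here every image has its own sentence.
3) The inverse task of story visualization: narrating a story from n images; it also draws on text-to-image techniques and reinforcement learning.
4) Retrieval-based story illustration: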
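III. Methods
3.1 Background
Task definition:
Given a story S = {S1, S2, S3, S4, S5},
generate a sequence of images X = {x1, x2, x3, x4, x5}, one image per sentence.
Challenges specific to this task:
1) Consistency with the story as a whole
2) Coherence of style and content across the images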
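3.2 Framework
Overview:
Text processing: sentence-level information (extracted by GRU1) + story-level information (extracted by the StoryEncoder) + coherence with the previous image's style and content (carried over by Text2Gist)
Image generation: n generator passes (one per sentence, with shared weights) + two discriminators (one judging each sentence-image pair, one judging the whole story)
Details:
Text preprocessing
1) Encoding:
a) Story encoding: the StoryEncoder maps the whole story to a low-dimensional code
b) Sentence encoding: each sentence is encoded into a 128-dimensional vector
2) Information fusion
a) Information from each preceding sentence: GRU (RNN)
b) Story-level information + information from the previous generation step: Text2Gist (GRU + DynamicFilter)
GAN:
1) Each sentence gets one pass through the generator (upsampling generator)
2) Two discriminators in total: one for text-image pairs, one for story-image sequences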
3.3 StoryEncoder
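1) The encoding h0 is randomly sampled from a normal distribution N(μ, σ²).
2) μ and σ are both obtained by learned projections: μ = MLP(S), log σ² = MLP(S) (in CA_NET below, a single fc layer produces both and its output is split in half).
3) Standard normal noise ε is then scaled by σ and shifted by μ, which finally gives h0 = μ + σ·ε.
4) To avoid gradient collapse and to make the semantic space smoother, a KL-divergence penalty is added: it measures the distance between the sampled distribution N(μ, σ²) and the standard normal distribution N(0, I) (it is added when the loss is computed).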
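Written out, steps 2) to 4) are the standard reparameterization trick plus the closed-form KL divergence between a diagonal Gaussian and N(0, I). Here d is the dimension of h0 (cfg.GAN.CONDITION_DIM in the code) and `logvar` in the code stands for log σ². KL_loss in Section 3.7 averages these terms instead of summing them, which only changes the scale of the penalty.

```latex
\begin{aligned}
[\mu,\ \log\sigma^2] &= \mathrm{MLP}(S), \qquad \epsilon \sim \mathcal{N}(0, I) \\
h_0 &= \mu + \sigma \odot \epsilon, \qquad \sigma = \exp\!\left(\tfrac{1}{2}\log\sigma^2\right) \\
D_{\mathrm{KL}}\!\left(\mathcal{N}(\mu, \operatorname{diag}\sigma^2)\,\|\,\mathcal{N}(0, I)\right)
  &= -\frac{1}{2}\sum_{j=1}^{d}\left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)
\end{aligned}
```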
import torch
import torch.nn as nn
import torch.nn.functional as F
from miscc.config import cfg  # global config object (assumed import path, as in the StackGAN-style upstream repo)


class CA_NET(nn.Module):
    """StoryEncoder (conditioning augmentation).

    Input: the story embedding text_embedding, shape (B, d_TEXT).
    An MLP maps it to mu and logvar, each of shape (B, c_dim),
    and the code h0 is sampled from N(mu, sigma^2).
    """

    def __init__(self):
        super(CA_NET, self).__init__()
        self.t_dim = cfg.TEXT.DIMENSION
        self.c_dim = cfg.GAN.CONDITION_DIM
        self.fc = nn.Linear(self.t_dim, self.c_dim * 2, bias=True)
        self.relu = nn.ReLU()

    def encode(self, text_embedding):
        """Return (mu, logvar): one fc + ReLU layer whose output is split into two halves."""
        # (B, d_TEXT) --fc+relu--> (B, 2c); columns [0, c) -> mu, columns [c, 2c) -> logvar
        x = self.relu(self.fc(text_embedding))  # ReLU(Wx + b)
        mu = x[:, :self.c_dim]
        logvar = x[:, self.c_dim:]
        return mu, logvar

    def reparametrize(self, mu, logvar):
        """Take (mu, logvar) and sample h0 ~ N(mu, sigma^2) via the reparameterization trick:
        h0 = eps * sigma + mu."""
        std = logvar.mul(0.5).exp()  # standard deviation sigma = exp(0.5 * logvar), logvar = MLP(story_emb)
        eps = torch.randn_like(std)  # eps ~ N(0, I), same shape/device as std; a constant, not a trainable parameter
        # gradients reach mu and logvar through std; eps itself receives none
        return eps.mul(std).add(mu)

    def forward(self, text_embedding):
        """Input story_emb; h0 = eps * sigma + mu, with (mu, logvar) = MLP(S)."""
        mu, logvar = self.encode(text_embedding)
        c_code = self.reparametrize(mu, logvar)
        return c_code, mu, logvar
3.4 Context Encoder
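Input: the sentences, with the story content used to initialize the memory (the hidden state).
Architecture: the first layer is a GRU; the second layer is a GRU + DynamicFilter (Text2Gist).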
# Methods of the StoryGAN class (full class in Section 3.5), shown here on their own.
def get_iteration_input(self, motion_input):
    """Build the GRU1 input for one time step.

    motion_input: (B, d_TEXT), the embedding of the current sentence.
    Returns cat([noise, motion_input], dim=1) of shape (B, d_noise + d_TEXT),
    where noise ~ N(0, I) has shape (B, d_noise).
    """
    num_samples = motion_input.shape[0]  # batch size B
    noise = torch.randn(num_samples, self.noise_dim,
                        device=motion_input.device, dtype=motion_input.dtype)
    return torch.cat((noise, motion_input), dim=1)

def get_gru_initial_state(self, num_samples):
    """Random initial GRU state drawn from N(0, I), shape (num_samples, d_noise + d_TEXT).
    (Not called anywhere in the code shown in this post.)"""
    device = next(self.parameters()).device
    return torch.randn(num_samples, self.noise_dim + self.motion_dim, device=device)

def sample_z_motion(self, motion_input, video_len=None):
    """GRU1 (self.recurrent, an nn.GRUCell).

    Input: motion_input of shape (B, video_len, d_TEXT), or (B, d_TEXT) for a single frame;
    at every step the sentence embedding is concatenated with fresh noise z.
    Output: i of shape (B*video_len, d_TEXT).
    """
    video_len = video_len if video_len is not None else self.video_len
    # initial hidden state h0: the first sentence embedding, (B, d_TEXT)
    if video_len > 1:
        h_t = [motion_input[:, 0, :]]
    else:
        h_t = [motion_input]
    # one GRUCell unrolled over the frames
    for frame_num in range(video_len):
        if len(motion_input.shape) == 2:
            e_t = self.get_iteration_input(motion_input)
        else:
            # e_t = [noise; sentence_t], the step-t GRU input, shape (B, d_noise + d_TEXT)
            e_t = self.get_iteration_input(motion_input[:, frame_num, :])
        h_t.append(self.recurrent(e_t, h_t[-1]))  # GRUCell output h_t: (B, d_TEXT)
    # GRUCell returns 2-D (B, d); view adds a time axis -> (B, 1, d) so the steps can be concatenated on dim 1
    z_m_t = [h_k.view(-1, 1, self.motion_dim) for h_k in h_t]
    # drop h0, concat -> (B, video_len, d_TEXT), flatten -> (B*video_len, d_TEXT) in batch-major order
    z_motion = torch.cat(z_m_t[1:], dim=1).view(-1, self.motion_dim)
    return z_motion

def motion_content_rnn(self, motion_input, content_input):
    """GRU2 (self.mocornn, an nn.GRUCell).

    Input: motion_input of shape (B, video_len, d_TEXT) and
    content_input = h0 of shape (B, d_TEXT), used as the initial hidden state.
    Output: H of shape (B*video_len, d_content).
    In sample_videos this is called with the raw sentence embeddings (not the GRU1 output),
    and its output is turned into the dynamic filters.
    """
    video_len = 1 if len(motion_input.shape) == 2 else self.video_len
    h_t = [content_input]  # h0
    if len(motion_input.shape) == 2:
        motion_input = motion_input.unsqueeze(1)  # (B, d) -> (B, 1, d)
    for frame_num in range(video_len):
        e_t = motion_input[:, frame_num, :]  # step-t input: (B, d_TEXT)
        h_t.append(self.mocornn(e_t, h_t[-1]))  # h_t: (B, d_content)
    c_m_t = [h_k.view(-1, 1, self.content_dim) for h_k in h_t]  # each (B, 1, d_content)
    # (B, video_len, d_content) -view-> (B*video_len, d_content)
    mocornn_co = torch.cat(c_m_t[1:], dim=1).view(-1, self.content_dim)
    return mocornn_co
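To make the shapes of the two recurrences concrete, here is a minimal CPU sketch under toy sizes. `gru1`, `gru2` and `run_context_encoder` are hypothetical names standing in for `self.recurrent`, `self.mocornn` and the two methods above; as in `sample_videos`, GRU2 starts from the mean story embedding and reads the raw sentences.

```python
import torch
import torch.nn as nn

# Hypothetical sizes; the real ones come from cfg.TEXT.DIMENSION and cfg.GAN.Z_DIM.
B, T, D_TEXT, D_NOISE = 2, 5, 8, 4

gru1 = nn.GRUCell(D_NOISE + D_TEXT, D_TEXT)  # role of self.recurrent
gru2 = nn.GRUCell(D_TEXT, D_TEXT)            # role of self.mocornn


def run_context_encoder(sentences):
    """sentences: (B, T, D_TEXT). Returns (gru1_out, gru2_out), each (B*T, D_TEXT)."""
    h1 = sentences[:, 0, :]  # GRU1 state starts from the first sentence
    h2 = sentences.mean(1)   # GRU2 state starts from the mean story embedding
    out1, out2 = [], []
    for t in range(sentences.shape[1]):
        s_t = sentences[:, t, :]
        noise = torch.randn(s_t.shape[0], D_NOISE)
        h1 = gru1(torch.cat((noise, s_t), dim=1), h1)
        h2 = gru2(s_t, h2)          # GRU2 reads the raw sentence, as in sample_videos
        out1.append(h1.unsqueeze(1))  # (B, 1, D_TEXT)
        out2.append(h2.unsqueeze(1))
    # (B, T, D_TEXT) -> (B*T, D_TEXT), batch-major: row b*T + t is sentence t of story b
    return (torch.cat(out1, dim=1).reshape(-1, D_TEXT),
            torch.cat(out2, dim=1).reshape(-1, D_TEXT))


i_seq, h_seq = run_context_encoder(torch.randn(B, T, D_TEXT))
print(i_seq.shape, h_seq.shape)  # torch.Size([10, 8]) torch.Size([10, 8])
```

The batch-major flattening matters later: every `(B*video_len, ...)` tensor in `sample_videos` uses the same row order, so rows from different branches can be concatenated safely.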
class DynamicFilterLayer(nn.Module):  # MergeLayer
    """Takes [image, filters] and returns a 4-D map: the "Filter(i) * h" step."""

    def __init__(self, filter_size, stride=(1, 1), pad=(0, 0), flip_filters=False, grouping=False):
        super(DynamicFilterLayer, self).__init__()
        self.filter_size = filter_size    # 3-tuple (filter_size, filter_size, 1) = (15, 15, 1)
        self.stride = stride              # 2-tuple (stored; forward always uses stride 1)
        self.pad = pad                    # 2-tuple, or 'same'
        self.flip_filters = flip_filters  # kernel flipping (not implemented in this excerpt)
        self.grouping = grouping          # grouped convolution (not implemented in this excerpt)

    def get_output_shape_for(self, input_shapes):
        """Output shape (N, 1, H, W), assuming 'same' padding.

        o(t) = Filter(i) * h. The inputs are 4-D because sample_videos reshapes both
        vectors into single-channel maps of shape (N, 1, H, W) before calling this layer.
        """
        shape = (input_shapes[0][0], 1, input_shapes[0][2], input_shapes[0][3])
        return shape

    def forward(self, _input, **kwargs):
        # _input = [m_image, c_filter] = [(B*video_len, 1, r_image_size, r_image_size),
        #                                 (B*video_len, 1, filter_size, filter_size)]
        border_mode = self.pad
        if border_mode == 'same':
            border_mode = tuple(s // 2 for s in self.filter_size[:2])
        image = _input[0]    # (B*video_len, 1, r, r): built from the sentence embeddings by image_net
        filters = _input[1]  # (B*video_len, 1, f, f): built from the GRU2 output by filter_net
        filter_size = self.filter_size
        n = filter_size[0] * filter_size[1] * filter_size[2]  # 15 * 15 * 1 = 225
        # identity kernel of shape (225, 1, 15, 15): output channel k copies pixel k of each 15x15 patch
        filter_localexpand = torch.eye(n, device=image.device, dtype=image.dtype).view(
            n, filter_size[2], filter_size[0], filter_size[1])
        # local patch extraction: (N, 1, r, r) -> (N, 225, r, r); with padding 7 the size stays r
        input_localexpand = F.conv2d(image, filter_localexpand, padding=border_mode)
        # weight by filters and sum over the 225 channels; since filters has a single channel it
        # broadcasts, so output[y, x] = filters[y, x] * (sum of the 15x15 neighbourhood of (y, x))
        output = torch.sum(input_localexpand * filters, dim=1, keepdim=True)
        return output  # (N, 1, r, r)
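The identity-kernel convolution in `forward` is an im2col-style patch extraction. The sketch below checks this against `torch.nn.functional.unfold` and also checks the broadcast behaviour noted in the comments. Random tensors stand in for `m_image` and `c_filter`; the batch size `N` is hypothetical.

```python
import torch
import torch.nn.functional as F

f = 15  # filter_size = r_image_size = 15, as in StoryGAN.__init__
n = f * f  # 225 patch positions
N = 3  # hypothetical B*video_len

image = torch.randn(N, 1, f, f)    # stands in for m_image
filters = torch.randn(N, 1, f, f)  # stands in for c_filter

# identity kernel: output channel k copies pixel k of each 15x15 patch
eye_kernel = torch.eye(n).view(n, 1, f, f)
patches = F.conv2d(image, eye_kernel, padding=f // 2)  # (N, 225, 15, 15)

# the same patches via F.unfold: (N, 1*225, 15*15) -> (N, 225, 15, 15)
unfolded = F.unfold(image, kernel_size=f, padding=f // 2).view(N, n, f, f)

# the layer's output, (N, 1, 15, 15)
out = torch.sum(patches * filters, dim=1, keepdim=True)
# with single-channel filters this equals filters times a 15x15 box sum of the image
box_sum = F.conv2d(image, torch.ones(1, 1, f, f), padding=f // 2)

print(torch.allclose(patches, unfolded, atol=1e-6),
      torch.allclose(out, filters * box_sum, atol=1e-4))
```

The check suggests that, with `c_filter` reshaped to a single channel, this layer does not apply a separate 15x15 kernel per pixel. Instead it scales a local sum of `m_image` by the filter map.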
3.5 Generator
class StoryGAN(nn.Module):
    def __init__(self, video_len):
        super(StoryGAN, self).__init__()
        self.gf_dim = cfg.GAN.GF_DIM * 8
        self.motion_dim = cfg.TEXT.DIMENSION   # motion: per-sentence (per-frame) description embedding
        self.content_dim = cfg.TEXT.DIMENSION  # content: story-level embedding; both use the encoded text dim
        self.noise_dim = cfg.GAN.Z_DIM         # noise dimension
        self.recurrent = nn.GRUCell(self.noise_dim + self.motion_dim, self.motion_dim)  # GRU1, used in sample_z_motion
        self.mocornn = nn.GRUCell(self.motion_dim, self.content_dim)  # GRU2, used in motion_content_rnn
        self.video_len = video_len
        self.n_channels = 3
        self.filter_size = 15
        self.r_image_size = 15
        self.define_module()

    def define_module(self):
        """fc          = Linear + BN + ReLU
        upsample1-4 = four upBlocks, each doubling H and W (4x4 -> 64x64)
        img         = conv3x3 + Tanh
        filter_net  = Linear + BN (produces the dynamic filters)
        image_net   = Linear + BN (produces the map that gets filtered)
        dfn_layer   = DynamicFilterLayer (no weights of its own; its filters come from filter_net)
        downsamples = (Conv + BN + LReLU) x 2, 15x15 -> 4x4
        """
        # from layers import DynamicFilterLayer  # upstream keeps it in layers.py; here it is the class from Section 3.4
        ninput = self.motion_dim + self.content_dim
        ngf = self.gf_dim
        # TEXT.DIMENSION -> GAN.CONDITION_DIM (defined, but bypassed in sample_videos)
        self.ca_net = CA_NET()
        # -> ngf * 8 features, viewed as (ngf/2) x 4 x 4 in sample_videos
        self.fc = nn.Sequential(
            nn.Linear(ninput, int(ngf * 4 * 4 / 2), bias=False),
            nn.BatchNorm1d(ngf * 4 * 2),
            nn.ReLU(True))
        # ngf x 4 x 4 -> ngf/2 x 8 x 8
        self.upsample1 = upBlock(ngf, ngf // 2)
        # -> ngf/4 x 16 x 16
        self.upsample2 = upBlock(ngf // 2, ngf // 4)
        # -> ngf/8 x 32 x 32
        self.upsample3 = upBlock(ngf // 4, ngf // 8)
        # -> ngf/16 x 64 x 64
        self.upsample4 = upBlock(ngf // 8, ngf // 16)
        # -> 3 x 64 x 64
        self.img = nn.Sequential(
            conv3x3(ngf // 16, 3),
            nn.Tanh())
        self.filter_net = nn.Sequential(
            nn.Linear(self.content_dim, self.filter_size ** 2, bias=False),
            nn.BatchNorm1d(self.filter_size ** 2),
            # nn.Softmax()
        )
        self.image_net = nn.Sequential(
            nn.Linear(self.motion_dim, self.r_image_size ** 2, bias=False),
            nn.BatchNorm1d(self.r_image_size ** 2)
        )
        self.dfn_layer = DynamicFilterLayer((self.filter_size, self.filter_size, 1),
                                            pad=(self.filter_size // 2, self.filter_size // 2),
                                            grouping=False)
        # 1 x 15 x 15 -> ngf x 8 x 8 -> ngf/2 x 4 x 4
        self.downsamples = nn.Sequential(
            nn.Conv2d(1, ngf, 3, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ngf, ngf // 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf // 2),
            nn.LeakyReLU(0.2, inplace=True),
        )
    def get_iteration_input(self, motion_input):
        """Build the GRU1 input for one time step.

        motion_input: (B, d_TEXT), the embedding of the current sentence.
        Returns cat([noise, motion_input], dim=1) of shape (B, d_noise + d_TEXT),
        where noise ~ N(0, I) has shape (B, d_noise).
        """
        num_samples = motion_input.shape[0]  # batch size B
        noise = torch.randn(num_samples, self.noise_dim,
                            device=motion_input.device, dtype=motion_input.dtype)
        return torch.cat((noise, motion_input), dim=1)

    def get_gru_initial_state(self, num_samples):
        """Random initial GRU state drawn from N(0, I), shape (num_samples, d_noise + d_TEXT).
        (Not called anywhere in the code shown in this post.)"""
        device = next(self.parameters()).device
        return torch.randn(num_samples, self.noise_dim + self.motion_dim, device=device)

    def sample_z_motion(self, motion_input, video_len=None):
        """GRU1 (self.recurrent, an nn.GRUCell).

        Input: motion_input of shape (B, video_len, d_TEXT), or (B, d_TEXT) for a single frame;
        at every step the sentence embedding is concatenated with fresh noise z.
        Output: i of shape (B*video_len, d_TEXT).
        """
        video_len = video_len if video_len is not None else self.video_len
        # initial hidden state h0: the first sentence embedding, (B, d_TEXT)
        if video_len > 1:
            h_t = [motion_input[:, 0, :]]
        else:
            h_t = [motion_input]
        # one GRUCell unrolled over the frames
        for frame_num in range(video_len):
            if len(motion_input.shape) == 2:
                e_t = self.get_iteration_input(motion_input)
            else:
                # e_t = [noise; sentence_t], the step-t GRU input, shape (B, d_noise + d_TEXT)
                e_t = self.get_iteration_input(motion_input[:, frame_num, :])
            h_t.append(self.recurrent(e_t, h_t[-1]))  # GRUCell output h_t: (B, d_TEXT)
        # GRUCell returns 2-D (B, d); view adds a time axis -> (B, 1, d) so the steps can be concatenated on dim 1
        z_m_t = [h_k.view(-1, 1, self.motion_dim) for h_k in h_t]
        # drop h0, concat -> (B, video_len, d_TEXT), flatten -> (B*video_len, d_TEXT) in batch-major order
        z_motion = torch.cat(z_m_t[1:], dim=1).view(-1, self.motion_dim)
        return z_motion

    def motion_content_rnn(self, motion_input, content_input):
        """GRU2 (self.mocornn, an nn.GRUCell).

        Input: motion_input of shape (B, video_len, d_TEXT) and
        content_input = h0 of shape (B, d_TEXT), used as the initial hidden state.
        Output: H of shape (B*video_len, d_content).
        In sample_videos this is called with the raw sentence embeddings (not the GRU1 output),
        and its output is turned into the dynamic filters.
        """
        video_len = 1 if len(motion_input.shape) == 2 else self.video_len
        h_t = [content_input]  # h0
        if len(motion_input.shape) == 2:
            motion_input = motion_input.unsqueeze(1)  # (B, d) -> (B, 1, d)
        for frame_num in range(video_len):
            e_t = motion_input[:, frame_num, :]  # step-t input: (B, d_TEXT)
            h_t.append(self.mocornn(e_t, h_t[-1]))  # h_t: (B, d_content)
        c_m_t = [h_k.view(-1, 1, self.content_dim) for h_k in h_t]  # each (B, 1, d_content)
        # (B, video_len, d_content) -view-> (B*video_len, d_content)
        mocornn_co = torch.cat(c_m_t[1:], dim=1).view(-1, self.content_dim)
        return mocornn_co
    def sample_videos(self, motion_input, content_input):
        """Generate one story's images: GRU1 + GRU2 + dynamic filter + (concat) + generator.

        Inputs: motion_input and content_input, both (B, video_len, d_TEXT); per the original
        note, the caller passes the same story description tensor for both
        (story_motion_input = story_content_input).
        """
        # StoryEncoder: the mean sentence embedding is used directly as h0 (CA_NET is bypassed here)
        content_mean = content_input.mean(1)  # (B, video_len, d_TEXT) -> (B, d_TEXT)
        r_code, r_mu, r_logvar = content_mean, content_mean, content_mean  # r_mu, r_logvar are returned for the KL term
        # GRU2
        crnn_code = self.motion_content_rnn(motion_input, content_mean)  # (B*video_len, d_content)
        # GRU1
        zm_code = self.sample_z_motion(motion_input)  # GRU over the sentences: (B*video_len, d_TEXT)
        # concat (GRU1 output + story content c)
        content_input = content_mean.repeat(1, self.video_len)  # (B, video_len * d_TEXT)
        content_input = content_input.view(
            content_input.shape[0] * self.video_len,       # B*video_len
            int(content_input.shape[1] / self.video_len))  # (video_len * d_TEXT) / video_len = d_TEXT
        c_code, c_mu, c_logvar = content_input, content_input, content_input  # used instead of self.ca_net(content_input)
        zmc_code = torch.cat((zm_code, c_code), dim=1)  # GRU1 output + content: (B*video_len, 2*d_TEXT)
        zmc_code = self.fc(zmc_code)                    # (B*video_len, ngf*8)
        zmc_code = zmc_code.view(-1, int(self.gf_dim / 2), 4, 4)  # (B*video_len, ngf/2, 4, 4)
        # DynamicFilter (GRU2 output as filters, sentence embeddings as the map m)
        temp = motion_input.view(-1, motion_input.shape[2])  # (B*video_len, d_TEXT)
        m_code, m_mu, m_logvar = temp, temp, temp  # m_mu, m_logvar are returned to the caller
        m_image = self.image_net(m_code)  # (B*video_len, d_TEXT) -> (B*video_len, r_image_size**2)
        m_image = m_image.view(-1, 1, self.r_image_size, self.r_image_size)  # (B*video_len, 1, 15, 15)
        c_filter = self.filter_net(crnn_code)  # (B*video_len, d_content) -fc-> (B*video_len, filter_size**2)
        c_filter = c_filter.view(-1, 1, self.filter_size, self.filter_size)  # (B*video_len, 1, 15, 15)
        mc_image = self.dfn_layer([m_image, c_filter])  # (B*video_len, 1, 15, 15)
        mc_image = self.downsamples(mc_image)           # (B*video_len, ngf/2, 4, 4)
        # concat on channels (DynamicFilter branch + [GRU1 + c] branch)
        zmc_all = torch.cat((zmc_code, mc_image), dim=1)  # (B*video_len, ngf, 4, 4)
        # Generate
        h_code = self.upsample1(zmc_all)  # ngf/2 x 8 x 8
        h_code = self.upsample2(h_code)   # ngf/4 x 16 x 16
        h_code = self.upsample3(h_code)   # ngf/8 x 32 x 32
        h_code = self.upsample4(h_code)   # ngf/16 x 64 x 64
        # state size 3 x 64 x 64
        h = self.img(h_code)
        fake_video = h.view(int(h.size(0) / self.video_len), self.video_len,
                            self.n_channels, h.size(3), h.size(3))  # (B, video_len, C, H, W)
        fake_video = fake_video.permute(0, 2, 1, 3, 4)  # (B, C, video_len, H, W)
        return None, fake_video, r_mu, r_logvar, m_mu, m_logvar
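Two reshapes in `sample_videos` rely on the batch-major row order. The sketch below uses hypothetical toy sizes. It checks that `repeat(1, T).view(B*T, D)` is the same as `repeat_interleave`, and that the final `view` + `permute` places frame t of story b at `video[b, :, t]`.

```python
import torch

B, T, D = 2, 5, 4  # hypothetical batch size, video_len and text dimension
content_mean = torch.randn(B, D)

# what sample_videos does: (B, D) -> (B, T*D) -> (B*T, D)
tiled = content_mean.repeat(1, T).view(B * T, D)
# the same tensor in one call: rows b*T ... b*T + T-1 all hold story b
interleaved = content_mean.repeat_interleave(T, dim=0)

# generated frames in the (B*T, C, H, W) layout of h = self.img(h_code)
frames = torch.randn(B * T, 3, 8, 8)
video = frames.view(B, T, 3, 8, 8).permute(0, 2, 1, 3, 4)  # (B, C, T, H, W)

print(torch.equal(tiled, interleaved), video.shape)  # True torch.Size([2, 3, 5, 8, 8])
```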
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding 1 (keeps H and W when stride=1).
    in_planes: number of input channels; out_planes: number of filters (output channels)."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


# Upscale the spatial size by a factor of 2
def upBlock(in_planes, out_planes):
    """Upsample by 2. in_planes: number of input channels; out_planes: number of filters."""
    block = nn.Sequential(
        nn.Upsample(scale_factor=2, mode='nearest'),  # nearest-neighbour upsample
        conv3x3(in_planes, out_planes),               # 3x3 conv
        nn.BatchNorm2d(out_planes),                   # BN
        nn.ReLU(True))                                # ReLU
    return block
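A quick shape check of the generator path under a hypothetical `ngf = 64` and batch `N = 4`. `up_block` mirrors the layer sequence of `upBlock`, and `downsamples` copies `StoryGAN.downsamples`. The check traces the dynamic-filter map from 15x15 down to 4x4, the channel concatenation, and the four x2 upsamplings up to 3 x 64 x 64.

```python
import torch
import torch.nn as nn

ngf = 64  # hypothetical; in StoryGAN ngf = cfg.GAN.GF_DIM * 8
N = 4     # hypothetical B*video_len


def up_block(cin, cout):
    # same layer sequence as upBlock: nearest x2 upsample, 3x3 conv, BN, ReLU
    return nn.Sequential(nn.Upsample(scale_factor=2, mode='nearest'),
                         nn.Conv2d(cin, cout, 3, 1, 1, bias=False),
                         nn.BatchNorm2d(cout), nn.ReLU(True))


downsamples = nn.Sequential(  # copy of StoryGAN.downsamples
    nn.Conv2d(1, ngf, 3, 2, 1, bias=False), nn.BatchNorm2d(ngf), nn.LeakyReLU(0.2, inplace=True),
    nn.Conv2d(ngf, ngf // 2, 4, 2, 1, bias=False), nn.BatchNorm2d(ngf // 2), nn.LeakyReLU(0.2, inplace=True))
decoder = nn.Sequential(  # upsample1..upsample4 followed by self.img
    up_block(ngf, ngf // 2), up_block(ngf // 2, ngf // 4),
    up_block(ngf // 4, ngf // 8), up_block(ngf // 8, ngf // 16),
    nn.Conv2d(ngf // 16, 3, 3, 1, 1, bias=False), nn.Tanh())

mc_image = downsamples(torch.randn(N, 1, 15, 15))  # 15 -> 8 -> 4
zmc_code = torch.randn(N, ngf // 2, 4, 4)          # stands in for fc(...).view(-1, ngf/2, 4, 4)
zmc_all = torch.cat((zmc_code, mc_image), dim=1)   # (N, ngf, 4, 4)
img = decoder(zmc_all)
print(mc_image.shape, zmc_all.shape, img.shape)
# torch.Size([4, 32, 4, 4]) torch.Size([4, 64, 4, 4]) torch.Size([4, 3, 64, 64])
```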
class ResBlock(nn.Module):
    """(conv + BN + ReLU + conv + BN) + x, followed by ReLU.
    Takes a single channel count, because both convs keep input channels = output channels (number of filters)."""

    def __init__(self, channel_num):
        super(ResBlock, self).__init__()
        self.block = nn.Sequential(
            conv3x3(channel_num, channel_num),
            nn.BatchNorm2d(channel_num),
            nn.ReLU(True),
            conv3x3(channel_num, channel_num),
            nn.BatchNorm2d(channel_num))
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        residual = x
        out = self.block(x)
        out += residual
        out = self.relu(out)
        return out
3.6 Discriminator
class D_IMG(nn.Module):
    """Image discriminator (single frame): the image-encoder part.

    Downsamples a 3 x 64 x 64 image to (ndf*8) x 4 x 4. In get_cond_logits the image code
    is multiplied element-wise with the text code, and the product is scored.
    """

    def __init__(self, use_categories=True):
        super(D_IMG, self).__init__()
        self.df_dim = cfg.GAN.DF_DIM
        self.ef_dim = cfg.TEXT.DIMENSION
        self.label_num = cfg.LABEL_NUM
        self.define_module(use_categories)

    def define_module(self, use_categories):
        ndf, nef = self.df_dim, self.ef_dim
        self.encode_img = nn.Sequential(
            nn.Conv2d(3, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            # state size (ndf*8) x 4 x 4
            nn.LeakyReLU(0.2, inplace=True)
        )
        if use_categories:
            # character classifier: (ndf*8) x 4 x 4 -> label_num x 1 x 1
            self.cate_classify = nn.Conv2d(ndf * 8, self.label_num, 4, 4, 1, bias=False)
        else:
            self.cate_classify = None
        self.get_cond_logits = D_GET_LOGITS(ndf, nef, 1)
        self.get_uncond_logits = None

    def forward(self, image):
        img_embedding = self.encode_img(image)  # (B, ndf*8, 4, 4)
        return img_embedding
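The spatial sizes in the comments above can be confirmed with a standalone copy of `encode_img` and `cate_classify`. `ndf`, `label_num` and `B` are hypothetical values. For a 4x4 input, kernel 4, stride 4, padding 1 gives floor((4 + 2 - 4) / 4) + 1 = 1, so the classifier produces one logit per label.

```python
import torch
import torch.nn as nn

ndf, label_num, B = 16, 9, 2  # hypothetical cfg.GAN.DF_DIM, cfg.LABEL_NUM and batch size

encode_img = nn.Sequential(  # same layer sizes as D_IMG.encode_img
    nn.Conv2d(3, ndf, 4, 2, 1, bias=False), nn.LeakyReLU(0.2, inplace=True),
    nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False), nn.BatchNorm2d(ndf * 2), nn.LeakyReLU(0.2, inplace=True),
    nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False), nn.BatchNorm2d(ndf * 4), nn.LeakyReLU(0.2, inplace=True),
    nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False), nn.BatchNorm2d(ndf * 8), nn.LeakyReLU(0.2, inplace=True))
cate_classify = nn.Conv2d(ndf * 8, label_num, 4, 4, 1, bias=False)

feat = encode_img(torch.randn(B, 3, 64, 64))  # 64 -> 32 -> 16 -> 8 -> 4
cate = cate_classify(feat)                    # 4x4 -> 1x1
print(feat.shape, cate.shape)  # torch.Size([2, 128, 4, 4]) torch.Size([2, 9, 1, 1])
```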
class D_STY(nn.Module):
    """Story discriminator [a sequence of images x the story]: the story image encoder.
    Output shape (N, C*L, W, H), i.e. the frame codes stacked along the channel axis."""

    def __init__(self):
        super(D_STY, self).__init__()
        self.df_dim = cfg.GAN.DF_DIM
        self.ef_dim = cfg.GAN.CONDITION_DIM
        self.text_dim = cfg.TEXT.DIMENSION
        self.label_num = cfg.LABEL_NUM
        self.define_module()

    def define_module(self):
        ndf, nef = self.df_dim, self.ef_dim
        self.encode_img = nn.Sequential(
            nn.Conv2d(3, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            # state size (ndf*8) x 4 x 4
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.get_cond_logits = D_GET_LOGITS(ndf, nef, cfg.VIDEO_LEN)
        self.get_uncond_logits = None
        self.cate_classify = None

    def forward(self, story):
        N, C, video_len, W, H = story.shape
        story = story.permute(0, 2, 1, 3, 4)          # (N, video_len, C, W, H)
        story = story.contiguous().view(-1, C, W, H)  # contiguous() is needed before view: (N*video_len, C, W, H)
        if hasattr(torch.cuda, 'empty_cache'):
            torch.cuda.empty_cache()
        # encode every frame; the result is already 4-D (a squeeze here would also drop
        # the batch axis when N*video_len == 1)
        story_embedding = self.encode_img(story)
        _, C1, W1, H1 = story_embedding.shape  # channels and size after the encoder
        # story_embedding = story_embedding.view(N, video_len, C1, W1, H1)
        # story_embedding = story_embedding.mean(1).squeeze()
        story_embedding = story_embedding.permute(2, 3, 0, 1)  # (W1, H1, N*L, C1)
        story_embedding = story_embedding.contiguous().view(W1, H1, N, video_len * C1)  # (W1, H1, N, L*C1)
        story_embedding = story_embedding.contiguous().permute(2, 3, 0, 1)  # (N, L*C1, W1, H1)
        return story_embedding
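The permute/view chain at the end of `D_STY.forward` stacks the L frame codes of each story along the channel axis, frame-major. Since the encoder output is contiguous in (N, L, C1, W1, H1) order, a single `view` should give the same tensor. The check below uses hypothetical sizes.

```python
import torch

N, L, C1, W1, H1 = 2, 5, 8, 4, 4   # hypothetical batch, video_len and encoder output size
emb = torch.randn(N * L, C1, W1, H1)  # stands in for self.encode_img(story)

# the permute/view chain used in D_STY.forward
a = emb.permute(2, 3, 0, 1).contiguous().view(W1, H1, N, L * C1).permute(2, 3, 0, 1)
# a single view on the contiguous encoder output
b = emb.view(N, L * C1, W1, H1)

print(a.shape, torch.equal(a, b))  # torch.Size([2, 40, 4, 4]) True
```

Channels `[t*C1, (t+1)*C1)` of story n therefore hold frame t. This explains why `D_GET_LOGITS.conv1` expects `ndf * 8 * video_len` input channels.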
class D_GET_LOGITS(nn.Module):
    """Scoring head. conv1: Conv + BN + LReLU on the image code;
    convc: Conv + BN + LReLU on the text code; conv2: Conv + Sigmoid for the final score."""

    def __init__(self, ndf, nef, video_len=1, bcondition=True):
        super(D_GET_LOGITS, self).__init__()
        self.df_dim = ndf
        self.ef_dim = nef
        self.bcondition = bcondition
        self.video_len = video_len
        if bcondition:
            # image code: (ndf*8*video_len) x 4 x 4 -> (ndf*8) x 4 x 4
            self.conv1 = nn.Sequential(
                conv3x3(ndf * 8 * video_len, ndf * 8),
                nn.BatchNorm2d(ndf * 8),
                nn.LeakyReLU(0.2, inplace=True)
            )
            self.conv2 = nn.Sequential(
                nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4),
                nn.Sigmoid()
            )
            # text code: nef x 4 x 4 -> (ndf*8) x 4 x 4
            self.convc = nn.Sequential(
                conv3x3(self.ef_dim, ndf * 8),
                nn.BatchNorm2d(ndf * 8),
                nn.LeakyReLU(0.2, inplace=True)
            )
        else:
            self.conv2 = nn.Sequential(
                nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4),
                nn.Sigmoid())
        # if video_len > 1:
        #     self.storynet = nn.GRUCell(self.ef_dim, self.ef_dim)

    def forward(self, h_code, c_code=None):
        # h_code: image code (B, ndf*8*video_len, 4, 4); c_code: text condition with nef features per row
        if self.bcondition and c_code is not None:
            c_code = c_code.view(-1, self.ef_dim, 1, 1)
            c_code = c_code.repeat(1, 1, 4, 4)  # (B, nef, 4, 4)
            c_code = self.convc(c_code)         # (B, ndf*8, 4, 4)
            h_code = self.conv1(h_code)         # (B, ndf*8, 4, 4)
            h_c_code = h_code * c_code          # element-wise product of image and text codes
        else:
            h_c_code = h_code
        # learned 4x4 conv with stride 4 + Sigmoid: one probability per sample
        output = self.conv2(h_c_code)
        return output.view(-1)  # (B,)
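A standalone shape trace of the conditional branch, as used by the story discriminator. The sizes are hypothetical, and `c_code` is assumed to hold one condition vector per story. The text vector is tiled to 4x4, both codes are mapped to `ndf*8` channels, multiplied, and reduced by the 4x4/stride-4 convolution to a single probability per sample.

```python
import torch
import torch.nn as nn

ndf, nef, L, B = 16, 12, 5, 2  # hypothetical DF_DIM, condition dim, video_len and batch size

conv1 = nn.Sequential(nn.Conv2d(ndf * 8 * L, ndf * 8, 3, 1, 1, bias=False),
                      nn.BatchNorm2d(ndf * 8), nn.LeakyReLU(0.2, inplace=True))
convc = nn.Sequential(nn.Conv2d(nef, ndf * 8, 3, 1, 1, bias=False),
                      nn.BatchNorm2d(ndf * 8), nn.LeakyReLU(0.2, inplace=True))
conv2 = nn.Sequential(nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4), nn.Sigmoid())

h_code = torch.randn(B, ndf * 8 * L, 4, 4)  # e.g. the output of D_STY.forward
c_code = torch.randn(B, nef)                # assumed: one text condition per story
c_map = convc(c_code.view(-1, nef, 1, 1).repeat(1, 1, 4, 4))  # (B, ndf*8, 4, 4)
score = conv2(conv1(h_code) * c_map).view(-1)                 # (B,), each in [0, 1]
print(score.shape)  # torch.Size([2])
```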
3.7 Loss Functions
def KL_loss(mu, logvar):
    # -0.5 * mean(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    KLD = torch.mean(KLD_element).mul_(-0.5)
    return KLD
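As a sanity check, `KL_loss` should match PyTorch's built-in KL divergence between `Normal(mu, sigma)` and `Normal(0, 1)`, averaged over all elements. The block repeats `KL_loss` so that it runs on its own. `mu` and `logvar` are random stand-ins for the encoder outputs.

```python
import torch
from torch.distributions import Normal, kl_divergence


def KL_loss(mu, logvar):
    # same formula as above: -0.5 * mean(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    return torch.mean(KLD_element).mul_(-0.5)


torch.manual_seed(0)
mu, logvar = torch.randn(4, 6), torch.randn(4, 6)  # hypothetical (B, c_dim) encoder outputs
reference = kl_divergence(Normal(mu, (0.5 * logvar).exp()), Normal(0.0, 1.0)).mean()
print(torch.allclose(KL_loss(mu, logvar), reference, atol=1e-6))
```

When mu = 0 and logvar = 0 the code already equals the prior, and the penalty is 0.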
def compute_discriminator_loss(netD, real_imgs, fake_imgs,
                               real_labels, fake_labels, real_catelabels,
                               conditions, gpus, mode='image'):
    """
    Args:
        netD: D_IMG or D_STY
        real_imgs: real images, loaded from Dataset['images'];
            (B, C, img_size, img_size) / (B, C, video_len, img_size, img_size)
        fake_imgs: generated images from sample_videos / sample_images;
            (B, C, img_size, img_size) / (B, C, video_len, img_size, img_size)
        real_labels: BCELoss targets, torch.FloatTensor(self.batch_size).fill_(1), shape (B,)
        fake_labels: BCELoss targets, torch.FloatTensor(self.batch_size).fill_(0), shape (B,)
        real_catelabels: targets of the multi-label loss, loaded from Dataset['labels'], (B, num_labels)
        conditions: text features from sample_videos / sample_images (content + motion + label), (B, d_ef)
        gpus: device ids passed to nn.parallel.data_parallel
        mode: 'image' or 'story'; only used to prefix the keys of loss_report

    LOSS:
    <BCELoss>:
        (1) D_logit = netD.get_cond_logits(img_feat, text_feat)
            Args:
                img_feat: (B, ndf*8, 4, 4)
                text_feat: (B, d_ef)
            Computation:
                img_feat = conv1(img_feat)
                text_feat = convc(text_feat.view(B, d_ef, 1, 1).repeat(1, 1, 4, 4))
                h_c = img_feat * text_feat  # (B, ndf*8, 4, 4), element-wise
                D_logit = sigmoid(conv2(h_c)).view(-1)
            Returns:
                D_logit: (B,), probabilities in (0, 1)
        (2) loss1 = BCELoss(D_logit, real/fake_labels)
            Args:
                D_logit: (B,)
                real_labels: [1] * B
                fake_labels: [0] * B
            Computation:
                binary cross-entropy
            Returns:
                scalar tensor
    <MultiLabelSoftMarginLoss>:
        (1) cate_logits = netD.cate_classify(img_feat), flattened to 2-D
            Args:
                img_feat: (B, ndf*8, 4, 4)
            Computation:
                nn.Conv2d(ndf * 8, self.label_num, 4, 4, 1, bias=False) -> (B, label_num, 1, 1)
            Returns:
                cate_logits: (B, label_num)
        (2) loss2 = MultiLabelSoftMarginLoss(cate_logits, real_catelabels)
            Args:
                cate_logits: (B, label_num)
                real_catelabels: (B, label_num)
            Computation:
                per-label binary cross-entropy on sigmoid(cate_logits), e.g. logits (0.5, 0.6, 0.7)
                for a sample against targets (1, 1, 0); both must have the same num_labels
            Returns:
                scalar tensor
    """
    # (1) Config
    criterion = nn.BCELoss()
    cate_criterion = nn.MultiLabelSoftMarginLoss()
    batch_size = real_imgs.size(0)
    cond = conditions.detach()
    fake = fake_imgs.detach()  # no gradients flow into G while D is updated

    # (2) BCELoss on (text_feat, img_feat) pairs
    # image features: (B, ndf*8, 4, 4) from D_IMG, (B, ndf*8*video_len, 4, 4) from D_STY
    real_features = nn.parallel.data_parallel(netD, (real_imgs), gpus)
    fake_features = nn.parallel.data_parallel(netD, (fake), gpus)
    # Disabled story-mode averaging from the source:
    # if mode == 'story':
    #     real_features_st = real_features
    #     fake_features = fake_features.mean(1).squeeze()
    #     real_features = real_features.mean(1).squeeze()

    # real pairs: each real image with its own text, target 1
    inputs = (real_features, cond)
    real_logits = nn.parallel.data_parallel(netD.get_cond_logits, inputs, gpus)
    errD_real = criterion(real_logits, real_labels)

    # wrong pairs: real image i with the text of sample i+1 (mismatched), target 0
    inputs = (real_features[:(batch_size - 1)], cond[1:])
    wrong_logits = nn.parallel.data_parallel(netD.get_cond_logits, inputs, gpus)
    errD_wrong = criterion(wrong_logits, fake_labels[1:])

    # fake pairs: each generated image with its text, target 0
    inputs = (fake_features, cond)
    fake_logits = nn.parallel.data_parallel(netD.get_cond_logits, inputs, gpus)
    errD_fake = criterion(fake_logits, fake_labels)

    # (3) BCELoss on img_feat only; get_uncond_logits is None for both D_IMG and D_STY above
    if netD.get_uncond_logits is not None:
        real_logits = nn.parallel.data_parallel(netD.get_uncond_logits, (real_features), gpus)
        fake_logits = nn.parallel.data_parallel(netD.get_uncond_logits, (fake_features), gpus)
        uncond_errD_real = criterion(real_logits, real_labels)
        uncond_errD_fake = criterion(fake_logits, fake_labels)
        errD = ((errD_real + uncond_errD_real) / 2. +
                (errD_fake + errD_wrong + uncond_errD_fake) / 3.)
        errD_real = (errD_real + uncond_errD_real) / 2.
        errD_fake = (errD_fake + uncond_errD_fake) / 2.
    else:
        errD = errD_real + (errD_fake + errD_wrong) * 0.5
    loss_report = {
        mode + ' Fake/Real Discriminator Loss (Real pairs) --> ': errD_real.item(),
        mode + ' Fake/Real Discriminator Loss (Wrong pairs) --> ': errD_wrong.item(),
        mode + ' Fake/Real Discriminator Loss (Fake pairs) --> ': errD_fake.item(),
    }

    # (4) multi-label character classification loss on the real images
    if netD.cate_classify is not None:
        cate_logits = nn.parallel.data_parallel(netD.cate_classify, real_features, gpus)  # (B, num_labels, 1, 1)
        cate_logits = cate_logits.view(cate_logits.size(0), -1)  # (B, num_labels), also safe when B == 1
        errD = errD + 1.0 * cate_criterion(cate_logits, real_catelabels)
        # get_multi_acc: accuracy helper from the upstream repo, not shown in this post
        acc = get_multi_acc(cate_logits.detach().cpu().numpy(), real_catelabels.detach().cpu().numpy())
        loss_report[mode + ' Character Classifier Accuracy (Discriminator) --> '] = acc
    return errD, loss_report
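To see how the three pair terms are weighted in the `else` branch, consider an uninformed discriminator that outputs 0.5 for every pair (hypothetical values). Each BCE term is then ln 2, and errD = ln 2 + 0.5 · (ln 2 + ln 2) = 2 ln 2. The "wrong" term uses only B - 1 pairs, because it pairs image i with the text of sample i + 1.

```python
import math

import torch
import torch.nn as nn

B = 4
criterion = nn.BCELoss()
real_labels, fake_labels = torch.ones(B), torch.zeros(B)

# hypothetical outputs of an uninformed discriminator
real_logits = torch.full((B,), 0.5)
wrong_logits = torch.full((B - 1,), 0.5)  # B-1 mismatched (image i, text i+1) pairs
fake_logits = torch.full((B,), 0.5)

errD_real = criterion(real_logits, real_labels)
errD_wrong = criterion(wrong_logits, fake_labels[1:])
errD_fake = criterion(fake_logits, fake_labels)
errD = errD_real + (errD_fake + errD_wrong) * 0.5  # branch used when get_uncond_logits is None
print(errD.item(), 2 * math.log(2))
```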
def compute_generator_loss(netD, fake_imgs, real_labels, fake_catelabels, conditions, gpus, mode='image'):
    """No real images are needed here, because this loss is used to update G:
    the generated images are scored against real_labels (target 1).

    Args:
        netD: D_IMG or D_STY
        fake_imgs: (B, C, img_size, img_size) / (B, C, video_len, img_size, img_size)
        real_labels: torch.FloatTensor(self.batch_size).fill_(1), shape (B,)
        fake_catelabels: targets of the multi-label loss, loaded from Dataset['labels'], (B, num_labels)
        conditions: text features from sample_videos / sample_images (content + motion + label), (B, d_ef)
    """
    # (1) Prepare
    criterion = nn.BCELoss()
    cate_criterion = nn.MultiLabelSoftMarginLoss()
    cond = conditions.detach()
    fake_features = nn.parallel.data_parallel(netD, (fake_imgs), gpus)  # not detached: gradients reach G
    # Disabled story-mode averaging from the source:
    # if mode == 'story':
    #     fake_features_st = fake_features
    #     fake_features = torch.mean(fake_features, dim=1).squeeze()

    # (2) BCELoss: fake pairs scored against target 1
    inputs = (fake_features, cond)
    fake_logits = nn.parallel.data_parallel(netD.get_cond_logits, inputs, gpus)
    errD_fake = criterion(fake_logits, real_labels)
    if netD.get_uncond_logits is not None:
        fake_logits = nn.parallel.data_parallel(netD.get_uncond_logits, (fake_features), gpus)
        uncond_errD_fake = criterion(fake_logits, real_labels)
        errD_fake += uncond_errD_fake
    loss_report = {
        mode + ' Fake/Real Generator Loss (Fake pairs) --> ': errD_fake.item(),
    }

    # (3) multi-label loss: the generated images should show the labelled characters
    if netD.cate_classify is not None:
        cate_logits = nn.parallel.data_parallel(netD.cate_classify, fake_features, gpus)
        cate_logits = cate_logits.mean(dim=-1).mean(dim=-1)  # average over H and W -> (B, num_labels)
        errD_fake = errD_fake + 1.0 * cate_criterion(cate_logits, fake_catelabels)
        # get_multi_acc: accuracy helper from the upstream repo, not shown in this post
        acc = get_multi_acc(cate_logits.detach().cpu().numpy(), fake_catelabels.detach().cpu().numpy())
        loss_report[mode + ' Character Classifier Accuracy (Generator) --> '] = acc
    return errD_fake, loss_report
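Both loss functions feed raw classifier outputs into `nn.MultiLabelSoftMarginLoss`, which applies the sigmoid itself. With default settings it equals `nn.BCEWithLogitsLoss` averaged over all B × num_labels entries, i.e. an independent binary cross-entropy per character. The check below uses random logits and hypothetical labels.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
cate_logits = torch.randn(4, 3)  # hypothetical (B, num_labels) classifier outputs
labels = torch.tensor([[1., 1., 0.],
                       [0., 1., 0.],
                       [1., 0., 0.],
                       [0., 0., 1.]])

multi = nn.MultiLabelSoftMarginLoss()(cate_logits, labels)
bce = nn.BCEWithLogitsLoss()(cate_logits, labels)  # sigmoid + BCE, mean over all entries
print(torch.allclose(multi, bce, atol=1e-6))
```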
IV. Experiments
4.1 Dataset
4.2 Experiments
V. Conclusion
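We study the story visualization task as a sequential conditional generation problem. The proposed StoryGAN model handles the task by combining the current input sentence with contextual information, which is achieved by the Text2Gist component of the context encoder. The ablation studies show that the two-level discriminators and the recurrent structure on the inputs help keep the generated images consistent with each other and with the story, while the context encoder effectively provides both local and global conditioning for the image generator.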