StoryVisualization Series (1) StoryGAN: A GAN for Story Visualization

Contents

Preface:

Abstract:

1. Introduction

2. Related Work

3. Methods

3.1 Background

3.2 Framework

3.3 StoryEncoder

3.4 Context Encoder

3.5 Generator

3.6 Discriminator

3.7 Loss Functions

4. Experiments

4.1 Dataset

4.2 Experiments

5. Conclusion


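Preface:

        The code in this post is excerpted and adapted from: https://github.com/yitong91/StoryGAN

        Paper: https://arxiv.org/abs/1812.02784

        Summary: StoryGAN, the work that introduced the Story Visualization task

        Architecture: StoryEncoder + ContextEncoder + upsampling generator + D_img + D_Story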

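Abstract:

We propose a new task called story visualization. Given a multi-sentence paragraph, the story is visualized by generating a sequence of images, one for each sentence. Compared with video generation, story visualization pays less attention to the continuity of the generated images (frames) and more to the global consistency across dynamic scenes and characters, a challenge that no single-image or video generation method has addressed. We therefore propose StoryGAN, a story-to-image-sequence generation model based on the sequential conditional GAN framework. The model is distinctive in that it contains a deep Context Encoder that dynamically tracks the story flow, and two discriminators, at the story level and the image level, that improve image quality and the consistency of the generated sequence. To evaluate the model, we modified existing datasets to create the CLEVR-SV and Pororo-SV datasets. Empirically, StoryGAN outperforms state-of-the-art models in image quality, contextual consistency metrics, and human evaluation.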

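1. Introduction

1) Challenges
    a) The generated images must be coherent and continuous as a sequence
    b) The characters' appearance and the backgrounds must also stay consistent from image to image

2) Similarities to and differences from video generation
    a) Video generation concentrates on capturing motion features so that movement transitions smoothly, whereas story visualization concentrates on static frames and does not model motion
    b) Video generation usually has a largely fixed background, and the model captures how that background changes over the course of the storyline. Seen this way, illustrating a story amounts to capturing the key frames of a long video.

3) Model structure: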

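2. Related Work

1) Text-to-image generation: flow-based models, VAEs, GANs

2) Video generation: text-to-video / image-to-video. There are two differences: first, our outputs are always static frames rather than continuous motion; second, the inputs differ, since video generation produces n frames (a short clip) from a single sentence

3) The inverse task of story visualization: narrating a story from n images, which also draws on T2I / reinforcement learning

4) Retrieval-based story illustration: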

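3. Methods

3.1 Background

Task description:

Given a story S = {s1, s2, s3, s4, s5}

produce a matching sequence of images X = {x1, x2, x3, x4, x5}

Challenges:

          1) Consistency with the overall story

          2) Coherence of style and content across images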

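3.2 Framework

Overview:

        Text processing: inter-sentence information (extracted by GRU1) + whole-story information (extracted by the StoryEncoder) + style/content continuity with the previous image (carried by Text2Gist)

        Image generation: one generation step per sentence + two discriminators (a per-image discriminator and a story discriminator)

Details:

        Text preprocessing
                1) Encoding:
                            a) Whole-story encoding: the Story Encoder maps the story to a low-dimensional vector
                            b) Sentence encoding: each sentence is encoded to 128 dimensions
                2) Information fusion
                            a) Information from each preceding sentence: GRU (RNN)
                            b) Whole-story information + information from the previous generation step: Text2Gist (GRU + DynamicFilter)
        GAN:
                 1) Each sentence is turned into one image by the Generator (upsampling generation); in the code below the same Generator is applied to every sentence
                 2) Two Discriminators in total: (Text-Image) and (Story-Image)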

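3.3 StoryEncoder

 1) The encoding $h_0$ is sampled at random from a normal distribution: $h_0 \sim \mathcal{N}(\mu(S), \operatorname{diag}(\sigma^2(S)))$

 2) $\mu$ and $\sigma^2$ are both obtained by mapping the story $S$ through an MLP:

$\mu = \mathrm{MLP}(S)$, $\sigma^2 = \mathrm{MLP}(S)$

3) Standard-normal noise $\epsilon$ is then scaled by $\sigma$ and shifted by $\mu$, which gives $h_0 = \mu(S) + \sigma(S) \odot \epsilon$ with $\epsilon \sim \mathcal{N}(0, I)$

 4) To keep the gradients from collapsing and to make the semantic space smoother, a KL-divergence penalty is added: it measures the distance between the sampling distribution and the standard normal distribution (it is added when the loss is computed)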
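
For reference, the sampling step and the penalty in item 4 can be written out as below. This is the standard closed form of the KL divergence between a diagonal Gaussian and the standard normal, not something specific to this post; $c$ denotes the dimension of the condition vector (c_dim in the code). Note that the KL_loss function in Section 3.7 averages the per-element terms with torch.mean instead of summing them, so it differs from this expression by a constant factor.

```latex
\begin{aligned}
h_0 &= \mu(S) + \sigma(S) \odot \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I) \\
\mathrm{KL}\big(\mathcal{N}(\mu, \operatorname{diag}(\sigma^2)) \,\|\, \mathcal{N}(0, I)\big)
  &= \frac{1}{2} \sum_{j=1}^{c} \left( \mu_j^2 + \sigma_j^2 - \log \sigma_j^2 - 1 \right) \\
  &= -\frac{1}{2} \sum_{j=1}^{c} \left( 1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2 \right)
\end{aligned}
```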

import torch
import torch.nn as nn
# cfg is the configuration object from the StoryGAN repository (not shown in this excerpt)

class CA_NET(nn.Module):
    """StoryEncoder:
    takes the initial story embedding text_embedding (len, d_TEXT),
    maps it through an MLP to mu and logvar, each (len, c),
    and samples the encoding h0"""
    def __init__(self):
        super(CA_NET, self).__init__()
        self.t_dim = cfg.TEXT.DIMENSION
        self.c_dim = cfg.GAN.CONDITION_DIM
        self.fc = nn.Linear(self.t_dim, self.c_dim * 2, bias=True)
        self.relu = nn.ReLU()

    def encode(self, text_embedding):
        """Map text_embedding through the MLP and split the result into mu and logvar"""
        # Story_emb = (len, d) --fc+relu--> (len, 2c) -> mu = columns 0..c ; logvar = columns c..2c
        x = self.relu(self.fc(text_embedding)) # ReLU(wx+b)
        mu = x[:, :self.c_dim]
        logvar = x[:, self.c_dim:]
        return mu, logvar

    def reparametrize(self, mu, logvar):
        """Sample h0 from N(mu, sigma^2):
        h = (noise * sigma) + mu"""
        std = logvar.mul(0.5).exp_() # std is the standard deviation: sigma = exp(logvar / 2), with logvar = MLP(story_emb)

        eps = torch.randn_like(std) # eps is the noise, drawn from the standard normal with the same shape and device as std
        # The noise itself is not trainable; gradients reach mu and logvar through the expression below
        # h = (noise * sigma) + mu
        return eps.mul(std).add_(mu)

    def forward(self, text_embedding):
        """Takes story_emb and returns h = (noise * sigma) + mu, where mu and logvar come from MLP(S)"""
        mu, logvar = self.encode(text_embedding)
        c_code = self.reparametrize(mu, logvar)
        return c_code, mu, logvar
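
To see what the reparameterization step does numerically, here is a minimal standard-library sketch (no PyTorch). The function name, the example values of mu and logvar, and the sample count are all illustrative assumptions, not taken from the StoryGAN repository. It draws many samples and checks that their empirical mean and standard deviation approach mu and exp(logvar / 2).

```python
import math
import random


def reparametrize(mu, logvar, rng):
    """Return h = mu + sigma * eps element-wise, with eps ~ N(0, 1)."""
    sample = []
    for m, lv in zip(mu, logvar):
        sigma = math.exp(0.5 * lv)   # sigma = exp(logvar / 2)
        eps = rng.gauss(0.0, 1.0)    # fresh standard-normal noise
        sample.append(m + sigma * eps)
    return sample


def mean(values):
    return sum(values) / len(values)


rng = random.Random(0)                # fixed seed so the run is repeatable
mu = [1.0, -2.0]                      # hypothetical means
logvar = [0.0, math.log(0.25)]        # sigma = 1.0 and 0.5
samples = [reparametrize(mu, logvar, rng) for _ in range(20000)]

means = [mean([s[i] for s in samples]) for i in range(2)]
stds = [
    math.sqrt(mean([(s[i] - means[i]) ** 2 for s in samples]))
    for i in range(2)
]
print(means, stds)  # should be close to [1.0, -2.0] and [1.0, 0.5]
```

Because the randomness lives entirely in eps, the sample is a deterministic function of mu and logvar once eps is fixed, which is what lets gradients flow back into the MLP.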

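3.4 Context Encoder

 Input: the sentences, with the story content used to initialize the memory (the initial hidden state)

 Architecture: a first GRU layer, followed by a second layer of GRU + DynamicFilter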


    def get_iteration_input(self, motion_input):  # motion_input is 2-D here: (B, d)
        """Build the input of one GRU step by concatenating
        fresh noise (B, d_noise) with
        motion_input (B, d), the sentence embedding of the current step"""
        num_samples = motion_input.shape[0]   # batch size B

        # T is not defined in this excerpt; it is assumed to be an alias for torch (or torch.cuda)
        noise = T.FloatTensor(num_samples, self.noise_dim).normal_(0,1) # standard-normal noise, (B, d_noise)
        return torch.cat((noise, motion_input), dim = 1)

    def get_gru_initial_state(self, num_samples):
        """Initial GRU state:
        sampled from a standard normal distribution,
        with shape (num_samples, d_noise + d_TEXT)"""
        return Variable(T.FloatTensor(num_samples, self.noise_dim+ self.motion_dim).normal_())

    def sample_z_motion(self, motion_input, video_len=None):
        """First GRU layer:
        consumes noise z concatenated with
        motion (B, video_len, d_TEXT)
        and returns the per-step outputs i as (B*video_len, d_TEXT)"""
        video_len = video_len if video_len is not None else self.video_len
        # initial hidden state h0: the first sentence embedding when there is more than one frame
        if video_len > 1:
            h_t = [motion_input[:,0,:]]
        else:
            h_t = [motion_input]
        # unroll the GRU cell over the frames
        for frame_num in range(video_len):
            if len(motion_input.shape) == 2:
                e_t = self.get_iteration_input(motion_input)
            else:
                # e_t = sentence + noise; motion_input is (B, video_len, d_TEXT), so the step-t input is (B, d_noise + d_TEXT)
                e_t = self.get_iteration_input(motion_input[:,frame_num,:])
            h_t.append(self.recurrent(e_t, h_t[-1]))  # h(t) = (B, d_TEXT)

        z_m_t = [h_k.view(-1, 1, self.motion_dim) for h_k in h_t]  # add a time axis: (B, d_TEXT) -> (B, 1, d_TEXT)
        z_motion = torch.cat(z_m_t[1:], dim=1).view(-1, self.motion_dim)  # drop h0, stack to (B, video_len, d_TEXT), flatten to (B*video_len, d_TEXT)
        return z_motion

    def motion_content_rnn(self, motion_input, content_input):
        """Second GRU layer:
        takes motion = i (B, video_len, d_TEXT) and
        content = h0 (B, d_TEXT),
        and returns H (B*video_len, d_TEXT)"""
        video_len = 1 if len(motion_input.shape) == 2 else self.video_len
        h_t = [content_input]  # h0
        if len(motion_input.shape) == 2:
            motion_input = motion_input.unsqueeze(1)
        for frame_num in range(video_len):
            e_t = motion_input[:,frame_num, :]  # step-t input of Text2Gist, (B, d_TEXT)
            h_t.append(self.mocornn(e_t, h_t[-1]))  # store the output h_t, (B, d_content)
        
        c_m_t = [h_k.view(-1, 1, self.content_dim) for h_k in h_t]  # (B, 1, d_content)
        mocornn_co = torch.cat(c_m_t[1:], dim=1).view(-1, self.content_dim) #  (B, video_len, d_content) -> view -> (B*video_len, d_content)
        return mocornn_co
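
Both recurrent layers above unroll an nn.GRUCell by hand: start from h0, feed one sentence per step, and keep every hidden state except h0. As a rough illustration, the sketch below does the same with a scalar GRU cell in plain Python. The update equations follow the ones documented for PyTorch's GRUCell; the weight dictionary, its values, and the inputs are made-up placeholders.

```python
import math


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def gru_cell(x, h, w):
    """One scalar GRU step in the GRUCell convention."""
    r = sigmoid(w["w_ir"] * x + w["w_hr"] * h + w["b_r"])   # reset gate
    z = sigmoid(w["w_iz"] * x + w["w_hz"] * h + w["b_z"])   # update gate
    n = math.tanh(w["w_in"] * x + w["b_in"] + r * (w["w_hn"] * h + w["b_hn"]))
    return (1.0 - z) * n + z * h                            # new hidden state


def unroll(inputs, h0, w):
    """Mirror the loop in sample_z_motion: keep h_1..h_T, drop h_0."""
    states = [h0]
    for x in inputs:
        states.append(gru_cell(x, states[-1], w))
    return states[1:]


# Hypothetical weights, chosen only so the example runs.
weights = {"w_ir": 0.5, "w_hr": 0.5, "b_r": 0.0,
           "w_iz": 0.5, "w_hz": 0.5, "b_z": 0.0,
           "w_in": 0.5, "w_hn": 0.5, "b_in": 0.0, "b_hn": 0.0}

outputs = unroll([0.1, -0.3, 0.7], h0=0.0, w=weights)

# With a saturated update gate (z close to 1) the cell copies its previous state.
frozen = unroll([0.1], h0=0.25, w=dict(weights, b_z=50.0))
print(outputs, frozen)
```

The second call shows why the choice of h0 matters: whatever the update gate lets through is carried forward, so initializing the second layer with the story content lets that content persist across frames.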

class DynamicFilterLayer(nn.Module): #MergeLayer
    """Takes [image, filters] and returns a 4-D tensor: the locally expanded image weighted by the predicted filter map"""
    def __init__(self, filter_size, stride=(1,1), pad=(0,0),flip_filters=False, grouping=False):
        super(DynamicFilterLayer, self).__init__()

        self.filter_size = filter_size    # 3-tuple (filter_size, filter_size, 1) = (15, 15, 1)

        self.stride = stride              # tuple 2
        self.pad = pad                    # tuple 2
        self.flip_filters = flip_filters  # whether to flip the kernels
        self.grouping = grouping          # whether to use grouped convolution
 
    def get_output_shape_for(self, input_shapes):
        """Output shape (B, 1, H, W), read from the first input.
        input_shapes[0] is the 4-D shape of the image tensor."""
        shape = (input_shapes[0][0], 1, input_shapes[0][2], input_shapes[0][3])
        return shape

    def forward(self, _input, **kwargs):
    #def get_output_for(self, _input, **kwargs):
        # Kernel flipping, border padding, and grouped convolution (their handling has been removed)
        conv_mode = 'conv' if self.flip_filters else 'cross'
        border_mode = self.pad
        if border_mode == 'same':
            border_mode = tuple(s // 2 for s in self.filter_size)

        # _input = [m_image, c_filter] = [(B*video_len, 1, r_image_size, r_image_size),(B*video_len, 1, filter_size, filter_size)]
        image = _input[0]  # m_image: the sentence embedding projected by image_net and reshaped to a 2-D map
        filters = _input[1]  # c_filter: predicted by filter_net from the second-layer GRU output
        filter_size = self.filter_size

        # Build one-hot kernels of shape (np.prod(filter_size), 1, filter_size, filter_size)
        filter_localexpand_np = np.reshape(
            np.eye(np.prod(filter_size)),    # identity matrix: row k becomes the k-th one-hot kernel
            (np.prod(filter_size), filter_size[2], filter_size[0], filter_size[1]))  # filter_size = (filter_size, filter_size, 1)
        # convert to a tensor on the same device as the input
        filter_localexpand = torch.from_numpy(filter_localexpand_np.astype('float32')).to(image.device)

        # Local expansion: channel k of the result holds the input shifted to kernel offset k.
        # With pad = filter_size // 2 the result is (B*video_len, filter_size**2, r_image_size, r_image_size).
        input_localexpand = F.conv2d(image, filter_localexpand, padding = self.pad)
        # filters broadcasts over the channel axis; summing over channels gives (B*video_len, 1, r_image_size, r_image_size)
        output = torch.sum(input_localexpand*filters, dim=1, keepdim=True)
        return output
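
The identity-kernel trick in DynamicFilterLayer is easier to follow on a tiny input. The plain-Python sketch below is my reading of the tensor shapes in the listing, not code from the repository: cross-correlating with one-hot kernels produces one shifted copy of the image per kernel offset, and the single-channel filter map is then broadcast over those copies before summing. The function names and the 3x3 example are illustrative assumptions.

```python
def local_expand(image, k):
    """Cross-correlate with k*k one-hot kernels (zero padding k // 2).

    Channel c = di * k + dj holds the image shifted to kernel offset (di, dj).
    """
    h, w = len(image), len(image[0])
    p = k // 2
    channels = []
    for di in range(k):
        for dj in range(k):
            ch = [[0.0] * w for _ in range(h)]
            for i in range(h):
                for j in range(w):
                    si, sj = i + di - p, j + dj - p
                    if 0 <= si < h and 0 <= sj < w:
                        ch[i][j] = image[si][sj]
            channels.append(ch)
    return channels


def dynamic_filter(image, filt):
    """Mirror torch.sum(input_localexpand * filters, dim=1) for one sample.

    filt has a single channel, so it multiplies every patch channel at the
    same spatial position before the sum over channels.
    """
    expanded = local_expand(image, len(filt))
    h, w = len(image), len(image[0])
    return [[sum(ch[i][j] for ch in expanded) * filt[i][j] for j in range(w)]
            for i in range(h)]


image = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
filt = [[1, 0, 0], [0, 2, 0], [0, 0, 0]]   # hypothetical predicted filter map
expanded = local_expand(image, 3)
out = dynamic_filter(image, filt)
print(out)
```

Under this reading, each output pixel is the filter value at that position times the sum of the zero-padded neighborhood around it, which only type-checks because the filter map and the image share the same spatial size (both 15 in the listing).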

3.5 Generator

class StoryGAN(nn.Module):
    def __init__(self, video_len):
        super(StoryGAN, self).__init__()
        self.gf_dim = cfg.GAN.GF_DIM * 8

        self.motion_dim = cfg.TEXT.DIMENSION # motion = label; here it refers to the individual extracted frames
        self.content_dim = cfg.TEXT.DIMENSION # content = all descriptions of the story; both use the encoded text dim
        self.noise_dim = cfg.GAN.Z_DIM  # dimension of the noise

        self.recurrent = nn.GRUCell(self.noise_dim + self.motion_dim, self.motion_dim)  # sample_z_motion
        self.mocornn = nn.GRUCell(self.motion_dim, self.content_dim)  # motion_content_rnn

        self.video_len = video_len
        self.n_channels = 3
        self.filter_size = 15
        self.r_image_size = 15
        self.define_module()
        

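    def define_module(self):
        """fc = Linear + BN + ReLU
        upsample1..4 = four upBlocks, each doubling the spatial size
        img = conv + Tanh
        filter_net = Linear + BN
        image_net = Linear + BN
        dfn_layer = dynamic filter layer (its filter map is predicted by filter_net)
        downsamples = Conv+BN+LReLU + Conv+BN+LReLU"""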
        from layers import DynamicFilterLayer
        ninput = self.motion_dim + self.content_dim
        ngf = self.gf_dim
        # TEXT.DIMENSION -> GAN.CONDITION_DIM
        self.ca_net = CA_NET()

        # -> (ngf/2) x 4 x 4 after reshaping; concatenated with the dynamic-filter branch this becomes ngf x 4 x 4
        self.fc = nn.Sequential(
            nn.Linear(ninput, int(ngf * 4 * 4 / 2), bias=False),
            nn.BatchNorm1d(ngf * 4 * 2),
            nn.ReLU(True))

        # ngf x 4 x 4 -> ngf/2 x 8 x 8
        self.upsample1 = upBlock(ngf, ngf // 2)
        # -> ngf/4 x 16 x 16
        self.upsample2 = upBlock(ngf // 2, ngf // 4)
        # -> ngf/8 x 32 x 32
        self.upsample3 = upBlock(ngf // 4, ngf // 8)
        # -> ngf/16 x 64 x 64
        self.upsample4 = upBlock(ngf // 8, ngf // 16)
        # -> 3 x 64 x 64
        self.img = nn.Sequential(
            conv3x3(ngf // 16, 3),
            nn.Tanh())

        self.filter_net = nn.Sequential(
            nn.Linear(self.content_dim,  self.filter_size**2, bias = False),
            nn.BatchNorm1d(self.filter_size**2),
            #nn.Softmax()
            )

        self.image_net = nn.Sequential(
            nn.Linear(self.motion_dim, self.r_image_size**2, bias = False),
            nn.BatchNorm1d(self.r_image_size**2)
            )

        self.dfn_layer = DynamicFilterLayer((self.filter_size, self.filter_size, 1), 
            pad = (self.filter_size//2, self.filter_size//2), grouping = False)

        self.downsamples = nn.Sequential(
            nn.Conv2d(1, ngf, 3, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(ngf, ngf//2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf//2),
            nn.LeakyReLU(0.2, inplace=True),
            )


    # get_iteration_input, get_gru_initial_state, sample_z_motion, and motion_content_rnn
    # are the same methods listed in Section 3.4 and are not repeated here.

    def sample_videos(self, motion_input, content_input):
        """Generate a sequence of images: GRU1 + GRU2 + dynamic filter + (concatenation) + Generator.
        Both arguments receive the story descriptions,
        i.e. story_motion_input = story_content_input"""

        # StoryEncoder: in this adapted version ca_net is bypassed, so r_mu and r_logvar are just the mean embedding
        content_mean = content_input.mean(1)  # h0: average over axis 1, (B, len*video_len, d_TEXT) -> (B, d_TEXT)
        r_code, r_mu, r_logvar = content_mean, content_mean, content_mean  # (B, d_TEXT); r_mu and r_logvar are returned for the KL term

        # GRU2
        crnn_code = self.motion_content_rnn(motion_input, content_mean)

        # GRU1
        zm_code = self.sample_z_motion(motion_input) # zm_code = GRU(sentence sequence) = (B*video_len, d_TEXT)

        # cat (GRU1 + c)
        content_input = content_mean.repeat(1, self.video_len)  # (B, video_len * d_TEXT)
        content_input = content_input.view(
            (content_input.shape[0] * self.video_len,  # B*video_len
             int(content_input.shape[1] / self.video_len)))  # d_TEXT, so the result is (B*video_len, d_TEXT)
        c_code, c_mu, c_logvar = content_input, content_input, content_input  # self.ca_net(content_input)
        zmc_code = torch.cat((zm_code, c_code), dim = 1)  # output_GRU1 + content = (B*video_len, 2*d_TEXT)
        zmc_code = self.fc(zmc_code)
        zmc_code = zmc_code.view(-1, int(self.gf_dim/2), 4, 4)  # (B*video_len, gf_dim/2, 4, 4)

        # DynamicFilter(GRU2 + m)
        temp = motion_input.view(-1, motion_input.shape[2])  # (B*video_len, d_TEXT)
        m_code, m_mu, m_logvar = temp, temp, temp  # (B*video_len, d_TEXT); m_mu and m_logvar are returned for the loss on the descriptions
        m_image = self.image_net(m_code)  # (B*video_len, d_TEXT) -> (B*video_len, r_image_size**2)
        m_image = m_image.view(-1, 1, self.r_image_size, self.r_image_size)  # m_image = (B*video_len, 1, r_image_size, r_image_size)
        c_filter = self.filter_net(crnn_code)  # (B*video_len, d_TEXT) -> fc -> (B*video_len, filter_size**2)
        c_filter = c_filter.view(-1, 1, self.filter_size, self.filter_size)  # filter = (B*video_len, 1, filter_size, filter_size)
        mc_image = self.dfn_layer([m_image, c_filter])
        mc_image = self.downsamples(mc_image)  # 4-D: (B*video_len, gf_dim/2, 4, 4)

        # cat (DynamicFilter + cat[GRU1+c])
        zmc_all = torch.cat((zmc_code, mc_image), dim = 1) # (B*video_len, gf_dim, 4, 4)

        # Generate
        h_code = self.upsample1(zmc_all)
        h_code = self.upsample2(h_code)
        h_code = self.upsample3(h_code)
        h_code = self.upsample4(h_code)
        # state size 3 x 64 x 64
        h = self.img(h_code)
        fake_video = h.view(int(h.size(0) / self.video_len), self.video_len, self.n_channels, h.size(3), h.size(3)) # (B, video_len, C, H, W)
        fake_video = fake_video.permute(0, 2, 1, 3, 4)  # (B, C, video_len, H, W)
        return None, fake_video, r_mu, r_logvar, m_mu, m_logvar

def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding 1: in_planes is the number of input channels, out_planes the number of filters"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


# Upscale the spatial size by a factor of 2
def upBlock(in_planes, out_planes):
    """Upsample by a factor of 2: in_planes is the number of input channels, out_planes the number of filters"""
    block = nn.Sequential(
        nn.Upsample(scale_factor=2, mode='nearest'),  # Upsample
        conv3x3(in_planes, out_planes),               # conv(kernel_size=3*3)
        nn.BatchNorm2d(out_planes),                   # BN
        nn.ReLU(True))                                # ReLU
    return block


class ResBlock(nn.Module):
    """(conv+BN+ReLU+conv+BN) + x. Takes a single channel count, because each conv uses the same number of input and output channels"""
    def __init__(self, channel_num):
        super(ResBlock, self).__init__()
        self.block = nn.Sequential(
            conv3x3(channel_num, channel_num),
            nn.BatchNorm2d(channel_num),
            nn.ReLU(True),
            conv3x3(channel_num, channel_num),
            nn.BatchNorm2d(channel_num))
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        residual = x
        out = self.block(x)
        out += residual
        out = self.relu(out)
        return out
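
The shape comments in the generator can be checked with the usual convolution size formula. The sketch below is plain arithmetic with no PyTorch; the helper name conv_out is my own. It walks the dynamic-filter branch (15 -> 8 -> 4) through the two layers of self.downsamples and then the four upBlocks (4 -> 64).

```python
def conv_out(size, kernel, stride, pad):
    """Spatial output size of a convolution (floor division, no dilation)."""
    return (size + 2 * pad - kernel) // stride + 1


# Dynamic-filter branch: the 15 x 15 map goes through self.downsamples.
after_first = conv_out(15, kernel=3, stride=2, pad=1)             # Conv2d(1, ngf, 3, 2, 1)
after_second = conv_out(after_first, kernel=4, stride=2, pad=1)   # Conv2d(ngf, ngf//2, 4, 2, 1)

# Generator trunk: four upBlocks, each doubling the spatial size.
sizes = [after_second]
for _ in range(4):
    sizes.append(sizes[-1] * 2)

print(after_first, after_second, sizes)
```

Since the fc branch is reshaped to gf_dim/2 channels at 4 x 4 and the dynamic-filter branch also ends at gf_dim/2 channels at 4 x 4, their concatenation has gf_dim channels, which is what upsample1 expects.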

3.6 Discriminator

 

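class D_IMG(nn.Module):
    """Image encoder of the image discriminator (single image): downsamples to (ndf*8) x 4 x 4. The resulting image code is multiplied element-wise with the encoded text and then scored"""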
    def __init__(self, use_categories = True):
        super(D_IMG, self).__init__()
        self.df_dim = cfg.GAN.DF_DIM
        self.ef_dim = cfg.TEXT.DIMENSION
        self.label_num = cfg.LABEL_NUM
        self.define_module(use_categories)

    def define_module(self, use_categories):
        ndf, nef = self.df_dim, self.ef_dim
        self.encode_img = nn.Sequential(
            nn.Conv2d(3, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size (ndf*2) x 16 x 16
            nn.Conv2d(ndf*2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size (ndf*4) x 8 x 8
            nn.Conv2d(ndf*4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            # state size (ndf * 8) x 4 x 4)
            nn.LeakyReLU(0.2, inplace=True)
        )

        if use_categories:
            self.cate_classify = nn.Conv2d(ndf * 8, self.label_num, 4, 4, 1, bias = False)
        else:
            self.cate_classify = None
        self.get_cond_logits = D_GET_LOGITS(ndf, nef, 1)
        self.get_uncond_logits = None

    def forward(self, image):
        img_embedding = self.encode_img(image)

        return img_embedding


class D_STY(nn.Module):
    """Image encoder of the story discriminator (image sequence + story): output is (N, C*L, W, H)"""
    def __init__(self):
        super(D_STY, self).__init__()
        self.df_dim = cfg.GAN.DF_DIM
        self.ef_dim = cfg.GAN.CONDITION_DIM
        self.text_dim = cfg.TEXT.DIMENSION
        self.label_num = cfg.LABEL_NUM
        self.define_module()

    def define_module(self):
        ndf, nef = self.df_dim, self.ef_dim
        self.encode_img = nn.Sequential(
            nn.Conv2d(3, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size (ndf*2) x 16 x 16
            nn.Conv2d(ndf*2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size (ndf*4) x 8 x 8
            nn.Conv2d(ndf*4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            # state size (ndf * 8) x 4 x 4)
            nn.LeakyReLU(0.2, inplace=True)
        )

        self.get_cond_logits = D_GET_LOGITS(ndf, nef, cfg.VIDEO_LEN)
        self.get_uncond_logits = None
        self.cate_classify = None

    def forward(self, story):
        N, C, video_len, W, H = story.shape
        story = story.permute(0,2,1,3,4)  # (N, video_len, C, W, H)
        story = story.contiguous().view(-1, C,W,H) # contiguous() is needed before view(); result is (N*video_len, C, W, H)
        if hasattr(torch.cuda, 'empty_cache'):
            torch.cuda.empty_cache()
        story_embedding = torch.squeeze(self.encode_img(story))  # encode the images; squeeze drops singleton dims, the result is expected to stay 4-D
        _, C1, W1, H1 = story_embedding.shape  # channel and spatial sizes after the convolutions
        #story_embedding = story_embedding.view(N,video_len, C1, W1, H1)
        #story_embedding = story_embedding.mean(1).squeeze()
        story_embedding = story_embedding.permute(2,3,0,1) # (W, H, N*L, C)
        story_embedding = story_embedding.contiguous().view( W1, H1, N, video_len * C1)  # (W, H, N, C*L): the L frames of a story are stacked along the channel axis
        story_embedding = story_embedding.contiguous().permute(2,3,0,1) # (N, C*L, W, H)
        #print(f"D_S_output (B, C*L, W, H)? = {story_embedding.shape}")
        return story_embedding


class D_GET_LOGITS(nn.Module):
    """Logit head: conv1 = Conv+BN+LReLU (image code), conv2 = Conv+Sigmoid (score), convc = Conv+BN+LReLU (text code)"""
    def __init__(self, ndf, nef, video_len = 1, bcondition=True):
        super(D_GET_LOGITS, self).__init__()
        self.df_dim = ndf
        self.ef_dim = nef
        self.bcondition = bcondition
        self.video_len = video_len
        if bcondition:
            self.conv1 = nn.Sequential(
                conv3x3(ndf * 8 * video_len, ndf * 8),
                nn.BatchNorm2d(ndf * 8),
                nn.LeakyReLU(0.2, inplace=True)
                )
            self.conv2 = nn.Sequential(
                nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4),
                nn.Sigmoid()
                )
            self.convc = nn.Sequential(
                conv3x3(self.ef_dim, ndf * 8),
                nn.BatchNorm2d(ndf * 8),
                nn.LeakyReLU(0.2, inplace=True)
                )
        else:
            self.conv2 = nn.Sequential(
                nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4),
                nn.Sigmoid())
        # if video_len > 1:
        #     self.storynet = nn.GRUCell(self.ef_dim, self.ef_dim)

    def forward(self, h_code, c_code=None):
        # conditional inputs: h_code (B, ndf*8*video_len, 4, 4) and c_code (B, ef_dim)
        if self.bcondition and c_code is not None:
            c_code = c_code.view(-1, self.ef_dim, 1, 1)
            c_code = c_code.repeat(1, 1, 4, 4)   # tile the text code to 4 x 4
            c_code = self.convc(c_code)
            h_code = self.conv1(h_code)
            h_c_code = h_code * c_code           # element-wise product of image and text codes
        else:
            h_c_code = h_code

        # Score with the learned conv2 head (Conv2d + Sigmoid), so that the final
        # scoring layer is trained together with the rest of the discriminator.
        output = self.conv2(h_c_code)
        return output.view(-1)

3.7 Loss Functions

 


def KL_loss(mu, logvar):
    # -0.5 * sum(1 + log(σ²) - μ² - σ²)
    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    KLD = torch.mean(KLD_element).mul_(-0.5)
    return KLD
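
As a quick sanity check of KL_loss, the following plain-Python sketch reproduces the same expression on lists of floats. It is a standalone illustration rather than a replacement for the tensor version; the name kl_loss is only used here. It averages over elements, as torch.mean does in the listing above.

```python
import math


def kl_loss(mu, logvar):
    """-0.5 * mean(1 + log(sigma^2) - mu^2 - sigma^2), element-wise over the inputs."""
    elements = [1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar)]
    return -0.5 * sum(elements) / len(elements)


at_prior = kl_loss([0.0, 0.0], [0.0, 0.0])    # already N(0, 1): no penalty
shifted = kl_loss([1.0], [0.0])               # mean moved by 1: penalty 0.5
narrow = kl_loss([0.0], [math.log(0.25)])     # variance shrunk to 0.25
print(at_prior, shifted, narrow)
```

The penalty is zero only when mu = 0 and logvar = 0, and grows as the encoder's distribution drifts from the standard normal in either mean or variance.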


def compute_discriminator_loss(netD, real_imgs, fake_imgs,
                               real_labels, fake_labels, real_catelabels,
                               conditions, gpus, mode='image'):
    """
    Args:
        netD: D_IMG or D_STY
        real_imgs: image features loaded from Dataset['images'], (B, C, img_size, img_size) / (B, C, video_len, img_size, img_size)
        fake_imgs: generated images obtained from sample_Video/Images, (B, C, img_size, img_size) / (B, C, video_len, img_size, img_size)
        real_labels: BCELoss targets, torch.FloatTensor(self.batch_size).fill_(1), (B,)
        fake_labels: BCELoss targets, torch.FloatTensor(self.batch_size).fill_(0), (B,)
        conditions: text features obtained from sample_Video/Images (content+motion+label), (B, d_ef)
        real_catelabels: targets of the multi-label loss, loaded from Dataset['labels'], (B, num_labels)
    LOSS:
        <BCELOSS>:
            (1) D_logit = D.get_cond_logits(img_feat, text_feat)
                    Args:
                        img_feat: (B, d_df*8, 4, 4)
                        text_feat: (B, d_TEXT)
                    Computation:
                        img_feat  = conv(img_feat)
                        text_feat = conv(text_feat.view(B, d_TEXT, 1, 1).repeat(1, 1, 4, 4))
                        h_c = img_feat * text_feat = (B, d_df*8, 4, 4) * (B, d_df*8, 4, 4)
                        D_logit = sigmoid(conv(h_c)).view(-1)
                    Returns:
                        D_logit: (B,)

            (2) loss1 = BCELOSS(D_logit, real/fake_labels)
                    Args:
                        D_logit: (B,)
                        real_labels: ([1] x batch_size)
                        fake_labels: ([0] x batch_size)
                    Computation:
                        binary cross-entropy loss
                    Returns:
                        float

        <MultiLabelSoftMarginLoss>:
            (1) cate_logits = D.cate_classify(img_feat).squeeze()
                    Args:
                        img_feat: (B, d_df*8, 4, 4)
                    Computation:
                        nn.Conv2d(ndf * 8, self.label_num, 4, 4, 1, bias = False).squeeze()
                    Returns:
                        cate_logits: (B, self.label_num)
            (2) loss2 = MultiLabelLoss(cate_logits, real_catelabels)
                    Args:
                        cate_logits:    (B, self.label_num)
                        real_catelabels: (B, self.label_num)
                    Computation:
                        multi-label binary cross-entropy, e.g. predictions [(0.5, 0.6, 0.7) x batch_size] compared with targets [(1, 1, 0) x batch_size]; num_labels must match
                    Returns:
                        float

    """

    "(1) Config"
    #print("real:", real_imgs.shape)
    #print("fake:", fake_imgs.shape)
    #print("text_feat:", conditions.shape)
    criterion = nn.BCELoss()
    cate_criterion =nn.MultiLabelSoftMarginLoss()
    batch_size = real_imgs.size(0)
    cond = conditions.detach()
    fake = fake_imgs.detach()

    "(2) BCELOSS(text_feat_and_img_feat)"
    # get img_feat
    real_features = nn.parallel.data_parallel(netD, (real_imgs), gpus)
    fake_features = nn.parallel.data_parallel(netD, (fake), gpus)

    #print("netD(real):", real_features.shape)
    #print("netD(fake):", fake_features.shape)
    """
    if mode == 'story':
        real_features_st = real_features
        fake_features = fake_features.mean(1).squeeze()
        real_features = real_features.mean(1).squeeze()"""

    # real pairs
    #print("real_features",real_features.shape)  # netD(real)=(B, d_df*8, 4, 4)
    #print("cond", cond.shape)  # condition = (B, d_TEXT+d_condition+nums_labels)
    # get text_feat and combine text_img_logit
    inputs = (real_features, cond)
    real_logits = nn.parallel.data_parallel(netD.get_cond_logits, inputs, gpus)
    errD_real = criterion(real_logits, real_labels)
    # wrong pairs
    #print("wrong_features", real_features[:(batch_size-1)].shape)
    #print("wrong_cond", cond[1:].shape)
    # get text_feat and combine text_img_logit
    inputs = (real_features[:(batch_size-1)], cond[1:])
    wrong_logits = \
        nn.parallel.data_parallel(netD.get_cond_logits, inputs, gpus)
    errD_wrong = criterion(wrong_logits, fake_labels[1:])
    # fake pairs
    #print("fake_features", fake_features.shape)
    #print("cond", cond.shape)
    # get text_feat and combine text_img_logit
    inputs = (fake_features, cond)
    fake_logits = nn.parallel.data_parallel(netD.get_cond_logits, inputs, gpus)
    errD_fake = criterion(fake_logits, fake_labels)

    "(3)BCELOSS(only img_feat)"
    if netD.get_uncond_logits is not None:
        real_logits = \
            nn.parallel.data_parallel(netD.get_uncond_logits,
                                      (real_features), gpus)
        fake_logits = \
            nn.parallel.data_parallel(netD.get_uncond_logits,
                                      (fake_features), gpus)
        uncond_errD_real = criterion(real_logits, real_labels)
        uncond_errD_fake = criterion(fake_logits, fake_labels)
        #
        errD = ((errD_real + uncond_errD_real) / 2. +
                (errD_fake + errD_wrong + uncond_errD_fake) / 3.)
        errD_real = (errD_real + uncond_errD_real) / 2.
        errD_fake = (errD_fake + uncond_errD_fake) / 2.
    else:
        errD = errD_real + (errD_fake + errD_wrong) * 0.5

    loss_report = {
        mode + ' Fake/Real Discriminator Loss (Real pairs) --> ': errD_real.data.item(),
        mode + ' Fake/Real Discriminator Loss (Wrong pairs) --> ': errD_wrong.data.item(),
        mode + ' Fake/Real Discriminator Loss (Fake pairs) --> ': errD_fake.data.item(),
    }
    "(4)MultiLabel's BCELOSS"
    if netD.cate_classify is not None:
        # character classification loss
        #print('Real features shape', real_features.shape)  # (B, d_df*8, 4, 4)
        cate_logits = nn.parallel.data_parallel(netD.cate_classify, real_features, gpus)  # (B, nums_labels, 1, 1)
        cate_logits = cate_logits.squeeze()   # (B, nums_labels)
        #print('Categorical logits shape', cate_logits.shape)   # (B, nums_labels)
        #print("real_character_labels shape", real_catelabels.shape)   # (B, nums_labels)
        errD = errD + 1.0 * cate_criterion(cate_logits, real_catelabels)
        acc = get_multi_acc(cate_logits.cpu().data.numpy(), real_catelabels.cpu().data.numpy())
        loss_report[mode + ' Character Classifier Accuracy (Discriminator) --> '] = acc

    return errD, loss_report
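
The conditional branch of the discriminator loss combines three kinds of pairs: real image with matching text, real image with mismatched text ("wrong" pairs, built by shifting the conditions by one position in the batch), and fake image with matching text. The plain-Python sketch below illustrates that combination for the case where get_uncond_logits is None. The function names and the probability values are hypothetical, and the epsilon clamp is a simplification of how nn.BCELoss guards its logarithms.

```python
import math


def bce(probs, targets, eps=1e-12):
    """Mean binary cross-entropy on probabilities."""
    total = 0.0
    for p, t in zip(probs, targets):
        total += t * math.log(max(p, eps)) + (1.0 - t) * math.log(max(1.0 - p, eps))
    return -total / len(probs)


def wrong_pairs(images, conditions):
    """Pair image i with condition i + 1, as real_features[:B-1] and cond[1:] do."""
    return list(zip(images[:-1], conditions[1:]))


def discriminator_loss(real_probs, wrong_probs, fake_probs):
    """errD = errD_real + 0.5 * (errD_fake + errD_wrong)."""
    err_real = bce(real_probs, [1.0] * len(real_probs))
    err_wrong = bce(wrong_probs, [0.0] * len(wrong_probs))
    err_fake = bce(fake_probs, [0.0] * len(fake_probs))
    return err_real + 0.5 * (err_fake + err_wrong)


pairs = wrong_pairs(["img0", "img1", "img2"], ["txt0", "txt1", "txt2"])
undecided = discriminator_loss([0.5, 0.5, 0.5], [0.5, 0.5], [0.5, 0.5, 0.5])
perfect = discriminator_loss([1.0, 1.0, 1.0], [0.0, 0.0], [0.0, 0.0, 0.0])
print(pairs, undecided, perfect)
```

A discriminator that outputs 0.5 everywhere scores 2 * ln 2, which gives a rough baseline for reading the logged loss values.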


def compute_generator_loss(netD, fake_imgs, real_labels, fake_catelabels, conditions, gpus, mode='image'):
    """No real images are needed here, because this step updates G.
    Args:
        netD: D_IMG or D_STY
        fake_imgs: (B, C, img_size, img_size) / (B, C, video_len, img_size, img_size)
        real_labels: torch.FloatTensor(self.batch_size).fill_(1), (B,)
        fake_catelabels: targets of the multi-label loss, loaded from Dataset['labels'], (B, num_labels)
        conditions: text features obtained from sample_Video/Images (content+motion+label), (B, d_ef)
    """
    "(1)Prepare"
    #print("fake:", fake_imgs.shape)
    #print("text_feat:", conditions.shape)
    criterion = nn.BCELoss()
    cate_criterion =nn.MultiLabelSoftMarginLoss()
    cond = conditions.detach()
    fake_features = nn.parallel.data_parallel(netD, (fake_imgs), gpus)
    #print("img_feat:", fake_features.shape)
    """
    if mode == 'story':
        fake_features_st = fake_features
        fake_features = torch.mean(fake_features, dim=1).squeeze()"""
    "(2)BCELOSS"
    # fake pairs
    inputs = (fake_features, cond)
    fake_logits = nn.parallel.data_parallel(netD.get_cond_logits, inputs, gpus)
    errD_fake = criterion(fake_logits, real_labels)

    if netD.get_uncond_logits is not None:
        fake_logits = \
            nn.parallel.data_parallel(netD.get_uncond_logits,
                                      (fake_features), gpus)
        uncond_errD_fake = criterion(fake_logits, real_labels)
        errD_fake += uncond_errD_fake

    loss_report = {
        mode + ' Fake/Real Generator Loss (Fake pairs) --> ': errD_fake.data.item(),
    }
    "(3)MultiLabelLOSS"
    if netD.cate_classify is not None:
        # print('Fake features shape', fake_features.shape)
        cate_logits = nn.parallel.data_parallel(netD.cate_classify, fake_features, gpus)
        cate_logits = cate_logits.mean(dim=-1).mean(dim=-1)
        cate_logits = cate_logits.squeeze()
        # print(cate_logits.shape, fake_catelabels.shape)
        errD_fake = errD_fake + 1.0 * cate_criterion(cate_logits, fake_catelabels)
        acc = get_multi_acc(cate_logits.cpu().data.numpy(), fake_catelabels.cpu().data.numpy())
        loss_report[mode + ' Character Classifier Accuracy (Generator) --> '] = acc


    return errD_fake, loss_report
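
Both loss functions add a character-classification term computed with nn.MultiLabelSoftMarginLoss. As an illustration of what that term measures, here is a plain-Python version of the formula given in the PyTorch documentation (mean over classes, then mean over the batch). The function names and the example logits are my own.

```python
import math


def log_sigmoid(x):
    """log(1 / (1 + exp(-x))), adequate for moderate |x|."""
    return -math.log1p(math.exp(-x))


def multilabel_soft_margin(logits, targets):
    """Mean over classes of the per-class binary cross-entropy on logits, then mean over the batch."""
    batch_total = 0.0
    for row_x, row_y in zip(logits, targets):
        per_class = [
            -(y * log_sigmoid(x) + (1.0 - y) * log_sigmoid(-x))
            for x, y in zip(row_x, row_y)
        ]
        batch_total += sum(per_class) / len(per_class)
    return batch_total / len(logits)


targets = [[1.0, 1.0, 0.0]]                                 # characters 0 and 1 appear in the frame
uninformed = multilabel_soft_margin([[0.0, 0.0, 0.0]], targets)
confident = multilabel_soft_margin([[10.0, 10.0, -10.0]], targets)
print(uninformed, confident)
```

Each label is treated as an independent yes/no decision, which suits frames where several characters can appear at once.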

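4. Experiments

4.1 Dataset

4.2 Experiments

5. Conclusion

We study the story visualization task as a sequential conditional generation problem. The proposed StoryGAN model handles the task by combining the current input sentence with contextual information, which is achieved by the Text2Gist component in the Context Encoder. The ablation study shows that the two-level discriminators and the recurrent structure over the inputs help ensure consistency between the generated images and the story, while the Context Encoder effectively provides the image generator with both local and global conditioning.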
