Analysis of the ADAPT Vision-Language Model

ADAPT: Action-aware Driving Caption Transformer
ADAPT consists of two parts: CSP (Control Signal Prediction) and DCG (Driving Caption Generation).

1. Architecture Overview

A pretrained Video Swin Transformer (Swin_base_patch244_window877_kinetics400_22k.pth) is used to extract frame features; its parameters are fine-tuned during training rather than frozen.

The DCG component is built from BertImgModel and BertPredictionHeadTransform, implementing a BERT-based structure that captions and explains the video: BertForImageCaptioning.

The CSP component consists of a BertEncoder, a linear img_embedding, and a linear decoder, and predicts the control signals (speed and course) from the frame features.
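
To make the overall layout concrete, here is a minimal structural sketch of how the two heads share a single Video Swin backbone. This is an illustration under assumptions, not the repository's actual wiring; the class name AdaptSketch and its argument names are hypothetical.

# Minimal structural sketch (hypothetical names, simplified wiring).
import torch

class AdaptSketch(torch.nn.Module):
    def __init__(self, video_swin, dcg_head, csp_head):
        super().__init__()
        self.video_swin = video_swin   # shared, trainable frame-feature extractor
        self.dcg_head = dcg_head       # BertForImageCaptioning-style VL transformer
        self.csp_head = csp_head       # Sensor_Pred_Head-style regression transformer

    def forward(self, frames, text_inputs, car_info):
        vid_feats = self.video_swin(frames)                # (B, C', T', H', W')
        vid_tokens = vid_feats.flatten(2).transpose(1, 2)  # (B, N, C') video tokens
        captions = self.dcg_head(img_feats=vid_tokens, **text_inputs)
        signals = self.csp_head(img_feats=vid_tokens, car_info=car_info)
        return captions, signals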

2. DCG (Driving Caption Generation) in Detail

The DCG pipeline is as follows:

(1) The input is converted into a 32-frame video clip together with the corresponding vehicle action narration and reasoning, e.g.: "[Action narration:] the car pulls over to the right side of the road, [Reasoning:] because the car is parking".

(2) Load the video transformer model used to extract frame features, e.g. Swin_base_patch244_window877_kinetics400_22k.pth.

(3) Initialize the BERT model for text understanding and the tokenizer for word segmentation, using a randomly initialized bert-base-uncased architecture; this BERT is a BertForImageCaptioning model (see the tokenization sketch after this list).

(4) Build the video captioning model from the video transformer, BERT, and tokenizer, and train it; the video transformer's parameters are also trained, not frozen.
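
Step (3) can be made concrete with a small tokenization sketch. This is an illustrative assumption rather than the repository's exact preprocessing; the token_type_ids convention (0 for narration, 1 for reasoning) matches the forward docstring quoted later in this article.

# Illustrative tokenization of a narration/reasoning pair (assumption, not the repo's exact code).
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
narration = "the car pulls over to the right side of the road"
reasoning = "because the car is parking"
enc = tokenizer(narration, reasoning, padding='max_length', max_length=40,
                truncation=True, return_tensors='pt')
# enc['input_ids']      -> word tokens fed to the captioning BERT
# enc['token_type_ids'] -> 0 for the narration segment, 1 for the reasoning segment
# enc['attention_mask'] -> 1 for real tokens, 0 for padding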

Core code 1: loading the video transformer model; initializing the BERT model and tokenizer

    # Get Video Swin backbone 
    swin_model = get_swin_model(args)

    # Get BERT and tokenizer for DCG (Driving Caption Generation) 
    bert_model, config, tokenizer = get_bert_model(args)



def get_swin_model(args):
    if int(args.img_res) == 384:
        assert args.vidswin_size == "large"
        config_path = 'src/modeling/video_swin/swin_%s_384_patch244_window81212_kinetics%s_22k.py'%(args.vidswin_size, args.kinetics)
        model_path = 'models/video_swin_transformer/swin_%s_384_patch244_window81212_kinetics%s_22k.pth'%(args.vidswin_size, args.kinetics)
    else:
        # in the case that args.img_res == '224'
        config_path = 'src/modeling/video_swin/swin_%s_patch244_window877_kinetics%s_22k.py'%(args.vidswin_size, args.kinetics)
        model_path = 'models/video_swin_transformer/swin_%s_patch244_window877_kinetics%s_22k.pth'%(args.vidswin_size, args.kinetics)
    if args.pretrained_2d:
        config_path = 'src/modeling/video_swin/swin_base_patch244_window877_kinetics400_22k.py'
        model_path = 'models/swin_transformer/swin_base_patch4_window7_224_22k.pth'

    logger.info(f'video swin (config path): {config_path}')
    if args.pretrained_checkpoint == '':
        logger.info(f'video swin (model path): {model_path}')
    cfg = Config.fromfile(config_path)
    pretrained_path = model_path if args.pretrained_2d else None
    backbone = SwinTransformer3D(
                    pretrained=pretrained_path,
                    pretrained2d=args.pretrained_2d,
                    patch_size=cfg.model['backbone']['patch_size'],
                    in_chans=3,
                    embed_dim=cfg.model['backbone']['embed_dim'],
                    depths=cfg.model['backbone']['depths'],
                    num_heads=cfg.model['backbone']['num_heads'],
                    window_size=cfg.model['backbone']['window_size'],
                    mlp_ratio=4.,
                    qkv_bias=True,
                    qk_scale=None,
                    drop_rate=0.,
                    attn_drop_rate=0.,
                    drop_path_rate=0.2,
                    norm_layer=torch.nn.LayerNorm,
                    patch_norm=cfg.model['backbone']['patch_norm'],
                    frozen_stages=-1,
                    use_checkpoint=False)

    video_swin = myVideoSwin(args=args, cfg=cfg, backbone=backbone)

    if not args.pretrained_2d:
        checkpoint_3d = torch.load(model_path, map_location='cpu')
        video_swin.load_state_dict(checkpoint_3d['state_dict'], strict=False)
    else:
        video_swin.backbone.init_weights()
    return video_swin
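
For orientation, a rough usage sketch of the returned backbone follows. The shapes are assumptions for the base configuration (patch size (2,4,4), 224x224 input, 32 frames, final channel dimension 1024), not values printed from the repository, and args is assumed to be the parsed training arguments.

# Rough usage sketch (shape values are assumptions for the base 224-res config).
import torch

video_swin = get_swin_model(args)        # Kinetics-pretrained Video Swin, args as in the training script
clip = torch.randn(2, 3, 32, 224, 224)   # (B, C, T, H, W)
with torch.no_grad():
    feats = video_swin(clip)             # roughly (B, 1024, T/2, H/32, W/32) = (2, 1024, 16, 7, 7)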


#BertForImageCaptioning
class BertForImageCaptioning(BertPreTrainedModel):
    r"""
    Bert for Image Captioning.
    """
    def __init__(self, config):
        super(BertForImageCaptioning, self).__init__(config)
        self.config = config
        self.bert = BertImgModel(config)
        self.cls = BertCaptioningHeads(config)
        self.loss = BertCaptioningLoss(config)
        # cclin
        self.cls_img_feat = BertIFPredictionHead(config)
        self.loss_img_feat = BertImgFeatureLoss(config)

        self.apply(self.init_weights)
        self.tie_weights()

        self.model_type = getattr(config, 'model_type', 'bert')
        if self.model_type == 'TIMM_vit':
            self.bert = BertImgModel(config)

#BertImgModel (excerpt of its __init__)
class BertImgModel(BertPreTrainedModel):
    def __init__(self, config):
        super(BertImgModel, self).__init__(config)

        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)

        self.img_dim = config.img_feature_dim #2054 #565
        logger.info('BertImgModel Image Dimension: {}'.format(self.img_dim))
        self.img_feature_type = config.img_feature_type
        try:
            self.use_img_layernorm = config.use_img_layernorm
        except:
            self.use_img_layernorm = None

        if config.img_feature_type == 'dis_code':
            self.code_embeddings = nn.Embedding(config.code_voc, config.code_dim, padding_idx=0)
            self.img_embedding = nn.Linear(config.code_dim, self.config.hidden_size, bias=True)
        elif config.img_feature_type == 'dis_code_t': # transpose
            self.code_embeddings = nn.Embedding(config.code_voc, config.code_dim, padding_idx=0)
            self.img_embedding = nn.Linear(config.code_size, self.config.hidden_size, bias=True)
        elif config.img_feature_type == 'dis_code_scale': # scaled
            self.input_embeddings = nn.Linear(config.code_dim, config.code_size, bias=True)
            self.code_embeddings = nn.Embedding(config.code_voc, config.code_dim, padding_idx=0)
            self.img_embedding = nn.Linear(config.code_dim, self.config.hidden_size, bias=True)
        else:
            self.img_embedding = nn.Linear(self.img_dim, self.config.hidden_size, bias=True)
            self.dropout = nn.Dropout(config.hidden_dropout_prob)
            if self.use_img_layernorm:
                self.LayerNorm = LayerNormClass(config.hidden_size, eps=config.img_layer_norm_eps)

        self.apply(self.init_weights)
        self.model_type = getattr(config, 'model_type', 'bert')
        if self.model_type == 'TIMM_vit':
            self.encoder = TIMMVitEncoder(config)

        # re-initialize img_embedding weight
        # self.img_embedding.weight.data.normal_(mean=0.0, std=config.img_initializer_range)


#BertEncoder
class BertEncoder(nn.Module):
    def __init__(self, config):
        super(BertEncoder, self).__init__()
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])

Core code 2: the vision-language VideoTransformer and its forward function

class VideoTransformer(torch.nn.Module):
    """ This is the one head module that performs Dirving Caption Generation. """
    def __init__(self, args, config, swin, transformer_encoder):
        super(VideoTransformer, self).__init__()
        """ Initializes the model.
        Parameters:
            args: basic args of ADAPT, mostly defined in `src/configs/VidSwinBert/BDDX_multi_default.json` and input args
            config: config of transformer_encoder, mostly defined in `models/captioning/bert-base-uncased/config.json`
            swin: torch module of the backbone to be used. See `src/modeling/load_swin.py`
            transformer_encoder: torch module of the transformer architecture. See `src/modeling/load_bert.py`
        """
        self.config = config
        self.use_checkpoint = args.use_checkpoint and not args.freeze_backbone
        if self.use_checkpoint:
            self.swin = checkpoint_wrapper(swin, offload_to_cpu=True)
        else:
            self.swin = swin
        self.trans_encoder = transformer_encoder
        self.img_feature_dim = int(args.img_feature_dim)
        self.use_grid_feat = args.grid_feat
        self.latent_feat_size = self.swin.backbone.norm.normalized_shape[0]
        self.fc = torch.nn.Linear(self.latent_feat_size, self.img_feature_dim)
        self.compute_mask_on_the_fly = False # deprecated
        self.mask_prob = args.mask_prob
        self.mask_token_id = -1
        self.max_img_seq_length = args.max_img_seq_length

        self.max_num_frames = getattr(args, 'max_num_frames', 2)
        self.expand_car_info = torch.nn.Linear(self.max_num_frames, self.img_feature_dim)

        # add sensor information
        self.use_car_sensor = getattr(args, 'use_car_sensor', False)

        # learn soft attention mask
        self.learn_mask_enabled = getattr(args, 'learn_mask_enabled', False)
        self.sparse_mask_soft2hard = getattr(args, 'sparse_mask_soft2hard', False)
        
        if self.learn_mask_enabled==True:
            self.learn_vid_att = torch.nn.Embedding(args.max_img_seq_length*args.max_img_seq_length,1)
            self.sigmoid = torch.nn.Sigmoid()

    def forward(self, *args, **kwargs):
        """ The forward process of ADAPT, 
        Parameters:
            input_ids: word tokens of input sentences tokenized by tokenizer
            attention_mask: multimodal attention mask in Vision-Language transformer
            token_type_ids: type tokens of input sentences,
                            0 means it is a narration sentence and 1 means a reasoning sentence, same size as input_ids
            img_feats: preprocessed frames of the video
            masked_pos: [MASK] position when performing MLM, used to locate the masked words
            masked_ids: ground truth of [MASK] when performing MLM
        """
        # grad cam can only input a tuple (args, kwargs)
        if isinstance(args, tuple) and len(args) != 0:
            kwargs = args[0]
            args= ()

        images = kwargs['img_feats']
        B, S, C, H, W = images.shape  # batch, segment, channel, height, width
        # (B x S x C x H x W) --> (B x C x S x H x W)
        images = images.permute(0, 2, 1, 3, 4)
        vid_feats = self.swin(images)

        # tokenize video features to video tokens
        if self.use_grid_feat==True:
            vid_feats = vid_feats.permute(0, 2, 3, 4, 1)
        vid_feats = vid_feats.view(B, -1, self.latent_feat_size)

        # use an mlp to transform video token dimension
        vid_feats = self.fc(vid_feats)

        # use video features to predict car tensor
        if self.use_car_sensor:
            car_infos = kwargs['car_info']
            car_infos = self.expand_car_info(car_infos)
            vid_feats = torch.cat((vid_feats, car_infos), dim=1)

        # prepare VL transformer inputs
        kwargs['img_feats'] = vid_feats

        # disable bert attention outputs to avoid some bugs
        if self.trans_encoder.bert.encoder.output_attentions:
            self.trans_encoder.bert.encoder.set_output_attentions(False)

        # learn soft attention mask
        if self.learn_mask_enabled:
            kwargs['attention_mask'] = kwargs['attention_mask'].float()
            vid_att_len = self.max_img_seq_length
            learn_att = self.learn_vid_att.weight.reshape(vid_att_len,vid_att_len)
            learn_att = self.sigmoid(learn_att)
            diag_mask = torch.diag(torch.ones(vid_att_len)).cuda()
            video_attention = (1. - diag_mask)*learn_att
            learn_att = diag_mask + video_attention
            if self.sparse_mask_soft2hard:
                learn_att = (learn_att>=0.5)*1.0
                learn_att = learn_att.cuda()
                learn_att.requires_grad = False
            kwargs['attention_mask'][:, -vid_att_len::, -vid_att_len::] = learn_att

        # Driving Caption Generation head
        outputs = self.trans_encoder(*args, **kwargs)

        # sparse attention mask loss
        if self.learn_mask_enabled:
            loss_sparsity = self.get_loss_sparsity(video_attention)  
            outputs = outputs + (loss_sparsity, )

        return outputs
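
The slicing at the end of the learned-mask branch is easier to see with a small illustration. The token counts below are hypothetical example values, not the model's actual configuration.

# Illustration of the multimodal attention-mask layout (example lengths only):
# rows/columns are ordered [text tokens | video tokens], and the learned soft
# mask overwrites only the bottom-right video-to-video block.
import torch

T_txt, T_vid = 40, 784                   # hypothetical text / video token counts
attn = torch.ones(1, T_txt + T_vid, T_txt + T_vid)
learned = torch.sigmoid(torch.randn(T_vid, T_vid))
attn[:, -T_vid:, -T_vid:] = learned      # same slice as attention_mask[:, -vid_att_len:, -vid_att_len:]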


# forward function of the VL transformer (BertImgModel, invoked through trans_encoder)
    def forward(self, input_ids, token_type_ids=None, attention_mask=None,
            position_ids=None, head_mask=None, img_feats=None,
            encoder_history_states=None):

        if attention_mask is None:
            if img_feats is not None:
                attention_mask = torch.ones((input_ids.shape[0], input_ids.shape[1] + img_feats.shape[1]), device=input_ids.device)
            else:
                attention_mask = torch.ones_like(input_ids)
            #if img_feats is not None: attention_mask = torch.ones_like((input_ids.shape[0], input_ids.shape[1]+img_feats.shape[1]))
            #else: attention_mask = torch.ones_like(input_ids)

        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is more simple than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        if attention_mask.dim() == 2:
            extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        elif attention_mask.dim() == 3:
            extended_attention_mask = attention_mask.unsqueeze(1)
        else:
            raise NotImplementedError

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to float if needed + fp16 compatibility
        else:
            head_mask = [None] * self.config.num_hidden_layers

        embedding_output = self.embeddings(input_ids, position_ids=position_ids,
                token_type_ids=token_type_ids)
        # add img_embedding_output and sum with embedding_output
        #logger.info('embedding_output: %s' % str(embedding_output.shape))
        if encoder_history_states is not None:
            if encoder_history_states[0].shape[1] != 0:
                assert img_feats is None or img_feats.shape[1]==0, "Cannot take image features while using encoder history states"

        if img_feats is not None:
            if self.img_feature_type == 'dis_code':
                code_emb = self.code_embeddings(img_feats)
                img_embedding_output = self.img_embedding(code_emb)
            elif self.img_feature_type == 'dis_code_t': # transpose
                code_emb = self.code_embeddings(img_feats)
                code_emb = code_emb.permute(0, 2, 1)
                img_embedding_output = self.img_embedding(code_emb)
            elif self.img_feature_type == 'dis_code_scale': # left scaled
                code_emb = self.code_embeddings(img_feats)
                #scale_output =
                # add scale ouput
                img_embedding_output = self.img_embedding(code_emb)
            elif self.img_feature_type == 'e2e' and self.model_type == 'TIMM_vit':
                img_embedding_output = img_feats
            else:
                if torch._C._get_tracing_state():
                    # Ugly workaround to make this work for ONNX.
                    #  It is also valid for PyTorch but I keep this path separate to remove once fixed in ONNX
                    img_embedding_output = self.img_embedding(img_feats.squeeze(0)).unsqueeze(0)
                else:
                    img_embedding_output = self.img_embedding(img_feats)
                #logger.info('img_embedding_output: %s' % str(img_embedding_output.shape))
                if self.use_img_layernorm:
                    img_embedding_output = self.LayerNorm(img_embedding_output)

                # add dropout on image embedding
                img_embedding_output = self.dropout(img_embedding_output)

                # sum two embeddings
                #padding_matrix = torch.zeros((embedding_output.shape[0], embedding_output.shape[1]-img_embedding_output.shape[1], embedding_output.shape[2])).cuda()
                #img_embedding_output = torch.cat((padding_matrix, img_embedding_output), 1)
                #embedding_output = embedding_output + img_embedding_output

            # concatenate two embeddings
            embedding_output = torch.cat((embedding_output, img_embedding_output), 1)

        encoder_outputs = self.encoder(embedding_output,
                extended_attention_mask, head_mask=head_mask,
                encoder_history_states=encoder_history_states)
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]  # add hidden_states and attentions if they are here
        return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)
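
As a quick worked example of the additive mask built above (illustrative values only):

# Attendable positions (1) become ~0.0, masked positions (0) become -10000.0 before the softmax.
import torch

mask = torch.tensor([[1., 1., 0.]])      # (batch, to_seq_length)
ext = mask[:, None, None, :]             # (batch, 1, 1, to_seq_length)
additive = (1.0 - ext) * -10000.0        # values: ~0.0, ~0.0, -10000.0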

3. CSP (Control Signal Prediction) in Detail

CSP uses the same frame-feature extraction as DCG; the difference lies in how the control signals are modeled and predicted. The core code is given below:

CSP core code:

class Sensor_Pred_Head(torch.nn.Module):
    """ This is the Control Signal Prediction head that performs sensor regression """
    def __init__(self, args):
        """ Initializes the prediction head.
        A simple transformer that performs sensor regression. 
        We simply use a transformer to regress the whole signals of a video, which is superficial and could be optimized to a large extent.
        """
        super(Sensor_Pred_Head, self).__init__()

        self.img_feature_dim = int(args.img_feature_dim)
        self.use_grid_feat = args.grid_feat

        # Motion Transformer implemented by bert
        self.config = BertConfig.from_pretrained(args.config_name if args.config_name else \
            args.model_name_or_path, num_labels=2, finetuning_task='image_captioning')
        self.encoder = BertEncoder(self.config)

        # type number of control signals to be used
        # TODO: Set this variable as an argument, corresponding to the control signals in the dataloader
        
        self.sensor_dim = len(args.signal_types)
        self.sensor_embedding = torch.nn.Linear(self.sensor_dim, self.config.hidden_size)
        self.sensor_dropout = nn.Dropout(self.config.hidden_dropout_prob)

        # a mlp to transform the dimension of video feature 
        self.img_dim = self.img_feature_dim
        self.img_embedding = nn.Linear(self.img_dim, self.config.hidden_size, bias=True)
        self.img_dropout = nn.Dropout(self.config.hidden_dropout_prob)

        # a sample regression decoder
        self.decoder = nn.Linear(self.config.hidden_size, self.sensor_dim)


    def forward(self, *args, **kwargs):
        """The forward process.
        Parameters:
            img_feats: video features extracted by video swin
            car_info: ground truth of control signals
        """
        vid_feats = kwargs['img_feats']
        car_info  = kwargs['car_info']

        car_info = car_info.permute(0, 2, 1)

        B, S, C = car_info.shape
        assert C == self.sensor_dim, f"{C}, {self.sensor_dim}"
        frame_num = S

        img_embedding_output = self.img_embedding(vid_feats)
        img_embedding_output = self.img_dropout(img_embedding_output)

        extended_attention_mask = self.get_attn_mask(img_embedding_output)

        encoder_outputs = self.encoder(img_embedding_output,
                                        extended_attention_mask)
        sequence_output = encoder_outputs[0][:, :frame_num, :]

        pred_tensor = self.decoder(sequence_output)

        loss = self.get_l2_loss(pred_tensor, car_info)

        return loss, pred_tensor
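
The excerpt calls two helpers, get_attn_mask and get_l2_loss, whose bodies are not shown. A plausible minimal version is sketched below purely as an assumption; it is not the repository's actual implementation.

# Hypothetical minimal versions of the helper methods used above (assumptions only).
import torch
import torch.nn.functional as F

def get_attn_mask(self, embedding_output):
    # Attend to every token: build a (B, 1, 1, L) additive mask of zeros.
    B, L, _ = embedding_output.shape
    mask = torch.ones((B, L), device=embedding_output.device, dtype=embedding_output.dtype)
    extended = mask[:, None, None, :]
    return (1.0 - extended) * -10000.0

def get_l2_loss(self, pred_tensor, car_info):
    # Mean-squared error between predicted and ground-truth control signals.
    return F.mse_loss(pred_tensor, car_info)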

4. Analysis

(1) Although the control-signal prediction head Sensor_Pred_Head defines a sensor_embedding layer, that layer is never used in its forward function;

(2) Besides the textual descriptions of vehicle actions, the BDD dataset also provides lane and object annotations that could be used for object detection and classification; this work does not cover detailed object recognition or drivable-trajectory prediction;

(3) The BERT architecture is used here not only for text understanding but also for modeling temporal information; the theoretical basis of the implementation is embeddings, masking, and multi-layer attention.

(4) Possible improvements: wire the sensor_embedding into the forward pass and evaluate its effect (a sketch is given below); add more capabilities such as object detection and trajectory prediction; the current semantic scope and corpus are small, so the model essentially amounts to a specialized language model that describes the car's current action and its reason in two sentences; comparing against more papers or patents would help in devising a few feasible, state-of-the-art directions.
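
Regarding points (1) and (4), here is a minimal sketch of how the unused sensor_embedding could be wired into Sensor_Pred_Head.forward. This is a hypothetical modification for illustration, not code from the repository.

# Hypothetical variant of Sensor_Pred_Head.forward that actually uses self.sensor_embedding:
# control signals are embedded and concatenated with the video tokens before the motion transformer.
import torch

def forward_with_sensor_embedding(self, *args, **kwargs):
    vid_feats = kwargs['img_feats']
    car_info = kwargs['car_info'].permute(0, 2, 1)           # (B, S, sensor_dim)
    frame_num = car_info.shape[1]

    # NOTE: feeding the unmasked ground-truth signals would let the model copy its target;
    # in practice they would need masking, noising, or restriction to past timesteps.
    sensor_tokens = self.sensor_dropout(self.sensor_embedding(car_info))
    img_tokens = self.img_dropout(self.img_embedding(vid_feats))

    fused = torch.cat((img_tokens, sensor_tokens), dim=1)
    extended_attention_mask = self.get_attn_mask(fused)
    encoder_outputs = self.encoder(fused, extended_attention_mask)

    sequence_output = encoder_outputs[0][:, :frame_num, :]   # frame-aligned slots, as in the original head
    pred_tensor = self.decoder(sequence_output)
    return self.get_l2_loss(pred_tensor, car_info), pred_tensor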
