LISA: Reasoning Segmentation via Large Language Model

Venue: CVPR 2024

Paper: https://openaccess.thecvf.com/content/CVPR2024/papers/Lai_LISA_Reasoning_Segmentation_via_Large_Language_Model_CVPR_2024_paper.pdf

Affiliation: CUHK

Motivation: despite the remarkable progress of perception systems in recent years, they still rely on explicit human instructions or predefined categories to identify target objects before executing visual recognition tasks. Such systems cannot actively reason about and understand implicit user intent. This raises the question: can we enable multimodal LLMs to output segmentation masks?

Proposed solution: this work introduces a new segmentation task, reasoning segmentation, whose goal is to output a segmentation mask given a complex and implicit query text. For example, instead of an explicit category name, the query might describe a property (e.g., "the food rich in Vitamin C"), which can only be resolved by reasoning with world knowledge.

To accomplish this task, the model must possess two key abilities: 1) reasoning jointly over the image and a complex, implicit text query; and 2) producing segmentation masks.

Implementation

  • We propose LISA: Large Language Instructed Segmentation Assistant, which inherits the language generation ability of multimodal large language models (LLMs) while also being able to produce segmentation masks.

  • We expand the original vocabulary with a <SEG> token and propose an embedding-as-mask paradigm to unlock segmentation capability.

  • We establish a benchmark of over one thousand image-instruction-mask data samples that incorporates complex reasoning and world knowledge for evaluation purposes.
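The embedding-as-mask idea behind the <SEG> token can be illustrated with a toy sketch: the LLM's last-layer hidden state at the <SEG> position is projected into the mask decoder's prompt space and correlated with dense image features. All sizes, the token position, and the dot-product "decoder" below are illustrative stand-ins, not the actual LISA/SAM implementation.

```python
# Hedged sketch of the embedding-as-mask paradigm (toy sizes, hypothetical
# names); the real model feeds the projected embedding to SAM's mask decoder.
import torch

torch.manual_seed(0)
hidden_dim, out_dim, h, w = 4096, 256, 16, 16

# Suppose the LLM produced hidden states for a 10-token answer,
# with the <SEG> token at position 7 (hypothetical).
last_hidden = torch.randn(10, hidden_dim)
seg_embedding = last_hidden[7]                     # (4096,)

# gamma: projection MLP from LLM space to the decoder's prompt space
gamma = torch.nn.Sequential(
    torch.nn.Linear(hidden_dim, hidden_dim),
    torch.nn.ReLU(),
    torch.nn.Linear(hidden_dim, out_dim),
)
prompt = gamma(seg_embedding)                      # (256,)

# Stand-in for the mask decoder: correlate the prompt with image features.
image_feats = torch.randn(out_dim, h, w)
mask_logits = torch.einsum("c,chw->hw", prompt, image_feats)
mask = mask_logits.sigmoid() > 0.5                 # binary mask, (16, 16)
print(mask.shape)
```

The key property this sketches is that segmentation becomes a by-product of language generation: whenever the LLM emits <SEG>, its embedding carries the mask.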

Model structure: unless otherwise specified, LLaVA-7B-v1-1 or LLaVA-13B-v1-1 is used as the base multimodal LLM F, and the ViT-H SAM backbone serves as the vision backbone Fenc. The projection layer γ is an MLP with channels [256, 4096, 4096].

Training setup: to preserve the knowledge learned by the pre-trained multimodal LLM (i.e., LLaVA in the experiments), the LLM is fine-tuned efficiently with LoRA while the vision backbone Fenc is completely frozen; the decoder Fdec is fully fine-tuned. In addition, the LLM token embeddings (embed_tokens), the LM head (lm_head), and the projection layer γ are also trainable.
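The trainability split above (frozen vision backbone; trainable decoder, embeddings, LM head, and projection) can be sketched by name-matching over parameters. The module names below are hypothetical stand-ins, and the LoRA wrapping itself (e.g., via peft) is omitted for brevity.

```python
# Hedged sketch of the freeze/finetune split, on a stand-in model
# (module names hypothetical; LoRA adapter injection not shown).
import torch.nn as nn

model = nn.ModuleDict({
    "vision_tower": nn.Linear(8, 8),       # Fenc: frozen
    "embed_tokens": nn.Embedding(10, 8),   # trainable
    "lm_head": nn.Linear(8, 10),           # trainable
    "mask_decoder": nn.Linear(8, 8),       # Fdec: fully fine-tuned
    "text_projection": nn.Linear(8, 8),    # gamma: trainable
})

trainable_keys = ("embed_tokens", "lm_head", "mask_decoder", "text_projection")
for name, p in model.named_parameters():
    # freeze everything that is not explicitly marked trainable
    p.requires_grad = any(k in name for k in trainable_keys)

frozen = [n for n, p in model.named_parameters() if not p.requires_grad]
print(frozen)
```

In the real script this split keeps optimizer state small: only LoRA adapters plus the few named modules accumulate gradients.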

Experiments

Datasets: the training data mainly comprises three parts: a semantic segmentation dataset, a vanilla referring segmentation dataset, and a visual question answering dataset.
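Mixing the three sources during training can be done by sampling each batch's source according to fixed rates, which is how the training script below interprets its --sample_rates argument. The rates and source names here are illustrative.

```python
# Hedged sketch of rate-based mixing over the training sources
# (rates illustrative; the script parses them from --sample_rates).
import random

random.seed(0)
sources = ["sem_seg", "refer_seg", "vqa"]
rates = [9.0, 3.0, 3.0]
probs = [r / sum(rates) for r in rates]

def sample_source():
    # pick which dataset the next sample comes from
    return random.choices(sources, weights=probs, k=1)[0]

counts = {s: 0 for s in sources}
for _ in range(1000):
    counts[sample_source()] += 1
print(counts)
```

Because the mix is sampled per step rather than concatenated, the epoch length is decoupled from the raw dataset sizes (the script computes samples_per_epoch from batch size and steps instead).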

Conclusion

  1. A new segmentation task: reasoning segmentation.

  2. An evaluation benchmark, ReasonSeg, which comprises over one thousand data samples.

  3. A model, LISA, which injects segmentation capability into current multimodal LLMs and proves surprisingly effective on the reasoning segmentation task.

Modify the parameters of the program below to optimize GPU memory usage:

```python
# --------------------------------------------------------
# LISA: Reasoning Segmentation via Large Language Model
# Licensed under Apache-2.0 license [see LICENSE for details]
# Authors: Xin Lai, Zhuotao Tian, Yukang Chen, Yanwei Li, Yuhui Yuan, Shu Liu, Jiaya Jia
# --------------------------------------------------------
# GSVA: Generalized Segmentation via Multimodal Large Language Models
# Modified by Zhuofan Xia
# --------------------------------------------------------
import argparse
import os
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
import shutil
from functools import partial

import torch
import transformers
import deepspeed
import deepspeed.comm as dist
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm

import model.llava.conversation as conversation_lib
from model import LisaGSVAForCausalLM, add_task_tokens, init_vision_seg_for_model
from data import MixedTrainingDataset, ValDataset, collate_fn
from solver import train_one_epoch, validate, eval_gres
from utils import get_logger


def parse_args():
    # parser.add_argument("--dataset_dir", required=True, type=str, help="Where do we store the huge datasets?")
    # parser.add_argument("--dataset", default="sem_seg||refer_seg||vqa||reason_seg", type=str)
    # parser.add_argument("--sem_seg_data", default="ade20k||cocostuff||pascal_part||paco_lvis||mapillary", type=str)
    # parser.add_argument("--refer_seg_data", default="refclef||refcoco||refcoco+||refcocog", type=str)
    # parser.add_argument("--image_size", default=1024, type=int, help="Image size of segmentation model.")
    # parser.add_argument("--lora_alpha", default=16, type=int)
    parser = argparse.ArgumentParser(description="GSVA Training and Evaluation")
    parser.add_argument("--segmentation_model_path", default="pretrain_weights_sam/sam_vit_h_4b8939.pth", type=str)
    parser.add_argument("--vision-tower", default="openai/clip-vit-large-patch14", type=str)
    parser.add_argument("--local_rank", default=0, type=int, help="For local rank in distributed training")
    # parser.add_argument("--mllm_model_path", default="liuhaotian/llava-llama-2-13b-chat-lightning-preview")
    parser.add_argument("--mllm_model_path", default="pretrain_weights_llava")
    parser.add_argument("--dataset_dir", default="dataset", type=str, help="Where do we store the huge datasets?")
    parser.add_argument("--precision", default="bf16", type=str, choices=["fp32", "bf16", "fp16"],
                        help="precision for training and inference")
    parser.add_argument("--image_size", default=128, type=int, help="Image size of segmentation model.")
    parser.add_argument("--model_max_length", default=1024, type=int)
    parser.add_argument("--lora_r", default=1, type=int)
    parser.add_argument("--dataset", default="refer_seg", type=str)
    parser.add_argument("--sample_rates", default="9,3,3,1", type=str)
    parser.add_argument("--sem_seg_data", default="ade20k", type=str)
    parser.add_argument("--refer_seg_data", default="refcocog", type=str)
    parser.add_argument("--vqa_data", default="llava_instruct_150k", type=str)
    parser.add_argument("--reason_seg_data", default="ReasonSeg|train", type=str)
    parser.add_argument("--val_dataset", default="ReasonSeg|val", type=str)
    parser.add_argument("--log_base_dir", default="./outputs", type=str)
    parser.add_argument("--exp_name", default="default", type=str)
    parser.add_argument("--epochs", default=10, type=int)
    parser.add_argument("--steps_per_epoch", default=500, type=int)
    # parser.add_argument("--batch_size", default=16, type=int, help="batch size per device per step")
    parser.add_argument("--batch_size", default=1, type=int, help="batch size per device per step")
    parser.add_argument("--grad_accumulation_steps", default=1, type=int)
    parser.add_argument("--val_batch_size", default=1, type=int)
    parser.add_argument("--workers", default=8, type=int)
    parser.add_argument("--lr", default=0.0003, type=float)
    parser.add_argument("--ce_loss_weight", default=1.0, type=float)
    parser.add_argument("--dice_loss_weight", default=0.5, type=float)
    parser.add_argument("--bce_loss_weight", default=2.0, type=float)
    parser.add_argument("--lora_alpha", default=1, type=int)
    parser.add_argument("--lora_dropout", default=0.05, type=float)
    parser.add_argument("--lora_target_modules", default="q_proj,v_proj", type=str)
    parser.add_argument("--explanatory", default=0.1, type=float)
    parser.add_argument("--beta1", default=0.9, type=float)
    parser.add_argument("--beta2", default=0.95, type=float)
    parser.add_argument("--num_classes_per_sample", default=2, type=int)
    parser.add_argument("--exclude_val", action="store_true", default=False)
    parser.add_argument("--no_eval", action="store_true", default=False)
    parser.add_argument("--eval_only", action="store_true", default=False)
    parser.add_argument("--out_dim", default=256, type=int)
    parser.add_argument("--resume", default="", type=str)
    parser.add_argument("--print_freq", default=1, type=int)
    parser.add_argument("--start_epoch", default=0, type=int)
    parser.add_argument("--train_mask_decoder", action="store_true", default=True)
    parser.add_argument("--use_mm_start_end", action="store_true", default=True)
    parser.add_argument("--auto_resume", action="store_true", default=False,
                        help="Whether resume the latest checkpoint when training is interrupted.")
    parser.add_argument("--no_sampling", action="store_true", default=False,
                        help="Only one dataset finetuning, train on full length dataset.")
    parser.add_argument("--val_refzom", action="store_true", default=False,
                        help="Default gres/zom evaluation, if True, RefZOM, else gRefCOCO.")
    parser.add_argument("--conv_type", default="llava_v1", type=str, choices=["llava_v1", "llava_llama_2"])
    parser.add_argument("--merge_lora_path", type=str, default=None, help="Path to destination HF checkpoint.")
    parser.add_argument("--weight", type=str, default=None, help="Path to a bin ckpt.")
    parser = deepspeed.add_config_arguments(parser)
    return parser.parse_args()


def main():
    # Get arguments from commandline
    args = parse_args()
    # Set up Deepspeed distributed environment
    torch.cuda.set_device(args.local_rank)
    dist.init_distributed()
    args.world_size = world_size = dist.get_world_size()
    args.rank = rank = dist.get_rank()
    local_rank: int
    args.local_rank = local_rank = dist.get_local_rank()
    # Set up logging dir
    args.log_dir = os.path.join(args.log_base_dir, args.exp_name)
    if rank == 0:
        os.makedirs(args.log_dir, exist_ok=True)
    logger = get_logger(args.log_dir, rank, name=args.exp_name)
    # Create model
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.mllm_model_path,
        cache_dir=None,
        model_max_length=args.model_max_length,
        padding_side="right",
        use_fast=False
    )
    tokenizer, args = add_task_tokens(tokenizer, args)
    # Determine working model precision
    args.torch_dtype = torch.float32
    if args.precision == "bf16":
        args.torch_dtype = torch.bfloat16
    elif args.precision == "fp16":
        args.torch_dtype = torch.half
    # Prepare model creation arguments
    model_args = {
        "train_mask_decoder": args.train_mask_decoder,
        "out_dim": args.out_dim,
        "ce_loss_weight": args.ce_loss_weight,
        "dice_loss_weight": args.dice_loss_weight,
        "bce_loss_weight": args.bce_loss_weight,
        "seg_token_idx": args.seg_token_idx,
        "segmentation_model_path": args.segmentation_model_path,
        "vision_tower": args.vision_tower,
        "use_mm_start_end": args.use_mm_start_end,
        "tokenizer": tokenizer,
        "rej_token_idx": args.rej_token_idx
    }
    model = LisaGSVAForCausalLM.from_pretrained(
        args.mllm_model_path, torch_dtype=args.torch_dtype, **model_args
    )
    # Set up two vision models for whole model, and lora
    model = init_vision_seg_for_model(model, tokenizer, args).half()
    # Evaluation or finetuning, btw, merge-lora always fails
    if args.weight is not None:
        # `args.weight` is a large `*.bin` file.
        state_dict = torch.load(args.weight, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=False)
        logger.info("Load trained weights successfully!")
    # Specify the conversation type
    conversation_lib.default_conversation = conversation_lib.conv_templates[args.conv_type]
    # Build training set
    if args.eval_only:
        train_dataset = None
    else:
        train_dataset = MixedTrainingDataset(
            args.dataset_dir,
            tokenizer,
            args.vision_tower,
            samples_per_epoch=args.batch_size * args.grad_accumulation_steps * args.steps_per_epoch * world_size,
            precision=args.precision,
            image_size=args.image_size,
            num_classes_per_sample=args.num_classes_per_sample,
            exclude_val=args.exclude_val,
            dataset=args.dataset,
            sample_rate=[float(x) for x in args.sample_rates.split(",")],
            sem_seg_data=args.sem_seg_data,
            refer_seg_data=args.refer_seg_data,
            vqa_data=args.vqa_data,
            reason_seg_data=args.reason_seg_data,
            explanatory=args.explanatory,
            no_sampling=args.no_sampling
        )
    if args.no_eval:
        val_dataset = None
        logger.info(f"Training with {len(train_dataset)} examples.")
    else:
        val_dataset = ValDataset(
            args.dataset_dir, tokenizer, args.vision_tower, args.val_dataset, args.image_size
        )
        grefcoco_val_ds = ValDataset(
            args.dataset_dir, tokenizer, args.vision_tower,
            'refzom|final|test' if args.val_refzom else 'grefcoco|unc|val',
            args.image_size
        )
        if args.eval_only:
            logger.info(f"Testing with {len(val_dataset)} examples.")
        else:
            logger.info(
                f"Training with {len(train_dataset)} examples and validating with {len(val_dataset)} examples, "
                f"also validating on gRefCOCO with {len(grefcoco_val_ds)} examples."
            )
    # The accelerated training configurations only work for ZeRO-2.
    if args.eval_only:
        ds_config = {
            "train_micro_batch_size_per_gpu": 1,
            "fp16": {"enabled": args.precision == "fp16"},
            "bf16": {"enabled": args.precision == "bf16"}
        }
    else:
        ds_config = {
            "train_micro_batch_size_per_gpu": args.batch_size,
            "gradient_accumulation_steps": args.grad_accumulation_steps,
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": args.lr,
                    "weight_decay": 0.0,
                    "betas": (args.beta1, args.beta2),
                },
            },
            "scheduler": {
                "type": "WarmupDecayLR",
                "params": {
                    "total_num_steps": args.epochs * args.steps_per_epoch,
                    "warmup_min_lr": 0,
                    "warmup_max_lr": args.lr,
                    "warmup_num_steps": 100,
                    "warmup_type": "linear",
                },
            },
            "fp16": {"enabled": args.precision == "fp16"},
            "bf16": {"enabled": args.precision == "bf16"},
            "gradient_clipping": 1.0,
            "zero_optimization": {
                "stage": 2,
                "contiguous_gradients": True,
                "overlap_comm": True,
                "reduce_scatter": True,
                "reduce_bucket_size": 1e9,
                "allgather_bucket_size": 1e9
            }
        }
    # Build a model engine wrapped with Deepspeed
    if args.eval_only:
        model_engine, optimizer, train_loader, scheduler = deepspeed.initialize(
            model=model, config=ds_config
        )
    else:
        logger.info('Before initializing deepspeed zero optimizer...')
        model_engine, optimizer, train_loader, scheduler = deepspeed.initialize(
            model=model,
            model_parameters=model.parameters(),
            training_data=train_dataset,
            collate_fn=partial(
                collate_fn,
                tokenizer=tokenizer,
                conv_type=args.conv_type,
                use_mm_start_end=args.use_mm_start_end,
                local_rank=local_rank,
            ),
            config=ds_config
        )
        train_loader.num_local_io_workers = args.workers
        logger.info('After initializing deepspeed zero optimizer!')
    # resume deepspeed checkpoint, `auto-resume` snippets are borrowed from Swin Transformer codebase:
    # https://github.com/microsoft/Swin-Transformer/blob/f82860bfb5225915aca09c3227159ee9e1df874d/utils.py#L163
    if args.auto_resume:
        checkpoints = os.listdir(args.log_dir)
        checkpoints = [ckpt for ckpt in checkpoints if ckpt.startswith('ckpt_model')]
        if len(checkpoints) > 0:
            args.resume = max([os.path.join(args.log_dir, d) for d in checkpoints], key=os.path.getmtime)
            logger.info(f"Auto resume found latest: {args.resume}")
        else:
            logger.info("No auto resume.")
    if args.resume:
        # resume from training, scattered checkpoints (list of ***.pt)
        load_path, client_state = model_engine.load_checkpoint(args.resume)
        with open(os.path.join(args.resume, "latest"), "r") as f:
            ckpt_dir = f.readlines()[0].strip()
        args.start_epoch = (
            int(ckpt_dir.replace("global_step", "")) // args.steps_per_epoch
        )
        logger.info(
            "resume training from {}, start from epoch {}".format(
                args.resume, args.start_epoch
            )
        )
    # Build validation dataset
    if val_dataset is not None:
        assert args.val_batch_size == 1
        val_sampler = DistributedSampler(val_dataset, shuffle=False, drop_last=False)
        val_loader = DataLoader(
            val_dataset,
            batch_size=args.val_batch_size,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=False,
            sampler=val_sampler,
            collate_fn=partial(
                collate_fn,
                tokenizer=tokenizer,
                conv_type=args.conv_type,
                use_mm_start_end=args.use_mm_start_end,
                local_rank=local_rank
            )
        )
        if val_dataset.ds not in ['grefcoco', 'refzom']:
            grefcoco_sampler = DistributedSampler(grefcoco_val_ds, shuffle=False, drop_last=False)
            grefcoco_loader = DataLoader(
                grefcoco_val_ds,
                batch_size=args.val_batch_size,
                shuffle=False,
                num_workers=args.workers,
                pin_memory=False,
                sampler=grefcoco_sampler,
                collate_fn=partial(
                    collate_fn,
                    tokenizer=tokenizer,
                    conv_type=args.conv_type,
                    use_mm_start_end=args.use_mm_start_end,
                    local_rank=local_rank
                )
            )
        else:
            grefcoco_loader = None
    # If we only want to evaluate models, then we evaluate them and quit the program.
    if args.eval_only:
        if val_dataset.ds in ['grefcoco', 'refzom']:
            eval_gres(val_loader, model_engine, 0, args, logger)
        else:
            validate(val_loader, model_engine, 0, args, logger)
        return
    # Otherwise, we train the model using the initialized Deepspeed-Zero model engine.
    logger.info("Training begin!")
    train_iter = iter(train_loader)
    for epoch in tqdm(range(args.start_epoch, args.epochs), desc="train:"):
        # train for one epoch, keep a `train_iter` for iter-based training
        train_iter = train_one_epoch(train_loader, model_engine, epoch, train_iter, args, logger)
        # barrier for saving checkpoints
        dist.barrier()
        save_dir = os.path.join(args.log_dir, f"ckpt_model_{epoch + 1:02d}")
        if rank == 0 and os.path.exists(save_dir):
            shutil.rmtree(save_dir)
        model_engine.save_checkpoint(save_dir)
        dist.barrier()
        # Skip if we don't need evaluation
        if args.no_eval:
            continue
        else:
            reason_giou, reason_ciou = validate(val_loader, model_engine, epoch, args, logger)
            grefcoco_giou, grefcoco_ciou, n_acc, t_acc = eval_gres(grefcoco_loader, model_engine, epoch, args, logger)
            if rank == 0:
                with open(os.path.join(args.log_dir, "quick_look_result.log"), "a") as t:
                    t.write(
                        f"[{epoch + 1}] reasonseg_val: gIoU:{reason_giou:.4f}, cIoU:{reason_ciou:.4f}, "
                        f"grefcoco_val: gIoU:{grefcoco_giou:.4f}, cIoU:{grefcoco_ciou:.4f}, NAcc:{n_acc:.4f}, TAcc:{t_acc:.4f}.\n"
                    )


if __name__ == "__main__":
    main()
```
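Besides the shrunken defaults in the script (batch_size=1, image_size=128, lora_r=1), memory can also be reclaimed through the DeepSpeed ZeRO configuration itself. The fragment below is a hedged sketch of such a patch, using config keys documented by DeepSpeed; bucket sizes and offload settings would need tuning per machine.

```python
# Hedged sketch of further GPU-memory-oriented ZeRO-2 settings
# (values illustrative; merge into ds_config before deepspeed.initialize).
ds_memory_patch = {
    "zero_optimization": {
        "stage": 2,
        "contiguous_gradients": True,
        "overlap_comm": True,
        "reduce_scatter": True,
        # smaller buckets trade communication efficiency for lower peak memory
        "reduce_bucket_size": 2e8,
        "allgather_bucket_size": 2e8,
        # move optimizer state (Adam moments) off the GPU
        "offload_optimizer": {"device": "cpu", "pin_memory": True},
    }
}
print(sorted(ds_memory_patch["zero_optimization"]))
```

Optimizer offload is usually the biggest single saving here, since Adam keeps two fp32 moment tensors per trainable parameter; the cost is extra host-device traffic per step.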