上下文窗口长度达到了100万token,LWM支持处理多模态信息,能在100万token中准确找到目标文本,还能一口气看完1小时的视频,RingAttention还与FlashAttention结合

LWM是一个强大的多模态模型,能够处理100万token的上下文,准确理解视频内容。通过RingAttention和FlashAttention技术,以及Pallas框架优化,LWM在长序列理解和视频处理任务中超越了其他闭源和开源模型。研究人员使用Books3数据集和逐步扩大的窗口训练,LWM在语言、图像和视频生成方面表现出色,且代码和模型已开源。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

上下文窗口长度达到了100万token,LWM支持处理多模态信息,能在100万token中准确找到目标文本,还能一口气看完1小时的视频,RingAttention还与FlashAttention结合使用,并通过Pallas框架进行优化,从而提高性能。

在这里插入图片描述

上下文窗口长度达到了100万token,持平了谷歌同时推出的王炸Gemini 1.5,伯克利出品。

强大的模型,命名也是简单粗暴——没有任何额外点缀,直接就叫LargeWorldModel(LWM)。

LWM支持处理多模态信息,能在100万token中准确找到目标文本,还能一口气看完1小时的视频。

网友看了不禁表示,这种大海捞针般的测试,LWM能完成的如此出色,而且还开源,实在是令人印象深刻。

那么,LWM的表现到底有多强呢?

百万上下文窗口,可看1小时视频
在测试过程中,研究人员用多段一个多小时的视频检验了LWM的长序列理解能力,这些视频由YouTube上不同的视频片段拼接而成。

他们将这些视频输入LWM,然后针对其中的细节进行提问,涉及的片段位于整个视频的不同位置,同时研究者还将LWM与GPT-4V等模型做了对比。

结果GPT-4V是一问一个不吱声,闭源强者Gemini Pro和开源强者Video-LLaVA都给出了错误的答案,只有LWM回答对了。

在另一段视频的测试中,其他模型都说找不到有关信息,只有LWM找到了答案,而且完全正确。

不仅是理解细节,LWM也能把握视频的整体内容,做出归纳总结。

在理解的基础之上,LWM也可以结合自有知识进行推理,比如分析视频中不符合常理的地方。

Benchmark测试结果显示,LWM在MSVD-QA等三个数据集上的评分仅次于Video-LLaVA。

LWM不仅能理解长短视频,在超长文本任务上的表现同样优异。

在1百万token窗口的“插针”检索测试中,LWM取得了单针检索全绿的成绩。

多针检索时,表现也同样优异:

语言任务数据集的测试结果表明,LWM在32k到1M的窗口长度上表现不输甚至超过只有4k窗口的Llama2-7B。

除了多模态信息理解,LWM还支持图像和视频的生成,至于效果,还是直接上图感受一下吧。

那么,研究人员又是怎样训练出这样一款世界模型的呢?

循序渐进,分而治之
LMW的训练过程,大致可分为两个阶段。

第一阶段的目标是建立一个能够处理长文本序列的语言模型,以理解复杂的文档和长文本内容。

为实现这一目的,研究人员采取了渐进式的训练方式,使用总计33B Token、由图书内容组成的Books3数据集,从32k开始训练,逐步将窗口扩增至1M。

而为了增强LWM的长文本处理能力,开发者应用了RingAttention机制。

RingAttention是该团队去年提出的一种窗口扩增方式,入选了ICLR 2024。

它运用了“分而治之”的思想,将长文本分成多个块,用多个计算设备做序列并行处理,然后再进行叠加,理论上允许模型扩展到无限长的上下文。

在LWM中,RingAttention还与FlashAttention结合使用,并通过Pallas框架进行优化,从而提高性能。

在文本能力的基础上,研究人员又用模型生成了部分QA数据,针对LWM的对话能力进行了优化。

第二阶段则是将视觉信息(如图像和视频)整合到模型中,以提高对多模态数据的理解能力。

在此阶段,研究人员对LWM-Text模型进行了架构修改,以支持视觉输入。

他们使用VQGAN将图像和视频帧转换为token,并与文本结合进行训练。

这一阶段同样采用循序渐进的训练方法, LWM首先在文本-图像数据集上进行训练,然后扩展到文本-视频数据集,且视频帧数逐步增多。

在训练过程中,模型还会随机交换文本和视觉数据的顺序,以学习文本-图像生成、图像理解、文本-视频生成和视频理解等多种任务。

性能方面,研究人员在TPUv4-1024(大致相对于450块A100)上训练,批大小为8M、全精度(float32)的条件下,花费的时间如下表所示,其中1M窗口版本用了58个小时。

目前,LWM的代码、模型都已开源,其中多模态模型为Jax版本,纯文本模型有Jax和PyTorch两个版本,感兴趣的话可以到GitHub页面中了解详情。

论文地址:
https://arxiv.org/abs/2402.08268
GitHub:
https://github.com/LargeWorldModel/LWM


Large World Model (LWM)

[Project]
[Paper]
[Models]

Large World Model (LWM) is a general-purpose large-context multimodal autoregressive model. It is trained on a large dataset of diverse long videos and books using RingAttention, and can perform language, image, and video understanding and generation.

Approach

Current language models fall short in understanding aspects of the world not easily described in words, and struggle with complex, long-form tasks. Video sequences offer valuable temporal information absent in language and static images, making them attractive for joint modeling with language. Such models could develop a understanding of both human textual knowledge and the physical world, enabling broader AI capabilities for assisting humans. However, learning from millions of tokens of video and language sequences poses challenges due to memory constraints, computational complexity, and limited datasets. To address these challenges, we curate a large dataset of diverse videos and books, utilize the RingAttention technique to scalably train on long sequences, and gradually increase context size from 4K to 1M tokens. This paper makes the following contributions: (a) Largest context size neural network: We train one of the largest context size transformers on long video and language sequences, setting new benchmarks in difficult retrieval tasks and long video understanding. (b) Solutions for overcoming vision-language training challenges, including using masked sequence packing for mixing different sequence lengths, loss weighting to balance language and vision, and model-generated QA dataset for long sequence chat. © A highly-optimized implementation with RingAttention, masked sequence packing, and other key features for training on millions-length multimodal sequences. (d) Fully open-sourced a family of 7B parameter models capable of processing long text documents (LWM-Text, LWM-Text-Chat) and videos (LWM, LWM-Chat) of over 1M tokens.
This work paves the way for training on massive datasets of long video and language to develop understanding of both human knowledge and the multimodal world, and broader capabilities.

LWM Capabilities

Setup

This codebase is supported on Ubuntu and has not been tested on Windows or macOS. We recommend using TPUs for training and inference, although it is also possible to use GPUs. On TPU, the code is highly optimized with Jax’s Pallas and can achieve high MFUs with RingAttention at very large context sizes. On GPU, the code is based on XLA and is not as optimized as it is for TPU.

Install the requirements with:

conda create -n lwm python=3.10
pip install -U "jax[cuda12_pip]==0.4.23" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install -r requirements.txt

or set up TPU VM with:

sh tpu_vm_setup.sh

Available models

There are language-only and video-language versions, offering context sizes from 32K, to 128K, 256K and 1M tokens. The vision-language models are available only in Jax, and the language-only models are available in both PyTorch and Jax. Below are the names of the available models and their corresponding context sizes and capabilities:

Model NameContext SizeLanguage or Vision-LanguageChat or BaseURL
LWM-Text-Chat-128K128KLanguageChat[Pytorch][Jax]
LWM-Text-Chat-256K256KLanguageChat[Pytorch][Jax]
LWM-Text-Chat-512K512KLanguageChat[Pytorch][Jax]
LWM-Text-Chat-1M1MLanguageChat[Pytorch][Jax]
LWM-Text-128K128KLanguageBase[Pytorch][Jax]
LWM-Text-256K256KLanguageBase[Pytorch][Jax]
LWM-Text-512K512KLanguageBase[Pytorch][Jax]
LWM-Text-1M1MLanguageBase[Pytorch][Jax]
LWM-Chat-32K32KVision-LanguageChat[Jax]
LWM-Chat-128K128KVision-LanguageChat[Jax]
LWM-Chat-1M1MVision-LanguageChat[Jax]

Code structure

Use scan_query_chunk_size and scan_key_chunk_size to control the block size in blockwise compute of the self-attention. Use scan_mlp_chunk_size to control the block size in blockwise compute of the feedforward network. Use scan_attention=True and scan_mlp=True to enable/disable blockwise compute in the self-attention and feed-forward network. Use remat_attention and remat_mlp to control the rematerialization policy with nothing_saveable recommended.

You can use mesh_dim=dp, fsdp, tp, sp to control the degree of parallelism and RingAttention. It is a string of 4 integers separated by commas, representing the number of data parallelism, fully sharded data parallelism, tensor parallelism, and sequence parallelism.
For example, mesh_dim='1,64,4,1' means 1 data parallelism, 64 fully sharded data parallelism, 4 tensor parallelism, and 1 sequence parallelism. mesh_dim='1,1,4,64' means 1 data parallelism, 1 fully sharded data parallelism, 4 tensor parallelism, and 64 sequence parallelism for RingAttention.

Running Jax Models

In this section, we provide instructions on how to run each of the provided scripts. For each script, you may need to fill in your own paths and values in the variables described in the beginning of each script.

To run each of the following scripts, use bash <script_name>.sh:

  • Language model training: bash scripts/run_train_text.sh
  • Vision-Language model training: bash scripts/run_train_vision_text.sh
  • Single Needle Evals (Language Model): bash scripts/run_eval_needle.sh
  • Multi Needle Evals (Language Model): bash scripts/run_eval_needle_multi.sh
  • Sampling images (Vision-Language Model): bash scripts/run_sample_image.sh
  • Sampling videos (Vision-LanguageModel): bash scripts/run_sample_video.sh
  • Image / Video understanding (Vision-Language Model): bash scripts/run_vision_chat.sh

By default the mesh_dim argument puts all devices on tp (tensor parallelism). For longer sequences, you may want to include sp, which is the last dimension in the mesh_dim.

When running needle evals, you may need to adjust the theta and max_sequence_length arguments in the scripts depending on the model. Below shows the correct values for each model.

LWM-Text-128K / LWM-Text-Chat-128KLWM-Text-256K / LWM-Text-Chat-256KLWM-Text-512K / LWM-Text-Chat-512KLWM-Text-1M / LWM-Text-Chat-1M
theta10000000100000002500000050000000
max_sequence_length1310722621445242881048576

An example of filling out a script (run_sample_video.sh) is as follows

#! /bin/bash

export SCRIPT_DIR="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
export PROJECT_DIR="$( cd -- "$( dirname -- "$SCRIPT_DIR" )" &> /dev/null && pwd )"
cd $PROJECT_DIR
export PYTHONPATH="$PYTHONPATH:$PROJECT_DIR"

export llama_tokenizer_path="/path/to/ckpt/folder/tokenizer.model"
export vqgan_checkpoint="/path/to/ckpt/folder/vqgan"
export lwm_checkpoint="/path/to/ckpt/folder/params"

python3 -u -m lwm.vision_generation \
    --prompt='Fireworks over the city' \
    --output_file='fireworks.mp4' \
    --temperature_image=1.0 \
    --temperature_video=1.0 \
    --top_k_image=8192 \
    --top_k_video=1000 \
    --cfg_scale_image=5.0 \
    --cfg_scale_video=1.0 \
    --vqgan_checkpoint="$vqgan_checkpoint" \
    --n_frames=8 \
    --mesh_dim='!1,1,-1,1' \
    --dtype='fp32' \
    --load_llama_config='7b' \
    --update_llama_config="dict(sample_mode='vision',theta=50000000,max_sequence_length=32768,scan_attention=False,scan_query_chunk_size=128,scan_key_chunk_size=128,scan_mlp=False,scan_mlp_chunk_size=8192,scan_layers=True)" \
    --load_checkpoint="params::$lwm_checkpoint" \
    --tokenizer.vocab_file="$llama_tokenizer_path"
read

Needle Haystack Data

Run python scripts/create_needle_data.py

Running PyTorch Models

Only text and text chat models are currently supported for PyTorch inference. PyTorch models can be loaded as Hugging Face LlamaForCausalLM models. Run python scripts/sample_pyt.py to sample. You may need to separately install torch.

Documentation

For more details on the codebase, please refer to the data.md and sharding.md.
The data.md provides details on the data processing and the sharding.md provides details on the sharding and parallelism.

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

代码讲故事

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值