Flux.1系列模型解析--Flux.1

部署运行你感兴趣的模型镜像

简介

Flux.1模型有三个版本,分别是pro、dev和schnell,三个模型性能依次递减,但生图效率依次提高。dev和schnell基于pro模型蒸馏而来,pro模型只能通过api访问,而dev、shcnell模型可获取具体权重,bfl并没有对Flux.1系列模型架构进行过多展示,只表明基于多模态和并行扩散 Transformer 模块的混合架构,参数扩展到了12B;通过基于流匹配范式训练,且引入旋转位置编码和并行注意力层来提高模型性能并提升硬件效率。
在这里插入图片描述

图1 Flux.1模型架构图

虽然bfl没有进一步公布详细的技术文档,但其在github上开源了推理代码,可以基于推理代码梳理出整个模型架构,图1就是reddit论坛上社区开发者发布的Flux.1模型架构图。Flux.1模型基于DiT架构,与LLMs相同使用RoPE来表征图片位置信息,先使用双流块、再使用单流块实现图像隐空间和文本编码空间的对齐,最终舍弃文本tokens,对图像tokens进行解码得到图片。图1要从下向上看,后续将针对其中的主要模块或概念结合推理代码进行说明。

文本编码器

如图1所示,文本提示词会经过T5 Encoder和CLIP两个文本编码器提取文本特征,官方推理代码具体实现如下所示,实现极其简洁,基于transformers库提供的接口直接将两个文本编码器封装在一个类中,使用时根据version参数自动识别、初始化对应实例。

from torch import Tensor, nn
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer

class HFEmbedder(nn.Module):  # 可分别初始化clip和t5的文本编码器类
    def __init__(self, version: str, max_length: int, **hf_kwargs):
        super().__init__()
        self.is_clip = version.startswith("openai")  # 判断是clip还是t5
        self.max_length = max_length
        self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"

        if self.is_clip:
            self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)  # 初始化clip的tokenizer
            self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)  # 初始化clip的文本编码器
        else:
            self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)  # 初始化t5的tokenizer
            self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)  # 初始化t5的文本编码器

        self.hf_module = self.hf_module.eval().requires_grad_(False)  # 设置为eval模式,并禁用梯度计算

    def forward(self, text: list[str]) -> Tensor:
        batch_encoding = self.tokenizer(
            text,  # 输入文本列表
            truncation=True,  # 允许截断
            max_length=self.max_length,  # 最大长度
            return_length=False,  # 不返回长度
            return_overflowing_tokens=False,  # 不返回溢出token
            padding="max_length",  # 填充到最大长度
            return_tensors="pt",  # 返回pytorch张量
        )

        outputs = self.hf_module(
            input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
            attention_mask=None,  # 不使用注意力掩码
            output_hidden_states=False,  # 不输出隐藏状态
        )
        return outputs[self.output_key].bfloat16()  # 返回输出

旋转位置编码

Flux.1模型中使用的是二维旋转位置编码,其会同时对文本和图像进行处理。Flux.1模型中完整的旋转基本由三部分组成,分别是。先分别基于文本和图片位置索引张量txt_idsimg_ids构建二维位置编码张量的EmbedND模块、对查询张量、键张量应用旋转位置编码的apply_rope函数和对带有旋转位置编码信息的张量进行注意力计算的attention函数。具体实现如下:

# src/flux/modules/layers.py
class EmbedND(nn.Module):
    """N维位置编码模块
    
    参数:
        dim (int): 位置编码维度, 通常为64或128
        theta (int): RoPE旋转角度参数, 通常为10000
        axes_dim (list[int]): 每个轴的编码维度, 如[32,32]表示2D位置编码,每个维度32
    """
    def __init__(self, dim: int = 64, theta: int = 10000, axes_dim: list[int] = [32, 32]):
        super().__init__()
        self.dim = dim  # 位置编码总维度,等于axes_dim之和
        self.theta = theta  # RoPE旋转参数
        self.axes_dim = axes_dim  # 每个轴的编码维度,如[32,32]表示2D位置编码,每个维度32
        
    def forward(self, ids: Tensor) -> Tensor:
        """前向传播
        
        参数:
            ids: shape为[batch_size, seq_len, n_axes]的位置索引张量,此处的seq_len是所有轴的编码维度之和
            
        返回:
            shape为[batch_size, 1, dim, 2, 2]的位置编码张量
        """
        n_axes = ids.shape[-1]  # 获取轴数,如2表示2D位置
        # 对每个轴应用RoPE编码并拼接
        emb = torch.cat(
            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
            dim=-3,
        )  # 每个轴的编码在-3维度上拼接,因为最后两个维度是旋转矩阵
        # 增加head维度
        return emb.unsqueeze(1)  # [B, 1, D, 2, 2],D为总编码维度,2x2为RoPE的旋转矩阵

# src/flux/math.py
import torch
from einops import rearrange
from torch import Tensor

def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    """
    注意力机制
    q: query张量 [batch, heads, seq_len, head_dim]
    k: key张量 [batch, heads, seq_len, head_dim]
    v: value张量 [batch, heads, seq_len, head_dim]
    pe: 位置编码张量 [batch, 1, dim, 2, 2]
    """
    q, k = apply_rope(q, k, pe)  # 将预计算的rope旋转矩阵应用于q,k

    x = torch.nn.functional.scaled_dot_product_attention(q, k, v)  # 计算注意力
    x = rearrange(x, "B H L D -> B L (H D)")  # 将多头注意力组合回整体

    return x  # [batch, seq_len, heads*head_dim]

def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    assert dim % 2 == 0
    # 计算每个位置的频率缩放因子;先生成序列 [0, 2, 4, ..., dim-2],然后除以 dim,得到得到 [0, 2/dim, 4/dim, ..., (dim-2)/dim]
    scale = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim
    omega = 1.0 / (theta**scale)  # 计算最终的角频率  ω_i = 1/θ^(2i/dim)
    out = torch.einsum("...n,d->...nd", pos, omega)  # Einstein 求和约定计算位置和频率的外积,shape: [batch, seq_len, dim//2]
    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)  # 构建旋转矩阵,shape: [batch, seq_len, dim//2, 4]
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)  # 重排列成矩阵形式,将最后一个维度4拆分成2*2,shape: [batch, seq_len, dim//2, 2, 2]
    return out.float()

def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    # 输入的q、k张量的最后一维拆分为两个维度,相当于构建复述形式;[batch, heads, seq_len, head_dim] --> [batch, heads, seq_len, head_dim//2, 1, 2],新增加的维度1是为了广播计算添加
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    # 进行旋转变换;freqs_cis[..., 0]、freqs_cis[..., 1]是一个行数为2的列向量,xq_[..., 0]、xq_[..., 1]、xk_[..., 0]、xk_[..., 1]是一个列数为2的行向量
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)  # 将结果重排列回原来的形状

旋转位置编码是在绝对位置编码的基础上通过旋转操作表征相对特征,在初始绝对位置编码时,为将文本和图像特征拼接,Flux.1模型中是初始化三维位置编码,图片中的每个patch通过时空坐标 ( t , h , w ) (t,h,w) (t,h,w)进行索引,此类构建方式与常规的视频生成模型相同,第一个维度 t t t就表示时间维度,而图片任务为可视为单帧视频,故 t t t设置为0;而对于文本位置编码,因为文本没有空间信息,故其是一个长度维度与图片位置编码可能不同但其他维度完全相同,且值为全0的张量。

Flux.1 backbone

Flux.1模型的backbone实现如下所示,可与图1对比。img对应图1中的latent、txt对应经过T5 Encoder提取的文本嵌入、y对应经过CLIP提取的文本嵌入。基于timestepsguidance初始化的编码特征会以及clip文本嵌入三者相加为vec,会在整个迭代预测过程中作为调制向量,用于计算对应的调制项。在初始化旋转位置编码pe后,先以imgtxtvecpe为输入经过多个双流模块的计算;然后将imgtxt拼接为一个张量,再经过多个单流模块的计算;最后只截取图片序列,将其输出层归一化、线性层等模块输出最终的latent。

class Flux(nn.Module):
    """
    Transformer model for flow matching on sequences.
    """

    def __init__(self, params: FluxParams):
        super().__init__()

        self.params = params
        self.in_channels = params.in_channels
        self.out_channels = params.out_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(
                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
            )  # 隐藏层维度必须能被头数整除
        pe_dim = params.hidden_size // params.num_heads  # 位置编码的维度数与单个自注意力头的维度数相同
        if sum(params.axes_dim) != pe_dim:  # 各个轴的维度之和应该等于位置编码的维度数
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)  # 多维旋转位置编码
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
        )
        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                )
                for _ in range(params.depth)
            ]
        )  # 双流注意力模块堆

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
                for _ in range(params.depth_single_blocks)
            ]
        )  # 单流注意力模块堆

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

    def forward(
        self,
        img: Tensor,  # 重排后的图像张量
        img_ids: Tensor,
        txt: Tensor,  # t5文本嵌入
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,  # vec  # clip文本嵌入
        guidance: Tensor | None = None,
    ) -> Tensor:
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256))  # 时间编码
        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError("Didn't get guidance strength for guidance distilled model.")
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))  # 叠加引导编码
        vec = vec + self.vector_in(y)  # 至此,时间编码、引导编码、clip文本嵌入都融合到vec中
        txt = self.txt_in(txt)

        ids = torch.cat((txt_ids, img_ids), dim=1)  # 文本位置ids和图像位置ids拼接
        pe = self.pe_embedder(ids)  # 旋转位置编码

        for block in self.double_blocks:
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)  # 双流模块

        img = torch.cat((txt, img), 1)  # 文本隐向量和图像隐向量拼接称单一向量
        for block in self.single_blocks:
            img = block(img, vec=vec, pe=pe)  # 单流模块
        img = img[:, txt.shape[1] :, ...]  # 只使用后半段的图片ids序列

        img = self.final_layer(img, vec)  # (B, img_seq_len, out_channels)
        return img

双流模块

该模块称为双流的原因就是其内部为图像、文本特征采用单独的模块进行计算。先使用输入的vec为图片和文本分别预测两个调制模块,然后均先分别使用第一个调值模块的scaleshift分量分别处理图片张量、文本张量,再分别应用对应的注意力模块处理q、k、v张量并将图像、文本的q、k、v张量对应拼接,得到最终参与注意力计算的q、k、v张量。注意力计算后,再从结果中拆分出文本、图像注意力输出,再分别使用第二个调制模块配合对应的层归一化、mlp层和原始输入的imgtxt得到最终的输出imgtxt向量。

class DoubleStreamBlock(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
        super().__init__()

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.img_mod = Modulation(hidden_size, double=True)
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)

        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

        self.txt_mod = Modulation(hidden_size, double=True)
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)

        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
        """
        参数:
            img: 图像张量 [batch, seq_len, hidden_size], 图片latent
            txt: 文本张量 [batch, seq_len, hidden_size],prompt经过T5 encoder后的嵌入
            vec: 调制向量 [batch, hidden_size], 是prompt经过clip、时间经过位置编码,guidance经过位置编码后的拼接张量
            pe: 位置编码张量 [batch, 1, dim, 2, 2]
        
        返回:
            img: 处理后的图像张量 [batch, seq_len, hidden_size]
            txt: 处理后的文本张量 [batch, seq_len, hidden_size]
        """
        # 分别预测图片和文本的调制项
        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)

        # prepare image for attention
        img_modulated = self.img_norm1(img)  # 归一化
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift  # 调制
        img_qkv = self.img_attn.qkv(img_modulated)  # 单独使用图片自注意力模块中的qkv子模块从拼接在一起的qkv
        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)  # 拆分
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)  # 单独使用图片自注意力模块中的归一化模块

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention
        q = torch.cat((txt_q, img_q), dim=2)  # 文本和图片的q拼接
        k = torch.cat((txt_k, img_k), dim=2)  # 文本和图片的k拼接
        v = torch.cat((txt_v, img_v), dim=2)  # 文本和图片的v拼接

        attn = attention(q, k, v, pe=pe)  # 注意力计算  
        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]  # 拆分出文本和图片的注意力结果

        # calculate the img blocks
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)  # 先单独使用图片自注意力模块中的投影层转换,再乘上图片调制项的门控系数,最后加上图片调制项的偏移量
        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)  # 与图片调制项中的第二组调制参数组合

        # calculate the txt blocks
        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
        return img, txt

单流模块

单流模块的输入是文本和图像拼接之后的向量,故只使用一个模块组进行计算。整个计算流程与双流模块基本一致,但有一点不同是单流模块中不像常规的transformer block,先进行注意力计算,然后执行mlp层,而是在构建q、k、v张量时就并行预测了mlp的输出,然后对注意力输出进行正则化处理时直接和mlp层内容拼接,也是并行处理。

class SingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float | None = None,
    ):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)  # 并行线性层,qkv转换时同时将mlp输入也预测处理
        # proj and mlp_out 
        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)  # 并行线性层,最后转换注意力计算时,直接和mlp的输出拼接一并处理

        self.norm = QKNorm(head_dim)

        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False)

    def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
        """
        参数:
            x: 输入张量 [batch, seq_len, hidden_size],经过双流注意力模块后输出的image latent和prompt latent拼接后的张量
            vec: 调制向量 [batch, hidden_size], 是prompt经过clip、时间经过位置编码,guidance经过位置编码后的拼接张量
            pe: 位置编码张量 [batch, 1, dim, 2, 2]
        
        返回:
            x: 处理后的张量 [batch, seq_len, hidden_size]
        """
        mod, _ = self.modulation(vec)  # 调制,只预测一组调制参数
        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift  # 调制
        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)  # 拆分

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)

        # compute attention
        attn = attention(q, k, v, pe=pe)
        # compute activation in mlp stream, cat again and run second linear layer
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        return x + mod.gate * output

VAE

Flux.1模型中使用到的VAE模型保持了常规的VAE架构,主要由编码器、解码器和重参数化采样层组成;不同点是其中的编码器和解码器均是Unet结构,并在Unet结构的中间层均加入了一个自注意力模块。其他没有太多可说的,具体可参考推理代码中的定义。

生图+采样

与常规的扩散模型相似,流匹配范式采样的原始数据也是纯噪声数据x,再基于x、提示词prompts构建输入,主要是调整图像shape和构建对应的多维绝对位置ids序列、文本嵌入和对应的文本绝对位置ids序列。在使用预定义的调度器构建采样时间步长后,就能开始去噪采样,flow matching范式是直接以线性插值的方式迭代更新。flux backbone预测的是离散序列形式的图像隐向量,先将其解包回图像隐空间,再使用vae的编码器将其解码回像素空间得到最终的生成图片。此过程涉及的细节角度,具体可参考以下代码中的注释,更多细节可进一步参考原始推理代码。想进一步了解Flow matching的读者也可参考笔者之前的文章从扩散模型开始的生成模型范式演变–FM(1)从扩散模型开始的生成模型范式演变–FM(2)

# 初始化噪声
def get_noise(
    num_samples: int,
    height: int,
    width: int,
    device: torch.device,
    dtype: torch.dtype,
    seed: int,
):  #  采样图片隐空间尺寸的噪声
    return torch.randn(  # 从标准正态分布中采样随机噪声
        num_samples,
        16,
        # allow for packing
        2 * math.ceil(height / 16),
        2 * math.ceil(width / 16),
        dtype=dtype,
        generator=torch.Generator(device="cpu").manual_seed(seed),
    ).to(device)

# 常规的数据准备函数
def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
    bs, c, h, w = img.shape  # 此处的img的shape与经过vae编码后的隐向量shape相同
    if bs == 1 and not isinstance(prompt, str):
        bs = len(prompt)
    # 图像重排和批次扩展;将图像隐向量在平面维度上分割为2*2的patch,再经过展平实现长度为H/2*W/2的patches序列,即完成了图像隐向量离散序列化,每个patch的维度是C*2*2
    img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)  # [B, C, H, W] -> [B, H/2*W/2, C*2*2]
    if img.shape[0] == 1 and bs > 1:
        img = repeat(img, "1 ... -> bs ...", bs=bs)

    # 生成图像三维位置ids
    img_ids = torch.zeros(h // 2, w // 2, 3)  # 因为将图像隐向量分割为2*2的patch,以空间角度位置编码的角度来看最后一个维度应该为2,此处为3的原因是后续会和文本位置ids拼接,在最前面添加一个区域模态的维度
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]  # 行索引,[0, 1, 2, ..., H/2-1]
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]  # 列索引,[0, 1, 2, ..., W/2-1]
    # 将三维位置ids拉平,再补齐batch
    img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)  # [H/2, W/2, 3] -> [B, H/2*W/2, 3]

    if isinstance(prompt, str):
        prompt = [prompt]
    txt = t5(prompt)  # t5 encoder编码后的文本嵌入
    if txt.shape[0] == 1 and bs > 1:
        txt = repeat(txt, "1 ... -> bs ...", bs=bs)
    txt_ids = torch.zeros(bs, txt.shape[1], 3)  # 为了能与图像的位置编码拼接,最后一个维度也是3

    vec = clip(prompt)  # clip encoder编码后的文本嵌入
    if vec.shape[0] == 1 and bs > 1:
        vec = repeat(vec, "1 ... -> bs ...", bs=bs)

    return {
        "img": img,  # 重排后的图像张量
        "img_ids": img_ids.to(img.device),  # 图像多维位置ids
        "txt": txt.to(img.device),  # t5文本嵌入
        "txt_ids": txt_ids.to(img.device),  # 文本位置ids
        "vec": vec.to(img.device),  # clip文本向量
    }


# 构建经过两个点(x1,y1)和(x2,y2)的线性函数
def get_lin_function(
    x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
) -> Callable[[float], float]:
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return lambda x: m * x + b

def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list[float]:
    # extra step for zero
    timesteps = torch.linspace(1, 0, num_steps + 1)  # 从1到0的num_steps+1个等差数列

    # shifting the schedule to favor high timesteps for higher signal images
    if shift:
        # estimate mu based on linear estimation between two points
        mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
        timesteps = time_shift(mu, 1.0, timesteps)

    return timesteps.tolist()


# 去噪
def denoise(
    model: Flux,
    # model input
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    vec: Tensor,
    # sampling parameters
    timesteps: list[float],
    guidance: float = 4.0,
    # extra img tokens (channel-wise)
    img_cond: Tensor | None = None,
    # extra img tokens (sequence-wise)
    img_cond_seq: Tensor | None = None,
    img_cond_seq_ids: Tensor | None = None,
):
    # this is ignored for schnell
    guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)  # 创建一个长度为batch size的一维张量,所有值都是guidance
    for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)  # 创建一个长度为batch size的一维张量,所有值都是t_curr
        img_input = img
        img_input_ids = img_ids
        if img_cond is not None:
            img_input = torch.cat((img, img_cond), dim=-1)
        if img_cond_seq is not None:
            assert (
                img_cond_seq_ids is not None
            ), "You need to provide either both or neither of the sequence conditioning"
            img_input = torch.cat((img_input, img_cond_seq), dim=1)
            img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1)
        pred = model(
            img=img_input,
            img_ids=img_input_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_vec,
            guidance=guidance_vec,
        )  # 使用flux backbone预测flow matching范式中当前时间步的移动速度
        if img_input_ids is not None:
            pred = pred[:, : img.shape[1]]  # 只使用前半段的图片ids序列

        img = img + (t_prev - t_curr) * pred  # flow matching范式更新时就直接以进行线性插值,即新的图像值就是当前预测值和上一步图像值的插值
    return img

您可能感兴趣的与本文相关的镜像

FLUX.1-dev

FLUX.1-dev

图片生成
FLUX

FLUX.1-dev 是一个由 Black Forest Labs 创立的开源 AI 图像生成模型版本,它以其高质量和类似照片的真实感而闻名,并且比其他模型更有效率

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值