OmniGen: Unified Image Generation (Code Walkthrough)


GitHub project page

Introduction to the OmniGen project

A unified image generation model
The emergence of large language models (LLMs) has unified language generation tasks and revolutionized human-machine interaction. In image generation, however, a unified model that can handle a wide range of tasks within a single framework remains largely unexplored. This work introduces OmniGen, a new diffusion model for unified image generation. Unlike popular diffusion models (e.g., Stable Diffusion), OmniGen does not need extra modules such as ControlNet or IP-Adapter to handle different control conditions.
OmniGen has the following characteristics:
1) Unification: besides text-to-image generation, it natively supports various downstream tasks such as image editing, subject-driven generation, and visual-conditional generation. It can also handle classical computer-vision tasks such as edge detection and human pose estimation by casting them as image generation tasks.
2) Simplicity: the architecture is highly simplified and requires no additional text encoder. Compared with existing diffusion models it is also more user-friendly: complex tasks can be completed purely through instructions, without extra preprocessing steps (e.g., human pose estimation), which greatly simplifies the image generation workflow.
3) Knowledge transfer: benefiting from learning in a unified format, OmniGen effectively transfers knowledge across tasks, handles unseen tasks and domains, and exhibits novel capabilities. The authors also explore the model's reasoning abilities and potential applications of chain-of-thought mechanisms.
This work is a first attempt at a general-purpose image generation model, and several open problems remain.
Supported tasks
Text-to-image generation
Mixed-modality prompts
For example, text-guided image editing and style transfer.
It can also perform image super-resolution and image brightening.
It also shows in-context understanding: given a reference example of how an image should be processed, it can infer what the task requires.
Overall, the model's comprehension ability is strong and its results are quite good.

Overall model architecture


The way the input data is processed here is worth studying closely:

Processing the input text

The input text is processed with OmniGenProcessor:

import os
import re
from typing import Dict, List

import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoTokenizer
from huggingface_hub import snapshot_download
# crop_arr (a resize / center-crop helper) is defined elsewhere in the OmniGen repo

class OmniGenProcessor:
    def __init__(self, 
                text_tokenizer, 
                max_image_size: int=1024):
        self.text_tokenizer = text_tokenizer
        self.max_image_size = max_image_size

        self.image_transform = transforms.Compose([
            transforms.Lambda(lambda pil_image: crop_arr(pil_image, max_image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
        ])

        self.collator = OmniGenCollator()
        self.separate_collator = OmniGenSeparateCollator()

    @classmethod
    def from_pretrained(cls, model_name):
        if not os.path.exists(model_name):
            cache_folder = os.getenv('HF_HUB_CACHE')
            model_name = snapshot_download(repo_id=model_name,
                                           cache_dir=cache_folder,
                                           allow_patterns="*.json")
        text_tokenizer = AutoTokenizer.from_pretrained(model_name)

        return cls(text_tokenizer)


    def process_image(self, image):
        image = Image.open(image).convert('RGB')
        return self.image_transform(image)
    
    def process_multi_modal_prompt(self, text, input_images):
        text = self.add_prefix_instruction(text)
        if input_images is None or len(input_images) == 0:
            model_inputs = self.text_tokenizer(text)
            return {"input_ids": model_inputs.input_ids, "pixel_values": None, "image_sizes": None}

        pattern = r"<\|image_\d+\|>"
        prompt_chunks = [self.text_tokenizer(chunk).input_ids for chunk in re.split(pattern, text)] 

        for i in range(1, len(prompt_chunks)):
            if prompt_chunks[i][0] == 1:
                prompt_chunks[i] = prompt_chunks[i][1:]

        image_tags = re.findall(pattern, text) 
        image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags]

        unique_image_ids = sorted(list(set(image_ids)))
        assert unique_image_ids == list(range(1, len(unique_image_ids)+1)), f"image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be {unique_image_ids}"
        # total images must be the same as the number of image tags
        assert len(unique_image_ids) == len(input_images), f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(input_images)} images"
        
        input_images = [input_images[x-1] for x in image_ids]

        all_input_ids = []
        img_inx = []
        idx = 0
        for i in range(len(prompt_chunks)):
            all_input_ids.extend(prompt_chunks[i])
            if i != len(prompt_chunks) -1:
                start_inx = len(all_input_ids)
                size = input_images[i].size(-2) *  input_images[i].size(-1) // 16 // 16
                img_inx.append([start_inx, start_inx+size])
                all_input_ids.extend([0]*size)

        return {"input_ids": all_input_ids, "pixel_values": input_images, "image_sizes": img_inx}


    def add_prefix_instruction(self, prompt):
        user_prompt = '<|user|>\n'
        generation_prompt = 'Generate an image according to the following instructions\n'
        assistant_prompt = '<|assistant|>\n<|diffusion|>'
        prompt_suffix = "<|end|>\n"
        prompt = f"{user_prompt}{generation_prompt}{prompt}{prompt_suffix}{assistant_prompt}"
        return prompt


    def __call__(self, 
                instructions: List[str], 
                input_images: List[List[str]] = None,
                height: int = 1024,
                width: int = 1024,
                negative_prompt: str = "low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers.",
                use_img_cfg: bool = True,
                separate_cfg_input: bool = False,
                ) -> Dict:

        if input_images is None:
            use_img_cfg = False
        if isinstance(instructions, str):
            instructions = [instructions]
            input_images = [input_images]
        
        input_data = []
        for i in range(len(instructions)):
            cur_instruction = instructions[i]
            cur_input_images = None if input_images is None else input_images[i]
            if cur_input_images is not None and len(cur_input_images) > 0:
                cur_input_images = [self.process_image(x) for x in cur_input_images]
            else:
                cur_input_images = None
                assert "<img><|image_1|></img>" not in cur_instruction
            
            mllm_input = self.process_multi_modal_prompt(cur_instruction, cur_input_images)

        
            neg_mllm_input, img_cfg_mllm_input = None, None
            neg_mllm_input = self.process_multi_modal_prompt(negative_prompt, None)
            if use_img_cfg:
                if cur_input_images is not None and len(cur_input_images) >= 1:
                    img_cfg_prompt = [f"<img><|image_{i+1}|></img>" for i in range(len(cur_input_images))]
                    img_cfg_mllm_input = self.process_multi_modal_prompt(" ".join(img_cfg_prompt), cur_input_images)
                else:
                    img_cfg_mllm_input = neg_mllm_input

            input_data.append((mllm_input, neg_mllm_input, img_cfg_mllm_input, [height, width]))

        if separate_cfg_input:
            return self.separate_collator(input_data)
        return self.collator(input_data)




class OmniGenCollator:
    def __init__(self, pad_token_id=2, hidden_size=3072):
        self.pad_token_id = pad_token_id
        self.hidden_size = hidden_size
    
    def create_position(self, attention_mask, num_tokens_for_output_images):
        position_ids = []
        text_length = attention_mask.size(-1)
        img_length = max(num_tokens_for_output_images)  
        for mask in attention_mask:
            temp_l = torch.sum(mask)
            temp_position = [0]*(text_length-temp_l) + [i for i in range(temp_l+img_length+1)] # we add a time embedding into the sequence, so add one more token
            position_ids.append(temp_position)
        return torch.LongTensor(position_ids)
    
    def create_mask(self, attention_mask, num_tokens_for_output_images):
        extended_mask = []
        padding_images = []
        text_length = attention_mask.size(-1)
        img_length = max(num_tokens_for_output_images)
        seq_len = text_length + img_length + 1 # we add a time embedding into the sequence, so add one more token
        inx = 0
        for mask in attention_mask:
            temp_l = torch.sum(mask)
            pad_l = text_length - temp_l

            temp_mask = torch.tril(torch.ones(size=(temp_l+1, temp_l+1)))

            image_mask = torch.zeros(size=(temp_l+1, img_length))
            temp_mask = torch.cat([temp_mask, image_mask], dim=-1)

            image_mask = torch.ones(size=(img_length, temp_l+img_length+1))
            temp_mask = torch.cat([temp_mask, image_mask], dim=0)

            if pad_l > 0:
                pad_mask = torch.zeros(size=(temp_l+1+img_length, pad_l))
                temp_mask = torch.cat([pad_mask, temp_mask], dim=-1)

                pad_mask = torch.ones(size=(pad_l, seq_len))
                temp_mask = torch.cat([pad_mask, temp_mask], dim=0)

            true_img_length = num_tokens_for_output_images[inx]
            pad_img_length = img_length - true_img_length
            if pad_img_length > 0:
                temp_mask[:, -pad_img_length:] = 0
                temp_padding_imgs = torch.zeros(size=(1, pad_img_length, self.hidden_size))
            else:
                temp_padding_imgs = None
            
            extended_mask.append(temp_mask.unsqueeze(0))
            padding_images.append(temp_padding_imgs)
            inx += 1
        return torch.cat(extended_mask, dim=0), padding_images
    
    def adjust_attention_for_input_images(self, attention_mask, image_sizes):
        for b_inx in image_sizes.keys():
            for start_inx, end_inx in image_sizes[b_inx]:
                attention_mask[b_inx][start_inx:end_inx, start_inx:end_inx] = 1

        return attention_mask
    
    def pad_input_ids(self, input_ids, image_sizes):
        max_l = max([len(x) for x in input_ids])
        padded_ids = []
        attention_mask = []
        new_image_sizes = []

        for i in range(len(input_ids)):
            temp_ids = input_ids[i]
            temp_l = len(temp_ids)
            pad_l = max_l - temp_l
            if pad_l == 0:
                attention_mask.append([1]*max_l)
                padded_ids.append(temp_ids)
            else:
                attention_mask.append([0]*pad_l+[1]*temp_l)
                padded_ids.append([self.pad_token_id]*pad_l+temp_ids)
            
            if i in image_sizes:
                new_inx = []
                for old_inx in image_sizes[i]:
                    new_inx.append([x+pad_l for x in old_inx])
                image_sizes[i] = new_inx

        return torch.LongTensor(padded_ids), torch.LongTensor(attention_mask), image_sizes


    def process_mllm_input(self, mllm_inputs, target_img_size):
        num_tokens_for_output_images = []
        for img_size in target_img_size:
            num_tokens_for_output_images.append(img_size[0]*img_size[1]//16//16)

        pixel_values, image_sizes = [], {}
        b_inx = 0
        for x in mllm_inputs:
            if x['pixel_values'] is not None:
                pixel_values.extend(x['pixel_values'])
                for size in x['image_sizes']:
                    if b_inx not in image_sizes:
                        image_sizes[b_inx] = [size]
                    else:
                        image_sizes[b_inx].append(size)
            b_inx += 1     
        pixel_values = [x.unsqueeze(0) for x in pixel_values]

        
        input_ids = [x['input_ids'] for x in mllm_inputs]
        padded_input_ids, attention_mask, image_sizes = self.pad_input_ids(input_ids, image_sizes)
        position_ids = self.create_position(attention_mask, num_tokens_for_output_images)
        attention_mask, padding_images = self.create_mask(attention_mask, num_tokens_for_output_images)
        attention_mask = self.adjust_attention_for_input_images(attention_mask, image_sizes)

        return padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes
    
    
    def __call__(self, features):
        mllm_inputs = [f[0] for f in features]
        cfg_mllm_inputs = [f[1] for f in features]
        img_cfg_mllm_input = [f[2] for f in features]
        target_img_size = [f[3] for f in features]

        
        if img_cfg_mllm_input[0] is not None:
            mllm_inputs = mllm_inputs + cfg_mllm_inputs + img_cfg_mllm_input
            target_img_size = target_img_size + target_img_size + target_img_size
        else:
            mllm_inputs = mllm_inputs + cfg_mllm_inputs
            target_img_size = target_img_size + target_img_size


        all_padded_input_ids, all_position_ids, all_attention_mask, all_padding_images, all_pixel_values, all_image_sizes = self.process_mllm_input(mllm_inputs, target_img_size)

        data = {"input_ids": all_padded_input_ids,
        "attention_mask": all_attention_mask,
        "position_ids": all_position_ids,
        "input_pixel_values": all_pixel_values,
        "input_image_sizes": all_image_sizes,
        "padding_images": all_padding_images,
        }
        return data


class OmniGenSeparateCollator(OmniGenCollator):
    def __call__(self, features):
        mllm_inputs = [f[0] for f in features]
        cfg_mllm_inputs = [f[1] for f in features]
        img_cfg_mllm_input = [f[2] for f in features]
        target_img_size = [f[3] for f in features]

        
        all_padded_input_ids, all_attention_mask, all_position_ids, all_pixel_values, all_image_sizes, all_padding_images = [], [], [], [], [], []


        padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes = self.process_mllm_input(mllm_inputs, target_img_size)
        all_padded_input_ids.append(padded_input_ids)
        all_attention_mask.append(attention_mask)
        all_position_ids.append(position_ids)
        all_pixel_values.append(pixel_values)
        all_image_sizes.append(image_sizes)
        all_padding_images.append(padding_images)

        if cfg_mllm_inputs[0] is not None:
            padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes = self.process_mllm_input(cfg_mllm_inputs, target_img_size)
            all_padded_input_ids.append(padded_input_ids)
            all_attention_mask.append(attention_mask)
            all_position_ids.append(position_ids)
            all_pixel_values.append(pixel_values)
            all_image_sizes.append(image_sizes)
            all_padding_images.append(padding_images)
        if img_cfg_mllm_input[0] is not None:
            padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes = self.process_mllm_input(img_cfg_mllm_input, target_img_size)
            all_padded_input_ids.append(padded_input_ids)
            all_attention_mask.append(attention_mask)
            all_position_ids.append(position_ids)
            all_pixel_values.append(pixel_values)
            all_image_sizes.append(image_sizes)
            all_padding_images.append(padding_images)

        data = {"input_ids": all_padded_input_ids,
        "attention_mask": all_attention_mask,
        "position_ids": all_position_ids,
        "input_pixel_values": all_pixel_values,
        "input_image_sizes": all_image_sizes,
        "padding_images": all_padding_images,
        }
        return data
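Before walking through a concrete example, here is a hedged usage sketch of how this processor is typically driven; the Hugging Face repo id, the import path, and the image file names are illustrative assumptions rather than values taken from this post:

# Usage sketch only: repo id, import path and image paths are assumptions.
from OmniGen import OmniGenProcessor   # assuming the repo exposes the processor at top level

processor = OmniGenProcessor.from_pretrained("Shitao/OmniGen-v1")

prompt = "Make <img><|image_1|></img> has the same style of <img><|image_2|></img>."
batch = processor(
    instructions=[prompt],
    input_images=[["content.png", "style.png"]],   # one list of image paths per instruction
    height=1024,
    width=1024,
    separate_cfg_input=True,                        # one collated batch per CFG branch
)
# batch contains input_ids, attention_mask, position_ids, input_pixel_values,
# input_image_sizes and padding_images for each CFG branch.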

For example, suppose the input text is:

prompt= "Make <img><|image_1|></img> has the same style of <img><|image_2|></img>.Maintain the consistency of the content in <img><|image_1|></img> and ensure that there is only the style of  <img><|image_2|></img>, without its content.",
            

First, the prompt is wrapped with a prefix and a suffix (via add_prefix_instruction); the result looks like this:

<|user|>
Generate an image according to the following instructions
Make  <img><|image_1|></img> has the same style of <img><|image_2|></img>.Maintain the consistency of the content in <img><|image_1|></img> and ensure that there is only the style of  <img><|image_2|></img>, without its content.<|end|>
<|assistant|>
<|diffusion|>

The text is then split at the positions where image tags appear:

# split the whole text with this regular expression
pattern = r"<\|image_\d+\|>"
prompt_chunks = [self.text_tokenizer(chunk).input_ids for chunk in re.split(pattern, text)] 

This splits the text into 5 chunks.
Then the leading BOS token (id 1) is stripped from every chunk except the first:

for i in range(1, len(prompt_chunks)):
     if prompt_chunks[i][0] == 1:
         prompt_chunks[i] = prompt_chunks[i][1:]

The image tags in the text and their ids are then identified.
There are 4 image tags in total, but only 2 distinct input images are actually used.
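Here is a small self-contained sketch of these steps (regex split and tag/id extraction) run directly on the wrapped prompt; the BOS-stripping step needs the real tokenizer, so it is only noted in a comment:

import re

text = ("<|user|>\nGenerate an image according to the following instructions\n"
        "Make <img><|image_1|></img> has the same style of <img><|image_2|></img>."
        "Maintain the consistency of the content in <img><|image_1|></img> and ensure that "
        "there is only the style of <img><|image_2|></img>, without its content.<|end|>\n"
        "<|assistant|>\n<|diffusion|>")

pattern = r"<\|image_\d+\|>"

# 1) split the prompt at every image tag -> 4 tags give 5 text chunks
chunks = re.split(pattern, text)
print(len(chunks))             # 5

# 2) in the real code each chunk is tokenized, and chunks after the first
#    drop their leading BOS token (id == 1) so the pieces can be concatenated.

# 3) collect the image tags and their ids: 4 tags, but only 2 distinct images
image_tags = re.findall(pattern, text)
image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags]
print(image_tags)              # ['<|image_1|>', '<|image_2|>', '<|image_1|>', '<|image_2|>']
print(image_ids)               # [1, 2, 1, 2]
print(sorted(set(image_ids)))  # [1, 2]  -> two unique input images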
The image and text information are then merged: a run of placeholder ids is inserted at each tag position, and its [start, end) span is recorded in img_inx:

 all_input_ids = []
 img_inx = []
 idx = 0
 for i in range(len(prompt_chunks)):
     all_input_ids.extend(prompt_chunks[i])
     if i != len(prompt_chunks) -1:
         start_inx = len(all_input_ids)
         size = input_images[i].size(-2) *  input_images[i].size(-1) // 16 // 16
         img_inx.append([start_inx, start_inx+size])
         all_input_ids.extend([0]*size)

Processing the data fed into the LLM

The input image information is converted into a sequence of 1024 tokens.

The convolution module Conv2d(4, 3072, kernel_size=(2, 2), stride=(2, 2)) turns a (1, 4, 64, 64) input into a (1, 1024, 3072) tensor.
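Below is a minimal sketch of this patch embedding (a stand-in for the repo's PatchEmbedMR: a strided Conv2d followed by a flatten; the class name here is illustrative):

import torch
import torch.nn as nn

class PatchEmbedSketch(nn.Module):
    """Illustrative stand-in for PatchEmbedMR: patchify a latent with a strided conv."""
    def __init__(self, patch_size=2, in_channels=4, hidden_size=3072):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, hidden_size,
                              kernel_size=patch_size, stride=patch_size, bias=True)

    def forward(self, latent):
        x = self.proj(latent)             # (N, 3072, H/2, W/2)
        x = x.flatten(2).transpose(1, 2)  # (N, H/2 * W/2, 3072)
        return x

latent = torch.randn(1, 4, 64, 64)        # VAE latent of a 512x512 image
tokens = PatchEmbedSketch()(latent)
print(tokens.shape)                        # torch.Size([1, 1024, 3072])

In OmniGen this happens inside patch_multiple_resolutions: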

 def patch_multiple_resolutions(self, latents, padding_latent=None, is_input_images:bool=False):
        if isinstance(latents, list):
            return_list = False
            if padding_latent is None:
                padding_latent = [None] * len(latents)
                return_list = True
            patched_latents, num_tokens, shapes = [], [], []
            for latent, padding in zip(latents, padding_latent):
                height, width = latent.shape[-2:]
                if is_input_images:
                    # this convolution turns the (4, 64, 64) input latent into (1, 1024, 3072) tokens
                    latent = self.input_x_embedder(latent)
                else:
                    latent = self.x_embedder(latent)
                pos_embed = self.cropped_pos_embed(height, width)    
                latent = latent + pos_embed
                if padding is not None:
                    latent = torch.cat([latent, padding], dim=-2)
                patched_latents.append(latent)

                num_tokens.append(pos_embed.size(1))
                shapes.append([height, width])
            if not return_list:
                latents = torch.cat(patched_latents, dim=0)
            else:
                latents = patched_latents
        else:
            height, width = latents.shape[-2:]
            if is_input_images:
                latents = self.input_x_embedder(latents)
            else:
                latents = self.x_embedder(latents)
            pos_embed = self.cropped_pos_embed(height, width)  
            latents = latents + pos_embed
            num_tokens = latents.size(1)
            shapes = [height, width]
        return latents, num_tokens, shapes

Embedding the input image information into the text embeddings

The full OmniGen model (it relies on helpers defined elsewhere in the repo: PatchEmbedMR, TimestepEmbedder, FinalLayer, get_2d_sincos_pos_embed and Phi3Transformer):

class OmniGen(nn.Module, PeftAdapterMixin):
    """
    Diffusion model with a Transformer backbone.
    """
    def __init__(
        self,
        transformer_config: Phi3Config,
        patch_size=2,
        in_channels=4,
        pe_interpolation: float = 1.0,
        pos_embed_max_size: int = 192,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.patch_size = patch_size
        self.pos_embed_max_size = pos_embed_max_size

        hidden_size = transformer_config.hidden_size

        self.x_embedder = PatchEmbedMR(patch_size, in_channels, hidden_size, bias=True)
        self.input_x_embedder = PatchEmbedMR(patch_size, in_channels, hidden_size, bias=True)

        self.time_token = TimestepEmbedder(hidden_size)
        self.t_embedder = TimestepEmbedder(hidden_size)
        
        self.pe_interpolation = pe_interpolation
        pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, interpolation_scale=self.pe_interpolation, base_size=64)
        self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=True)

        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)

        self.initialize_weights()

        self.llm = Phi3Transformer(config=transformer_config)
        self.llm.config.use_cache = False
    
    @classmethod
    def from_pretrained(cls, model_name):
        if not os.path.exists(model_name):
            cache_folder = os.getenv('HF_HUB_CACHE')
            model_name = snapshot_download(repo_id=model_name,
                                           cache_dir=cache_folder,
                                           ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
        config = Phi3Config.from_pretrained(model_name)
        model = cls(config)
        if os.path.exists(os.path.join(model_name, 'model.safetensors')):
            print("Loading safetensors")
            ckpt = load_file(os.path.join(model_name, 'model.safetensors'))
        else:
            ckpt = torch.load(os.path.join(model_name, 'model.pt'), map_location='cpu')
        model.load_state_dict(ckpt)
        return model

    def initialize_weights(self):
        assert not hasattr(self, "llama")

        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)
        
        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)

        w = self.input_x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.input_x_embedder.proj.bias, 0)


        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
        nn.init.normal_(self.time_token.mlp[0].weight, std=0.02)
        nn.init.normal_(self.time_token.mlp[2].weight, std=0.02)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def unpatchify(self, x, h, w):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, C, H, W)
        """
        c = self.out_channels

        x = x.reshape(shape=(x.shape[0], h//self.patch_size, w//self.patch_size, self.patch_size, self.patch_size, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h, w))
        return imgs


    def cropped_pos_embed(self, height, width):
        """Crops positional embeddings for SD3 compatibility."""
        if self.pos_embed_max_size is None:
            raise ValueError("`pos_embed_max_size` must be set for cropping.")

        height = height // self.patch_size
        width = width // self.patch_size
        if height > self.pos_embed_max_size:
            raise ValueError(
                f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
            )
        if width > self.pos_embed_max_size:
            raise ValueError(
                f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
            )

        top = (self.pos_embed_max_size - height) // 2
        left = (self.pos_embed_max_size - width) // 2
        spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
        spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :]
        # print(top, top + height, left, left + width, spatial_pos_embed.size())
        spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
        return spatial_pos_embed


    def patch_multiple_resolutions(self, latents, padding_latent=None, is_input_images:bool=False):
        if isinstance(latents, list):
            return_list = False
            if padding_latent is None:
                padding_latent = [None] * len(latents)
                return_list = True
            patched_latents, num_tokens, shapes = [], [], []
            for latent, padding in zip(latents, padding_latent):
                height, width = latent.shape[-2:]
                if is_input_images:
                    latent = self.input_x_embedder(latent)
                else:
                    latent = self.x_embedder(latent)
                pos_embed = self.cropped_pos_embed(height, width)    
                latent = latent + pos_embed
                if padding is not None:
                    latent = torch.cat([latent, padding], dim=-2)
                patched_latents.append(latent)

                num_tokens.append(pos_embed.size(1))
                shapes.append([height, width])
            if not return_list:
                latents = torch.cat(patched_latents, dim=0)
            else:
                latents = patched_latents
        else:
            height, width = latents.shape[-2:]
            if is_input_images:
                latents = self.input_x_embedder(latents)
            else:
                latents = self.x_embedder(latents)
            pos_embed = self.cropped_pos_embed(height, width)  
            latents = latents + pos_embed
            num_tokens = latents.size(1)
            shapes = [height, width]
        return latents, num_tokens, shapes

    
    def forward(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, padding_latent=None, past_key_values=None, return_past_key_values=True):
        """
        
        """
        input_is_list = isinstance(x, list)
        x, num_tokens, shapes = self.patch_multiple_resolutions(x, padding_latent)
        time_token = self.time_token(timestep, dtype=x[0].dtype).unsqueeze(1)   
        
        if input_img_latents is not None:
            input_latents, _, _ = self.patch_multiple_resolutions(input_img_latents, is_input_images=True)
        if input_ids is not None:
            # embed the text condition tokens
            condition_embeds = self.llm.embed_tokens(input_ids).clone()
            input_img_inx = 0
            for b_inx in input_image_sizes.keys():
                for start_inx, end_inx in input_image_sizes[b_inx]:
                    condition_embeds[b_inx, start_inx: end_inx] = input_latents[input_img_inx]
                    input_img_inx += 1
            if input_img_latents is not None:
                assert input_img_inx == len(input_latents) 
            # the 1024-token gap reserved at each image position has just been filled with that image's latents; concatenate [text condition, time token, noisy latents]
            input_emb = torch.cat([condition_embeds, time_token, x], dim=1)
        else:
            input_emb = torch.cat([time_token, x], dim=1)
        # feed all the resulting embeddings (text condition, time token, noise tokens), each of width 3072, into the LLM
        output = self.llm(inputs_embeds=input_emb, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values)
        output, past_key_values = output.last_hidden_state, output.past_key_values
        if input_is_list:
            image_embedding = output[:, -max(num_tokens):]
            time_emb = self.t_embedder(timestep, dtype=x.dtype)
            x = self.final_layer(image_embedding, time_emb)
            latents = []
            for i in range(x.size(0)):
                latent = x[i:i+1, :num_tokens[i]]
                latent = self.unpatchify(latent, shapes[i][0], shapes[i][1])
                latents.append(latent)
        else:
            image_embedding = output[:, -num_tokens:]
            time_emb = self.t_embedder(timestep, dtype=x.dtype)
            x = self.final_layer(image_embedding, time_emb)
            latents = self.unpatchify(x, shapes[0], shapes[1])

        if return_past_key_values:
            return latents, past_key_values
        return latents

    @torch.no_grad()
    def forward_with_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache):
        """
        Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
        """        
        self.llm.config.use_cache = use_kv_cache
        model_out, past_key_values = self.forward(x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, past_key_values=past_key_values, return_past_key_values=True)
        if use_img_cfg:
            cond, uncond, img_cond = torch.split(model_out, len(model_out) // 3, dim=0)
            cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
            model_out = [cond, cond, cond]
        else:
            cond, uncond = torch.split(model_out, len(model_out) // 2, dim=0)
            cond = uncond + cfg_scale * (cond - uncond)
            model_out = [cond, cond]
        
        return torch.cat(model_out, dim=0), past_key_values


    @torch.no_grad()
    def forward_with_separate_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, return_past_key_values=True):
        """
        Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
        """        
        self.llm.config.use_cache = use_kv_cache
        if past_key_values is None:
            past_key_values = [None] * len(attention_mask)

        x = torch.split(x, len(x) // len(attention_mask), dim=0)
        timestep = timestep.to(x[0].dtype)
        timestep = torch.split(timestep, len(timestep) // len(input_ids), dim=0)

        model_out, pask_key_values = [], []
        for i in range(len(input_ids)):
            temp_out, temp_pask_key_values = self.forward(x[i], timestep[i], input_ids[i], input_img_latents[i], input_image_sizes[i], attention_mask[i], position_ids[i], past_key_values[i])
            model_out.append(temp_out)
            pask_key_values.append(temp_pask_key_values)

        if len(model_out) == 3:
            cond, uncond, img_cond = model_out
            cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
            model_out = [cond, cond, cond]
        elif len(model_out) == 2:
            cond, uncond = model_out
            cond = uncond + cfg_scale * (cond - uncond)
            model_out = [cond, cond]
        else:
            return model_out[0]
        
        return torch.cat(model_out, dim=0), pask_key_values


All inputs are projected into the 3072-dimensional hidden space.
When a guidance image is provided there are three noise-prediction branches: text-guided, unconditional, and image-guided.
For each branch, the information is concatenated in the order [text condition, time embedding, noise tokens] and fed into the LLM as the conditional input to predict the noise.
Because 1024 placeholder tokens were reserved at each image position inside the text condition, those slots are later replaced by the corresponding image embeddings, as the toy sketch below illustrates.
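Here is a toy sketch of that placeholder replacement and of the final concatenation order, using made-up, much smaller sizes than the real 3072-dimensional embeddings:

import torch

hidden = 8            # toy hidden size (3072 in the real model)
text_len, img_tokens, noise_tokens = 20, 6, 10

# text condition embedding with a placeholder gap recorded as [start, end)
condition_embeds = torch.zeros(1, text_len, hidden)
image_sizes = {0: [[5, 5 + img_tokens]]}            # batch 0: one image at positions 5..11

# patch-embedded input image latents (one entry per image occurrence)
input_latents = [torch.randn(img_tokens, hidden)]

for b_inx, spans in image_sizes.items():
    for img_inx, (start, end) in enumerate(spans):
        # overwrite the reserved placeholder slots with the image embedding
        condition_embeds[b_inx, start:end] = input_latents[img_inx]

time_token = torch.randn(1, 1, hidden)              # timestep embedding
x = torch.randn(1, noise_tokens, hidden)            # patch-embedded noisy latent

input_emb = torch.cat([condition_embeds, time_token, x], dim=1)
print(input_emb.shape)    # torch.Size([1, 31, 8]) -> [text cond | time token | noise]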

The forward function

    def forward(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, padding_latent=None, past_key_values=None, return_past_key_values=True):
        """
        
        """
        input_is_list = isinstance(x, list)
        x, num_tokens, shapes = self.patch_multiple_resolutions(x, padding_latent)
        time_token = self.time_token(timestep, dtype=x[0].dtype).unsqueeze(1)   
        
        if input_img_latents is not None:
            input_latents, _, _ = self.patch_multiple_resolutions(input_img_latents, is_input_images=True)
        if input_ids is not None:
            # embed the text condition tokens
            condition_embeds = self.llm.embed_tokens(input_ids).clone()
            input_img_inx = 0
            for b_inx in input_image_sizes.keys():
            # walk over the image positions recorded earlier in input_image_sizes
                for start_inx, end_inx in input_image_sizes[b_inx]:
                    condition_embeds[b_inx, start_inx: end_inx] = input_latents[input_img_inx]
                    input_img_inx += 1
            if input_img_latents is not None:
                assert input_img_inx == len(input_latents) 
            # the 1024-token gap reserved at each image position has just been filled with that image's latents; concatenate [text condition, time token, noisy latents]
            input_emb = torch.cat([condition_embeds, time_token, x], dim=1)
        else:
            input_emb = torch.cat([time_token, x], dim=1)
        # feed all the resulting embeddings (text condition, time token, noise tokens), each of width 3072, into the LLM
        output = self.llm(inputs_embeds=input_emb, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values)
        output, past_key_values = output.last_hidden_state, output.past_key_values
        if input_is_list:
            image_embedding = output[:, -max(num_tokens):]
            time_emb = self.t_embedder(timestep, dtype=x.dtype)
            x = self.final_layer(image_embedding, time_emb)
            latents = []
            for i in range(x.size(0)):
                latent = x[i:i+1, :num_tokens[i]]
                latent = self.unpatchify(latent, shapes[i][0], shapes[i][1])
                latents.append(latent)
        else:
            image_embedding = output[:, -num_tokens:]
            time_emb = self.t_embedder(timestep, dtype=x.dtype)
            x = self.final_layer(image_embedding, time_emb)
            latents = self.unpatchify(x, shapes[0], shapes[1])

        if return_past_key_values:
            return latents, past_key_values
        return latents

Finally, the last num_tokens positions of the output, of shape (1, 1024, 3072), are taken as the predicted noise, linearly projected to (1, 1024, 16), and then reshaped into (1, 4, 64, 64).

# take the last num_tokens positions of the output as the noise prediction: (1, 1024, 3072)
image_embedding = output[:, -num_tokens:]
time_emb = self.t_embedder(timestep, dtype=x.dtype)
# project x from (1, 1024, 3072) to (1, 1024, 16)
x = self.final_layer(image_embedding, time_emb)
# finally reshape (1, 1024, 16) into (1, 4, 64, 64)
latents = self.unpatchify(x, shapes[0], shapes[1])

FinalLayer is defined as:
class FinalLayer(nn.Module):
    """
    The final layer of DiT.
    """
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        # split the time embedding into a shift and a scale, each of shape (1, 3072)
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        # modulate x (1, 1024, 3072) with the time embedding
        x = modulate(self.norm_final(x), shift, scale)
        # project (1, 1024, 3072) down to (1, 1024, 16)
        x = self.linear(x)
        return x
        return x

The modulate helper applies a scale-and-shift transform conditioned on the time embedding:

def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

The instantiated FinalLayer module looks like this:

FinalLayer(
  (norm_final): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
  (linear): Linear(in_features=3072, out_features=16, bias=True)
  (adaLN_modulation): Sequential(
    (0): SiLU()
    (1): Linear(in_features=3072, out_features=6144, bias=True)
  )
)

Finally, the (1, 1024, 16) tensor is reshaped back into (1, 4, 64, 64):

    def unpatchify(self, x, h, w):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, C, H, W)
        """
        c = self.out_channels

        x = x.reshape(shape=(x.shape[0], h//self.patch_size, w//self.patch_size, self.patch_size, self.patch_size, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h, w))
        return imgs
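As a quick shape check, here is a sketch that puts the two pieces together; it assumes the FinalLayer and modulate definitions above are in scope (together with torch / torch.nn) and uses a standalone copy of unpatchify so it runs outside the model class:

import torch

def unpatchify_standalone(x, h, w, patch_size=2, out_channels=4):
    # (N, T, patch_size**2 * C) -> (N, C, h, w); h and w are latent-space sizes
    n = x.shape[0]
    x = x.reshape(n, h // patch_size, w // patch_size, patch_size, patch_size, out_channels)
    x = torch.einsum('nhwpqc->nchpwq', x)
    return x.reshape(n, out_channels, h, w)

final_layer = FinalLayer(hidden_size=3072, patch_size=2, out_channels=4)

image_embedding = torch.randn(1, 1024, 3072)    # LLM output at the noise positions
time_emb = torch.randn(1, 3072)                 # stand-in for the t_embedder output

x = final_layer(image_embedding, time_emb)
print(x.shape)                                   # torch.Size([1, 1024, 16])

latents = unpatchify_standalone(x, 64, 64)
print(latents.shape)                             # torch.Size([1, 4, 64, 64])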


Summary

OmniGen as a whole adopts a DiT-like architecture.
For the input text and image information:
the text is simply tokenized into token ids, while each input image is encoded by the VAE into a (4, 64, 64) latent and then mapped to 1024 image tokens.
The text embedding (with the image tokens written into the reserved placeholder slots), the timestep embedding, and the noise tokens are concatenated to form the overall conditional input.
A Phi-3 model is then called to predict the noise (internally a DiT-style transformer that uses only self-attention).

Because there are three kinds of conditioning input (the user prompt, the negative prompt, and the image prompt), three conditional branches are constructed.
The forward pass is therefore repeated three times, and the three noise predictions are combined with weighted guidance to obtain the final noise estimate.
Iterating this process for about 50 steps produces the final image:

cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
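As a toy sketch of one guidance step (the tensors and guidance scales below are made up for illustration; the real pipeline runs this combination inside its scheduler loop for roughly 50 steps):

import torch

def combine_cfg(cond, uncond, img_cond, cfg_scale=2.5, img_cfg_scale=1.6):
    # image guidance first, then text guidance on top, mirroring forward_with_cfg
    return uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)

# toy noise predictions for one latent of shape (1, 4, 64, 64)
cond, uncond, img_cond = (torch.randn(1, 4, 64, 64) for _ in range(3))
noise_pred = combine_cfg(cond, uncond, img_cond)
print(noise_pred.shape)    # torch.Size([1, 4, 64, 64])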