MiniCPM-O-2.6代码阅读详解

1.入口文件 modeling_minicpmo.py

class MiniCPMO(MiniCPMOPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        # 初始化语言模型(LLM),使用Qwen2ForCausalLM作为基础模型
        self.llm = Qwen2ForCausalLM(config)
        # 为语言模型添加一个方法,用于生成模型标准输入
        self.llm.prepare_inputs_for_generation = types.MethodType(prepare_inputs_for_generation, self.llm)  # patch llm
        # 获取语言模型的隐藏层大小作为嵌入维度
        self.embed_dim = self.llm.config.hidden_size
        # 初始化视觉模块
        if self.config.init_vision:
            # 初始化视觉模块并获取视觉嵌入维度
            self.vpm = self.init_vision_module()
            self.vision_dim = self.vpm.embed_dim
            # 初始化重采样器,用于将视觉特征映射到语言模型的嵌入空间
            self.resampler = self.init_resampler(self.embed_dim, self.vision_dim)

        # 初始化音频模块
        if self.config.init_audio:
            # 初始化音频模块
            self.apm = self.init_audio_module()
            # 计算音频输出维度
            audio_output_dim = int(self.apm.config.encoder_ffn_dim // 4)
            # 初始化音频平均池化层
            self.audio_avg_pooler = nn.AvgPool1d(self.config.audio_pool_step, stride=self.config.audio_pool_step)
            # 初始化音频投影层,用于将音频特征映射到语言模型的嵌入空间
            self.audio_projection_layer = MultiModalProjector(in_dim=audio_output_dim, out_dim=self.embed_dim)
            # 设置音频编码器的层数
            self.audio_encoder_layer = -1

        # 初始化文本到语音(TTS)模块
        if self.config.init_tts:
            # 检查是否安装了TTS依赖库
            assert _tts_deps, "please make sure vector_quantize_pytorch and vocos are installed."
            # 初始化TTS模块
            self.tts = self.init_tts_module()
        # 初始化处理器,用于处理输入数据
        self.processor = AutoProcessor.from_pretrained(self.config._name_or_path, trust_remote_code=True)
        # 定义终止符,用于标记生成文本的结束
        self.terminators = ["<|im_end|>", "<|endoftext|>"]
        # 定义默认的TTS聊天模板
        self.default_tts_chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n<|spk_bos|><|spk|><|spk_eos|><|tts_bos|>' }}{% endif %}"

        # 强制不停止生成的标志
        self.force_no_stop = False

        # 初始化会话状态,用于流式API
        self.reset_session()

其中视觉模块用的是:model = SiglipVisionTransformer(self.config.vision_config)

其中音频模块用的是:model = MiniCPMWhisperEncoder(self.config.audio_config)

其中self.resampler是一个 2D perceiver-resampler网络,主要目的是能够处理高维输入(如图像、音频等),通过将输入映射到低维的潜在空间来减少计算复杂度,并用于从输入中提取关键信息且重新采样到固定数量的查询(queries)。

其中TTS模块(是一个功能强大的文本到语音模型,支持 LLM 隐藏状态条件和流式生成)用的是:model = ConditionalChatTTS(self.config.tts_config)

2.forward函数

def forward(self, data, **kwargs):
        vllm_embedding, vision_hidden_states = self.get_vllm_embedding(data)

        if self.config.init_audio:
            vllm_embedding = self.get_omni_embedding(
                data, input_embeddings=vllm_embedding, chunk_length=self.config.audio_chunk_length
            )

        position_ids = data["position_ids"]
        if position_ids.dtype != torch.int64:
            position_ids = position_ids.long()

        # compatible with llama factory
        for key in ["input_ids", "inputs_embeds", "position_ids"]:
            if key in kwargs:
                del kwargs[key]

        return self.llm(input_ids=None, position_ids=position_ids, inputs_embeds=vllm_embedding, **kwargs)

其中重要的就是self.get_vllm_embeddingself.get_omni_embedding

其中self.get_vllm_embedding函数输入的参数data,需要先经过MiniCPMVImageProcessor,主要功能是对输入图像进行预处理,将其转换为模型可接受的格式。它支持多种输入格式,并提供了图像切片、归一化、标准化和通道维度调整等功能其实这部分和V2.6都是类似的。

其中的preprocess函数:

def preprocess(
    self,
    images: Union[Image.Image, List[Image.Image], List[List[Image.Image]]],
    do_pad: Optional[bool] = True,
    max_slice_nums: int = None,
    return_tensors: Optional[Union[str, TensorType]] = None,
    **kwargs,
) -> MiniCPMOBatchFeature:
    """
    预处理图像数据,将其转换为模型可接受的格式。
    
    参数:
        images: 输入图像,可以是单张图像、图像列表或嵌套的图像列表。
        do_pad: 是否对图像进行填充(默认为 True)。
        max_slice_nums: 每张图像的最大切片数量(可选)。
        return_tensors: 返回的张量类型(如 "pt" 表示 PyTorch 张量,可选)。
        **kwargs: 其他关键字参数。
    
    返回:
        MiniCPMOBatchFeature: 包含预处理后的图像数据(pixel_values、image_sizes、tgt_sizes)的对象。
    """
    
    # 将输入图像统一转换为嵌套列表格式
    if isinstance(images, Image.Image):  # 单张图像
        images_list = [[images]]
    elif isinstance(images[0], Image.Image):  # 单层图像列表
        images_list = [images]
    else:  # 嵌套图像列表
        images_list = images

    # 初始化存储预处理结果的列表
    new_images_list = []  # 存储预处理后的图像数据
    image_sizes_list = []  # 存储原始图像的尺寸
    tgt_sizes_list = []  # 存储目标尺寸(经过 patch 处理后的尺寸)

    # 遍历每张图像
    for _images in images_list:
        # 如果图像为空,跳过处理
        if _images is None or len(_images) == 0:
            new_images_list.append([])
            image_sizes_list.append([])
            tgt_sizes_list.append([])
            continue

        # 检查图像格式是否有效
        if not valid_images(_images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        # 将图像转换为 PIL 格式并统一为 RGB 模式
        _images = [self.to_pil_image(image).convert("RGB") for image in _images]
        
        # 推断图像的通道维度格式(如 ChannelDimension.FIRST 或 ChannelDimension.LAST)
        input_data_format = infer_channel_dimension_format(np.array(_images[0]))

        # 初始化当前图像的预处理结果
        new_images = []  # 存储当前图像的预处理结果
        image_sizes = [image.size for image in _images]  # 存储当前图像的原始尺寸
        tgt_sizes = []  # 存储当前图像的目标尺寸

        # 对每张图像进行切片和预处理
        for image in _images:
            # 将图像切片为多个 patch
            image_patches = self.get_sliced_images(image, max_slice_nums)
            
            # 将图像 patch 转换为 numpy 数组,并归一化到 [0, 1] 范围
            image_patches = [to_numpy_array(image).astype(np.float32) / 255 for image in image_patches]
            
            # 对每个 patch 进行标准化处理
            image_patches = [
                self.normalize(image=image, mean=self.mean, std=self.std, input_data_format=input_data_format)
                for image in image_patches
            ]
            
            # 将通道维度调整为模型需要的格式(如 ChannelDimension.FIRST)
            image_patches = [
                to_channel_dimension_format(image, ChannelDimension.FIRST, input_channel_dim=input_data_format)
                for image in image_patches
            ]
            
            # 对每个 patch 进行 reshape 操作,并计算目标尺寸
            for slice_image in image_patches:
                new_images.append(self.reshape_by_patch(slice_image))
                tgt_sizes.append(
                    np.array((slice_image.shape[1] // self.patch_size, slice_image.shape[2] // self.patch_size))
                )

        # 将目标尺寸堆叠为一个数组
        if tgt_sizes:
            tgt_sizes = np.vstack(tgt_sizes)

        # 将当前图像的预处理结果添加到总列表中
        new_images_list.append(new_images)
        image_sizes_list.append(image_sizes)
        tgt_sizes_list.append(tgt_sizes)

    # 返回预处理后的数据,封装为 MiniCPMOBatchFeature 对象
    return MiniCPMOBatchFeature(
        data={"pixel_values": new_images_list, "image_sizes": image_sizes_list, "tgt_sizes": tgt_sizes_list},
        tensor_type=return_tensors,
    )

视图数据预处理完毕,就需要用视觉模型进行编码,不过依旧需要做很多预处理操作,最终每个图像(切分后)都会用视觉模块(vpm)计算视觉token:

if B > vision_batch_size:
    hs = []
    for i in range(0, B, vision_batch_size):
        start_idx = i
        end_idx = i + vision_batch_size
        tmp_hs = self.vpm(
            all_pixel_values[start_idx:end_idx],
            patch_attention_mask=patch_attn_mask[start_idx:end_idx],
            tgt_sizes=tgt_sizes[start_idx:end_idx],
        ).last_hidden_state
        hs.append(tmp_hs)
    vision_embedding = torch.cat(hs, dim=0)
else:
    vision_embedding = self.vpm(
        all_pixel_values, patch_attention_mask=patch_attn_mask, tgt_sizes=tgt_sizes
    ).last_hidden_state

由于图像编码后token过长,还需要进一步提炼

vision_embedding = self.resampler(vision_embedding, tgt_sizes)

接下来就是Audio编码,输入为梅尔频谱图(batch_size, 80, frames),具体怎么来的可以看processing_minicpmo.py内的MiniCPMOProcessor:

...
        if images is not None:
            image_inputs = self.image_processor(
                images, do_pad=do_pad, max_slice_nums=max_slice_nums, return_tensors=return_tensors
            )
        else:
            image_inputs = None

        if audios is not None:
            audio_features, audio_feature_lens, audio_phs = self.audio_feature_extract(
                audios, audio_parts, chunk_input, sampling_rate
            )
        else:
            audio_features, audio_feature_lens, audio_phs = [], [], []
def get_omni_embedding(self, data, input_embeddings, chunk_length=-1, stream_input=False):
    """
    获取包含音频特征的多模态嵌入。
    
    参数:
        data: 包含音频特征和其他相关数据的字典。
        input_embeddings: 输入的语言模型嵌入。
        chunk_length: 音频分块长度(用于控制是否使用分块注意力机制)。
        stream_input: 是否使用流式音频嵌入(默认为 False)。
    
    返回:
        包含音频特征的最终嵌入。
    """
    
    # 根据是否使用流式输入,获取音频嵌入
    if stream_input:
        audio_embeddings = self.get_audio_embedding_streaming(data)  # 流式音频嵌入
    else:
        audio_embeddings = self.get_audio_embedding(data, chunk_length)  # 普通音频嵌入

    # 获取输入嵌入的批量大小
    bs = len(input_embeddings)

    # 如果数据中包含音频特征
    if len(data.get("audio_features", [])) > 0:
        # 检查音频嵌入和输入嵌入的数量是否一致
        assert len(audio_embeddings) == len(input_embeddings)
        
        # 如果音频嵌入不为空
        if len(audio_embeddings) > 0:
            audio_bounds = data["audio_bounds"]  # 获取音频边界信息

            # 如果配置为分块输入
            if self.config.chunk_input:
                for i in range(bs):
                    # 将音频嵌入拼接为一个张量,并调整设备和数据类型
                    audio_embs = torch.cat(audio_embeddings[i], dim=0).to(
                        device=input_embeddings.device, dtype=input_embeddings.dtype
                    )
                    audio_start_pos = 0
                    
                    # 将音频嵌入分配到输入嵌入的对应位置
                    for bound in audio_bounds[i]:
                        audio_len = bound[1] - bound[0]  # 计算当前音频段的长度
                        input_embeddings[0, bound[0] : bound[1]] = audio_embs[
                            audio_start_pos : audio_start_pos + audio_len, :
                        ]
                        audio_start_pos += audio_len  # 更新音频嵌入的起始位置
            else:
                # 如果不分块输入
                for i in range(bs):
                    audio_embs = audio_embeddings[i]  # 获取当前批次的音频嵌入
                    bounds = audio_bounds[i]  # 获取当前批次的音频边界
                    
                    # 将音频嵌入分配到输入嵌入的对应位置
                    for embs, bound in zip(audio_embs, bounds):
                        audio_indices = torch.arange(bound[0], bound[1], dtype=torch.long).to(
                            input_embeddings.device
                        )
                        
                        # 检查音频嵌入和索引的形状是否匹配
                        if embs.shape[0] != len(audio_indices):
                            raise ValueError(
                                f"Shape mismatch: Trying to assign embeddings of shape {embs.shape} "
                                f"to input indices of length {len(audio_indices)}"
                            )
                        
                        # 将音频嵌入分配到输入嵌入的对应位置
                        input_embeddings[i, audio_indices] = embs.to(input_embeddings.dtype)
    
    # 如果处于训练模式且没有音频特征
    elif self.training:
        for i in range(bs):
            # 使用虚拟音频嵌入(占位符)
            input_embeddings = input_embeddings + audio_embeddings[0].mean() * 0

    # 返回包含音频特征的最终嵌入
    return input_embeddings

        其中get_audio_embedding_streaming支持流式处理(实时、增量),使用 past_key_values 缓存,仅支持 batch_size=1get_audio_embedding非流式处理(批量、离线)

        对于past_key_values 缓存:是 Transformer 模型中的一种缓存机制,主要用于加速自回归生成任务(如文本生成、语音生成等)。它的作用是存储之前计算过的 Key 和 Value 向量,避免在生成新 token 时重复计算历史 token 的 Key 和 Value,从而显著提高推理效率。

        在 get_audio_embedding_streaming 函数中,past_key_values 用于缓存之前音频帧的 Key 和 Value 向量,以便在流式处理中加速推理。具体来说:每次处理新的音频帧时,只需计算当前帧的 Query 向量,并从 past_key_values 中读取历史帧的 Key 和 Value 向量。

        接着一个比较重要的模块就是TTSgenerate_audio,主要是将生成的文本转换为语音,其中应该是先转成梅尔频谱图,在将梅尔频谱图解码为音频波形,具体可以查看以下两个函数:

if use_tts_template and generate_audio:
    mel_spec = self._generate_mel_spec(inputs, outputs, answer)
    wav_numpy, sr = self.decode_mel_to_audio(mel_spec, output_audio_path)
def _generate_mel_spec(self, inputs, outputs, text, output_chunk_size=25, tts_max_new_tokens=2048):
    """
    生成梅尔频谱图(Mel-spectrogram),用于文本到语音(TTS)任务。
    参数:
        inputs: 输入数据,包含文本和其他相关信息。
        outputs: 模型输出,包含语言模型的隐藏状态等。
        text: 输入的文本。
        output_chunk_size: 每次生成的音频 token 数量(默认为 25)。
        tts_max_new_tokens: 最大生成的音频 token 数量(默认为 2048)。
    返回:
        mel_spec: 生成的梅尔频谱图。
    """
    # 获取说话者嵌入(speaker embeddings)
    spk_embeds = self._get_last_spk_embeds(inputs, outputs)
    # 提取 TTS 相关的文本部分
    text = text.split("<|tts_bos|>")[-1]  # 从 <|tts_bos|> 开始提取
    gen_text = text.split("<|tts_eos|>")[0]  # 提取到 <|tts_eos|> 之前的部分
    # 准备 TTS 文本并编码为 token
    tts_text, tts_token_lens = self.prepare_tts_text(gen_text)
    tts_inputs = self.tts_processor.text_tokenizer.encode(tts_text, add_special_tokens=False)
    tts_input_ids = torch.Tensor(tts_inputs).unsqueeze(0).to("cuda", dtype=torch.long)
    # 构建流式文本掩码
    streaming_tts_text_mask = self._build_streaming_mask(tts_token_lens).to(device=self.tts.device)

    # 配置生成过程中的 logits 处理器和 warper
    logits_warpers, logits_processors = gen_logits(
        num_code=626, top_P=self.tts.top_p, top_K=self.tts.top_k, repetition_penalty=self.tts.repetition_penalty
    )

    # 计算条件长度(包括说话者嵌入和流式文本预留长度)
    condition_length = (
        1 + self.tts.use_speaker_embedding * self.tts.num_spk_embs + self.tts.streaming_text_reserved_len + 1
    )

    # 初始化嵌入和 past_key_values
    dtype = self.tts.emb_text.weight.dtype
    emb = torch.zeros(1, condition_length, self.tts.num_vq, dtype=dtype, device=self.tts.device)
    past_key_values = [
        (
            torch.zeros(
                1,
                self.tts.config.num_attention_heads,
                condition_length - 1,
                self.tts.config.hidden_size // self.tts.config.num_attention_heads,
                dtype=emb.dtype,
                device=self.tts.device,
            ),
            torch.zeros(
                1,
                self.tts.config.num_attention_heads,
                condition_length - 1,
                self.tts.config.hidden_size // self.tts.config.num_attention_heads,
                dtype=emb.dtype,
                device=self.tts.device,
            ),
        )
        for _ in range(self.tts.config.num_hidden_layers)
    ]
    # 初始化音频输入 token
    audio_input_ids = torch.zeros(1, condition_length, self.tts.num_vq, dtype=torch.long, device=self.tts.device)
    # 生成音频 token
    eos_lab = False  # 是否生成结束标记
    for chunk_idx in range(math.ceil(emb.shape[1] / self.tts.streaming_text_chunk_size)):
        # 计算当前块的起始和结束位置
        if chunk_idx == 0:
            begin = chunk_idx * self.tts.streaming_text_chunk_size + 0
            end = (
                (chunk_idx + 1) * self.tts.streaming_text_chunk_size
                + 1
                + self.tts.use_speaker_embedding * self.tts.num_spk_embs
            )
        else:
            begin = (
                chunk_idx * self.tts.streaming_text_chunk_size
                + 1
                + self.tts.use_speaker_embedding * self.tts.num_spk_embs
            )
            end = min(
                (chunk_idx + 1) * self.tts.streaming_text_chunk_size
                + 1
                + self.tts.use_speaker_embedding * self.tts.num_spk_embs,
                condition_length - 1,
            )

        # 处理当前块的文本 token
        if end - begin > 0:
            text_input_ids = tts_input_ids[:, begin:end]
            position_ids = torch.arange(begin, end, dtype=torch.long, device=self.tts.device).unsqueeze(0)

            # 预填充文本 token 并更新 past_key_values
            if begin == 0:
                past_key_values = self.tts.prefill_text(
                    input_ids=text_input_ids,
                    position_ids=position_ids,
                    past_key_values=past_key_values,
                    lm_spk_emb_last_hidden_states=spk_embeds,
                )
            else:
                past_key_values = self.tts.prefill_text(
                    input_ids=text_input_ids, position_ids=position_ids, past_key_values=past_key_values
                )

        # 生成音频 token
        outputs = self.tts.generate(
            input_ids=audio_input_ids,
            past_key_values=past_key_values,
            streaming_tts_text_mask=streaming_tts_text_mask,
            max_new_token=output_chunk_size,
            force_no_stop=self.force_no_stop,
            temperature=torch.tensor([0.1, 0.3, 0.1, 0.3], dtype=torch.float, device=self.tts.device),
            eos_token=torch.tensor([625], dtype=torch.long, device=self.tts.device),
            logits_warpers=logits_warpers,
            logits_processors=logits_processors,
        )
        audio_input_ids = outputs.audio_input_ids
        past_key_values = outputs.past_key_values

        # 检查是否生成结束标记
        if outputs.finished:
            logger.debug("Generation finished.")
            eos_lab = True
            break

    # 如果未生成结束标记,则继续生成
    if not eos_lab:
        logger.debug("eos_lab False, Generation continue.")
        while True:
            outputs = self.tts.generate(
                input_ids=audio_input_ids,
                past_key_values=past_key_values,
                streaming_tts_text_mask=streaming_tts_text_mask,
                max_new_token=output_chunk_size,
                force_no_stop=self.force_no_stop,
                temperature=torch.tensor([0.1, 0.3, 0.1, 0.3], dtype=torch.float, device=self.tts.device),
                eos_token=torch.tensor([625], dtype=torch.long, device=self.tts.device),
                logits_warpers=logits_warpers,
                logits_processors=logits_processors,
            )

            audio_input_ids = outputs.audio_input_ids
            past_key_values = outputs.past_key_values

            # 检查是否生成结束标记或达到最大 token 数量
            if outputs.finished:
                logger.debug("Generation finished.")
                break
            if outputs.new_ids.shape[1] > tts_max_new_tokens:
                logger.debug(f"Generation length > {tts_max_new_tokens}, stopped.")
                break

    # 将生成的音频 token 解码为梅尔频谱图
    mel_spec = self.tts.decode_to_mel_specs(outputs.new_ids)
    return mel_spec
    
    
def decode_mel_to_audio(self, mel_spec, output_path=""):
    """
    将梅尔频谱图(Mel-spectrogram)解码为音频波形,并保存到指定路径(如果提供了路径)。
    参数:
        mel_spec: 输入的梅尔频谱图,形状为 (1, num_mels, time_steps)。
        output_path: 音频文件的保存路径(可选)。如果未提供路径,则仅返回音频数据。
    返回:
        wav_numpy: 解码后的音频波形数据(NumPy 数组)。
        sr: 音频的采样率(默认为 24000 Hz)。
    """
    # 使用 torch.inference_mode 上下文管理器,禁用梯度计算以加速推理
    with torch.inference_mode():
        # 使用 vocos 模型将梅尔频谱图解码为音频波形
        wav_numpy = self.vocos.decode(mel_spec.float()).cpu().squeeze()
        # 设置采样率
        sr = 24000
    # 如果提供了输出路径,则将音频保存为文件
    if output_path:
        # 使用 soundfile 库将音频数据写入文件
        sf.write(output_path, wav_numpy.numpy(), samplerate=sr)
        # 记录日志,提示音频已保存
        logger.info(f"Audio saved to {output_path}")
    # 返回解码后的音频波形数据和采样率
    return wav_numpy, sr

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值