Coggle数据科学 | 小白学大模型:LogitsProcessor 文本可控生成

本文来源公众号“Coggle数据科学”,仅用于学术分享,侵权删,干货满满。

原文链接:小白学大模型:LogitsProcessor 文本可控生成

大模型的输出往往难以精准把控。它们在生成文本时,仿佛拥有自己的“想法”,并不总是按照人类预期的轨迹前行。

  • 大模型的生成逻辑是基于概率统计的,并完全按照提示词行动;

  • 大模型会生成冗长且偏离主题的内容;

  • 大模型会生成与指定格式存在差异的内容;

LogitsProcessorZoo 简介

https://github.com/NVIDIA/logits-processor-zoo

LogitsProcessorZoo 是一个强大的工具库,它为大语言模型的输出控制提供了多种实用的 logits 处理器(logits processor)。这些处理器能够在模型生成文本的过程中,对 logits(即模型输出的原始概率分布)进行调整,从而引导模型生成更符合用户需求的文本。

pip install logits-processor-zoo

目前,LogitsProcessorZoo 支持以下主流框架:

  • transformers:Hugging Face 的 transformers 库,广泛应用于自然语言处理任务。

  • vLLM:一个高效的推理框架,专注于大规模语言模型的快速部署。

  • TensorRT-LLM:NVIDIA 的 TensorRT-LLM,利用 TensorRT 的优化能力,为语言模型提供高性能推理支持。

使用案例

import vllm
from logits_processor_zoo.vllm import GenLengthLogitsProcessor, CiteFromPromptLogitsProcessor, ForceLastPhraseLogitsProcessor

model = vllm.LLM(
            model_name,
            trust_remote_code=True,
            dtype="half",
            enforce_eager=True
        )
tokenizer = model.get_tokenizer()
        
logits_processors = [
    CiteFromPromptLogitsProcessor(tokenizer, boost_factor=2.0),
    GenLengthLogitsProcessor(tokenizer, boost_factor=-0.2, p=1),
    ForceLastPhraseLogitsProcessor("\n\nReferences:\n", tokenizer)
]


gen_output = model.generate(
            prompts,
            vllm.SamplingParams(
                n=1,
                temperature=0,
                seed=0,
                skip_special_tokens=True,
                max_tokens=64,
                logits_processors=logits_processors
            ),
            use_tqdm=False
        )

实现原理

LogitsProcessorZoo 提供了多种 logits 处理器,每种处理器都有其独特的功能。

GenLengthLogitsProcessor

GenLengthLogitsProcessor 是一个 logits 处理器,用于调整生成序列的长度。它的核心思想是通过动态调整 EOS(End-of-Sequence,序列结束)标记的概率,来控制生成文本的长度。

  • boost_factor:控制 EOS 概率的调整方向和强度。如果 boost_factor 为正,随着生成序列的长度增加,EOS 的概率会逐渐增加,从而鼓励生成更短的文本;如果 boost_factor 为负,EOS 的概率会逐渐减少,从而鼓励生成更长的文本。

  • p:控制长度调整的强度。p 是一个幂次参数,决定了生成序列长度对 EOS 概率调整的影响程度。p 值越大,长度对 EOS 概率的调整越明显。

  • token_count:记录当前生成的标记数量,用于计算 EOS 概率的调整值。

class GenLengthLogitsProcessor:
    def __init__(self, tokenizer: PreTrainedTokenizer, boost_factor: float,
                 p: int = 2, complete_sentences: bool = False):
        self.eos_token = tokenizer.eos_token_id
        self.boost_factor = boost_factor
        self.p = p
        self.token_count = 0
        self.full_stop_token = text_to_token(tokenizer, "It is a sentence.", last=True)
        self.new_line_token = text_to_token(tokenizer, "It is a new line\n", last=True)
        self.complete_sentences = complete_sentences

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.Tensor:
        boost_val = self.boost_factor * (self.token_count ** self.p) / (10 ** self.p)

        if self.complete_sentences:
            enabled = (input_ids[:, -1] == self.full_stop_token) | (input_ids[:, -1] == self.new_line_token)
            scores[:, self.eos_token] += enabled * boost_val
        else:
            scores[:, self.eos_token] += boost_val
        self.token_count += 1
        return scores

CiteFromPromptLogitsProcessor

CiteFromPromptLogitsProcessor 是一个 logits 处理器,用于调整模型生成文本时对提示(prompt)内容的引用概率。它的核心思想是通过动态调整提示中出现的标记(tokens)的概率,来引导模型生成与提示相似或相反的内容。

在初始化时,CiteFromPromptLogitsProcessor 会将每个提示(prompt)中的标记提取出来,并存储为一个集合。如果设置了 boost_eos 参数为 True,还会将 EOS(End-of-Sequence)标记加入到集合中。

在生成过程中,CiteFromPromptLogitsProcessor 会根据 boost_factor 参数调整提示中出现的标记的概率:

  • 正值:增加提示中出现的标记的概率,鼓励模型生成与提示相似的内容。

  • 负值:减少提示中出现的标记的概率,鼓励模型生成与提示不同的内容。

class CiteFromPromptLogitsProcessor:
    def __init__(self, tokenizer: PreTrainedTokenizer, prompts: List[str], boost_factor: float = 1.0,
                 boost_eos: bool = True):
        self.boost_factor = boost_factor

        self.boost_ids = []
        for prompt in prompts:
            prompt_tokens = set(tokenizer.encode(prompt))

            if boost_eos:
                prompt_tokens.add(tokenizer.eos_token_id)

            self.boost_ids.append(list(prompt_tokens))

    def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor:
        for i in range(scores.shape[0]):
            scores[i, self.boost_ids[i]] += self.boost_factor
        return scores

ForceLastPhraseLogitsProcessor

ForceLastPhraseLogitsProcessor 是一个 logits 处理器,用于强制语言模型在生成文本结束之前,插入一个指定的短语(phrase)。这种机制特别适用于需要在生成文本末尾添加特定内容的场景,例如提供引用、感谢用户、总结等。

在初始化时,ForceLastPhraseLogitsProcessor 会将指定的短语(phrase)通过分词器(tokenizer)转换为标记(tokens)。这些标记将被依次插入到生成文本中。

在生成过程中,ForceLastPhraseLogitsProcessor 会动态调整 logits,确保模型在生成文本结束之前,逐步插入指定短语的标记:

  • 迭代器(iterators:每个批次(batch)中的每个生成任务都有一个独立的迭代器,用于跟踪当前需要插入的标记位置。

  • 强制插入:当模型即将结束生成(即下一个标记是 EOS)时,处理器会强制模型生成指定短语的第一个标记。随后,逐步引导模型生成短语中的后续标记,直到短语完全插入。

class ForceLastPhraseLogitsProcessor:
    def __init__(self, phrase: str, tokenizer: PreTrainedTokenizer, batch_size: int):
        self.eos_token_id = tokenizer.eos_token_id
        self.phrase_tokens = tokenizer.encode(phrase, add_special_tokens=False)
        self.iterators = torch.zeros(batch_size, dtype=torch.int32)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.Tensor:
        for i in range(scores.shape[0]):
            it = self.iterators[i].item()
            if scores[i, :].argmax() == self.eos_token_id and it == 0:
                scores[i, self.phrase_tokens[it]] = scores[i].max() + 1
                self.iterators[i] += 1
            elif len(self.phrase_tokens) > it > 0:
                scores[i, self.phrase_tokens[it]] = scores[i].max() + 1
                self.iterators[i] += 1

        return scores

MultipleChoiceLogitsProcessor

MultipleChoiceLogitsProcessor 是一个 logits 处理器,专门用于处理多项选择题,引导语言模型从给定的选项中选择一个答案。它的核心思想是通过调整 logits,使模型更倾向于生成正确的选项标记(如 "1"、"2"、"3" 等)。

在初始化时,MultipleChoiceLogitsProcessor 会将每个选项(choices)通过分词器(tokenizer)转换为标记(tokens)。这些标记将被用于后续的 logits 调整。

在生成过程中,MultipleChoiceLogitsProcessor 会动态调整 logits,确保模型更倾向于生成选项标记:

  • boost_first_words:如果设置了 boost_first_words 参数,处理器会识别选项的起始位置,并对选项的第一个标记的概率进行调整,以增强模型对选项的识别能力。

  • 强制选择:在每次生成时,处理器会将选项标记的概率提升到一个非常高的值(very_large_number),从而确保模型生成这些选项标记。

class MultipleChoiceLogitsProcessor:
    def __init__(self, tokenizer: PreTrainedTokenizer, choices: List[str] = None,
                 delimiter: str = ".", boost_first_words: float = 0.0):
        if choices is None:
            choices = ["1", "2", "3", "4"]

        self.new_line_tokens = get_new_line_tokens(tokenizer)
        self.delimiter_token = text_to_token(tokenizer, delimiter, last=False)
        self.choice_tokens = [text_to_token(tokenizer, choice, last=False) for choice in choices]
        self.boost_first_words = boost_first_words
        self.very_large_number = 999

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.Tensor:
        for row_ind in range(input_ids.shape[0]):
            if self.boost_first_words:
                choice = 0

                first_tokens = []
                for i in range(len(input_ids[row_ind]) - 3):
                    # A choice is like "\nA) hair dryer", where first token is "hair"
                    choice_starts = (
                            (input_ids[row_ind, i].item() in self.new_line_tokens) and
                            (input_ids[row_ind, i + 1] == self.choice_tokens[choice]) and
                            (input_ids[row_ind, i + 2] == self.delimiter_token)
                    )

                    if choice_starts:
                        first_tokens.append(input_ids[row_ind, i + 3])
                        choice += 1

                        if choice >= len(self.choice_tokens):
                            break

                boost = self.boost_first_words * scores[row_ind, first_tokens]
                scores[row_ind, self.choice_tokens[:len(first_tokens)]] += boost

        scores[:, self.choice_tokens] += self.very_large_number

        return scores

THE END !

文章结束,感谢阅读。您的点赞,收藏,评论是我继续更新的动力。大家有推荐的公众号可以评论区留言,共同学习,一起进步。

### HAL_TIM_PeriodElapsedCallback 函数功能与用法 #### 1. 功能描述 `HAL_TIM_PeriodElapsedCallback` 是 STM32 HAL 库中的回调函数,用于处理定时器周期结束事件。当定时器的计数值达到设定的最大值并触发更新事件时,该回调函数会被调用[^1]。 此函数的主要作用是在中断服务程序中被自动调用,允许用户在不修改底层驱动的情况下实现自定义逻辑。它通常用来响应特定的时间间隔到达后的动作,例如刷新数据、切换状态或其他实时任务调度[^2]。 --- #### 2. 定义形式 以下是 `HAL_TIM_PeriodElapsedCallback` 的典型定义: ```c void HAL_TIM_PeriodElapsedCallback(TIM_HandleTypeDef *htim) { // 用户可以在此处编写自己的代码来处理定时器周期溢出事件 } ``` - **参数说明** - `TIM_HandleTypeDef *htim`: 这是一个指向定时器句柄结构体的指针,包含了配置和运行状态的信息。通过这个句柄,可以在回调函数内部访问当前定时器的相关属性或重新设置其行为。 --- #### 3. 使用方法 为了使能这一回调机制,需完成以下几个步骤: 1. 初始化定时器:利用 `HAL_TIM_Base_Init` 或其他初始化接口完成硬件资源分配以及基础参数配置(如预分频系数、计数器周期等)。 2. 启动带中断模式的定时器:调用 `HAL_TIM_Base_Start_IT(htim)` 来开启定时器及其关联的中断请求。这一步会启用相应的中断线,并注册默认的中断服务例程(ISR)[^1]。 3. 实现回调函数:根据实际需求重写 `HAL_TIM_PeriodElapsedCallback` 方法的内容。每当发生一次完整的计数循环后,即进入下一轮计数前,都会跳转到此处执行指定的操作[^3]。 4. 清除标志位/中断挂起比特 (可选): 如果需要手动管理某些特殊类型的干扰信号,则可能还需要借助宏指令如 __HAL_TIM_CLEAR_IT() 对应位置零操作。 --- #### 示例代码片段 下面展示了一个简单的应用案例——每秒钟点亮 LED 一次: ```c #include "stm32f4xx_hal.h" // 假设已正确设置了 GPIO 和 TIM 句柄 htim2 uint8_t led_state = 0; void HAL_TIM_PeriodElapsedCallback(TIM_HandleTypeDef *htim){ if(htim->Instance == TIM2){ // 判断是否来自 TIM2 中断 if(led_state == 0){ HAL_GPIO_WritePin(GPIOA, GPIO_PIN_5, GPIO_PIN_SET); // 打开LED led_state = 1; } else { HAL_GPIO_WritePin(GPIOA, GPIO_PIN_5, GPIO_PIN_RESET); // 关闭LED led_state = 0; } } } int main(void){ /* MCU Initialization */ // 配置GPIO PA5作为输出端口 // 设置 TIM2 参数 TIM_HandleTypeDef timHandle; timHandle.Instance = TIM2; timHandle.Init.Prescaler = 8399; // 设定预分频值使得频率接近1KHz timHandle.Init.CounterMode = TIM_COUNTERMODE_UP; timHandle.Init.Period = 9999; // 计数至最大值约等于一秒 timHandle.Init.ClockDivision = TIM_CLOCKDIVISION_DIV1; if(HAL_TIM_Base_Init(&timHandle) != HAL_OK){ Error_Handler(); } // 开启 IT 模式的定时器 HAL_TIM_Base_Start_IT(&timHandle); while(1); } ``` 上述例子展示了如何结合外部设备控制形成规律性的脉冲序列。 ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值