past_key_values在P-TuningV2中的巧用

背景

目前HuggingFace发布了关于微调LLMs的方法包——Parameter-Efficient Fine-Tuning(PEFT),其中包含下面6种方法:

  1. LoRA: LORA: LOW-RANK ADAPTATION OF LARGE LANGUAGE MODELS
  2. Prefix Tuning: Prefix-Tuning: Optimizing Continuous Prompts for Generation, P-Tuning v2: Prompt Tuning Can Be Comparable to Fine-tuning Universally Across Scales and Tasks
  3. P-Tuning: GPT Understands, Too
  4. Prompt Tuning: The Power of Scale for Parameter-Efficient Prompt Tuning
  5. AdaLoRA: Adaptive Budget Allocation for Parameter-Efficient Fine-Tuning
  6. LLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention

此外也列出了该包对不同的任务中,不同方法和模型的支持情况(我只列出了关于NLP的,还有部分图像的):

image-20230607154231676 image-20230607154941828 image-20230607155331943 image-20230607155407529

但是还没有P-Tuning v2: Prompt Tuning Can Be Comparable to Fine-tuning Universally Across Scales and Tasks的方法,因此我就看源码是怎么处理的。

在研究和阅读其他人blog期间,发现有些人对P-Tuning描述不准确。下面是一些存在的不准确描述情况❎:

  • 把Prompt Tuning称作是P-Tuning ❎;
  • 使用P-Tuning V2的时候,直接称作是P-Tuning❎;
  • 使用了P-Tuning,但是却说是使用了P-TuningV2❎
  • ……

因此需要注意甄别(主要是P-Tuning和Prompt-Tuning的方法提出时间就差了一个月,并且在方法上有一定的相似性,都是在Embedding中使用了continuous prompt)

Prompt-Tuning一文中对两者的不同做了说明:

P-tuning” where learnable continuous prompts are interleaved throughout the embedded input, using patterns based on human design. Prompt tuning approach removes this complication by simply prepending the prompt to the input. To achieve strong SuperGLUE results, P-tuning has to be used in conjunction with model tuning, that is, models jointly update both the prompt and the main model parameters【也就是全量微调,但是文章的论点是GPT在NLU任务上的表现,并且还突出few-shot场景下效果】, whereas our approach keeps the original language model frozen” “As another difference, P-tuning requires the addition of “anchor” tokens in the input (e.g. a question mark following the hypothesis in the RTE task) to achieve strong performance, while prompt tuning leaves inputs untouched.”

P-Tuning V2源码定位

这里以run_script/run_rte_roberta.sh为例,下面是代码:

export TASK_NAME=superglue
export DATASET_NAME=rte
export CUDA_VISIBLE_DEVICES=0

bs=32
lr=5e-3
dropout=0.1
psl=128
epoch=100

python3 run.py \
  --model_name_or_path roberta-large \
  --task_name $TASK_NAME \
  --dataset_name $DATASET_NAME \
  --do_train \
  --do_eval \
  --max_seq_length 128 \
  --per_device_train_batch_size $bs \
  --learning_rate $lr \
  --num_train_epochs $epoch \
  --pre_seq_len $psl \
  --output_dir checkpoints/$DATASET_NAME-roberta/ \
  --overwrite_output_dir \
  --hidden_dropout_prob $dropout \
  --seed 11 \
  --save_strategy no \
  --evaluation_strategy epoch \
  --prefix

通过查看arguments.py文件可以查看到prefix参数是如下定义的:

 prefix: bool = field(
        default=False,
        metadata={
   
   
            "help": "Will use P-tuning v2 during training"
        }
    )

【为什么非要用prefix呢?我一开始以为这个参数是使用prefix Tuning呢QAQ】

因为shell脚本中的task=superglue,可以看到第96行中引用了get_trainer

if data_args.task_name.lower() == "superglue":
  assert data_args.dataset_name.lower() in SUPERGLUE_DATASETS
  from tasks.superglue.get_trainer import get_trainer

最终定位到了BertPrefixForQuestionAnswering类,可以看到P-Tuning V2的关键入口代码:

class BertPrefixForQuestionAnswering(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.pre_seq_len = config.pre_seq_len
        self.n_layer = config.num_hidden_layers
        self.n_head = config.num_attention_heads
        self.n_embd = config.hidden_size // config.num_attention_heads

        self.bert = BertModel(config, add_pooling_layer=False)
        self.qa_outputs = torch.nn.Linear(config.hidden_size, config.num_labels)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.prefix_enc
输入名称: input_ids 输入名称: attention_mask 输入名称: position_ids 输入名称: past_key_values.0.key 输入名称: past_key_values.0.value 输入名称: past_key_values.1.key 输入名称: past_key_values.1.value 输入名称: past_key_values.2.key 输入名称: past_key_values.2.value 输入名称: past_key_values.3.key 输入名称: past_key_values.3.value 输入名称: past_key_values.4.key 输入名称: past_key_values.4.value 输入名称: past_key_values.5.key 输入名称: past_key_values.5.value 输入名称: past_key_values.6.key 输入名称: past_key_values.6.value 输入名称: past_key_values.7.key 输入名称: past_key_values.7.value 输入名称: past_key_values.8.key 输入名称: past_key_values.8.value 输入名称: past_key_values.9.key 输入名称: past_key_values.9.value 输入名称: past_key_values.10.key 输入名称: past_key_values.10.value 输入名称: past_key_values.11.key 输入名称: past_key_values.11.value 输入名称: past_key_values.12.key 输入名称: past_key_values.12.value 输入名称: past_key_values.13.key 输入名称: past_key_values.13.value 输入名称: past_key_values.14.key 输入名称: past_key_values.14.value 输入名称: past_key_values.15.key 输入名称: past_key_values.15.value 输入名称: past_key_values.16.key 输入名称: past_key_values.16.value 输入名称: past_key_values.17.key 输入名称: past_key_values.17.value 输入名称: past_key_values.18.key 输入名称: past_key_values.18.value 输入名称: past_key_values.19.key 输入名称: past_key_values.19.value 输入名称: past_key_values.20.key 输入名称: past_key_values.20.value 输入名称: past_key_values.21.key 输入名称: past_key_values.21.value 输入名称: past_key_values.22.key 输入名称: past_key_values.22.value 输入名称: past_key_values.23.key 输入名称: past_key_values.23.value Traceback (most recent call last): File "/home/yejianxiong.yjx/test_qwen/test_cpu_v3.py", line 68, in <module> print(generate_response("你好", conversation_history)) File "/home/yejianxiong.yjx/test_qwen/test_cpu_v3.py", line 45, in generate_response outputs = session.run( File "/home/yejianxiong.yjx/test_qwen/test_env/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 269, in run self._validate_input(list(input_feed.keys())) File "/home/yejianxiong.yjx/test_qwen/test_env/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 251, in _validate_input raise ValueError( ValueError: Required inputs (['past_key_values.0.key', 'past_key_values.0.value', 'past_key_values.1.key', 'past_key_values.1.value', 'past_key_values.2.key', 'past_key_values.2.value', 'past_key_values.3.key', 'past_key_values.3.value', 'past_key_values.4.key', 'past_key_values.4.value', 'past_key_values.5.key', 'past_key_values.5.value', 'past_key_values.6.key', 'past_key_values.6.value', 'past_key_values.7.key', 'past_key_values.7.value', 'past_key_values.8.key', 'past_key_values.8.value', 'past_key_values.9.key', 'past_key_values.9.value', 'past_key_values.10.key', 'past_key_values.10.value', 'past_key_values.11.key', 'past_key_values.11.value', 'past_key_values.12.key', 'past_key_values.12.value', 'past_key_values.13.key', 'past_key_values.13.value', 'past_key_values.14.key', 'past_key_values.14.value', 'past_key_values.15.key', 'past_key_values.15.value', 'past_key_values.16.key', 'past_key_values.16.value', 'past_key_values.17.key', 'past_key_values.17.value', 'past_key_values.18.key', 'past_key_values.18.value', 'past_key_values.19.key', 'past_key_values.19.value', 'past_key_values.20.key', 'past_key_values.20.value', 'past_key_values.21.key', 'past_key_values.21.value', 'past_key_values.22.key', 'past_key_values.22.value', 'past_key_values.23.key', 'past_key_values.23.value']) are missing from input feed (['input_ids', 'attention_mask', 'position_ids']).
最新发布
01-06
Traceback (most recent call last): File "/home/yejianxiong.yjx/test_qwen/test_cpu_v3.py", line 64, in <module> print(generate_response("你好", conversation_history)) File "/home/yejianxiong.yjx/test_qwen/test_cpu_v3.py", line 42, in generate_response outputs = session.run( File "/home/yejianxiong.yjx/test_qwen/test_env/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 269, in run self._validate_input(list(input_feed.keys())) File "/home/yejianxiong.yjx/test_qwen/test_env/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 251, in _validate_input raise ValueError( ValueError: Required inputs (['position_ids', 'past_key_values.0.key', 'past_key_values.0.value', 'past_key_values.1.key', 'past_key_values.1.value', 'past_key_values.2.key', 'past_key_values.2.value', 'past_key_values.3.key', 'past_key_values.3.value', 'past_key_values.4.key', 'past_key_values.4.value', 'past_key_values.5.key', 'past_key_values.5.value', 'past_key_values.6.key', 'past_key_values.6.value', 'past_key_values.7.key', 'past_key_values.7.value', 'past_key_values.8.key', 'past_key_values.8.value', 'past_key_values.9.key', 'past_key_values.9.value', 'past_key_values.10.key', 'past_key_values.10.value', 'past_key_values.11.key', 'past_key_values.11.value', 'past_key_values.12.key', 'past_key_values.12.value', 'past_key_values.13.key', 'past_key_values.13.value', 'past_key_values.14.key', 'past_key_values.14.value', 'past_key_values.15.key', 'past_key_values.15.value', 'past_key_values.16.key', 'past_key_values.16.value', 'past_key_values.17.key', 'past_key_values.17.value', 'past_key_values.18.key', 'past_key_values.18.value', 'past_key_values.19.key', 'past_key_values.19.value', 'past_key_values.20.key', 'past_key_values.20.value', 'past_key_values.21.key', 'past_key_values.21.value', 'past_key_values.22.key', 'past_key_values.22.value', 'past_key_values.23.key', 'past_key_values.23.value']) are missing from input feed (['input_ids', 'attention_mask']).
01-06
评论 3
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值