Environment:
LLaMA-Factory: 0.9.1
transformers: 4.46.1
Model: GLM-4-9B-chat
Error:
Running Evaluate & Predict through LLaMA-Factory fails with the traceback below.
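For context, such a predict run can be driven by a config equivalent to the following Python sketch (hypothetical paths and dataset name; LLaMA-Factory's run_exp accepts the same keys as its YAML configs):

    from llamafactory.train.tuner import run_exp

    # Minimal Evaluate & Predict sketch; adjust paths/dataset to your setup.
    run_exp({
        "model_name_or_path": "/path/to/glm-4-9b-chat",  # local GLM-4-9B-chat files
        "stage": "sft",
        "do_predict": True,
        "finetuning_type": "lora",
        "template": "glm4",
        "eval_dataset": "alpaca_zh_demo",                # placeholder dataset
        "predict_with_generate": True,
        "output_dir": "saves/glm4-predict",
    })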
[rank0]: Traceback (most recent call last):
[rank0]: File "/data/dms/LLaMA-Factory-0.9.1/src/llamafactory/launcher.py", line 23, in <module>
[rank0]: launch()
[rank0]: File "/data/dms/LLaMA-Factory-0.9.1/src/llamafactory/launcher.py", line 19, in launch
[rank0]: run_exp()
[rank0]: File "/data/dms/LLaMA-Factory-0.9.1/src/llamafactory/train/tuner.py", line 50, in run_exp
[rank0]: run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
[rank0]: File "/data/dms/LLaMA-Factory-0.9.1/src/llamafactory/train/sft/workflow.py", line 127, in run_sft
[rank0]: predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs)
[rank0]: File "/data/dms/minconda3/envs/factory091/lib/python3.10/site-packages/transformers/trainer_seq2seq.py", line 259, in predict
[rank0]: return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
[rank0]: File "/data/dms/minconda3/envs/factory091/lib/python3.10/site-packages/transformers/trainer.py", line 4042, in predict
[rank0]: output = eval_loop(
[rank0]: File "/data/dms/minconda3/envs/factory091/lib/python3.10/site-packages/transformers/trainer.py", line 4158, in evaluation_loop
[rank0]: losses, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
[rank0]: File "/data/dms/LLaMA-Factory-0.9.1/src/llamafactory/train/sft/trainer.py", line 121, in prediction_step
[rank0]: loss, generated_tokens, _ = super().prediction_step( # ignore the returned labels (may be truncated)
[rank0]: File "/data/dms/minconda3/envs/factory091/lib/python3.10/site-packages/transformers/trainer_seq2seq.py", line 331, in prediction_step
[rank0]: generated_tokens = self.model.generate(**generation_inputs, **gen_kwargs)
[rank0]: File "/data/dms/minconda3/envs/factory091/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]: return func(*args, **kwargs)
[rank0]: File "/data/dms/minconda3/envs/factory091/lib/python3.10/site-packages/transformers/generation/utils.py", line 2215, in generate
[rank0]: result = self._sample(
[rank0]: File "/data/dms/minconda3/envs/factory091/lib/python3.10/site-packages/transformers/generation/utils.py", line 3209, in _sample
[rank0]: model_kwargs = self._update_model_kwargs_for_generation(
[rank0]: File "/root/.cache/huggingface/modules/transformers_modules/glm-4-9b-chat/modeling_chatglm.py", line 939, in _update_model_kwargs_for_generation
[rank0]: model_kwargs["past_key_values"] = self._extract_past_from_model_output(
[rank0]: TypeError: GenerationMixin._extract_past_from_model_output() got an unexpected keyword argument 'standardize_cache_format'
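A quick way to confirm the mismatch in your own environment (diagnostic sketch):

    import inspect
    import transformers
    from transformers.generation.utils import GenerationMixin

    print(transformers.__version__)
    # On 4.46.x the printed signature no longer includes standardize_cache_format,
    # so remote model code that still forwards that keyword raises the TypeError
    # above. On 4.40.x the keyword is still part of the signature.
    print(inspect.signature(GenerationMixin._extract_past_from_model_output))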
Cause:
The transformers version that the model's bundled code targets is too old. Checking config.json shows transformers_version is 4.40.2; the model files updated for 4.44.0 fix the problem. Simply download the latest configuration files from ModelScope (魔搭社区), seven files in total.
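A download sketch using the ModelScope SDK (assuming the model id is ZhipuAI/glm-4-9b-chat and skipping the multi-GB weight shards via ignore_file_pattern; the seven config files are likely config.json, configuration_chatglm.py, generation_config.json, modeling_chatglm.py, tokenization_chatglm.py, tokenizer.model and tokenizer_config.json):

    from modelscope import snapshot_download

    # Fetch only the small config/code/tokenizer files, not the weights.
    local_dir = snapshot_download(
        "ZhipuAI/glm-4-9b-chat",
        ignore_file_pattern=[r".*\.safetensors"],
    )
    print(local_dir)  # copy the downloaded files over your local model directory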
If you diff the two, you will see that the content of modeling_chatglm.py has changed: the newer version no longer uses the standardize_cache_format parameter.
[Screenshot: modeling_chatglm.py shipped with the 4.40.2 files]
[Screenshot: modeling_chatglm.py shipped with the 4.44.0 files]
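The gist of the change, paraphrased (not a verbatim copy of either file): the old override forwards a keyword that newer transformers (4.45+) removed, while the updated file matches the newer helper, which returns a (cache_name, cache) pair instead:

    # modeling_chatglm.py, old (transformers_version 4.40.2) -- breaks on 4.46.1:
    model_kwargs["past_key_values"] = self._extract_past_from_model_output(
        outputs, standardize_cache_format=standardize_cache_format  # kwarg removed upstream
    )

    # modeling_chatglm.py, new (transformers_version 4.44.0) -- works on 4.46.1:
    cache_name, cache = self._extract_past_from_model_output(outputs)
    model_kwargs[cache_name] = cache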