```python
@torch.no_grad()
def generate(
    self,
    inputs: Optional[torch.Tensor] = None,
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
    synced_gpus: Optional[bool] = None,
    assistant_model: Optional["PreTrainedModel"] = None,
    streamer: Optional["BaseStreamer"] = None,
    negative_prompt_ids: Optional[torch.Tensor] = None,
    negative_prompt_attention_mask: Optional[torch.Tensor] = None,
    use_model_defaults: Optional[bool] = None,
    custom_generate: Optional[str] = None,
    **kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
2242 r"""
2243
2244 Generates sequences of token ids for models with a language modeling head.
2245
2246 <Tip warning={true}>
2247
2248 Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
2249 model's default generation configuration. You can override any `generation_config` by passing the corresponding
2250 parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
2251
2252 For an overview of generation strategies and code examples, check out the [following
2253 guide](../generation_strategies).
2254
2255 </Tip>
2256
2257 Parameters:
2258 inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
2259 The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
2260 method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
2261 should be in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of
2262 `input_ids`, `input_values`, `input_features`, or `pixel_values`.
2263 generation_config ([`~generation.GenerationConfig`], *optional*):
2264 The generation configuration to be used as base parametrization for the generation call. `**kwargs`
2265 passed to generate matching the attributes of `generation_config` will override them. If
2266 `generation_config` is not provided, the default will be used, which has the following loading
2267 priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
2268 configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
2269 default values, whose documentation should be checked to parameterize generation.
2270 logits_processor (`LogitsProcessorList`, *optional*):
2271 Custom logits processors that complement the default logits processors built from arguments and
2272 generation config. If a logit processor is passed that is already created with the arguments or a
2273 generation config an error is thrown. This feature is intended for advanced users.
2274 stopping_criteria (`StoppingCriteriaList`, *optional*):
2275 Custom stopping criteria that complements the default stopping criteria built from arguments and a
2276 generation config. If a stopping criteria is passed that is already created with the arguments or a
2277 generation config an error is thrown. If your stopping criteria depends on the `scores` input, make
2278 sure you pass `return_dict_in_generate=True, output_scores=True` to `generate`. This feature is
2279 intended for advanced users.
2280 prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
2281 If provided, this function constraints the beam search to allowed tokens only at each step. If not
2282 provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
2283 `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
2284 on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful
2285 for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
2286 Retrieval](https://arxiv.org/abs/2010.00904).
2287 synced_gpus (`bool`, *optional*):
2288 Whether to continue running the while loop until max_length. Unless overridden, this flag will be set
2289 to `True` if using `FullyShardedDataParallel` or DeepSpeed ZeRO Stage 3 with multiple GPUs to avoid
2290 deadlocking if one GPU finishes generating before other GPUs. Otherwise, defaults to `False`.
2291 assistant_model (`PreTrainedModel`, *optional*):
2292 An assistant model that can be used to accelerate generation. The assistant model must have the exact
2293 same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistant model
2294 is much faster than running generation with the model you're calling generate from. As such, the
2295 assistant model should be much smaller.
2296 streamer (`BaseStreamer`, *optional*):
2297 Streamer object that will be used to stream the generated sequences. Generated tokens are passed
2298 through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
2299 negative_prompt_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
2300 The negative prompt needed for some processors such as CFG. The batch size must match the input batch
2301 size. This is an experimental feature, subject to breaking API changes in future versions.
2302 negative_prompt_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
2303 Attention_mask for `negative_prompt_ids`.
2304 use_model_defaults (`bool`, *optional*):
2305 When it is `True`, unset parameters in `generation_config` will be set to the model-specific default
2306 generation configuration (`model.generation_config`), as opposed to the global defaults
2307 (`GenerationConfig()`). If unset, models saved starting from `v4.50` will consider this flag to be
2308 `True`.
2309 custom_generate (`str`, *optional*):
2310 A string containing the name of a huggingface.co repository. If provided, the custom `generate`
2311 function defined in that reposity's `custom_generate/generate.py` file will be executed instead of the
2312 standard `generate` method. Note that the logic is for generation is entirely defined in that
2313 repository, and the return type may be different from the standard `generate` method.
2314 kwargs (`Dict[str, Any]`, *optional*):
2315 Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
2316 forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
2317 specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
2318
2319 Return:
2320 [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
2321 or when `config.return_dict_in_generate=True`) or a `torch.LongTensor`.
2322
2323 If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
2324 [`~utils.ModelOutput`] types are:
2325
2326 - [`~generation.GenerateDecoderOnlyOutput`],
2327 - [`~generation.GenerateBeamDecoderOnlyOutput`]
2328
2329 If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
2330 [`~utils.ModelOutput`] types are:
2331
2332 - [`~generation.GenerateEncoderDecoderOutput`],
2333 - [`~generation.GenerateBeamEncoderDecoderOutput`]
2334 """
    # 0. If requested, load an arbitrary generation recipe from the Hub and run it instead
    if custom_generate is not None:
        trust_remote_code = kwargs.pop("trust_remote_code", None)
        # Get all `generate` arguments in a single variable. Custom functions are responsible for handling them:
        # they receive the same inputs as `generate`, only with `model` instead of `self`. They can access
        # methods from `GenerationMixin` through `model`.
        global_keys_to_exclude = {"self", "kwargs"}
        generate_arguments = {key: value for key, value in locals().items() if key not in global_keys_to_exclude}
        generate_arguments.update(kwargs)

        custom_generate_function = self.load_custom_generate(
            custom_generate, trust_remote_code=trust_remote_code, **kwargs
        )
        return custom_generate_function(model=self, **generate_arguments)

    # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
    tokenizer = kwargs.pop("tokenizer", None)  # Pull this out first, we only use it for stopping criteria
    assistant_tokenizer = kwargs.pop("assistant_tokenizer", None)  # only used for assisted generation

    generation_config, model_kwargs = self._prepare_generation_config(
        generation_config, use_model_defaults, **kwargs
    )
    self._validate_model_kwargs(model_kwargs.copy())
    self._validate_assistant(assistant_model, tokenizer, assistant_tokenizer)

    # 2. Set generation parameters if not already defined
    if synced_gpus is None:
        synced_gpus = (is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)) and dist.get_world_size() > 1

    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

    accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
    requires_attention_mask = "encoder_outputs" not in model_kwargs
    kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None

    # 3. Define model inputs
    inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
        inputs, generation_config.bos_token_id, model_kwargs
    )
    batch_size = inputs_tensor.shape[0]

    device = inputs_tensor.device
    self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)

    # decoder-only models must use left-padding for batched generation.
    if not self.config.is_encoder_decoder:
        # If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
        # Note: If using `inputs_embeds`, this check does not work, because we want to be more hands-off.
        if (
            generation_config._pad_token_tensor is not None
            and batch_size > 1
            and len(inputs_tensor.shape) == 2
            and torch.sum(inputs_tensor[:, -1] == generation_config._pad_token_tensor) > 0
        ):
            logger.warning(
                "A decoder-only architecture is being used, but right-padding was detected! For correct "
                "generation results, please set `padding_side='left'` when initializing the tokenizer."
            )

    # 4. Define other model kwargs
    # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are
    # generating the first new token or not, and we only want to use the embeddings for the first new token)
    if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
        generation_config.use_cache = True

    if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
        model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
            inputs_tensor, generation_config, model_kwargs
        )
    elif kwargs_has_attention_mask:
        # TODO (joao): generalize this check with other types of inputs
        if model_input_name == "input_ids" and len(model_kwargs["attention_mask"].shape) > 2:
            raise ValueError("`attention_mask` passed to `generate` must be 2D.")

    if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
        # if the model is an encoder-decoder, `encoder_outputs` are created and added to `model_kwargs`
        model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
            inputs_tensor, model_kwargs, model_input_name, generation_config
        )

    # 5. Prepare `input_ids` which will be used for auto-regressive generation
    if self.config.is_encoder_decoder:
        input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
            batch_size=batch_size,
            model_input_name=model_input_name,
            model_kwargs=model_kwargs,
            decoder_start_token_id=generation_config._decoder_start_token_tensor,
            device=inputs_tensor.device,
        )
    else:
        input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")

    if generation_config.token_healing:
        input_ids = self.heal_tokens(input_ids, tokenizer)

    if streamer is not None:
        streamer.put(input_ids.cpu())

    # 6. Prepare `max_length` depending on other stopping criteria.
    input_ids_length = input_ids.shape[1]
    has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
    has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
    generation_config = self._prepare_generated_length(
        generation_config=generation_config,
        has_default_max_length=has_default_max_length,
        has_default_min_length=has_default_min_length,
        model_input_name=model_input_name,
        inputs_tensor=inputs_tensor,
        input_ids_length=input_ids_length,
    )

    # If the model supports `logits_to_keep` in forward(), set it to 1 to avoid computing the whole
    # logit matrix. This can save a lot of memory during the first forward pass. Note that assisted decoding
    # dynamically overrides this value as it can need more than the last token logits
    if self._supports_logits_to_keep() and "logits_to_keep" not in model_kwargs:
        model_kwargs["logits_to_keep"] = 1

    self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)

    # 7. Prepare the cache.
    # - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`.
    # - different models have a different cache name expected by the model (default = "past_key_values")
    # - `max_length`, prepared above, is used to determine the maximum cache length
    max_cache_length = generation_config.max_length - 1
    if (
        inputs_tensor.shape[1] != input_ids_length
        and model_input_name == "inputs_embeds"
        and not self.config.is_encoder_decoder
    ):
        max_cache_length += inputs_tensor.shape[1]
    self._prepare_cache_for_generation(
        generation_config, model_kwargs, assistant_model, batch_size, max_cache_length, device
    )

    # 8. determine generation mode
    generation_mode = generation_config.get_generation_mode(assistant_model)

    if streamer is not None and (generation_config.num_beams > 1):
        raise ValueError(
            "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
        )

    if self.device.type != input_ids.device.type:
        warnings.warn(
            "You are calling .generate() with the `input_ids` being on a device type different"
            f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
            f" is on {self.device.type}. You may experience unexpected behaviors or slower generation."
            " Please make sure that you have put `input_ids` to the"
            f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before"
            " running `.generate()`.",
            UserWarning,
        )

    # 9. prepare logits processors and stopping criteria
    prepared_logits_processor = self._get_logits_processor(
        generation_config=generation_config,
        input_ids_seq_length=input_ids_length,
        encoder_input_ids=inputs_tensor,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        logits_processor=logits_processor,
        device=inputs_tensor.device,
        model_kwargs=model_kwargs,
        negative_prompt_ids=negative_prompt_ids,
        negative_prompt_attention_mask=negative_prompt_attention_mask,
    )
    prepared_stopping_criteria = self._get_stopping_criteria(
        generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs
    )

    # Set model_kwargs `use_cache` so we can use it later in forward runs
    model_kwargs["use_cache"] = generation_config.use_cache

    # 10. go into different generation modes
    if generation_mode == GenerationMode.ASSISTED_GENERATION:
        if generation_config.num_return_sequences > 1:
            raise ValueError(
                "num_return_sequences has to be 1 when doing assisted generate, "
                f"but is {generation_config.num_return_sequences}."
            )
        if batch_size > 1:
            raise ValueError("assisted generate is only supported for batch_size = 1")
        if not model_kwargs["use_cache"]:
            raise ValueError("assisted generate requires `use_cache=True`")
        if generation_config.cache_implementation in ["static", "hybrid", "sliding_window"]:
            raise ValueError("assisted generate is not supported with static cache classes")
        if self._is_stateful:
            # In assisted generation we need the ability to confirm whether the model would pick certain tokens,
            # which is not possible with stateful models (they can't reset to a previous subset of generated text)
            raise ValueError(
                f"assisted generation is not supported with stateful models, such as {self.__class__.__name__}"
            )

        # 11. Get the candidate generator, given the parameterization
        candidate_generator = self._get_candidate_generator(
            generation_config=generation_config,
            input_ids=input_ids,
            inputs_tensor=inputs_tensor,
            assistant_model=assistant_model,
            logits_processor=logits_processor,
            target_tokenizer=tokenizer,
            assistant_tokenizer=assistant_tokenizer,
            model_kwargs=model_kwargs,
        )

        # 12. run assisted generate
        result = self._assisted_decoding(
            input_ids,
            candidate_generator=candidate_generator,
            logits_processor=prepared_logits_processor,
            stopping_criteria=prepared_stopping_criteria,
            generation_config=generation_config,
            synced_gpus=synced_gpus,
            streamer=streamer,
            **model_kwargs,
        )
    elif generation_mode == GenerationMode.DOLA_GENERATION:
        if self._is_stateful:
            # DoLa decoding was not designed for stateful models, and would require some changes
            raise ValueError(
                f"dola decoding is not supported with stateful models, such as {self.__class__.__name__}"
            )
        result = self._dola_decoding(
            input_ids,
            dola_layers=generation_config.dola_layers,
            logits_processor=prepared_logits_processor,
            stopping_criteria=prepared_stopping_criteria,
            generation_config=generation_config,
            synced_gpus=synced_gpus,
            streamer=streamer,
            **model_kwargs,
        )

    elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH:
        if not model_kwargs["use_cache"]:
            raise ValueError("Contrastive search requires `use_cache=True`")
        if self._is_stateful:
            # Just like assisted generation, we need to be able to rollback to a previous state (see comment above)
            raise ValueError(
                f"contrastive search is not supported with stateful models, such as {self.__class__.__name__}"
            )

        result = self._contrastive_search(
            input_ids,
            logits_processor=prepared_logits_processor,
            stopping_criteria=prepared_stopping_criteria,
            generation_config=generation_config,
            synced_gpus=synced_gpus,
            streamer=streamer,
            **model_kwargs,
        )

    elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
        # 11. expand input_ids with `num_return_sequences` additional sequences per batch
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids=input_ids,
            expand_size=generation_config.num_return_sequences,
            is_encoder_decoder=self.config.is_encoder_decoder,
            **model_kwargs,
        )

        # 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
        result = self._sample(
            input_ids,
            logits_processor=prepared_logits_processor,
            stopping_criteria=prepared_stopping_criteria,
            generation_config=generation_config,
            synced_gpus=synced_gpus,
            streamer=streamer,
            **model_kwargs,
        )

    elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
        # 11. interleave input_ids with `num_beams` additional sequences per batch
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids=input_ids,
            expand_size=generation_config.num_beams,
            is_encoder_decoder=self.config.is_encoder_decoder,
            **model_kwargs,
        )
        # 12. run beam sample
        result = self._beam_search(
            input_ids,
            logits_processor=prepared_logits_processor,
            stopping_criteria=prepared_stopping_criteria,
            generation_config=generation_config,
            synced_gpus=synced_gpus,
            **model_kwargs,
        )

    elif generation_mode == GenerationMode.GROUP_BEAM_SEARCH:
        # 11. prepare beam search scorer
        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=generation_config.num_beams,
            device=inputs_tensor.device,
            length_penalty=generation_config.length_penalty,
            do_early_stopping=generation_config.early_stopping,
            num_beam_hyps_to_keep=generation_config.num_return_sequences,
            num_beam_groups=generation_config.num_beam_groups,
            max_length=generation_config.max_length,
        )
        # 12. interleave input_ids with `num_beams` additional sequences per batch
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids=input_ids,
            expand_size=generation_config.num_beams,
            is_encoder_decoder=self.config.is_encoder_decoder,
            **model_kwargs,
        )
        # 13. run beam search
        result = self._group_beam_search(
            input_ids,
            beam_scorer,
            logits_processor=prepared_logits_processor,
            stopping_criteria=prepared_stopping_criteria,
            generation_config=generation_config,
            synced_gpus=synced_gpus,
            **model_kwargs,
        )

    elif generation_mode == GenerationMode.CONSTRAINED_BEAM_SEARCH:
        final_constraints = []
        if generation_config.constraints is not None:
            final_constraints = generation_config.constraints

        if generation_config.force_words_ids is not None:

            def typeerror():
                raise ValueError(
                    "`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]` "
                    f"of positive integers, but is {generation_config.force_words_ids}."
                )

            if (
                not isinstance(generation_config.force_words_ids, list)
                or len(generation_config.force_words_ids) == 0
            ):
                typeerror()

            for word_ids in generation_config.force_words_ids:
                if isinstance(word_ids[0], list):
                    if not isinstance(word_ids, list) or len(word_ids) == 0:
                        typeerror()
                    if any(not isinstance(token_ids, list) for token_ids in word_ids):
                        typeerror()
                    if any(
                        any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids)
                        for token_ids in word_ids
                    ):
                        typeerror()

                    constraint = DisjunctiveConstraint(word_ids)
                else:
                    if not isinstance(word_ids, list) or len(word_ids) == 0:
                        typeerror()
                    if any((not isinstance(token_id, int) or token_id < 0) for token_id in word_ids):
                        typeerror()

                    constraint = PhrasalConstraint(word_ids)
                final_constraints.append(constraint)

        # 11. prepare beam search scorer
        constrained_beam_scorer = ConstrainedBeamSearchScorer(
            constraints=final_constraints,
            batch_size=batch_size,
            num_beams=generation_config.num_beams,
            device=inputs_tensor.device,
            length_penalty=generation_config.length_penalty,
            do_early_stopping=generation_config.early_stopping,
            num_beam_hyps_to_keep=generation_config.num_return_sequences,
            max_length=generation_config.max_length,
        )
        # 12. interleave input_ids with `num_beams` additional sequences per batch
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids=input_ids,
            expand_size=generation_config.num_beams,
            is_encoder_decoder=self.config.is_encoder_decoder,
            **model_kwargs,
        )
        # 13. run beam search
        result = self._constrained_beam_search(
            input_ids,
            constrained_beam_scorer=constrained_beam_scorer,
            logits_processor=prepared_logits_processor,
            stopping_criteria=prepared_stopping_criteria,
            generation_config=generation_config,
            synced_gpus=synced_gpus,
            **model_kwargs,
        )

    # Convert to legacy cache format if requested
    if (
        generation_config.return_legacy_cache is True
        and hasattr(result, "past_key_values")
        and getattr(result.past_key_values, "to_legacy_cache") is not None
    ):
        result.past_key_values = result.past_key_values.to_legacy_cache()
    return result
```
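For orientation, here is a minimal usage sketch of the method shown above. The checkpoint name, prompt, and parameter values are illustrative assumptions rather than anything taken from the excerpt; with `do_sample=True` the call flows through steps 1 to 9 and ends in the `_sample` branch of step 10.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint; any causal LM with a generation_config behaves the same way.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")

# Ad hoc kwargs (`max_new_tokens`, `do_sample`, `top_p`) are absorbed into `generation_config`
# by `_prepare_generation_config` (step 1) and override the model's defaults.
output_ids = model.generate(**inputs, max_new_tokens=20, do_sample=True, top_p=0.9)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```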
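The `prefix_allowed_tokens_fn` argument documented in the docstring can be exercised roughly as follows, reusing the illustrative `model`, `tokenizer`, and `inputs` from the first sketch. The allowed token set here is arbitrary and only meant to show the callback's shape.

```python
# Only allow a small, fixed set of token ids at every decoding step (purely illustrative).
allowed = tokenizer(" the fox jumps", add_special_tokens=False).input_ids + [tokenizer.eos_token_id]

def prefix_allowed_tokens_fn(batch_id, input_ids):
    # Must return the list of token ids allowed at the next step for this batch element.
    return allowed

constrained_ids = model.generate(
    **inputs,
    num_beams=4,
    prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
    max_new_tokens=10,
)
```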
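Assisted (speculative) generation and streaming, both visible in the dispatch above, might be invoked like this. Again a sketch: `distilgpt2` is assumed only because it shares `gpt2`'s tokenizer, which the assisted-generation checks require.

```python
from transformers import TextStreamer

# Assisted generation: the assistant must use the same tokenizer and should be much smaller.
# Note the guards in step 10: batch_size must be 1 and use_cache must be True.
assistant = AutoModelForCausalLM.from_pretrained("distilgpt2")
_ = model.generate(**inputs, assistant_model=assistant, max_new_tokens=40)

# Streaming: tokens are pushed to `streamer.put(...)` as they are produced.
# Per the guard in step 8, a streamer cannot be combined with `num_beams > 1`.
streamer = TextStreamer(tokenizer)
_ = model.generate(**inputs, streamer=streamer, max_new_tokens=40)
```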
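The beam-search family of branches is selected purely through `generation_config` attributes. A hedged sketch, reusing the same illustrative objects:

```python
# Plain beam search (BEAM_SEARCH mode): num_beams > 1 with do_sample=False.
_ = model.generate(**inputs, num_beams=4, num_return_sequences=2, max_new_tokens=20)

# Diverse / group beam search (GROUP_BEAM_SEARCH mode).
_ = model.generate(
    **inputs, num_beams=4, num_beam_groups=2, diversity_penalty=1.0, max_new_tokens=20
)

# Constrained beam search (CONSTRAINED_BEAM_SEARCH mode): `force_words_ids` is turned into
# PhrasalConstraint / DisjunctiveConstraint objects by the validation loop shown above.
force_words_ids = tokenizer([" jumps"], add_special_tokens=False).input_ids
_ = model.generate(**inputs, num_beams=4, force_words_ids=force_words_ids, max_new_tokens=20)
```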
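Contrastive search and DoLa decoding are likewise triggered by configuration flags rather than separate methods. A sketch; whether DoLa is available depends on the transformers version and on the model exposing intermediate hidden states.

```python
# Contrastive search (CONTRASTIVE_SEARCH mode): enabled by penalty_alpha > 0 together with
# top_k > 1, and requires use_cache=True as checked in that branch.
_ = model.generate(**inputs, penalty_alpha=0.6, top_k=4, max_new_tokens=40)

# DoLa decoding (DOLA_GENERATION mode): `dola_layers` selects the layers to contrast against.
_ = model.generate(**inputs, dola_layers="high", repetition_penalty=1.2, max_new_tokens=40)
```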
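Finally, the `custom_generate` entry point (step 0) and the experimental classifier-free-guidance path (step 9) could look roughly like this. The Hub repository id is hypothetical, so that call is left commented out; the CFG flags are real `GenerationConfig` attributes but the feature is marked experimental in the docstring.

```python
# Custom generation recipe from the Hub (step 0): the repo id below is a placeholder. A real
# repository must ship a custom_generate/generate.py file and requires trust_remote_code=True.
# _ = model.generate(**inputs, custom_generate="some-org/some-decoding-recipe", trust_remote_code=True)

# Classifier-free guidance (experimental): a guidance_scale != 1 adds a CFG logits processor
# that consumes `negative_prompt_ids` / `negative_prompt_attention_mask` in step 9.
negative = tokenizer("boring text", return_tensors="pt")
_ = model.generate(
    **inputs,
    guidance_scale=1.5,
    negative_prompt_ids=negative.input_ids,
    negative_prompt_attention_mask=negative.attention_mask,
    max_new_tokens=20,
)
```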