# from_pretrained is a classmethod that instantiates a model from a pretrained checkpoint (local path or Hub repo)
@classmethod
def from_pretrained(
cls,
# pretrained_model_name_or_path: name or path of the pretrained model, either a local directory or a Hub model id
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
# model_args: positional arguments forwarded to the model's __init__
*model_args,
# config: a configuration object or its path; if not provided, it is loaded from pretrained_model_name_or_path
config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
# cache_dir: directory used to cache downloaded model files
cache_dir: Optional[Union[str, os.PathLike]] = None,
# ignore_mismatched_sizes: whether to ignore checkpoint weights whose shapes do not match the model
ignore_mismatched_sizes: bool = False,
# force_download: whether to force re-downloading the model weights
force_download: bool = False,
# local_files_only: only use local files and never attempt a remote download
local_files_only: bool = False,
# token: Hugging Face Hub access token used when downloading the weights
token: Optional[Union[str, bool]] = None,
# revision: the model revision to use (branch name, tag, or commit hash)
revision: str = "main",
# use_safetensors: whether to load the weights from safetensors files
use_safetensors: bool = None,
**kwargs,
):
# Pop frequently used options out of kwargs
state_dict = kwargs.pop("state_dict", None)
from_tf = kwargs.pop("from_tf", False)
from_flax = kwargs.pop("from_flax", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
output_loading_info = kwargs.pop("output_loading_info", False)
use_auth_token = kwargs.pop("use_auth_token", None)
trust_remote_code = kwargs.pop("trust_remote_code", None)  # allow running custom modeling code from the Hub repo; only effective with Auto classes (see the warning below)
_ = kwargs.pop("mirror", None)
from_pipeline = kwargs.pop("_from_pipeline", None)
from_auto_class = kwargs.pop("_from_auto", False)
_fast_init = kwargs.pop("_fast_init", True)
torch_dtype = kwargs.pop("torch_dtype", None)  # dtype to load the weights in, e.g. torch.bfloat16
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", None)  # reduce CPU RAM usage while loading
device_map = kwargs.pop("device_map", None)  # with "auto", accelerate assigns modules to devices automatically
max_memory = kwargs.pop("max_memory", None)
offload_folder = kwargs.pop("offload_folder", None)
offload_state_dict = kwargs.pop("offload_state_dict", False)
offload_buffers = kwargs.pop("offload_buffers", False)
load_in_8bit = kwargs.pop("load_in_8bit", False)
load_in_4bit = kwargs.pop("load_in_4bit", False)  # 4-bit loading (deprecated); pass quantization_config instead
quantization_config = kwargs.pop("quantization_config", None)  # quantization settings, e.g. a BitsAndBytesConfig
subfolder = kwargs.pop("subfolder", "")
commit_hash = kwargs.pop("_commit_hash", None)
variant = kwargs.pop("variant", None)
adapter_kwargs = kwargs.pop("adapter_kwargs", {})
adapter_name = kwargs.pop("adapter_name", "default")  # PEFT adapter name; the default is usually fine
use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)  # deprecated in favor of attn_implementation; see the example below
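# As promised above, a minimal sketch of how these options are typically passed from user code
# (the model id is a placeholder, and flash-attn being installed is an assumption):
#
#     from transformers import AutoModelForCausalLM
#     import torch
#
#     model = AutoModelForCausalLM.from_pretrained(
#         "meta-llama/Llama-2-7b-hf",               # Hub id or local path
#         torch_dtype=torch.bfloat16,                # load weights in bf16 instead of fp32
#         device_map="auto",                         # let accelerate place the layers
#         attn_implementation="flash_attention_2",   # replaces use_flash_attention_2=True
#     )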
# When FSDP is enabled, force low_cpu_mem_usage
if is_fsdp_enabled():
low_cpu_mem_usage = True
# use_auth_token is deprecated; use token instead
if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
FutureWarning,
)
if token is not None:
raise ValueError(
"`token` and `use_auth_token` are both specified. Please set only the argument `token`."
)
token = use_auth_token
# Forward the token through adapter_kwargs as well, unless one was already set there
if token is not None and adapter_kwargs is not None and "token" not in adapter_kwargs:
adapter_kwargs["token"] = token
# If use_safetensors was not set and safetensors is unavailable, disable it
if use_safetensors is None and not is_safetensors_available():
use_safetensors = False
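# Note (behavioral, refers to code further down): if safetensors IS available and use_safetensors
# stays None, the loader prefers *.safetensors files over *.bin when the checkpoint provides both.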
if trust_remote_code is True:
logger.warning(
"The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is"
" ignored."
)
# Try to resolve the commit hash from the checkpoint's config file
if commit_hash is None:
if not isinstance(config, PretrainedConfig):
resolved_config_file = cached_file(
pretrained_model_name_or_path,
CONFIG_NAME,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
_raise_exceptions_for_gated_repo=False,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
)
commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
else:
commit_hash = getattr(config, "_commit_hash", None)
# If PEFT is available, look for an adapter config; when one is found, loading is redirected to the base model
if is_peft_available():
_adapter_model_path = adapter_kwargs.pop("_adapter_model_path", None)
if _adapter_model_path is None:
_adapter_model_path = find_adapter_config_file(
pretrained_model_name_or_path,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
_commit_hash=commit_hash,
**adapter_kwargs,
)
if _adapter_model_path is not None and os.path.isfile(_adapter_model_path):
with open(_adapter_model_path, "r", encoding="utf-8") as f:
_adapter_model_path = pretrained_model_name_or_path
pretrained_model_name_or_path = json.load(f)["base_model_name_or_path"]
else:
_adapter_model_path = None
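# Net effect of the block above: when pretrained_model_name_or_path points at a PEFT adapter repo,
# the two variables are swapped; _adapter_model_path remembers the adapter location, while
# pretrained_model_name_or_path is redirected to adapter_config.json["base_model_name_or_path"],
# so the base weights are loaded first and the adapter is attached afterwards.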
# Normalize device_map into the dictionary form expected downstream
if isinstance(device_map, torch.device):
device_map = {"": device_map}
elif isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
try:
device_map = {"": torch.device(device_map)}
except RuntimeError:
raise ValueError(
"When passing device_map as a string, the value needs to be a device name (e.g. cpu, cuda:0) or "
f"'auto', 'balanced', 'balanced_low_0', 'sequential' but found {device_map}."
)
elif isinstance(device_map, int):
if device_map < 0:
raise ValueError(
"You can't pass device_map as a negative int. If you want to put the model on the cpu, pass device_map = 'cpu' "
)
else:
device_map = {"": device_map}
# Passing a device_map implies (and requires) low_cpu_mem_usage=True
if device_map is not None:
if low_cpu_mem_usage is None:
low_cpu_mem_usage = True
elif not low_cpu_mem_usage:
raise ValueError("Passing along a `device_map` requires `low_cpu_mem_usage=True`")
# low_cpu_mem_usage is incompatible with DeepSpeed ZeRO-3 and requires Accelerate to be installed
if low_cpu_mem_usage:
if is_deepspeed_zero3_enabled():
raise ValueError(
"DeepSpeed Zero-3 is not compatible with `low_cpu_mem_usage=True` or with passing a `device_map`."
)
elif not is_accelerate_available():
raise ImportError(
"Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install accelerate`"
)
# Convert the deprecated load_in_4bit / load_in_8bit flags into a BitsAndBytesConfig
if load_in_4bit or load_in_8bit:
if quantization_config is not None:
raise ValueError(
"You can't pass `load_in_4bit`or `load_in_8bit` as a kwarg when passing "
"`quantization_config` argument at the same time."
)
config_dict = {k: v for k, v in kwargs.items() if k in inspect.signature(BitsAndBytesConfig).parameters}
config_dict = {**config_dict, "load_in_4bit": load_in_4bit, "load_in_8bit": load_in_8bit}
quantization_config, kwargs = BitsAndBytesConfig.from_dict(
config_dict=config_dict, return_unused_kwargs=True, **kwargs
)
logger.warning(
"The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. "
"Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead."
)
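# The replacement the warning asks for, as a sketch (the model id is a placeholder):
#     from transformers import AutoModelForCausalLM, BitsAndBytesConfig
#     bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
#     model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",
#                                                  quantization_config=bnb_config)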
from_pt = not (from_tf | from_flax)
# Build the user-agent metadata sent along with download requests
user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
if from_pipeline is not None:
user_agent["using_pipeline"] = from_pipeline
# In offline mode, force local_files_only=True
if is_offline_mode() and not local_files_only:
logger.info("Offline mode: forcing local_files_only=True")
local_files_only = True
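# is_offline_mode() is driven by the TRANSFORMERS_OFFLINE / HF_HUB_OFFLINE environment variables,
# so e.g. `TRANSFORMERS_OFFLINE=1 python run.py` lands here and only cached files are used.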
# Load the model configuration if a PretrainedConfig was not passed in directly
if not isinstance(config, PretrainedConfig):
config_path = config if config is not None else pretrained_model_name_or_path
config, model_kwargs = cls.config_class.from_pretrained(
config_path,
cache_dir=cache_dir,
return_unused_kwargs=True,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
_from_auto=from_auto_class,
_from_pipeline=from_pipeline,
**kwargs,
)
else:
# A config object was passed in directly; deep-copy it so the caller's object is not mutated
config = copy.deepcopy(config)
# Apply an attn_implementation override passed through kwargs
kwarg_attn_imp = kwargs.pop("attn_implementation", None)
if kwarg_attn_imp is not None and config._attn_implementation != kwarg_attn_imp:
config._attn_implementation = kwarg_attn_imp
model_kwargs = kwargs
# Resolve the quantization configuration (checkpoint-provided and/or user-provided)
pre_quantized = getattr(config, "quantization_config", None) is not None
if pre_quantized or quantization_config is not None:
if pre_quantized:
config.quantization_config = AutoHfQuantizer.merge_quantization_configs(
config.quantization_config, quantization_config
)
else:
config.quantization_config = quantization_config
hf_quantizer = AutoHfQuantizer.from_config(config.quantization_config, pre_quantized=pre_quantized)
else:
hf_quantizer = None
# With quantization enabled, validate the environment and let the quantizer adjust dtype and device_map
if hf_quantizer is not None:
hf_quantizer.validate_environment(
torch_dtype=torch_dtype, from_tf=from_tf, from_flax=from_flax, device_map=device_map
)
torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
device_map = hf_quantizer.update_device_map(device_map)
if low_cpu_mem_usage is None:
low_cpu_mem_usage = True
logger.warning("`low_cpu_mem_usage` was None, now set to True since model is quantized.")
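# Two cases are handled above: `pre_quantized` means the checkpoint itself ships a
# quantization_config (e.g. a GPTQ/AWQ export), whereas a user-supplied quantization_config
# (e.g. bitsandbytes) quantizes the full-precision weights on the fly while they are loaded.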
is_quantized = hf_quantizer is not None
# Whether the checkpoint is sharded into multiple weight files
is_sharded = False
sharded_metadata = None
# State used while loading the model weights (the actual loading follows below)
loading_info = None
keep_in_fp32_modules = None
use_keep_in_fp32_modules = False
# (......)
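# The omitted remainder resolves the checkpoint files (single vs. sharded, safetensors vs. bin),
# instantiates the model (on the meta device when low_cpu_mem_usage=True), loads and dispatches
# the state dict according to device_map, and finally attaches the PEFT adapter if one was found.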