MindSpore源码精读之ckpt及safetensors权重保存流程

前言

在深度学习的世界里,模型的训练犹如一场漫长而艰辛的旅程。当我们历经无数次迭代,调整超参数,优化损失函数,终于让模型达到满意的性能时,一个不容忽视的环节摆在面前 —— 网络权重的保存。这可不是简单的存储操作,它承载着模型训练的成果与智慧结晶。从模型复用的角度来看,保存的权重能让我们在不同场景下快速调用已训练好的模型,无需重复训练,大大节省时间和计算资源。在迁移学习中,保存的权重更是成为知识迁移的桥梁,助力新任务快速收敛。接下来,让我们深入 MindSpore 框架,探寻 ckpt 和 safetensors 这两种权重保存流程的奥秘。

前置知识

在深入探究 MindSpore 中 ckpt 和 safetensors 权重保存流程之前,掌握一些关键的前置知识是十分必要的。这些知识不仅能帮助我们理解源码的工作原理,还能让我们明白不同技术在深度学习模型保存过程中所起到的重要作用。

深度学习中的网络权重

权重的基本概念

在深度学习模型里,网络权重是模型的核心要素,它本质上是一组可学习的参数。想象一个简单的线性回归模型 (y = wx + b),其中 (w) 就是权重,它决定了输入 (x) 对输出 (y) 的影响程度。在更复杂的神经网络中,如多层感知机(MLP)、卷积神经网络(CNN)和循环神经网络(RNN)等,权重分布在各个神经元之间的连接上。

以卷积神经网络为例,卷积层中的卷积核就是一组权重。这些卷积核在图像数据上滑动,通过卷积操作提取图像的特征。不同的权重组合能够提取出不同类型的特征,如边缘、纹理等。在训练过程中,模型会根据输入数据和目标输出不断调整这些权重,使得模型的预测结果尽可能接近真实值。

权重保存的意义

模型训练是一个计算资源和时间消耗都很大的过程。一旦训练完成,保存网络权重就显得尤为重要。保存的权重可以用于后续的推理任务,即对新的数据进行预测。同时,权重保存也为模型的复用和迁移学习提供了可能。例如,在一个图像分类任务中训练好的模型权重,可以作为基础,在另一个相关的图像分类任务中进行微调,从而大大缩短训练时间并提高模型性能。

Protocol Buffers(Protobuf)

Protobuf 的定义与原理

Protocol Buffers(简称 Protobuf)是由 Google 开发的一种与语言无关、平台无关、可扩展的序列化结构数据的方法。它允许你定义数据的结构,然后使用特殊生成的源代码轻松地在各种数据流中使用各种语言对结构化数据进行序列化和反序列化。

Protobuf 的工作流程主要分为以下几个步骤:首先,使用 .proto 文件定义数据结构,这类似于定义一个类或结构体。例如:

syntax = "proto3";

message Person {
  string name = 1;
  int32 age = 2;
  string email = 3;
}

在这个例子中,我们定义了一个名为 Person 的消息类型,包含三个字段:name(字符串类型)、age(32 位整数类型)和 email(字符串类型)。每个字段后面的数字(如 123)是该字段的唯一标识,用于在序列化和反序列化过程中识别字段。

接着,使用 protoc 编译器根据 .proto 文件生成特定语言的代码。例如,对于 Python 语言,可以运行以下命令:

protoc --python_out=. your_file.proto

这将生成一个 Python 模块,其中包含用于创建、序列化和反序列化 Person 消息的类和方法。

最后,在代码中使用生成的类来创建消息对象,将其序列化为二进制数据,或者将二进制数据反序列化为消息对象。以下是一个 Python 示例:

import your_file_pb2

# 创建一个 Person 对象
person = your_file_pb2.Person()
person.name = "Alice"
person.age = 25
person.email = "alice@example.com"

# 序列化对象
serialized_data = person.SerializeToString()

# 反序列化数据
new_person = your_file_pb2.Person()
new_person.ParseFromString(serialized_data)

print(new_person.name)  # 输出: Alice

Safetensors

Safetensors 的基本介绍

在深度学习领域,张量作为一种多维数组,是模型权重以及输入输出数据的主要表现形式。而 Safetensors 作为一种用于安全存储和交换张量数据的格式,正逐渐崭露头角。它的诞生旨在应对传统张量存储格式(像 .npy.pt 等)在安全性与性能方面存在的问题。

Safetensors 具备多个显著特点,使其在众多存储格式中脱颖而出:

  • 安全性高:Safetensors 构建了严格的数据验证与检查机制,可有效抵御因恶意构造数据而引发的安全漏洞。它会仔细核查张量的形状、数据类型等信息是否合法合规,从根源上避免因数据不一致而产生的各类错误,确保数据的可靠性和安全性。
  • 零拷贝特性带来极快加载速度:Safetensors 采用了经过精心优化的二进制格式,其核心亮点在于运用了零拷贝技术。传统的数据加载过程往往需要在内存中进行多次数据复制操作,这不仅耗费大量时间,还会占用额外的内存资源。而零拷贝技术允许数据在存储介质和内存之间直接进行传输,无需中间的复制步骤。在处理大规模模型权重时,这种特性能够让加载过程的速度得到大幅度提升,显著提高模型的部署效率。
  • 跨框架兼容性强:Safetensors 打破了不同深度学习框架之间的壁垒,能够在诸如 PyTorch、TensorFlow、MindSpore 等多个主流框架之间实现数据的顺畅交换。这为用户在不同的开发环境中使用相同的张量数据提供了极大的便利,大大提高了数据的复用性和灵活性。

MindSpore实现

理解了上述前置知识过后,我们就可以开始详细的研究MindSpore是如何利用protobuf和safetensors的能力来进行权重保存的了。先看一下函数定义:

def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True,
                    async_save=False, append_dict=None, enc_key=None, enc_mode="AES-GCM", choice_func=None,
                    crc_check=False, format="ckpt", **kwargs):
"""
参数:
    - **save_obj** (Union[Cell, list, dict]) - 待保存的对象。数据类型可为 :class:`mindspore.nn.Cell` 、list或dict。

        - 若为list,可以是 `Cell.trainable_params()` 的返回值,或元素为dict的列表(如[{"name": param_name, "data": param_data},…],`param_name` 的类型必须是str,`param_data` 的类型必须是Parameter或者Tensor)。
        - 若为dict,可以是 :func:`mindspore.load_checkpoint` 的返回值。
    
    - **ckpt_file_name** (str) - checkpoint文件名称。如果文件已存在,将会覆盖原有文件。
    - **integrated_save** (bool) - 在并行场景下是否合并保存拆分的Tensor。默认值: ``True`` 。
    - **async_save** (Union[bool, str]) - 是否使用异步方式保存checkpoint文件,True时默认使用异步线程;如果是str类型,选择异步保存方式,可以是 "process" 或 "thread"。默认值: ``False`` 。
    - **append_dict** (dict) - 需要保存的其他信息。dict的键必须为str类型,dict的值类型必须是int、float、bool、string、Parameter或Tensor类型。默认值: ``None`` 。
    - **enc_key** (Union[None, bytes]) - 用于加密的字节类型密钥。如果值为 ``None`` ,那么不需要加密。默认值: ``None`` 。
    - **enc_mode** (str) - 该参数在 `enc_key` 不为 ``None`` 时有效,指定加密模式,目前仅支持 ``"AES-GCM"`` , ``"AES-CBC"`` 和 ``"SM4-CBC"`` 。默认值: ``"AES-GCM"`` 。
    - **choice_func** (function) - 用于自定义控制保存参数的函数。函数的输入值为字符串类型的Parameter名称,并且返回值是一个布尔值。默认值: ``None`` 。

        - 如果返回 ``True`` ,则匹配自定义条件的Parameter将被保存。 
        - 如果返回 ``False`` ,则未匹配自定义条件的Parameter不会被保存。

    - **crc_check** (bool) - 是否在保存checkpoint时进行crc32校验,并把计算结果写到文件中。默认值: ``False`` 。
    - **format** (str) - 输出文件的格式,可以是 "ckpt" 或 "safetensors"。默认值:``"ckpt"``。
    - **kwargs** (dict) - 配置选项字典。
"""

查看官方文档可以详细的理解每一个参数的意义,很好理解,可以看到,最基础也是最重要的两个参数就是save_obj和ckpt_file_name,配置这俩参数就可以成功完成权重保存,此外,通过配置参数format就可以简单的控制存储落盘的权重文件格式,选择为ckpt或者safetensors格式保存,其中ckpt格式就是使用protobuf能力保存定义的格式结尾。接下来我们一段一段的对代码进行解析。

准备阶段代码解析

start_save_time = time.time()
ckpt_file_name = _check_save_obj_and_ckpt_file_name(save_obj, ckpt_file_name, format)
integrated_save = Validator.check_bool(integrated_save)
async_save = _check_async_save(async_save)
append_dict = _check_append_dict(append_dict)
enc_key = Validator.check_isinstance('enc_key', enc_key, (type(None), bytes))
enc_mode = Validator.check_isinstance('enc_mode', enc_mode, str)
crc_check = Validator.check_isinstance('crc_check', crc_check, bool)
map_param_inc = kwargs.get('incremental', False)
logger.info("Execute the process of saving checkpoint files.")
global_step_num = kwargs.get('global_step_num', None)
_check_save_checkpoint_upsupported_param(format, enc_key, enc_mode, async_save, map_param_inc, global_step_num)

这一段代码比较简单,不做过多的解析,本阶段承担了一些简单的参数检查,日志记录等工作,获取了一些通过kwargs入参的关键字参数,本段代码有效的提高了保存操作的稳定性和可靠性。

处理追加字典与保存对象转换

if append_dict and "__exception_save__" in append_dict:
    s1 = mindspore.hal.Stream()
    with mindspore.hal.StreamCtx(s1):
        save_obj = _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choice_func)
        for k_name, value in append_dict.items():
            if isinstance(value, (Tensor, Parameter)):
                append_dict[k_name] = Tensor(Tensor_.move_to(value, "CPU", False))
    s1.synchronize()
else:
    save_obj = _convert_save_obj_to_param_list(save_obj, integrated_save, append_dict, choice_func)

if append_dict:
    if "__exception_save__" in append_dict:
        del append_dict["__exception_save__"]
    append_info_list = []
    for k_name, value in append_dict.items():
        if isinstance(value, Generator):
            value = value.get_state()
        elif not isinstance(value, str):
            value = Tensor(value)
        append_info_list.append({"name": k_name, "data": value})
    save_obj.extend(append_info_list)

本阶段单独对append_dict这个参数进行了处理,在append_dict 存在且其中包含 "exception_save" 键时,通过创建流并设置上下文的方式,转换保存对象,不存在时则直接进行转换,用于处理追加字典并扩展保存对象,确保在保存模型权重时能够包含额外的信息,同时保证数据的一致性和设备兼容性,该部分功能在本文不进行详细讲解,感兴趣的同学们可以自行研读源码。

数据整理与存储准备

data_list = OrderedDict()
data_list_np = OrderedDict()
with _ckpt_mutex:
    for param in save_obj:
        if param["name"] == "random_op":
            if os.getenv("AITURBO") == "1":
                data_list_np["random_op"] = []
                data_list_np["random_op"].append(param["data"])
                if crc_check:
                    bytes_value = bytes(data_list_np[key][0])
                    data_list_np[key].append(binascii.crc32(bytes_value))
            else:
                data_list["random_op"] = param["data"]
            continue
        key = param["name"]
        data_list[key] = []
        data_list_np[key] = []
        if isinstance(param["data"], MapParameter):
            data_list[param["name"]].append("mapparameter")
            data_list[param["name"]].append(param["data"])
            continue
        if isinstance(param["data"], list):
            if param["data"][0] == "persistent_data":
                _save_param_list_data(data_list, key, param)
            elif param["data"][0] == "offload_parameter":
                data_list[key].append("offload_parameter")
                _save_param_list_data(data_list, key, param)

        if isinstance(param["data"], str):
            if os.getenv("AITURBO") == "1":
                data_list_np[key].append(np.array(param["data"]))
                if crc_check:
                    bytes_value = data_list_np[key][0].tobytes()
                    data_list_np[key].append(binascii.crc32(bytes_value))
            else:
                data_list[key].append([0])
                data_list[key].append('str')
                data = np.array(param["data"])
                data_list[key].append(data)
        else:
            if isinstance(param["data"], Parameter):
                param["data"].init_data()
            if os.getenv("AITURBO") == "1":
                data_list_np[key].append(param["data"].asnumpy())
                if crc_check:
                    bytes_value = data_list_np[key][0].tobytes()
                    data_list_np[key].append(binascii.crc32(bytes_value))
            else:
                dims = []
                for dim in param['data'].shape:
                    dims.append(dim)
                data_list[key].append(dims)
                tensor_type = str(param["data"].dtype)
                data_list[key].append(tensor_type)
                data = param["data"] if async_save != "process" else param["data"].asnumpy()
                data_list[key].append(data)

在 MindSpore 进行模型权重保存的流程中,将保存对象中的参数数据进行整理和分类存储是关键步骤。上面这段代码看起来比较复杂,其实只是处理的分支比较多罢了,实际上函数本身就是负责对 save_obj 中的参数进行处理,将其存储到 data_list 或 data_list_np 中,同时根据不同的条件进行了特殊的处理。

1. 初始化有序字典

data_list = OrderedDict()
data_list_np = OrderedDict()

这里创建了两个有序字典 data_list 和 data_list_np,用于存储整理后的参数数据。有序字典可以保证元素的插入顺序,方便后续处理。

2. 加锁处理参数

在 MindSpore 权重保存流程中,为了保证在多线程或多进程环境下对参数处理的线程安全,避免数据竞争问题,代码中使用了锁机制。其中,_ckpt_mutex 锁的初始化是整个线程安全保障的重要一环。

_ckpt_mutex = RLock()
with _ckpt_mutex:

RLock 即可重入锁(Reentrant Lock),它是 Python 标准库 threading 模块中的一个类。与普通的锁(Lock)不同,可重入锁允许同一个线程多次获取该锁而不会产生死锁。这意味着在一个线程已经持有该锁的情况下,它可以再次请求该锁,而不会被阻塞。当线程释放锁时,需要相应次数的释放操作才能完全释放该锁。

在权重保存过程中,对 save_obj 中的参数进行处理是一个关键步骤,而这个过程可能会涉及到多个线程或进程同时访问和修改共享数据。通过初始化 _ckpt_mutex 为可重入锁,我们可以在处理参数时使用 with _ckpt_mutex: 语句来创建一个临界区,确保在同一时间只有一个线程可以进入该临界区进行操作。

当一个线程进入 with 语句块时,它会尝试获取 _ckpt_mutex 锁。如果锁当前未被其他线程持有,该线程将成功获取锁并进入临界区执行参数处理代码;如果锁已经被其他线程持有,该线程将被阻塞,直到锁被释放。这样就保证了在多线程环境下对参数处理的线程安全,避免了数据竞争和不一致的问题。

可重入锁的可重入性在某些复杂的代码逻辑中非常有用。例如,如果在参数处理过程中,某个函数内部又需要再次获取 _ckpt_mutex 锁,使用可重入锁可以避免死锁的发生。因为同一个线程可以多次获取该锁,而不会被阻塞。

_ckpt_mutex = RLock() 这一初始化语句为 MindSpore 权重保存流程中的参数处理提供了重要的线程安全保障。通过使用可重入锁,代码能够在多线程或多进程环境下稳定运行,确保数据的一致性和完整性。

3. 处理 random_op 参数

if param["name"] == "random_op":
    if os.getenv("AITURBO") == "1":
        data_list_np["random_op"] = []
        data_list_np["random_op"].append(param["data"])
        if crc_check:
            bytes_value = bytes(data_list_np[key][0])
            data_list_np[key].append(binascii.crc32(bytes_value))
    else:
        data_list["random_op"] = param["data"]
    continue

当参数名称为 "random_op" 时,根据环境变量 AITURBO 的值进行不同处理:

  • 如果 AITURBO 为 "1",将参数数据存储到 data_list_np 中,并在需要时进行 CRC 校验,将校验结果也存储在列表中。
  • 否则,将参数数据直接存储到 data_list 中。

4. 处理 MapParameter 类型的参数

if isinstance(param["data"], MapParameter):
    data_list[param["name"]].append("mapparameter")
    data_list[param["name"]].append(param["data"])
    continue

如果参数数据是 MapParameter 类型,将 "mapparameter" 标记和参数数据依次添加到 data_list 中。

5. 处理列表类型的参数

if isinstance(param["data"], list):
    if param["data"][0] == "persistent_data":
        _save_param_list_data(data_list, key, param)
    elif param["data"][0] == "offload_parameter":
        data_list[key].append("offload_parameter")
        _save_param_list_data(data_list, key, param)

如果参数数据是列表类型,根据列表的第一个元素判断是 "persistent_data" 还是 "offload_parameter",并调用 _save_param_list_data 函数进行处理。

6. 处理字符串类型的参数

if isinstance(param["data"], str):
    if os.getenv("AITURBO") == "1":
        data_list_np[key].append(np.array(param["data"]))
        if crc_check:
            bytes_value = data_list_np[key][0].tobytes()
            data_list_np[key].append(binascii.crc32(bytes_value))
    else:
        data_list[key].append([0])
        data_list[key].append('str')
        data = np.array(param["data"])
        data_list[key].append(data)

当参数数据是字符串类型时,同样根据 AITURBO 的值进行不同处理:

  • 如果 AITURBO 为 "1",将字符串转换为 numpy 数组存储到 data_list_np 中,并在需要时进行 CRC 校验。
  • 否则,在 data_list 中依次添加 [0]'str' 标记和转换后的 numpy 数组。

7. 处理其他类型的参数

else:
    if isinstance(param["data"], Parameter):
        param["data"].init_data()
    if os.getenv("AITURBO") == "1":
        data_list_np[key].append(param["data"].asnumpy())
        if crc_check:
            bytes_value = data_list_np[key][0].tobytes()
            data_list_np[key].append(binascii.crc32(bytes_value))
    else:
        dims = []
        for dim in param['data'].shape:
            dims.append(dim)
        data_list[key].append(dims)
        tensor_type = str(param["data"].dtype)
        data_list[key].append(tensor_type)
        data = param["data"] if async_save != "process" else param["data"].asnumpy()
        data_list[key].append(data)

对于其他类型的参数,首先如果是 Parameter 类型,调用 init_data() 方法初始化数据。然后根据 AITURBO 的值进行不同处理:

  • 如果 AITURBO 为 "1",将参数数据转换为 numpy 数组存储到 data_list_np 中,并在需要时进行 CRC 校验。
  • 否则,在 data_list 中依次添加参数的维度信息、数据类型和数据本身。

综上所述,这段代码通过对不同类型的参数进行分类处理,将参数数据存储到合适的有序字典中,并在需要时进行 CRC 校验,为后续的权重保存操作做好了数据准备。这种细致的处理方式确保了不同类型的参数都能被正确存储和管理。 其中,AITURBO 是一种特殊的存储方式,MindSpore实现此处时,单独做了一些特殊的处理,不用过于关注。

根据不同模式选择不同的保存策略

在 MindSpore 进行模型权重保存时,会根据不同的条件选择不同的保存策略。这段代码根据环境变量 AITURBO 的值以及 async_save 参数的设置,决定采用特殊存储结构保存、异步保存(包括进程和线程方式)还是同步保存。

if os.getenv("AITURBO") == "1":
    from aiturbo.checkpoint import aiturbo_mindspore as aiturbo
    ckpt_name = os.path.basename(ckpt_file_name)
    aiturbo.save_ckpt(ckpt_name, global_step_num, data_list_np, crc_check)
elif async_save:
    if async_save == "process":
        if sys.platform.startswith("win"):
            logger.warning("The Win platform currently does not support asynchronous process saving of ckpt, "
                           "so serial saving of ckpt is used now.")
            _exec_save(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc, crc_check, format)
        else:
            _wait_async_process_save_ckpt()
            ctx = mp.get_context("fork")
            cond = ctx.Condition()
            process_flag = True
            while process_flag:
                process = ctx.Process(target=_async_process_save,
                                      args=(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc, crc_check,
                                            format, cond), daemon=True, name="asyn_save_ckpt")
                process.start()
                with cond:
                    wait_flag = cond.wait(timeout=5)
                    if not wait_flag:
                        logger.warning("Async save process fails to create. will kill and recreate")
                        process.kill()
                    else:
                        process_flag = False
    else:
        data_copy = copy.deepcopy(data_list)
        _wait_async_thread_save_ckpt()
        thr = Thread(target=_exec_save,
                     args=(ckpt_file_name, data_copy, enc_key, enc_mode, map_param_inc, crc_check, format),
                     name="asyn_save_ckpt")
        thr.start()
else:
    _exec_save(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc, crc_check, format)

1. 使用 AITurbo 特殊存储结构保存

if os.getenv("AITURBO") == "1":
    from aiturbo.checkpoint import aiturbo_mindspore as aiturbo
    ckpt_name = os.path.basename(ckpt_file_name)
    aiturbo.save_ckpt(ckpt_name, global_step_num, data_list_np, crc_check)

当环境变量 AITURBO 的值为 "1" 时,表明需要使用 AITurbo 特殊存储结构进行保存。首先从 aiturbo.checkpoint 模块中导入 aiturbo_mindspore 并别名化为 aiturbo,然后获取检查点文件名的基础名称,最后调用 aiturbo.save_ckpt 函数进行保存操作,该函数会使用 AITurbo 特有的存储逻辑,将 data_list_np 中的数据以及全局步数和 CRC 校验信息保存到指定的文件中,此处我们不再深入进行详细讲解。

2. 异步保存模式

elif async_save:

当 async_save 为 True 时,进入异步保存模式,根据 async_save 的具体值分为进程异步保存和线程异步保存两种情况。

进程异步保存
if async_save == "process":
    if sys.platform.startswith("win"):
        logger.warning("The Win platform currently does not support asynchronous process saving of ckpt, "
                       "so serial saving of ckpt is used now.")
        _exec_save(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc, crc_check, format)
    else:
        _wait_async_process_save_ckpt()
        ctx = mp.get_context("fork")
        cond = ctx.Condition()
        process_flag = True
        while process_flag:
            process = ctx.Process(target=_async_process_save,
                                  args=(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc, crc_check,
                                        format, cond), daemon=True, name="asyn_save_ckpt")
            process.start()
            with cond:
                wait_flag = cond.wait(timeout=5)
                if not wait_flag:
                    logger.warning("Async save process fails to create. will kill and recreate")
                    process.kill()
                else:
                    process_flag = False

如果 async_save 的值为 "process",表示采用进程异步保存方式。但在 Windows 平台上,由于不支持 fork 方式创建进程,会输出警告信息并采用同步保存方式调用 _exec_save 函数。在其他支持 fork 的平台上,首先调用 _wait_async_process_save_ckpt 函数进行一些准备工作,然后使用 multiprocessing 模块创建一个新的进程,指定目标函数为 _async_process_save,并传入相关参数。使用条件变量 cond 进行等待,如果在 5 秒内没有收到信号,说明进程创建失败,会杀死该进程并重新创建,直到成功为止。

线程异步保存
else:
    data_copy = copy.deepcopy(data_list)
    _wait_async_thread_save_ckpt()
    thr = Thread(target=_exec_save,
                 args=(ckpt_file_name, data_copy, enc_key, enc_mode, map_param_inc, crc_check, format),
                 name="asyn_save_ckpt")
    thr.start()

当 async_save 不为 "process" 时,采用线程异步保存方式。首先对 data_list 进行深拷贝,避免多线程操作时的数据竞争问题。然后调用 _wait_async_thread_save_ckpt 函数进行准备工作,接着创建一个新的线程,指定目标函数为 _exec_save,并传入相关参数,最后启动线程进行异步保存。

3. 同步保存模式

else:
    _exec_save(ckpt_file_name, data_list, enc_key, enc_mode, map_param_inc, crc_check, format)

当 async_save 为 False 时,采用同步保存模式,直接调用 _exec_save 函数将 data_list 中的数据保存到指定的检查点文件中。

到此处为止,MindSpore的save_checkpoint整个函数基本上就写到了结尾,后面还剩余几行日志的打印信息,此处就不再额外加以赘述,由于无论是异步保存还是同步保存,最终其实都是调用的_exec_save这个函数完成的保存,因此本文继续分析一下_exec_save这个函数的具体实现。

_exec_save 函数:实现模型权重保存的核心逻辑

在 MindSpore 里,_exec_save 函数承担着将模型权重保存到文件的核心任务。无论是同步还是异步保存操作,最终都会调用此函数来完成具体的保存工作。下面我们就对该函数的具体实现进行详细分析。

def _exec_save(ckpt_file_name, data_list, enc_key=None, enc_mode="AES-GCM", map_param_inc=False, crc_check=False,
               format="ckpt"):
    """Execute the process of saving checkpoint into file."""
    try:
        with _ckpt_mutex:
            file_name_list = list(os.path.splitext(ckpt_file_name))
            file_name_list[1] = file_name_list[1].replace(f".{format}", ".tmp")
            tmp_name = ''.join(file_name_list)
            if os.path.exists(ckpt_file_name):
                os.chmod(ckpt_file_name, stat.S_IWUSR)
                os.remove(ckpt_file_name)
            if os.path.exists(tmp_name):
                os.chmod(tmp_name, stat.S_IWUSR)
                os.remove(tmp_name)

1. 临时文件处理与旧文件清理

  • 首先,函数使用 _ckpt_mutex 加锁,确保同一时间只有一个线程可以执行保存操作,避免数据竞争。
  • 通过 os.path.splitext 函数将检查点文件名拆分为文件名和扩展名,将扩展名替换为 .tmp,得到临时文件名 tmp_name
  • 检查原检查点文件和临时文件是否存在,如果存在则修改其权限为可写(stat.S_IWUSR)并删除,为新的保存操作做准备。
            if format == "ckpt":
                ckpt_total_io_time = 0
                with _ckpt_fs.create(tmp_name, *_ckpt_fs.create_args) as f:
                    plain_data = None
                    if enc_key is not None:
                        plain_data = BytesIO()

                    crc_num = 0
                    for name, value in data_list.items():
                        if name == "random_op":
                            _write_random_seed(name, value, f)
                            continue
                        if value[0] == "mapparameter":
                            _write_mapparameter(name, value, f, map_param_inc)
                            continue
                        if value[0] == "offload_parameter":
                            new_value = value[1:]
                            new_value[2] = value[3]
                            _write_parameter_bytes_data(name, new_value, f, enc_key, plain_data, ckpt_total_io_time)
                            _offload_if_config(value[3])
                            continue
                        if value[1] == "str":
                            crc_num, ckpt_total_io_time = _write_parameter_data(name, value, f, enc_key, plain_data,
                                                                                crc_num, crc_check,
                                                                                ckpt_total_io_time)
                            continue
                        if isinstance(value[2], np.ndarray):
                            crc_num, ckpt_total_io_time = _write_parameter_data(name, value, f, enc_key, plain_data,
                                                                                crc_num, crc_check,
                                                                                ckpt_total_io_time)
                            continue
                        if isinstance(value[2], Tensor) and hasattr(value[2], "slice_num") and value[2].slice_num > 1:
                            _write_hugeparameter(name, value, f)
                            continue

                        crc_num, ckpt_total_io_time = _write_parameter_bytes_data(name, value, f, enc_key, plain_data,
                                                                                  crc_num, crc_check,
                                                                                  ckpt_total_io_time)

                    if enc_key is not None:
                        plain_data.seek(0)
                        max_block_size = ENCRYPT_BLOCK_SIZE * 1024
                        block_data = plain_data.read(max_block_size)
                        while block_data:
                            f.write(_encrypt(block_data, len(block_data), enc_key, len(enc_key), enc_mode))
                            block_data = plain_data.read(max_block_size)
                    if crc_check:
                        f.write('crc_num'.encode() + crc_num.to_bytes(10, byteorder='big'))
                vlog_print("1", "ME", __file__, sys._getframe().f_lineno,
                           f"Save ckpt io cost time:{ckpt_total_io_time}.")

2. 以 CKPT 格式保存的逻辑

  • 当保存格式为 "ckpt" 时,函数会打开临时文件进行写入操作。
  • 若提供了加密密钥 enc_key,会创建一个 BytesIO 对象 plain_data 用于存储未加密的数据。
  • 遍历 data_list 中的每个参数,根据参数的类型调用不同的写入函数:
    • 对于 "random_op" 参数,调用 _write_random_seed 函数写入随机种子信息。
    • 对于 mapparameter 类型的参数,调用 _write_mapparameter 函数写入映射参数信息。
    • 对于 offload_parameter 类型的参数,调用 _write_parameter_bytes_data 函数写入参数数据,并根据配置进行卸载操作。
    • 对于字符串类型或 numpy.ndarray 类型的参数,调用 _write_parameter_data 函数写入参数数据。
    • 对于大参数(slice_num > 1 的 Tensor),调用 _write_hugeparameter 函数进行特殊处理。
    • 其他情况调用 _write_parameter_bytes_data 函数写入参数数据。
  • 如果使用了加密,会对 plain_data 中的数据进行加密并写入文件。
  • 若开启了 CRC 校验,会将 CRC 校验值写入文件末尾。
  • 最后打印保存 CKPT 文件的 I/O 耗时。 这里虽然根据分支执行了很多不同的函数,但是其实每个函数的实现的核心逻辑都是相似的,该函数就是为了实现将参数字节数据写入 Protobuf 文件,我们以 _write_parameter_bytes_data 函数为例,详细剖析该函数的具体实现逻辑。
def _write_parameter_bytes_data(name, value, f, enc_key, plain_data, crc_num=0, crc_check=False, ckpt_total_io_time=0):
    """Write parameter bytes data into protobuf file."""
    bytes_value = value[2].get_bytes()
    chunk_size = 1024 * SLICE_SIZE
1. 数据准备与分块设置
  • 获取字节数据bytes_value = value[2].get_bytes() 这行代码从 value 列表的第三个元素中获取参数的字节数据。
  • 设置分块大小chunk_size = 1024 * SLICE_SIZE 定义了每次写入文件的数据块大小。这里将 1024 与 SLICE_SIZE 相乘,SLICE_SIZE 是一个预定义的常量,值是 512 * 1024, 用于控制分块的大小,这样做可以避免一次性处理大量数据导致内存占用过高。
    for i in range(0, len(bytes_value), chunk_size):
        checkpoint_list = Checkpoint()
        param_value = checkpoint_list.value.add()
        param_value.tag = name
        param_tensor = param_value.tensor
        param_tensor.dims.extend(value[0])
        param_tensor.tensor_type = value[1]
        param_tensor.tensor_content = bytes_value[i:i + chunk_size]
2. 分块处理与 Protobuf 消息构建
  • 分块遍历:使用 for 循环以 chunk_size 为步长遍历 bytes_value。这样可以将整个字节数据分成多个小块进行处理。
  • 创建 Protobuf 消息:在每次循环中,创建一个 Checkpoint 消息对象 checkpoint_list,并向其 value 字段中添加一个新的 Value 消息对象 param_value
  • 填充消息内容
    • param_value.tag = name:将参数的名称赋值给 param_value 的 tag 字段。
    • param_tensor = param_value.tensor:获取 param_value 中的 tensor 子消息对象。
    • param_tensor.dims.extend(value[0]):将 value 列表的第一个元素(参数的维度信息)扩展到 param_tensor 的 dims 字段中。
    • param_tensor.tensor_type = value[1]:将 value 列表的第二个元素(参数的数据类型)赋值给 param_tensor 的 tensor_type 字段。
    • param_tensor.tensor_content = bytes_value[i:i + chunk_size]:将当前数据块的字节数据赋值给 param_tensor 的 tensor_content 字段。
        if enc_key is None:
            output_data = checkpoint_list.SerializeToString()
            if crc_check:
                crc_num = binascii.crc32(output_data, crc_num)
            io_start_time = time.time()
            f.write(output_data)
            io_end_time = time.time()
            io_cost_time = io_end_time - io_start_time
            ckpt_total_io_time += io_cost_time
        else:
            plain_data.write(checkpoint_list.SerializeToString())
3. 数据写入与 CRC 校验
  • 无加密情况:当 enc_key 为 None 时,表示不进行加密操作。首先将 checkpoint_list 消息对象序列化为二进制数据 output_data。如果开启了 CRC 校验(crc_check 为 True),则使用 binascii.crc32 函数更新 crc_num。接着记录写入操作的开始时间和结束时间,计算 I/O 耗时并累加到 ckpt_total_io_time 中,最后将序列化后的数据写入文件 f
  • 加密情况:当 enc_key 不为 None 时,表示需要进行加密操作。此时将 checkpoint_list 消息对象序列化后的数据写入 plain_data 对象,后续会对 plain_data 中的数据进行加密处理。
    return crc_num, ckpt_total_io_time
4. 返回结果

函数最后返回更新后的 crc_num 和 ckpt_total_io_time,以便在后续的处理中使用。crc_num 可用于最终的 CRC 校验结果记录,ckpt_total_io_time 则记录了整个写入操作的总 I/O 耗时。

综上所述,_write_parameter_bytes_data 函数通过分块处理参数的字节数据,将其封装为 Protobuf 消息对象,根据是否加密的情况进行不同的处理,并在需要时进行 CRC 校验,最终将数据写入文件或临时存储对象,同时记录了 I/O 耗时,为后续的保存操作提供了重要的支持。

            elif format == "safetensors":
                save_dict = {}
                crc_num = 0
                for name in sorted(data_list.keys()):
                    value = data_list[name]
                    save_dict[name] = value[2].asnumpy()

                    if crc_check:
                        crc_num = binascii.crc32(bytes(name, encoding='utf-8'), crc_num)
                        crc_num = binascii.crc32(
                            bytes(save_dict[name]), crc_num)
                safetensors_save_time_start = time.time()
                if crc_check:
                    save_file(save_dict, tmp_name, metadata={
                        "crc_num": str(crc_num)})
                else:
                    save_file(save_dict, tmp_name)
                safetensors_save_time_end = time.time()
                cost_time = safetensors_save_time_end - safetensors_save_time_start
                vlog_print("1", "ME", __file__, sys._getframe().f_lineno, f"Save safetensors io cost time:{cost_time}.")

3. 以 Safetensors 格式保存的逻辑

  • 当保存格式为 "safetensors" 时,函数会将 data_list 中的参数转换为 numpy 数组并存储在 save_dict 中。
  • 若开启了 CRC 校验,会对参数名和参数数据进行 CRC 校验并更新 crc_num
  • 调用 save_file 函数将 save_dict 保存到临时文件中,若开启了 CRC 校验,会将校验值作为元数据写入文件。
  • 记录保存 Safetensors 文件的耗时并打印。
            if not os.path.exists(tmp_name):
                logger.warning(f"Rename failed, can't find {tmp_name}, it is possible that multiple processes have "
                               f"simultaneously modified a file.")
            else:
                os.rename(tmp_name, ckpt_file_name)
            os.chmod(ckpt_file_name, stat.S_IRUSR)
    except BaseException as e:
        logger.critical("Failed to save the checkpoint file %s. Maybe don't have the permission to write files, "
                        "or the disk space is insufficient and so on.", ckpt_file_name)
        raise e

4. 文件重命名与权限设置

  • 保存完成后,检查临时文件是否存在。若存在,则将临时文件重命名为最终的检查点文件;若不存在,会输出警告信息。
  • 最后,将检查点文件的权限设置为只读(stat.S_IRUSR),以保护文件内容不被意外修改。

5. 异常处理

  • 若在保存过程中出现异常,函数会记录错误信息并重新抛出异常,提示用户可能是权限不足或磁盘空间不足等问题导致保存失败。

综上所述,_exec_save 函数根据不同的保存格式(CKPT 或 Safetensors),对模型权重数据进行处理和保存,同时支持加密和 CRC 校验功能,确保了保存过程的安全性和数据完整性。在保存完成后,还会进行文件重命名和权限设置,以保证文件的正常使用和安全性。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值