A Deep Dive into the PyTorch Checkpoint System, with Optimizations (01)

What does a checkpoint contain?

See Saving and loading a general checkpoint in PyTorch ‒ PyTorch Tutorials 2.0.0+cu117 documentation: a checkpoint contains at least the model, ideally also the optimizer, and may additionally carry the loss, the epoch, and whatever miscellaneous state the particular training algorithm needs.

A user can write a checkpoint like this:

    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        model_state_dict = model.state_dict()
        optimizer_state_dict = optimizer.state_dict()
        torch.save({
            'epoch': epoch,
            'model_state_dict': model_state_dict,
            'optimizer_state_dict': optimizer_state_dict,
        }, "model.pt")
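For completeness, the matching load side can be sketched as below. The tiny model and optimizer here are hypothetical stand-ins, just so the snippet is runnable; any module/optimizer pair works the same way:

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Hypothetical tiny model/optimizer, standing in for the real training objects.
model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)

torch.save({
    'epoch': 3,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, "model.pt")

# To resume, construct the same classes first, then restore their state.
model2 = nn.Linear(4, 2)
optimizer2 = optim.SGD(model2.parameters(), lr=0.1)
checkpoint = torch.load("model.pt")
model2.load_state_dict(checkpoint['model_state_dict'])
optimizer2.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1  # continue from the next epoch
```

Note that the state dicts must be loaded into already-constructed objects; the checkpoint stores only their state, not the objects themselves.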

Flow analysis

def save(obj, f: Union[str, os.PathLike, BinaryIO, IO[bytes]],
         pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True) -> None:

    """
    Args:
        obj: saved object
        f: a file-like object (has to implement write and flush) or a string or
           os.PathLike object containing a file name
        pickle_module: module used for pickling metadata and objects
        pickle_protocol: can be specified to override the default protocol
    """
    with _open_file_like(f, 'wb') as opened_file:
        if _use_new_zipfile_serialization:
            with _open_zipfile_writer(opened_file) as opened_zipfile:
                _save(obj, opened_zipfile, pickle_module, pickle_protocol)
                return
        _legacy_save(obj, opened_file, pickle_module, pickle_protocol)

pickle_module=pickle means that by default the pickle module handles serialization. pickle is Python's serialization module and can serialize arbitrary Python objects; see pickle — Python object serialization.

The protocol defaults to DEFAULT_PROTOCOL (2). pickle defines several protocols, described in the documentation above, but note that even protocol 0, i.e.

Protocol version 0 is the original “human-readable” protocol and is backwards compatible with earlier versions of Python.

is still not actually readable.
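A quick experiment with plain pickle (no PyTorch involved) shows what that means: protocol 0 output is pure ASCII, but it is a stream of pickle opcodes, not something you would read as data:

```python
import pickle

obj = {'epoch': 3, 'lr': 0.01}

p0 = pickle.dumps(obj, protocol=0)  # "human-readable" protocol: all ASCII bytes
p2 = pickle.dumps(obj, protocol=2)  # the default DEFAULT_PROTOCOL: binary

# Protocol 0 is ASCII, but still a stream of opcodes, not readable data.
print(p0)
print(p2)
```

Both round-trip back to the original object with pickle.loads; "human-readable" only means the opcodes happen to be printable characters.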

_use_new_zipfile_serialization selects between the two branches; it defaults to True, in which case the output is packaged as a zip archive (the entries are stored rather than compressed, but the container is standard zip).

In other words, this code picks the container format and serialization path based on the caller's arguments, then calls the real serialization function, _save or _legacy_save (for now we only follow _save).

def _save(obj, zip_file, pickle_module, pickle_protocol):
    serialized_storages = {}
    id_map: Dict[int, str] = {}

    storage_dtypes: Dict[int, torch.dtype] = {}

    def persistent_id(obj):
        # omitted for now; examined below

    # Write the pickle data for `obj`
    data_buf = io.BytesIO()
    pickler = pickle_module.Pickler(data_buf, protocol=pickle_protocol)
    pickler.persistent_id = persistent_id
    pickler.dump(obj)
    data_value = data_buf.getvalue()
    zip_file.write_record('data.pkl', data_value, len(data_value))

    # Write each tensor to a file named tensor/the_tensor_key in the zip archive
    for key in sorted(serialized_storages.keys()):
        name = f'data/{key}'
        storage = serialized_storages[key]
        # given that we copy things around anyway, we might use storage.cpu()
        # this means to that to get tensors serialized, you need to implement
        # .cpu() on the underlying Storage
        if storage.device.type != 'cpu':
            storage = storage.cpu()
        # Now that it is on the CPU we can directly copy it into the zip file
        num_bytes = storage.nbytes()
        zip_file.write_record(name, storage.data_ptr(), num_bytes)

Two parts of the data get serialized:

  1. The user's object: a persistent_id method is defined, and the object is dumped directly with the pickle module.
    The pickle documentation, pickle — Python object serialization, introduces persistent_id:
Do nothing by default. This exists so a subclass can override it.
If [persistent_id()](https://docs.python.org/3/library/pickle.html#pickle.Pickler.persistent_id) returns None, obj is pickled as usual. Any other value causes [Pickler](https://docs.python.org/3/library/pickle.html#pickle.Pickler) to emit the returned value as a persistent ID for obj. The meaning of this persistent ID should be defined by [Unpickler.persistent_load()](https://docs.python.org/3/library/pickle.html#pickle.Unpickler.persistent_load). Note that the value returned by [persistent_id()](https://docs.python.org/3/library/pickle.html#pickle.Pickler.persistent_id) cannot itself have a persistent ID.

In short, the pickler replaces such an object with a small persistent ID in the stream; at deserialization time that ID is recognized again (and can be validated) by persistent_load.
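The mechanism can be demonstrated with plain pickle, independent of PyTorch. The FakeStorage class below is a made-up stand-in for torch's storage; the pickler writes only a tag into the pickle stream and stashes the payload in a side table, the same pattern _save uses with serialized_storages:

```python
import io
import pickle

class FakeStorage:
    """Stand-in for a torch storage: a blob we want to keep out of the pickle."""
    def __init__(self, data):
        self.data = data

external = {}  # side table, analogous to serialized_storages

class MyPickler(pickle.Pickler):
    def persistent_id(self, obj):
        if isinstance(obj, FakeStorage):
            key = str(len(external))       # incrementing key, like id_map
            external[key] = obj.data       # payload goes to the side table
            return ('storage', key)        # only this small tag is pickled
        return None                        # everything else is pickled as usual

class MyUnpickler(pickle.Unpickler):
    def persistent_load(self, pid):
        tag, key = pid
        assert tag == 'storage'
        return FakeStorage(external[key])  # resolve the tag back to the payload

buf = io.BytesIO()
MyPickler(buf).dump({'epoch': 3, 'weights': FakeStorage([1.0, 2.0])})
buf.seek(0)
restored = MyUnpickler(buf).load()
```

persistent_id was invoked for every object the pickler visited (the dict, its keys, the int) and returned None for all of them except the FakeStorage.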

Now look at the persistent_id method:

  def persistent_id(obj):
        if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):

            if isinstance(obj, torch.storage.TypedStorage):
                # TODO: Once we decide to break serialization FC, this case
                # can be deleted
                storage = obj._storage
                storage_dtype = obj.dtype
                storage_type_str = obj.pickle_storage_type()
                storage_type = getattr(torch, storage_type_str)
                storage_numel = obj.size()

            else:
                storage = obj
                storage_dtype = torch.uint8
                storage_type = normalize_storage_type(type(obj))
                storage_numel = storage.nbytes()

            # If storage is allocated, ensure that any other saved storages
            # pointing to the same data all have the same dtype. If storage is
            # not allocated, don't perform this check
            if storage.data_ptr() != 0:
                if storage.data_ptr() in storage_dtypes:
                    if storage_dtype != storage_dtypes[storage.data_ptr()]:
                        raise RuntimeError(
                            'Cannot save multiple tensors or storages that '
                            'view the same data as different types')
                else:
                    storage_dtypes[storage.data_ptr()] = storage_dtype

            storage_key = id_map.setdefault(storage._cdata, str(len(id_map)))
            location = location_tag(storage)
            serialized_storages[storage_key] = storage

            return ('storage',
                    storage_type,
                    storage_key,
                    location,
                    storage_numel)

        return None

The logic is entered only when the input is a torch.storage.TypedStorage, or one of torch's storage types:

UntypedStorage, DoubleStorage, FloatStorage, LongStorage, IntStorage,
    ShortStorage, CharStorage, ByteStorage, HalfStorage, BoolStorage,
    QUInt8Storage, QInt8Storage, QInt32Storage, BFloat16Storage,
    ComplexFloatStorage, ComplexDoubleStorage, QUInt4x2Storage, QUInt2x4Storage,
    TypedStorage

Otherwise it returns None and nothing special happens.

If the object is a structure, persistent_id is called recursively on every key and every value. For example, for

 {
            'epoch': epoch,
            'model_state_dict': model_state_dict,
            'optimizer_state_dict': optimizer_state_dict,
        }

it is called on 'epoch' and epoch, on 'model_state_dict' and model_state_dict, and so on.

Then note this line: storage_key = id_map.setdefault(storage._cdata, str(len(id_map))). setdefault inserts the key with the given default value only if it is not already present. Since id_map starts out empty, each new storage._cdata is assigned an incrementing integer (as a string), so storage_key is an incrementing integer too. Finally, serialized_storages[storage_key] = storage, so the keys of serialized_storages are those same integers.
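The setdefault pattern is easy to verify in isolation (the addresses below are made up):

```python
# id_map maps a storage address (storage._cdata in the real code) to a key.
id_map = {}

def key_for(cdata):
    # setdefault: insert str(len(id_map)) only if cdata is not present yet.
    return id_map.setdefault(cdata, str(len(id_map)))

print(key_for(0x7f00))  # first new address  -> '0'
print(key_for(0x8f00))  # second new address -> '1'
print(key_for(0x7f00))  # seen before        -> '0' again (shared storage)
```

The third call shows why the map is keyed by address: two tensors viewing the same storage get the same key, so the data is written out only once.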

  2. The storages collected along the way: persistent_id puts the payloads that still need dumping into the serialized_storages map (what exactly they are is examined below). The loop then walks the map and writes the bytes behind each storage's pointer to data/{key}, where key is that element's identifier.
    The comment says as much:
Write each tensor to a file named tensor/the_tensor_key in the zip archive

Per the earlier analysis the keys are integers, so the result is a series of files: data/0, data/1, and so on.

Finally, note that PyTorch leaves behind a single file, the user-specified model.pt. We can unpack it with the unzip command and inspect the contents:

  • data/: the raw bytes of every tensor storage (the model weights); binary, not human-readable
  • data.pkl: the user's top-level structure (model structure + optimizer), serialized with pickle; also not human-readable
  • version: a small file recording the serialization format version, written by PyTorch itself; users never touch it
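Since the container is plain zip, it can also be inspected from Python. The snippet below builds a hand-made archive that mimics this layout (the entry names, the "model/" prefix, and the version value are illustrative stand-ins, not a real torch.save output):

```python
import zipfile

# Build a tiny archive that mimics the layout torch.save produces.
with zipfile.ZipFile("model_demo.zip", "w") as zf:
    zf.writestr("model/data.pkl", b"<pickled object graph>")  # user structure
    zf.writestr("model/data/0", b"<raw storage bytes>")       # tensor payload
    zf.writestr("model/version", b"3\n")                      # format version

# namelist() shows the same entries `unzip -l` would.
with zipfile.ZipFile("model_demo.zip") as zf:
    names = zf.namelist()
    print(names)
```

The same namelist() call against a real model.pt is a convenient way to confirm the data.pkl / data/{key} / version layout described above.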