PyTorch Checkpoint System: Deep Dive and Optimization (02)

load

Flow analysis

A checkpoint file can be loaded with the following code:

checkpoint = torch.load("ckpt.pt")
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)

By default PyTorch reads a zip archive: it unzips the file and deserializes the contents with pickle, giving the tensors inside special treatment. torch.load is the entry point of the load logic.
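You can confirm the archive format yourself with a self-contained snippet (the file name here is arbitrary):

import zipfile

import torch

torch.save({'weight': torch.randn(2, 2)}, 'ckpt.pt')
# checkpoints written by the default (zipfile-based) serializer are valid zip archives
print(zipfile.is_zipfile('ckpt.pt'))  # True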

def load(f, map_location=None, pickle_module=pickle, **pickle_load_args):
    _check_dill_version(pickle_module)
    with _open_file_like(f, 'rb') as opened_file:
        if _is_zipfile(opened_file):
            orig_position = opened_file.tell()
            with _open_zipfile_reader(opened_file) as opened_zipfile:
                return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
        return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)

Mirroring the save path, zip files go through _load, while non-zip files (the _legacy_save kind) go through _legacy_load. PyTorch also supports various kinds of customization, which the docstring spells out clearly:

  1. If :attr:`map_location` is a callable, it will be called once for each serialized storage with two arguments: storage and location. The storage argument will be the initial deserialization of the storage, residing on the CPU.
  2. Each serialized storage has a location tag associated with it which identifies the device it was saved from, and this tag is the second argument passed to :attr:`map_location`. The builtin location tags are 'cpu' for CPU tensors and 'cuda:device_id' (e.g. 'cuda:2') for CUDA tensors.
  3. :attr:`map_location` should return either ``None`` or a storage. If :attr:`map_location` returns a storage, it will be used as the final deserialized object, already moved to the right device. Otherwise, :func:`torch.load` will fall back to the default behavior, as if :attr:`map_location` wasn't specified.

map_location lets you control what gets loaded onto the CPU or onto a specific GPU.
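A few common map_location usages that follow directly from the documented behavior above (the file name is just a placeholder):

import torch

# force every storage onto the CPU, e.g. when no GPU is available
checkpoint = torch.load('ckpt.pt', map_location='cpu')

# remap storages saved from cuda:1 onto cuda:0 via a dict of location tags
checkpoint = torch.load('ckpt.pt', map_location={'cuda:1': 'cuda:0'})

# a callable receives (storage, location) for each storage; returning the
# storage unchanged keeps it on the CPU, where the initial deserialization happens
checkpoint = torch.load('ckpt.pt', map_location=lambda storage, loc: storage)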

def _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', **pickle_load_args):
    restore_location = _get_restore_location(map_location)

    loaded_storages = {}

    def load_tensor(data_type, size, key, location): # omitted
    def persistent_load(saved_id): # omitted

    load_module_mapping: Dict[str, str] = {
        'torch.tensor': 'torch._tensor'
    }

    class UnpicklerWrapper(pickle_module.Unpickler):  # type: ignore[name-defined] omitted

    data_file = io.BytesIO(zip_file.get_record(pickle_file))

    unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
    unpickler.persistent_load = persistent_load
    result = unpickler.load()

    torch._utils._validate_loaded_sparse_tensors()

    return result

The logic of _load is:

  1. Decide where the data should go on restore, via _get_restore_location. Its body is a chain of if/else branches; we focus on the default return value, default_restore_location:
def default_restore_location(storage, location):
    for _, _, fn in _package_registry:
        result = fn(storage, location)
        if result is not None:
            return result

In the source, _package_registry starts out as an empty list; you can see that its append method gets called, but to know what it actually holds at runtime you would pretty much have to trace PyTorch's whole initialization flow.

  2. Initialize the unpickler. It subclasses pickle_module.Unpickler and overrides find_class; persistent_load is then assigned on the instance. Finally, unpickler.load() reads the data from the file and deserializes it.
    find_class only remaps 'torch.tensor' to 'torch._tensor'. Does that mean serialization and deserialization behave differently for tensors? (See the sketch right after this list.)
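Not exactly. Judging from the load_module_mapping defined in _load, the remap looks like a backward-compatibility shim: checkpoints pickled by older releases reference the module torch.tensor, which newer releases renamed to torch._tensor, so find_class rewrites stale module paths before resolving a class. A minimal sketch of what the omitted UnpicklerWrapper boils down to (assuming its body follows the mapping; the real class does a bit more):

import pickle

load_module_mapping = {
    # module renamed in newer PyTorch releases; old pickles still say 'torch.tensor'
    'torch.tensor': 'torch._tensor',
}

class UnpicklerWrapper(pickle.Unpickler):
    def find_class(self, mod_name, name):
        # rewrite a stale module path, then fall back to the normal class lookup
        mod_name = load_module_mapping.get(mod_name, mod_name)
        return super().find_class(mod_name, name)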

persistent_load lets the caller hook in extra deserialization behavior.

def persistent_load(saved_id):
    assert isinstance(saved_id, tuple)
    typename = _maybe_decode_ascii(saved_id[0])
    data = saved_id[1:]

    assert typename == 'storage', \
        f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
    data_type, key, location, size = data
    if key not in loaded_storages:
        load_tensor(data_type, size, key, _maybe_decode_ascii(location))
    storage = loaded_storages[key]
    return storage
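For intuition, here is a hypothetical persistent id matching the unpacking above (the key is a stringified counter assigned at save time; the size value is made up):

import torch

#           (typename,  data_type,          key, location, size)
saved_id = ('storage', torch.FloatStorage, '0', 'cuda:0', 16384)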

_maybe_decode_ascii is an extra conversion step kept for Python 2/3 compatibility; we can ignore it. Recall that at save time the tensors were dumped separately, always with the typename 'storage', and data carries the concrete tensor payload. load_tensor does the actual loading.

def load_tensor(data_type, size, key, location):
    name = f'data/{key}'
    dtype = data_type(0).dtype

    storage = zip_file.get_storage_from_record(name, size, dtype).storage()
    loaded_storages[key] = restore_location(storage, location)

'data/{key}' is the name of the record that holds the serialized tensor inside the archive, as mentioned in the previous post. After the bytes are read, the data is put back where it originally lived (usually GPU memory).
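The layout is easy to inspect with the standard zipfile module (output is illustrative; the archive prefix can vary across versions):

import zipfile

import torch

torch.save({'w': torch.randn(4, 4)}, 'ckpt.pt')
with zipfile.ZipFile('ckpt.pt') as zf:
    print(zf.namelist())
# typical entries: ['ckpt/data.pkl', 'ckpt/data/0', 'ckpt/version']
# data.pkl is the pickled object tree; data/0 holds the raw bytes of storage key '0'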

Stepping through with a debugger shows that _package_registry holds two entries at this point, corresponding to the functions _cpu_deserialize and _cuda_deserialize. At the call site result = fn(storage, location), location is 'cuda:0', so execution enters _cuda_deserialize; its core is the return obj.cuda(device) call, which performs a cudaMemcpy internally.
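Those entries are installed at import time through torch.serialization.register_package. As an illustration of the mechanism (the priority values below match the built-in cpu/cuda handlers of the version discussed here, so treat them as indicative), a custom handler could be registered to intercept the dispatch loop in default_restore_location:

from torch.serialization import register_package

def _tag_noop(obj):
    # claim no location tag at save time
    return None

def _deserialize_to_cpu(obj, location):
    # unconditionally restore every storage into host memory
    return obj.cpu()

# Hypothetical: priority 5 sorts before the built-in cpu (10) and cuda (20)
# handlers, so this deserializer wins the loop in default_restore_location.
register_package(5, _tag_noop, _deserialize_to_cpu)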

Finally, torch._utils._validate_loaded_sparse_tensors() is called to validate the loaded sparse tensors.
