load
Flow analysis
A user loads a checkpoint file with code like the following:
```python
import torch

checkpoint = torch.load("ckpt.pt")                      # deserialize the checkpoint dict
model.load_state_dict(checkpoint['model_state_dict'])   # restore parameters into the model
model.to(device)                                        # move the model to the target device
```
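For context, a checkpoint like this comes from the mirror-image `torch.save` call. The following is the standard pattern (the optimizer entry is a typical extra you might track, not something this walkthrough relies on):

```python
# Standard checkpoint-saving pattern (hypothetical training script):
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, "ckpt.pt")
```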
By default PyTorch reads a compressed (zip) file: the archive is unzipped and its contents are deserialized with pickle, but the tensors inside get special treatment.
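Since the checkpoint is an ordinary zip archive, we can peek inside it with the standard library (a quick sanity check; the exact record prefix varies with the PyTorch version and how the file was written):

```python
import zipfile

with zipfile.ZipFile("ckpt.pt") as zf:
    print(zf.namelist())
# e.g. ['ckpt/data.pkl', 'ckpt/data/0', 'ckpt/data/1', ..., 'ckpt/version']
```

torch.load is the entry point of the load logic: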
```python
def load(f, map_location=None, pickle_module=pickle, **pickle_load_args):
    _check_dill_version(pickle_module)

    with _open_file_like(f, 'rb') as opened_file:
        if _is_zipfile(opened_file):
            orig_position = opened_file.tell()
            with _open_zipfile_reader(opened_file) as opened_zipfile:
                return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
        return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
```
Mirroring the save side, zip files go through _load while non-zip files (the _legacy_save format) go through _legacy_load. PyTorch also supports customized placement logic, which the docstring spells out clearly:
- If :attr:`map_location` is a callable, it will be called once for each serialized storage with two arguments: storage and location. The storage argument will be the initial deserialization of the storage, residing on the CPU.
- Each serialized storage has a location tag associated with it which identifies the device it was saved from, and this tag is the second argument passed to :attr:`map_location`. The builtin location tags are `'cpu'` for CPU tensors and `'cuda:device_id'` (e.g. `'cuda:2'`) for CUDA tensors.
- :attr:`map_location` should return either `None` or a storage. If :attr:`map_location` returns a storage, it will be used as the final deserialized object, already moved to the right device. Otherwise, :func:`torch.load` will fall back to the default behavior, as if :attr:`map_location` wasn't specified.
In short, map_location lets you control which device (CPU, or a particular GPU) each storage lands on.
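For example, all of the following are standard forms accepted by torch.load:

```python
# Common map_location forms:
ckpt = torch.load("ckpt.pt", map_location="cpu")                    # everything onto CPU
ckpt = torch.load("ckpt.pt", map_location=torch.device("cuda:0"))   # a torch.device
ckpt = torch.load("ckpt.pt", map_location={"cuda:1": "cuda:0"})     # remap location tags
ckpt = torch.load("ckpt.pt",
                  map_location=lambda storage, loc: storage)        # callable: keep on CPU
```

Back to the zip branch, where _load does the real work: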
```python
def _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', **pickle_load_args):
    restore_location = _get_restore_location(map_location)

    loaded_storages = {}

    def load_tensor(data_type, size, key, location):
        ...  # omitted

    def persistent_load(saved_id):
        ...  # omitted

    load_module_mapping: Dict[str, str] = {
        'torch.tensor': 'torch._tensor'
    }

    class UnpicklerWrapper(pickle_module.Unpickler):  # type: ignore[name-defined]
        ...  # omitted

    data_file = io.BytesIO(zip_file.get_record(pickle_file))

    unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
    unpickler.persistent_load = persistent_load
    result = unpickler.load()

    torch._utils._validate_loaded_sparse_tensors()

    return result
```
_load's logic:
- Decide where the data should live on restore. This is determined by _get_restore_location, whose body is a chain of if/else branches; we focus on its default return value, default_restore_location:
```python
def default_restore_location(storage, location):
    for _, _, fn in _package_registry:
        result = fn(storage, location)
        if result is not None:
            return result
```
- _package_registry starts out empty. You can see its append method being called (via register_package), but you would have to trace PyTorch's whole import-time flow to know what it contains at runtime; a sketch of how it gets populated appears near the end of this section.
- Build the unpickler: it subclasses pickle_module.Unpickler and overrides find_class and persistent_load, then unpickler.load() reads the pickle stream from the file and deserializes it.
find_class only remaps the module name 'torch.tensor' to 'torch._tensor'. This is not serialization and deserialization diverging for tensors; it is backward compatibility: the module was renamed to torch._tensor (so it no longer shadows the torch.tensor() function), and the mapping keeps checkpoints written under the old name loadable.
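Abridged from the same file, the override looks like this (the Storage branch additionally reroutes storage classes through StorageType):

```python
class UnpicklerWrapper(pickle_module.Unpickler):  # type: ignore[name-defined]
    def find_class(self, mod_name, name):
        if type(name) is str and 'Storage' in name:
            try:
                return StorageType(name)
            except KeyError:
                pass
        # remap renamed modules, e.g. 'torch.tensor' -> 'torch._tensor'
        mod_name = load_module_mapping.get(mod_name, mod_name)
        return super().find_class(mod_name, name)
```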
persistent_load is the hook that lets callers customize extra deserialization behavior:
```python
def persistent_load(saved_id):
    assert isinstance(saved_id, tuple)
    typename = _maybe_decode_ascii(saved_id[0])
    data = saved_id[1:]

    assert typename == 'storage', \
        f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
    data_type, key, location, size = data
    if key not in loaded_storages:
        load_tensor(data_type, size, key, _maybe_decode_ascii(location))
    storage = loaded_storages[key]
    return storage
```
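This is pickle's standard persistent-ID mechanism: at save time PyTorch's persistent_id returns such a tuple for each storage and pickle records only that stub; at load time persistent_load resolves the stub back to an object. A minimal self-contained sketch of the same pattern (toy code, not PyTorch's):

```python
import io
import pickle

blobs = {}  # stands in for the data/{key} records in the checkpoint zip

class Pickler(pickle.Pickler):
    def persistent_id(self, obj):
        if isinstance(obj, bytearray):   # pretend bytearrays are "storages"
            key = str(len(blobs))
            blobs[key] = bytes(obj)      # the payload goes out-of-band
            return ('storage', key)      # only this stub enters the pickle stream
        return None                      # everything else pickles normally

class Unpickler(pickle.Unpickler):
    def persistent_load(self, saved_id):
        typename, key = saved_id
        assert typename == 'storage'
        return bytearray(blobs[key])     # resolve the stub back to a payload

buf = io.BytesIO()
Pickler(buf).dump({'weight': bytearray(b'\x01\x02')})
buf.seek(0)
print(Unpickler(buf).load())  # {'weight': bytearray(b'\x01\x02')}
```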
_maybe_decode_ascii is an extra conversion step kept for Python 2/3 compatibility; we can ignore it. Recall that at save time every tensor's payload was dumped out-of-band with typename 'storage'; data carries the storage's metadata (data_type, key, location, size), and the actual bytes are brought in by load_tensor:
```python
def load_tensor(data_type, size, key, location):
    name = f'data/{key}'
    dtype = data_type(0).dtype

    storage = zip_file.get_storage_from_record(name, size, dtype).storage()
    loaded_storages[key] = restore_location(storage, location)
```
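A side note on the dtype = data_type(0).dtype line: data_type is a storage class such as torch.FloatStorage, so constructing a zero-element storage is a cheap way to recover the element type (this assumes a PyTorch version whose storages expose .dtype, which is the case for the code above):

```python
import torch

print(torch.FloatStorage(0).dtype)  # torch.float32
print(torch.LongStorage(0).dtype)   # torch.int64
```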
'data/{key}' is the record name the tensor was serialized under, as mentioned earlier. After the bytes are read, restore_location puts the storage back where it came from (usually GPU memory). Stepping through with a debugger shows _package_registry holds two elements,
corresponding to the functions _cpu_deserialize and _cuda_deserialize. At the call site result = fn(storage, location), location is 'cuda:0', so _cuda_deserialize handles it; the core of that function is the return obj.cuda(device) call, which performs a cudaMemcpy under the hood.
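As promised above, here is how those two entries get into _package_registry at import time, abridged from torch/serialization.py (exact bodies vary across versions):

```python
def register_package(priority, tagger, deserializer):
    queue_elem = (priority, tagger, deserializer)
    _package_registry.append(queue_elem)
    _package_registry.sort()  # lower priority value runs first

register_package(10, _cpu_tag, _cpu_deserialize)    # handles location == 'cpu'
register_package(20, _cuda_tag, _cuda_deserialize)  # handles 'cuda:...' locations

def _cuda_deserialize(obj, location):
    if location.startswith('cuda'):
        device = validate_cuda_device(location)
        ...
        return obj.cuda(device)  # copies the storage onto the GPU (cudaMemcpy)
```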
Finally, torch._utils._validate_loaded_sparse_tensors() is called to validate any loaded sparse tensors.