A deep dive into the PyTorch checkpoint system, and its optimization
What does a checkpoint contain?
According to Saving and loading a general checkpoint in PyTorch ‒ PyTorch Tutorials 2.0.0+cu117 documentation, a checkpoint contains at least the model, ideally also the optimizer, and possibly the loss, the current epoch, and whatever other odds and ends the specific algorithm needs.
A user can write a checkpoint like this:
for epoch in range(1, args.epochs + 1):
    train(args, model, device, train_loader, optimizer, epoch)
    model_state_dict = model.state_dict()
    optimizer_state_dict = optimizer.state_dict()
    torch.save({
        'epoch': epoch,
        'model_state_dict': model_state_dict,
        'optimizer_state_dict': optimizer_state_dict,
    }, "model.pt")
Flow analysis
def save(obj, f: Union[str, os.PathLike, BinaryIO, IO[bytes]],
         pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True) -> None:
    """
    Args:
        obj: saved object
        f: a file-like object (has to implement write and flush) or a string or
           os.PathLike object containing a file name
        pickle_module: module used for pickling metadata and objects
        pickle_protocol: can be specified to override the default protocol
    """
    with _open_file_like(f, 'wb') as opened_file:
        if _use_new_zipfile_serialization:
            with _open_zipfile_writer(opened_file) as opened_zipfile:
                _save(obj, opened_zipfile, pickle_module, pickle_protocol)
                return
        _legacy_save(obj, opened_file, pickle_module, pickle_protocol)
pickle_module=pickle means the pickle module is used for serialization by default. pickle is Python's serialization module and can serialize arbitrary Python objects; see pickle — Python object serialization.
The protocol defaults to DEFAULT_PROTOCOL (2). pickle defines several protocols, which the documentation above describes, but note that even protocol 0, i.e.
Protocol version 0 is the original “human-readable” protocol and is backwards compatible with earlier versions of Python.
is still unreadable in practice.
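A quick sanity check (a standalone sketch, not from the PyTorch source) shows that even protocol 0 output is only nominally "human-readable":

import pickle

# Even the "human-readable" protocol 0 is just a stream of ASCII pickle opcodes.
blob = pickle.dumps({'epoch': 3, 'lr': 0.1}, protocol=0)
print(blob[:40])   # e.g. b'(dp0\nVepoch\np1\nI3\ns...' -- opcodes, not meant for humans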
_use_new_zipfile_serialization selects between two branches; it defaults to True, in which case the output is written as a zip archive.
In short, this code picks the file format and serialization path based on the user's input, then calls the real serialization function, _save or _legacy_save (for now we only follow _save).
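A quick way to see both branches from the outside (a small sketch; the file names here are made up):

import torch
import zipfile

obj = {'w': torch.randn(2, 2)}

torch.save(obj, "new_format.pt")                                          # default: zip-based format, handled by _save
torch.save(obj, "old_format.pt", _use_new_zipfile_serialization=False)    # legacy format, handled by _legacy_save

print(zipfile.is_zipfile("new_format.pt"))   # True
print(zipfile.is_zipfile("old_format.pt"))   # False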
def _save(obj, zip_file, pickle_module, pickle_protocol):
    serialized_storages = {}
    id_map: Dict[int, str] = {}
    storage_dtypes: Dict[int, torch.dtype] = {}

    def persistent_id(obj):
        # omitted here, examined below
        ...

    # Write the pickle data for `obj`
    data_buf = io.BytesIO()
    pickler = pickle_module.Pickler(data_buf, protocol=pickle_protocol)
    pickler.persistent_id = persistent_id
    pickler.dump(obj)
    data_value = data_buf.getvalue()
    zip_file.write_record('data.pkl', data_value, len(data_value))

    # Write each tensor to a file named tensor/the_tensor_key in the zip archive
    for key in sorted(serialized_storages.keys()):
        name = f'data/{key}'
        storage = serialized_storages[key]
        # given that we copy things around anyway, we might use storage.cpu()
        # this means to that to get tensors serialized, you need to implement
        # .cpu() on the underlying Storage
        if storage.device.type != 'cpu':
            storage = storage.cpu()

        # Now that it is on the CPU we can directly copy it into the zip file
        num_bytes = storage.nbytes()
        zip_file.write_record(name, storage.data_ptr(), num_bytes)
There are two parts of data to serialize:
- The user's object: a persistent_id function is defined and attached to the pickler, and the object is then dumped directly with the pickle module.
The pickle documentation, pickle — Python object serialization, describes persistent_id:
Do nothing by default. This exists so a subclass can override it.
If [persistent_id()](https://docs.python.org/3/library/pickle.html#pickle.Pickler.persistent_id) returns None, obj is pickled as usual. Any other value causes [Pickler](https://docs.python.org/3/library/pickle.html#pickle.Pickler) to emit the returned value as a persistent ID for obj. The meaning of this persistent ID should be defined by [Unpickler.persistent_load()](https://docs.python.org/3/library/pickle.html#pickle.Unpickler.persistent_load). Note that the value returned by [persistent_id()](https://docs.python.org/3/library/pickle.html#pickle.Pickler.persistent_id) cannot itself have a persistent ID.
In short, instead of pickling such an object inline, the pickler records an ID for it; at deserialization time persistent_load uses that ID to locate the real data and can run some validation.
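The mechanism itself is plain pickle and independent of PyTorch. A minimal sketch of the persistent_id / persistent_load pair (the External class and the side table are made up for illustration):

import io
import pickle

side_table = {}          # stands in for serialized_storages

class External:
    def __init__(self, payload):
        self.payload = payload

class MyPickler(pickle.Pickler):
    def persistent_id(self, obj):
        if isinstance(obj, External):        # only "special" objects get a persistent ID
            key = str(len(side_table))
            side_table[key] = obj.payload    # the real payload goes elsewhere, like data/{key}
            return ('external', key)         # this small tuple is stored in the pickle instead
        return None                          # everything else is pickled as usual

class MyUnpickler(pickle.Unpickler):
    def persistent_load(self, pid):
        tag, key = pid
        assert tag == 'external'
        return External(side_table[key])     # resolve the ID back into an object

buf = io.BytesIO()
MyPickler(buf).dump({'x': External(b'big blob'), 'epoch': 3})
restored = MyUnpickler(io.BytesIO(buf.getvalue())).load()
print(restored['x'].payload, restored['epoch'])   # b'big blob' 3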
Now look at torch's persistent_id function:
def persistent_id(obj):
    if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
        if isinstance(obj, torch.storage.TypedStorage):
            # TODO: Once we decide to break serialization FC, this case
            # can be deleted
            storage = obj._storage
            storage_dtype = obj.dtype
            storage_type_str = obj.pickle_storage_type()
            storage_type = getattr(torch, storage_type_str)
            storage_numel = obj.size()
        else:
            storage = obj
            storage_dtype = torch.uint8
            storage_type = normalize_storage_type(type(obj))
            storage_numel = storage.nbytes()

        # If storage is allocated, ensure that any other saved storages
        # pointing to the same data all have the same dtype. If storage is
        # not allocated, don't perform this check
        if storage.data_ptr() != 0:
            if storage.data_ptr() in storage_dtypes:
                if storage_dtype != storage_dtypes[storage.data_ptr()]:
                    raise RuntimeError(
                        'Cannot save multiple tensors or storages that '
                        'view the same data as different types')
            else:
                storage_dtypes[storage.data_ptr()] = storage_dtype

        storage_key = id_map.setdefault(storage._cdata, str(len(id_map)))
        location = location_tag(storage)
        serialized_storages[storage_key] = storage

        return ('storage',
                storage_type,
                storage_key,
                location,
                storage_numel)

    return None
The logic is only entered when the input is a torch.storage.TypedStorage, or one of torch's storage classes:
UntypedStorage, DoubleStorage, FloatStorage, LongStorage, IntStorage,
ShortStorage, CharStorage, ByteStorage, HalfStorage, BoolStorage,
QUInt8Storage, QInt8Storage, QInt32Storage, BFloat16Storage,
ComplexFloatStorage, ComplexDoubleStorage, QUInt4x2Storage, QUInt2x4Storage,
TypedStorage
Otherwise it returns None and does nothing.
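In practice those storages come from the tensors inside the state dicts: a Tensor itself is not a storage, but its underlying storage is. A small check (behavior shown for roughly PyTorch 2.0; the exact classes have shifted across versions):

import torch

t = torch.arange(4, dtype=torch.float32)

print(torch.is_storage(t))                    # False: a Tensor is not a storage
print(type(t.storage()))                      # typically torch.storage.TypedStorage (first branch above)
print(type(t.untyped_storage()))              # torch.UntypedStorage (the raw byte view)
print(torch.is_storage(t.untyped_storage()))  # True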
If the object being saved is a composite structure, persistent_id is invoked for every key and every value the pickler encounters while walking it. For example, with
{
    'epoch': epoch,
    'model_state_dict': model_state_dict,
    'optimizer_state_dict': optimizer_state_dict,
}
it is invoked for 'epoch' and epoch, 'model_state_dict' and model_state_dict, and so on, all the way down. Tensors themselves fail the storage check, but pickling a Tensor (via its __reduce_ex__) exposes its underlying storage, which is what finally takes the storage branch above.
Also note this line: storage_key = id_map.setdefault(storage._cdata, str(len(id_map))). setdefault adds the key with the given default value if it is not already in the dict, and returns the stored value either way. Since id_map starts out empty, each new storage._cdata is mapped to the next integer (as a string), so storage_key is an incrementing integer. Finally, serialized_storages[storage_key] = storage, so the keys of serialized_storages are those same integers.
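The setdefault idiom on its own (a standalone sketch, nothing PyTorch-specific):

id_map = {}

def key_for(cdata):
    # The first time a storage pointer is seen it gets the next integer (as a string);
    # repeated storages map back to the same key, so a shared storage is only written once.
    return id_map.setdefault(cdata, str(len(id_map)))

print(key_for(0x7f01))  # '0'
print(key_for(0x7f02))  # '1'
print(key_for(0x7f01))  # '0' again -- the existing value is returned, nothing new is added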
- Inside persistent_id, the remaining content that needs dumping is recorded in the serialized_storages map; as shown above, those entries are the storages. _save then iterates over every entry in the map and writes the bytes its data pointer refers to into data/{key} (key being that entry's identifier).
This matches the comment in the code:
Write each tensor to a file named tensor/the_tensor_key in the zip archive
Given the analysis above, the keys are integers, so the archive ends up containing a series of files named data/0, data/1, and so on.
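Because each data/{key} entry is just the raw bytes of one storage, you can (purely as an experiment) reinterpret one of them yourself. A sketch, assuming the archive prefix is model/, the first storage is data/0, and you happen to know out of band that it holds float32 values:

import zipfile
import torch

with zipfile.ZipFile("model.pt") as zf:
    raw = zf.read("model/data/0")            # raw storage bytes, no pickle framing

t = torch.frombuffer(bytearray(raw), dtype=torch.float32)   # dtype/shape must come from data.pkl
print(t.shape, t[:5])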
Finally, PyTorch leaves behind just one file, the model.pt the user specified. We can extract it with the unzip command and look at what is inside:
data/
: holds all the tensors (the model data), one file of raw bytes per storage
data.pkl
: holds the user's input structure (model structure + optimizer); this is pickle-serialized data and completely unreadable
version
: a version file maintained by the archive writer itself, unrelated to the user's data
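You can also list the archive without extracting it. A small sketch (the top-level folder name inside the archive depends on how the writer was opened, so the exact prefix may differ):

import zipfile

with zipfile.ZipFile("model.pt") as zf:
    for info in zf.infolist():
        print(info.filename, info.file_size)

# Typical output looks roughly like:
#   model/data.pkl      <pickled structure>
#   model/data/0        <raw bytes of one storage>
#   model/data/1        ...
#   model/version       <format version>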