class ModelEMA:
    """Model Exponential Moving Average, from https://github.com/rwightman/pytorch-image-models

    Keep a moving average of everything in the model state_dict (parameters and buffers).
    This is intended to allow functionality like
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    A smoothed version of the weights is necessary for some training schemes to perform well.
    This class is sensitive to where it is initialized in the sequence of model init,
    GPU assignment and distributed training wrappers.

    Update rule, applied to every floating-point state_dict entry on each ``update``::

        ema = d * ema + (1 - d) * model,   d = decay * (1 - exp(-updates / tau))

    The exponential ramp keeps ``d`` close to 0 for the first updates, so the EMA
    closely follows the rapidly changing early weights. ``d`` then approaches ``decay``.
    """

    def __init__(self, model, decay=0.9999, updates=0, tau=2000):
        """Create the EMA copy of ``model``.

        Args:
            model: model to track. If ``is_parallel(model)`` is true, the wrapped
                ``model.module`` is copied instead of the wrapper itself.
            decay: asymptotic EMA decay factor that ``d`` ramps up towards.
            updates: number of EMA updates already performed (e.g. when resuming).
            tau: time constant, in number of updates, of the decay ramp.
                The default 2000 reproduces the previously hard-coded value.
        """
        # Independent copy in eval mode. It keeps the source model's dtype,
        # i.e. an FP32 EMA for an FP32 model. (An FP16 EMA via
        # ``self.ema.half()`` is intentionally not used.)
        self.ema = deepcopy(model.module if is_parallel(model) else model).eval()
        self.updates = updates  # number of EMA updates performed so far
        # Decay ramps exponentially from ~0 towards `decay` (helps early epochs).
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))
        # The EMA copy is never optimized directly, so it needs no gradients.
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        """Blend the current weights of ``model`` into the EMA copy, in place.

        Increments ``self.updates`` and computes ``d = self.decay(self.updates)``.
        Every floating-point entry ``v`` of the EMA state_dict then becomes
        ``d * v + (1 - d) * model_value``. Non-floating-point entries are left unchanged.

        Args:
            model: the model being trained. If ``is_parallel(model)`` is true, the
                state_dict of ``model.module`` is used.
        """
        with torch.no_grad():
            self.updates += 1
            d = self.decay(self.updates)

            msd = model.module.state_dict() if is_parallel(model) else model.state_dict()  # model state_dict
            for k, v in self.ema.state_dict().items():
                # state_dict() tensors share storage with the EMA's parameters and
                # buffers, so these in-place ops modify the EMA model itself.
                if v.dtype.is_floating_point:
                    v *= d
                    v += (1 - d) * msd[k].detach()

    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
        """Copy attributes from ``model`` onto the EMA model via ``copy_attr``.

        ``include`` and ``exclude`` are forwarded unchanged to ``copy_attr``. By
        default the ``process_group`` and ``reducer`` attributes are excluded
        (presumably distributed-training internals; confirm against ``copy_attr``).
        """
        copy_attr(self.ema, model, include, exclude)
滑动平均的公式为
$$mv_{t} = \mathrm{decay} \cdot mv_{t-1} + (1 - \mathrm{decay}) \cdot \mathrm{variable}$$
对应上述代码中的
v *= d
v += (1 - d) * msd[k].detach()
具体细节在此不赘述。EMA 权重通过每次调用 `update(model)` 随训练逐步更新,在 test 或 val 时则使用这份平滑后的权重代替原始模型权重。
本文介绍了如何在PyTorch中实现Model Exponential Moving Average (EMA),一种用于平滑模型权重的技巧,其结果常用于测试或验证阶段。它对训练过程中的模型参数做指数滑动平均(衰减系数随更新次数逐渐增大),从而得到更平滑、更稳定的一份模型权重。



