1、位置
metricLogger 度量记录器,一般放在utils文件里面
是一个在训练过程中输出的类
它先接触数据,使用yield迭代传送进去训练。
2、class MetricLogger
class MetricLogger(object):
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
def update(self, **kwargs):
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
v = v.item()
assert isinstance(v, (float, int))
self.meters[k].update(v)
def __getattr__(self, attr):
if attr in self.meters:
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format(
type(self).__name__, attr))
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append(
"{}: {:.4f} ({:.4f})".format(name, meter.median, meter.global_avg)
)
return self.delimiter.join(loss_str)
初始化时利用了SmoothValue这个类,主要包含以下属性:
self.deque | 利用队列来获取的数值 |
---|---|
self.total | 记录累计的数值的总和 |
self.count | 记录累计的个数的总和 |
self.fmt |
MetricLogger类就两个属性,一个是self.meters,另一个是self.delimiter。self.meters初始化利用的是SmoothValue这个类,所以可以使用SmoothValue的属性和方法。self.meters里面是是一个字典:{name:meter}meter是字典,主要信息存储在列表中。
self.delimiter = delimeter:是一个字符串类型
MetricLogger类的方法,主要是update,用来更新数值。还包括__getattr__、__str__方法