# Source code for common.utils.meter
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from typing import Optional, List
class AverageMeter(object):
    r"""Computes and stores the average and current value.

    Args:
        name (str): Name of the tracked quantity, printed by ``__str__``.
        fmt (str, optional): Format spec (including the leading ``':'``)
            applied to both the current value and the running average.
            ``None`` falls back to ``':f'``. Default: ``':f'``.

    Examples::

        >>> # Initialize a meter to record loss
        >>> losses = AverageMeter('Loss', ':.4e')
        >>> # Update meter after every minibatch update
        >>> losses.update(loss_value, batch_size)
    """

    def __init__(self, name: str, fmt: Optional[str] = ':f'):
        self.name = name
        # The annotation admits ``None``; substitute the default spec so that
        # ``__str__`` does not fail concatenating ``None`` into the template.
        self.fmt = ':f' if fmt is None else fmt
        self.reset()

    def reset(self):
        """Clear the current value, running sum, count and average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value and fold it into the average.

        Args:
            val: Latest value (e.g. a loss averaged over a minibatch).
            n (int): Weight of ``val`` in the running average, e.g. the
                batch size. Default: 1.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Avoid dividing by zero when only zero-weight updates were recorded.
        if self.count > 0:
            self.avg = self.sum / self.count

    def __str__(self):
        """Render as ``'<name> <val> (<avg>)'`` using ``self.fmt``."""
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)
class AverageMeterDict(object):
    """Keeps one :class:`AverageMeter` per name and updates them together.

    Args:
        names (list): Names of the meters to create.
        fmt (str, optional): Format spec forwarded to every meter.
            Default: ``':f'``.
    """

    def __init__(self, names: List, fmt: Optional[str] = ':f'):
        meters = {}
        for meter_name in names:
            meters[meter_name] = AverageMeter(meter_name, fmt)
        self.dict = meters

    def reset(self):
        """Reset every contained meter."""
        for key in self.dict:
            self.dict[key].reset()

    def update(self, accuracies, n=1):
        """Update the meter named by each key of ``accuracies``.

        Args:
            accuracies (dict): Maps meter name to its latest value; every key
                must be one of the names given at construction.
            n (int): Weight passed to each meter's ``update``. Default: 1.
        """
        for key, value in accuracies.items():
            meter = self.dict[key]
            meter.update(value, n)

    def average(self):
        """Return a new dict mapping each meter name to its running average."""
        result = {}
        for key, meter in self.dict.items():
            result[key] = meter.avg
        return result

    def __getitem__(self, item):
        """Return the meter registered under ``item``."""
        meters = self.dict
        return meters[item]
class Meter(object):
    """Computes and stores the current value.

    Args:
        name (str): Name of the tracked quantity, printed by ``__str__``.
        fmt (str, optional): Format spec (including the leading ``':'``)
            applied to the current value. ``None`` falls back to ``':f'``.
            Default: ``':f'``.
    """

    def __init__(self, name: str, fmt: Optional[str] = ':f'):
        self.name = name
        # The annotation admits ``None``; substitute the default spec so that
        # ``__str__`` does not fail concatenating ``None`` into the template.
        self.fmt = ':f' if fmt is None else fmt
        self.reset()

    def reset(self):
        """Reset the stored value to 0."""
        self.val = 0

    def update(self, val):
        """Replace the stored value with ``val``."""
        self.val = val

    def __str__(self):
        """Render as ``'<name> <val>'`` using ``self.fmt``."""
        fmtstr = '{name} {val' + self.fmt + '}'
        return fmtstr.format(**self.__dict__)
class ProgressMeter(object):
    """Prints a one-line, tab-separated progress report for a batch.

    Args:
        num_batches (int): Total number of batches; its digit count sets the
            padded width of the batch counter.
        meters (list): Objects whose ``str()`` is appended to each line,
            e.g. :class:`AverageMeter` instances.
        prefix (str): Text placed before the batch counter. Default: ``""``.
    """

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print ``prefix[batch/num_batches]`` followed by every meter.

        Args:
            batch (int): Index of the current batch.
        """
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        """Build a template like ``'[{:3d}/100]'`` for ``num_batches=100``."""
        # Coerce with int() rather than ``// 1``, which keeps floats as floats:
        # a float count such as 10.0 would then break the ``'d'`` format below.
        num_batches = int(num_batches)
        num_digits = len(str(num_batches))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'