|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
|
|
class AvgTensor:
    """Accumulate a weighted running mean of tensors.

    Each ``update(val, n)`` adds ``val * n`` to a running sum and ``n`` to a
    running count; ``compute()`` returns ``sum / count``. The sum is held in
    ``accumulate_dtype`` so that low-precision inputs do not lose accuracy
    while being accumulated.
    """

    def __init__(self, accumulate_dtype=torch.float32):
        """Create an empty accumulator.

        Args:
            accumulate_dtype: dtype used for the running sum.
        """
        self.accumulate_dtype = accumulate_dtype
        self.sum = None
        self.count = None
        self.reset()
        # FIXME handle distributed operation

    def reset(self):
        """Drop all accumulated state; buffers are re-created on next update."""
        self.sum = None
        self.count = None

    def update(self, val: torch.Tensor, n=1):
        """Add ``val`` to the running sum with weight ``n``.

        Args:
            val: tensor to accumulate. The first tensor seen after
                construction/reset fixes the shape and device of the sum.
            n: weight of this observation (e.g. a batch size).
        """
        if self.sum is None:
            # Lazily allocate so shape/device follow the first observation.
            self.sum = torch.zeros_like(val, dtype=self.accumulate_dtype)
            self.count = torch.tensor(0, dtype=torch.long, device=val.device)
        # Cast BEFORE scaling: multiplying in val's own dtype can overflow
        # (fp16 -> inf) or wrap (small integer dtypes) before the result ever
        # reaches the wider accumulator.
        self.sum += val.to(dtype=self.accumulate_dtype) * n
        self.count += n

    def compute(self):
        """Return the weighted mean ``sum / count``.

        Must be called after at least one ``update()``; before that both
        ``sum`` and ``count`` are ``None`` and the division raises.
        """
        return self.sum / self.count
|
|
|
|
|
|
|
|
|
|
|
|
class TensorEma:
    """Exponential moving average (EMA) of a stream of tensors.

    After each ``update(val)`` the stored state becomes
    ``(1 - smoothing_factor) * val + smoothing_factor * state``. The state is
    kept in ``accumulate_dtype`` and is exposed as the ``val`` attribute.
    """

    def __init__(
            self,
            smoothing_factor=0.9,
            init_zero=False,
            accumulate_dtype=torch.float32
    ):
        """Create an empty EMA tracker.

        Args:
            smoothing_factor: weight given to the previous state on each
                update; the new observation gets ``1 - smoothing_factor``.
            init_zero: if True the state starts from zeros, otherwise it
                starts from the first observed tensor.
            accumulate_dtype: dtype in which the state is stored.
        """
        self.accumulate_dtype = accumulate_dtype
        self.smoothing_factor = smoothing_factor
        self.init_zero = init_zero
        self.val = None
        self.reset()
        # FIXME handle distributed operation

    def reset(self):
        """Forget the current state; the next update re-initialises it."""
        self.val = None

    def update(self, val):
        """Blend a new observation into the moving average.

        Args:
            val: tensor observation. It is cast to ``accumulate_dtype`` before
                blending so type promotion cannot change the dtype of the
                stored state, and so a low-precision input is not scaled in
                its own dtype.
        """
        val = val.to(dtype=self.accumulate_dtype)
        prev = self.val
        if prev is None:
            # First observation after construction/reset: seed the state.
            # No clone is needed, the blend below always yields a new tensor.
            prev = torch.zeros_like(val) if self.init_zero else val
        self.val = (1. - self.smoothing_factor) * val + self.smoothing_factor * prev
|