You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
51 lines
1.7 KiB
51 lines
1.7 KiB
import time
|
|
from typing import Optional
|
|
|
|
from timm.metrics import ScalarAvgMinMax
|
|
|
|
|
|
class Tracker:
    """Wall-clock timing tracker for an iterate-over-batches loop.

    Intervals are measured with ``time.perf_counter()`` and pushed into
    ``ScalarAvgMinMax`` meters via their ``update(value)`` method
    (presumably accumulating avg/min/max -- see ``timm.metrics``).

    Expected call order within one iteration::

        tracker.mark_iter()            # iteration boundary
        ...                            # fetch batch
        tracker.mark_iter_data_end()   # records data_time
        ...                            # model step
        tracker.mark_iter_step_end()   # records step_time
        tracker.mark_iter()            # next boundary, records iter_time

    ``mark_epoch()`` records the interval between consecutive calls in
    ``epoch_time``.
    """

    def __init__(self):
        self.data_time = ScalarAvgMinMax()  # time for data loader to produce batch of samples
        self.step_time = ScalarAvgMinMax()  # time for model step
        self.iter_time = ScalarAvgMinMax()  # full iteration time incl. data, step, and book-keeping
        self.epoch_time = ScalarAvgMinMax()  # time between consecutive mark_epoch() calls

        # perf_counter() value at the most recent mark_iter(); None until the first call.
        self.iter_timestamp: Optional[float] = None
        # Start point of the next data/step interval; advanced by mark_iter() and each *_end mark.
        self.prev_timestamp: Optional[float] = None
        # perf_counter() value at the most recent mark_epoch(); None until the first call.
        self.epoch_timestamp: Optional[float] = None

    def _mark_interval(self, meter: "ScalarAvgMinMax") -> float:
        """Record the time elapsed since ``prev_timestamp`` into *meter*.

        Advances ``prev_timestamp`` to now so the next interval starts here.

        Returns:
            The recorded interval in seconds.

        Raises:
            AssertionError: if no reference timestamp exists yet
                (``mark_iter()`` has not been called).
        """
        # Explicit check instead of a bare ``assert`` so the guard survives
        # ``python -O``; AssertionError is kept for backward compatibility.
        if self.prev_timestamp is None:
            raise AssertionError(
                'No reference timestamp: call mark_iter() before marking data/step end.')
        timestamp = time.perf_counter()
        elapsed = timestamp - self.prev_timestamp
        meter.update(elapsed)
        self.prev_timestamp = timestamp
        return elapsed

    def _measure_iter(self, ref_timestamp=None):
        """Reset the data/step reference point to the current time.

        Args:
            ref_timestamp: Currently ignored; kept for signature compatibility.
        """
        # NOTE(review): ref_timestamp is unused -- confirm whether callers expect
        # it to override the current time.
        self.prev_timestamp = time.perf_counter()

    def mark_iter(self):
        """Mark an iteration boundary.

        Records the time since the previous ``mark_iter()`` into ``iter_time``
        (nothing is recorded on the very first call). It then resets both the
        iteration and the data/step reference timestamps to now.
        """
        timestamp = time.perf_counter()
        if self.iter_timestamp is not None:
            self.iter_time.update(timestamp - self.iter_timestamp)
        self.iter_timestamp = self.prev_timestamp = timestamp

    def mark_iter_data_end(self):
        """Record the time since the last mark as data-loading time (``data_time``).

        Raises:
            AssertionError: if called before any ``mark_iter()``.
        """
        self._mark_interval(self.data_time)

    def mark_iter_step_end(self):
        """Record the time since the last mark as model-step time (``step_time``).

        Raises:
            AssertionError: if called before any ``mark_iter()``.
        """
        self._mark_interval(self.step_time)

    def mark_epoch(self):
        """Mark an epoch boundary.

        Records the time since the previous ``mark_epoch()`` into
        ``epoch_time`` (nothing is recorded on the first call).
        """
        # NOTE(review): this does not reset iter_timestamp, so if the tracker is
        # reused across epochs, the first iter_time of an epoch also covers any
        # time spent between epochs -- confirm this is intended.
        timestamp = time.perf_counter()
        if self.epoch_timestamp is not None:
            self.epoch_time.update(timestamp - self.epoch_timestamp)
        self.epoch_timestamp = timestamp
|
|
|