You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
pytorch-image-models/timm/bits/tracker.py

51 lines
1.7 KiB

import time
from typing import Optional
from timm.metrics import ScalarAvgMinMax
class Tracker:
    """Accumulate wall-clock timing statistics for an iterative (e.g. training) loop.

    Intended call order per iteration::

        tracker.mark_iter()            # iteration start; also closes the previous iteration
        ...produce batch...
        tracker.mark_iter_data_end()   # records data_time
        ...model step...
        tracker.mark_iter_step_end()   # records step_time

    plus ``mark_epoch()`` at each epoch boundary. Durations are differences of
    ``time.perf_counter()`` values (seconds) and are fed to each meter's ``update``.
    """

    def __init__(self):
        self.data_time = ScalarAvgMinMax()  # time for data loader to produce batch of samples
        self.step_time = ScalarAvgMinMax()  # time for model step
        self.iter_time = ScalarAvgMinMax()  # full iteration time incl. data, step, and book-keeping
        self.epoch_time = ScalarAvgMinMax()  # time between consecutive mark_epoch() calls

        self.iter_timestamp: Optional[float] = None  # start of the current iteration (set by mark_iter)
        self.prev_timestamp: Optional[float] = None  # end of the most recently measured phase
        self.epoch_timestamp: Optional[float] = None  # start of the current epoch (set by mark_epoch)

    def _measure_iter(self, ref_timestamp: Optional[float] = None) -> Optional[float]:
        """Advance ``prev_timestamp`` to now and return the time elapsed since a reference.

        Args:
            ref_timestamp: ``perf_counter()`` value to measure from. Defaults to
                ``prev_timestamp`` (the end of the previously measured phase).

        Returns:
            Elapsed seconds, or ``None`` when no reference timestamp is available.
        """
        timestamp = time.perf_counter()
        if ref_timestamp is None:
            ref_timestamp = self.prev_timestamp
        self.prev_timestamp = timestamp
        return None if ref_timestamp is None else timestamp - ref_timestamp

    def _check_iter_started(self, caller: str) -> None:
        """Raise ``RuntimeError`` if no phase reference exists yet (``mark_iter`` not called)."""
        # Explicit raise rather than ``assert``: asserts are stripped under ``python -O``,
        # after which the subtraction would fail with an opaque ``float - None`` TypeError.
        if self.prev_timestamp is None:
            raise RuntimeError(f'{caller}() called before mark_iter()')

    def mark_iter(self):
        """Mark the start of a new iteration.

        Records ``iter_time`` as the time since the previous ``mark_iter()`` call.
        Nothing is recorded on the very first call. Also resets the phase reference,
        so the next ``mark_iter_data_end()`` measures from this point.
        """
        timestamp = time.perf_counter()
        if self.iter_timestamp is not None:
            self.iter_time.update(timestamp - self.iter_timestamp)
        # Both references share one timestamp so data_time starts exactly at the iteration start.
        self.iter_timestamp = self.prev_timestamp = timestamp

    def mark_iter_data_end(self):
        """Record ``data_time``: time since the iteration start (or the last phase mark).

        Raises:
            RuntimeError: if ``mark_iter()`` has not been called yet.
        """
        self._check_iter_started('mark_iter_data_end')
        self.data_time.update(self._measure_iter())

    def mark_iter_step_end(self):
        """Record ``step_time``: time since the last phase mark (normally ``mark_iter_data_end``).

        Raises:
            RuntimeError: if ``mark_iter()`` has not been called yet.
        """
        self._check_iter_started('mark_iter_step_end')
        self.step_time.update(self._measure_iter())

    def mark_epoch(self):
        """Mark an epoch boundary.

        Records ``epoch_time`` as the time since the previous ``mark_epoch()`` call.
        Nothing is recorded on the first call.
        """
        # NOTE(review): iteration timestamps are not reset here, so the first iter_time
        # after a boundary includes any work done between epochs — confirm this is intended.
        timestamp = time.perf_counter()
        if self.epoch_timestamp is not None:
            self.epoch_time.update(timestamp - self.epoch_timestamp)
        self.epoch_timestamp = timestamp