pull/1239/head
parent 1b0c8e7b01
commit 12d9a6d4d2
@@ -0,0 +1,8 @@
# Timm Bits

A collection of reusable components and lightweight abstractions for training and evaluating neural networks.

This is an early WIP with the primary goal of getting up and running on TPUs first. Expect significant changes, rewrites, and additions.

The current train.py and validate.py scripts are evolving to use the timm.bits components; they will also change significantly.
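
As a rough orientation, the sketch below wires a few of the pieces in this PR together (device environment plus updater). It is illustrative only and not part of the diff: `initialize_device`, `create_updater`, `to_device`, `autocast` and `apply` come from the modules added below, while the toy model, optimizer and data are made up for the example. It also assumes a CUDA or XLA device is present, since the CPU fallback is not implemented yet.

```python
# Illustrative only -- not part of this PR. Assumes CUDA or XLA is available
# (initialize_device() currently raises NotImplementedError on CPU-only hosts).
import torch
from timm.bits import initialize_device, create_updater

dev_env = initialize_device()                 # picks XLA if present, else CUDA
model = torch.nn.Linear(10, 2)
model = dev_env.to_device(model)[0]           # to_device() returns a list of modules
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
updater = create_updater(optimizer, dev_env=dev_env)

x = torch.randn(8, 10, device=dev_env.device)
y = torch.randint(0, 2, (8,), device=dev_env.device)
with dev_env.autocast():                      # real autocast when amp=True, no-op otherwise
    loss = torch.nn.functional.cross_entropy(model(x), y)
updater.apply(loss)                           # backward, optional clip, optimizer step
```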
@@ -0,0 +1,10 @@
from .device_env_factory import initialize_device, get_device
from .device_env import DeviceEnv
#from .evaluate import evaluate, eval_step
from .logger import Logger
#from .task import TaskClassify
from .updater import Updater
from .updater_factory import create_updater
from .tracker import Tracker
#from .task_metrics import TaskMetrics, TaskMetricsClassify
#from .train import train_one_epoch, TrainServices, TrainState, TrainCfg, Experiment
@@ -0,0 +1,58 @@
import torch
import abc


class DeviceEnv(abc.ABC):

    @property
    @abc.abstractmethod
    def device(self) -> torch.device:
        pass

    @property
    @abc.abstractmethod
    def local_rank(self) -> int:
        pass

    @property
    @abc.abstractmethod
    def global_rank(self) -> int:
        pass

    @property
    @abc.abstractmethod
    def is_distributed(self) -> bool:
        pass

    @property
    @abc.abstractmethod
    def world_size(self) -> int:
        pass

    @property
    @abc.abstractmethod
    def is_master(self) -> bool:
        pass

    @property
    @abc.abstractmethod
    def type(self) -> str:
        pass

    @property
    @abc.abstractmethod
    def autocast(self):
        pass

    @abc.abstractmethod
    def wrap_distributed(self, *modules):
        pass

    @abc.abstractmethod
    def to_device(self, *modules: torch.nn.Module):
        pass

    #@abc.abstractmethod
    def mark_step(self):
        # FIXME this is for XLA only, make it common to all devices w/ appropriate no-ops?
        pass
@@ -0,0 +1,90 @@
import os
from contextlib import suppress

import torch
from torch.nn.parallel import DistributedDataParallel

from .device_env import DeviceEnv


def is_cuda_available():
    return torch.cuda.is_available()


class DeviceEnvCuda(DeviceEnv):

    def __init__(self, device_idx=None, local_rank=None, amp=False, memory_format=None):
        assert torch.cuda.device_count()
        torch.backends.cudnn.benchmark = True
        self._local_rank = 0
        self._distributed = False
        self._world_size = 1
        self._global_rank = 0
        if 'WORLD_SIZE' in os.environ:
            self._distributed = int(os.environ['WORLD_SIZE']) > 1
        if self._distributed:
            if local_rank is None:
                lr = os.environ.get('LOCAL_RANK', None)
                if lr is None:
                    raise RuntimeError(
                        'At least one of LOCAL_RANK env variable or local_rank arg must be set to valid integer.')
                self._local_rank = int(lr)
            else:
                self._local_rank = int(local_rank)
            self._device = torch.device('cuda:%d' % self._local_rank)
            torch.cuda.set_device(self._local_rank)
            torch.distributed.init_process_group(backend='nccl', init_method='env://')
            self._world_size = torch.distributed.get_world_size()
            self._global_rank = torch.distributed.get_rank()
        else:
            self._device = torch.device('cuda' if device_idx is None else f'cuda:{device_idx}')
        self._memory_format = memory_format
        self._amp = amp
        self._autocast = torch.cuda.amp.autocast if amp else suppress

    @property
    def device(self):
        return self._device

    @property
    def local_rank(self):
        return self._local_rank

    @property
    def global_rank(self):
        return self._global_rank

    @property
    def is_distributed(self):
        return self._distributed

    @property
    def world_size(self):
        return self._world_size

    @property
    def is_master(self):
        return self._local_rank == 0

    @property
    def type(self) -> str:
        return 'cuda'

    @property
    def amp(self) -> bool:
        return self._amp

    @property
    def autocast(self):
        return self._autocast

    def wrap_distributed(self, *modules, **kwargs):
        return [DistributedDataParallel(m, device_ids=[self._local_rank], **kwargs) for m in modules]

    def to_device(self, *modules: torch.nn.Module):
        # FIXME handling dtype / memformat... disable flags, enable flags, diff fn?
        return [m.to(device=self._device, memory_format=self._memory_format) for m in modules]
@@ -0,0 +1,34 @@
from .device_env_cuda import DeviceEnvCuda, is_cuda_available
from .device_env_xla import DeviceEnvXla, is_xla_available

_device_env = None


def initialize_device(force_cpu: bool = False, xla_device_type=None, **kwargs):
    global _device_env
    if _device_env is not None:
        # warning
        return _device_env

    denv = None
    if not force_cpu:
        if is_xla_available(xla_device_type):
            # XLA supports more than just TPU, but by default will only look at TPU
            denv = DeviceEnvXla(**kwargs, xla_device_type=xla_device_type)
        elif is_cuda_available():
            denv = DeviceEnvCuda(**kwargs)

    if denv is None:
        # FIXME implement CPU support
        raise NotImplementedError()

    _device_env = denv
    return denv


def get_device():
    if _device_env is None:
        raise RuntimeError('Please initialize device environment by calling initialize_device first.')
    return _device_env
@@ -0,0 +1,85 @@
import os
from contextlib import suppress
import torch

try:
    import torch_xla.core.xla_model as xm
    import torch_xla.amp as xa
    _HAS_XLA = True
except ImportError as e:
    xm = None
    xa = None
    _HAS_XLA = False

from .device_env import DeviceEnv


def is_xla_available(xla_device_type=None):
    if not _HAS_XLA:
        return False
    supported_devs = xm.get_xla_supported_devices(devkind=xla_device_type)
    print(supported_devs)
    return len(supported_devs) >= 1


class DeviceEnvXla(DeviceEnv):

    def __init__(self, xla_device_type=None, device_idx=None, local_rank=0, amp=False):
        self._device = xm.xla_device(n=device_idx, devkind=xla_device_type)
        print(self._device)
        self._local_rank = xm.get_local_ordinal(local_rank)
        self._world_size = xm.xrt_world_size()
        self._distributed = self._world_size > 1
        self._global_rank = 0
        if self._distributed:
            self._global_rank = xm.get_ordinal()
        if amp:
            self._autocast = xa.autocast
        else:
            self._autocast = suppress
        self._memory_format = None

    @property
    def device(self):
        return self._device

    @property
    def local_rank(self):
        return self._local_rank

    @property
    def global_rank(self):
        return self._global_rank

    @property
    def is_distributed(self):
        return self._distributed

    @property
    def world_size(self):
        return self._world_size

    @property
    def is_master(self):
        return self._global_rank == 0

    @property
    def type(self) -> str:
        return 'xla'

    @property
    def amp(self) -> bool:
        return False

    @property
    def autocast(self):
        return self._autocast

    def wrap_distributed(self, *modules):
        # NO-OP
        return tuple([m for m in modules])

    def to_device(self, *modules: torch.nn.Module):
        return [m.to(device=self._device, memory_format=self._memory_format) for m in modules]

    def mark_step(self):
        xm.mark_step()
@@ -0,0 +1,36 @@
from functools import partial

import torch

from timm.utils.agc import adaptive_clip_grad


def get_clip_grad_fn(mode: str = 'norm', norm_type: float = 2.0):
    if mode == 'norm':
        return partial(torch.nn.utils.clip_grad_norm_, norm_type=norm_type)
    elif mode == 'value':
        return torch.nn.utils.clip_grad_value_
    elif mode == 'agc':
        return partial(adaptive_clip_grad, norm_type=norm_type)
    else:
        assert False, f"Unknown clip mode ({mode})."


def get_clip_parameters(model):
    if hasattr(model, 'get_clip_parameters'):
        return model.get_clip_parameters()
    else:
        return model.parameters()


class GradClipper:

    def __init__(self, model, clip_value, clip_mode='norm'):
        self.model = model
        self.clip_fn = get_clip_grad_fn(clip_mode)
        self.clip_value = clip_value
        self.enabled = True

    def __call__(self):
        if self.enabled:
            self.clip_fn(get_clip_parameters(self.model), self.clip_value)
@@ -0,0 +1,223 @@
import csv
import logging
import os
from collections import OrderedDict
from typing import Optional, Tuple, Dict, Union

import torch

_logger = logging.getLogger(__name__)

try:
    from torch.utils.tensorboard import SummaryWriter
    HAS_TB = True
except ImportError as e:
    HAS_TB = False

try:
    import wandb
    HAS_WANDB = True
except ImportError:
    HAS_WANDB = False


# FIXME old formatting for reference, to remove
#
# def log_eval(batch_idx, last_idx, batch_time, loss, top1, top5, log_suffix=''):
#     log_name = 'Test' + log_suffix
#     logging.info(
#         f'{log_name}: [{batch_idx:>4d}/{last_idx}] '
#         f'Time: {batch_time.smooth_val:.3f} ({batch_time.avg:.3f}) '
#         f'Loss: {loss.smooth_val:>7.4f} ({loss.avg:>6.4f}) '
#         f'Acc@1: {top1.smooth_val:>7.4f} ({top1.avg:>7.4f}) '
#         f'Acc@5: {top5.smooth_val:>7.4f} ({top5.avg:>7.4f})'
#     )
#
#
# def log_train(epoch, step, num_steps, loss, batch_size, batch_time, data_time, lr, world_size=1):
#     last_step = max(0, num_steps - 1)
#     progress = 100. * step / last_step if last_step else 0.
#     log_str = f'Train: {epoch} [{step:>4d}/{num_steps} ({progress:>3.0f}%)]' \
#               f' Time: {batch_time.smooth_val:.3f}s, {batch_size * world_size / batch_time.smooth_val:>7.2f}/s' \
#               f' ({batch_time.avg:.3f}s, {batch_size * world_size / batch_time.avg:>7.2f}/s)' \
#               f' Data: {data_time.smooth_val:.3f} ({data_time.avg:.3f})'
#     log_str += f' Loss: {loss.smooth_val:>9.6f} ({loss.avg:>6.4f}) '
#     log_str += f' LR: {lr:.3e} '
#
#     if args.save_images and output_dir:
#         torchvision.utils.save_image(
#             input,
#             os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
#             padding=0,
#             normalize=True)


def summary_row_dict(results, index=None, index_name='epoch'):
    assert isinstance(results, dict)
    row_dict = OrderedDict()
    if index is not None:
        row_dict[index_name] = index
    if not results:
        return row_dict
    if isinstance(next(iter(results.values())), dict):
        # each key in results is a per-phase results dict, flatten by prefixing with phase name
        for p, pr in results.items():
            assert isinstance(pr, dict)
            row_dict.update([('_'.join([p, k]), v) for k, v in pr.items()])
    else:
        row_dict.update(results)
    return row_dict


class SummaryCsv:
    def __init__(self, output_dir, filename='summary.csv'):
        self.output_dir = output_dir
        self.filename = os.path.join(output_dir, filename)
        self.needs_header = not os.path.exists(self.filename)

    def update(self, row_dict):
        with open(self.filename, mode='a') as cf:
            dw = csv.DictWriter(cf, fieldnames=row_dict.keys())
            if self.needs_header:  # first iteration (epoch == 1 can't be used)
                dw.writeheader()
                self.needs_header = False
            dw.writerow(row_dict)


def _add_kwargs(text_update, name_map=None, **kwargs):
    def _to_str(key, val):
        if isinstance(val, float):
            return f'{key}: {val:.4f}'
        else:
            return f'{key}: {val}'

    def _map_name(key, name_map, capitalize=True):
        if name_map is None:
            if capitalize:
                return key.capitalize() if not key.isupper() else key
            else:
                return key
        return name_map.get(key, None)

    for k, v in kwargs.items():
        if isinstance(v, dict):
            # log each k, v of a dict kwarg as separate items
            for kk, vv in v.items():
                name = _map_name(kk, name_map)
                if not name:
                    continue
                text_update += [_to_str(name, vv)]
        else:
            name = _map_name(k, name_map, capitalize=True)
            if not name:
                continue
            text_update += [_to_str(name, v)]


class Logger:

    def __init__(
            self,
            experiment_name=None,
            output_dir=None,
            logger=None,
            log_wandb=False,
            hparams=None):

        self.output_dir = output_dir  # for tensorboard, csv, console logging to file?
        self.logger = logger or logging.getLogger('log')
        hparams = hparams or {}

        # Setup CSV writer(s)
        if output_dir is not None:
            self.csv_writer = SummaryCsv(output_dir=output_dir)
        else:
            self.csv_writer = None

        # Setup Tensorboard
        self.summary_writer = None  # FIXME tensorboard

        # Setup W&B
        self.wandb_run = None
        if log_wandb:
            if HAS_WANDB:
                self.wandb_run = wandb.init(project=experiment_name, config=hparams)
            else:
                _logger.warning("You've requested to log metrics to wandb but package not found. "
                                "Metrics not being logged to wandb, try `pip install wandb`")

        # FIXME image save

    def log_step(
            self,
            phase: str,
            step: int,
            end_step: Optional[int] = None,
            loss: Optional[float] = None,
            rate: Optional[float] = None,
            epoch: Optional[int] = None,
            phase_suffix: str = '',
            **kwargs,
    ):
        """ log train/eval step
        """
        phase_title = f'{phase.capitalize()} ({phase_suffix})' if phase_suffix else f'{phase.capitalize()}'
        progress = 100. * step / end_step if end_step else 0.
        text_update = [
            phase_title,
            f'Epoch: {epoch}' if epoch is not None else None,
            f'Step: {step}' if end_step is None else None,
            f'Step: [{step}/{end_step} ({progress:>3.0f}%)]' if end_step is not None else None,
            f'Rate: {rate:.2f}/s' if rate is not None else None,
            f'Loss: {loss:.5f}' if loss is not None else None,
        ]
        _add_kwargs(text_update, **kwargs)
        log_str = ' '.join(item for item in text_update if item)
        self.logger.info(log_str)
        if self.summary_writer is not None:
            # FIXME log step values to tensorboard
            pass

    def log_phase(
            self,
            phase: str = 'eval',
            epoch: Optional[int] = None,
            name_map: Optional[dict] = None,
            **kwargs
    ):
        """log completion of evaluation or training phase
        """
        title = [
            f'{phase.capitalize()}',
            f'epoch: {epoch}' if epoch is not None else None,
            'completed. ',
        ]
        title_str = ' '.join(i for i in title if i)
        results = []
        _add_kwargs(results, name_map=name_map, **kwargs)
        log_str = title_str + ', '.join(item for item in results if item)
        self.logger.info(log_str)

    def write_summary(
            self,
            results: Dict,  # Dict or Dict of Dict where first level keys are treated as per-phase results
            index: Optional[Union[int, str]] = None,
            index_name: str = 'epoch',
    ):
        """ Log complete results for all phases (typically called at end of epoch)

        Args:
            results (dict or dict[dict]): dict of results to write, or multiple dicts where first level
                key is the name of results dict for each phase
            index: value for row index (typically epoch #)
            index_name: name for row index header (typically 'epoch')
        """
        row_dict = summary_row_dict(index=index, index_name=index_name, results=results)
        if self.csv_writer:
            self.csv_writer.update(row_dict)
        if self.wandb_run is not None:
            wandb.log(row_dict)
        if self.summary_writer:
            # FIXME log epoch summaries to tensorboard
            pass
@@ -0,0 +1,50 @@
import time
from typing import Optional

from timm.metrics import ScalarAvgMinMax


class Tracker:

    def __init__(self):
        self.data_time = ScalarAvgMinMax()  # time for data loader to produce batch of samples
        self.step_time = ScalarAvgMinMax()  # time for model step
        self.iter_time = ScalarAvgMinMax()  # full iteration time incl. data, step, and book-keeping
        self.epoch_time = ScalarAvgMinMax()

        self.iter_timestamp: Optional[float] = None
        self.prev_timestamp: Optional[float] = None
        self.epoch_timestamp: Optional[float] = None

    def _measure_iter(self, ref_timestamp=None):
        timestamp = time.perf_counter()
        self.prev_timestamp = timestamp

    def mark_iter(self):
        timestamp = time.perf_counter()
        if self.iter_timestamp is not None:
            iter_time = timestamp - self.iter_timestamp
            self.iter_time.update(iter_time)
        self.iter_timestamp = self.prev_timestamp = timestamp

    def mark_iter_data_end(self):
        assert self.prev_timestamp is not None
        timestamp = time.perf_counter()
        data_time = timestamp - self.prev_timestamp
        self.data_time.update(data_time)
        self.prev_timestamp = timestamp

    def mark_iter_step_end(self):
        assert self.prev_timestamp is not None
        timestamp = time.perf_counter()
        step_time = timestamp - self.prev_timestamp
        self.step_time.update(step_time)
        self.prev_timestamp = timestamp

    def mark_epoch(self):
        timestamp = time.perf_counter()
        if self.epoch_timestamp is not None:
            epoch_time = timestamp - self.epoch_timestamp
            self.epoch_time.update(epoch_time)
        self.epoch_timestamp = timestamp
@@ -0,0 +1,54 @@
from typing import Callable, Optional, Union

import torch

from .grad_clipper import GradClipper


class Updater:

    def __init__(
            self,
            optimizer: torch.optim.Optimizer,
            clip_value: Optional[Union[Callable, float]] = None,
            clip_mode: str = 'norm'):

        self.optimizer = optimizer
        self.clipper: Optional[GradClipper] = None
        if clip_value is not None:
            if isinstance(clip_value, Callable):
                self.clipper = clip_value
            else:
                # FIXME GradClipper expects (model, clip_value, clip_mode) but no model reference is available here
                self.clipper = GradClipper(clip_value, clip_mode)
        self.scaler = None
        self.create_graph = getattr(self.optimizer, 'second_order', False)
        self.num_accumulated = 0
        self.after_step_closure = False

    def apply(self, loss: torch.Tensor, accumulate=False):
        loss.backward(create_graph=self.create_graph)
        if self.clipper is not None:
            self.clipper()
        if not accumulate:
            self.optimizer.step()
            self.reset()
        else:
            self.num_accumulated += 1

    def reset(self):
        self.optimizer.zero_grad()
        self.num_accumulated = 0

    def state_dict(self):
        state_dict = dict(optimizer=self.optimizer.state_dict())
        if self.scaler is not None:
            state_dict['scaler'] = self.scaler.state_dict()
        return state_dict

    def load_state_dict(self, state_dict):
        if 'optimizer' in state_dict:
            self.optimizer.load_state_dict(state_dict['optimizer'])
        if 'scaler' in state_dict and self.scaler is not None:
            self.scaler.load_state_dict(state_dict['scaler'])
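
A rough sketch (not part of the diff) of the accumulation contract implied by `Updater.apply` above: pass `accumulate=True` for all but the last micro-batch so the optimizer only steps and zeroes gradients once per effective batch. The toy model, data, and `accum_steps` value are illustrative only; clipping is left off because the clipper wiring is still flagged FIXME above.

```python
# Illustrative only -- gradient accumulation with the Updater defined above.
import torch
from timm.bits import Updater

model = torch.nn.Linear(4, 2)
updater = Updater(torch.optim.SGD(model.parameters(), lr=0.1))
accum_steps = 4
for i in range(8):                              # 8 micro-batches -> 2 optimizer steps
    x, y = torch.randn(2, 4), torch.randint(0, 2, (2,))
    loss = torch.nn.functional.cross_entropy(model(x), y)
    # step + zero_grad only on every accum_steps-th micro-batch
    updater.apply(loss, accumulate=(i + 1) % accum_steps != 0)
```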
@@ -0,0 +1,36 @@
from typing import Callable, Optional, Union, Any

import torch

from .updater import Updater


class UpdaterCuda(Updater):
    def __init__(
            self,
            optimizer: torch.optim.Optimizer,
            clip_value: Optional[Union[Callable, float]] = None,
            clip_mode: str = 'norm',
            use_scaler: bool = False,
            scaler_kwargs: Any = None,
    ):
        super().__init__(optimizer=optimizer, clip_value=clip_value, clip_mode=clip_mode)
        scaler_kwargs = scaler_kwargs or {}
        if use_scaler:
            self.scaler = torch.cuda.amp.GradScaler(**scaler_kwargs)

    def apply(self, loss: torch.Tensor, accumulate=False):
        if self.scaler is not None:
            self.scaler.scale(loss).backward(create_graph=self.create_graph)
            if self.clipper is not None:
                self.scaler.unscale_(self.optimizer)  # unscale the gradients of optimizer's assigned params in-place
                self.clipper()
            if not accumulate:
                self.scaler.step(self.optimizer)
                self.reset()
            else:
                self.num_accumulated += 1
            self.scaler.update()
        else:
            Updater.apply(self, loss, accumulate)
@@ -0,0 +1,30 @@
from typing import Callable, Optional, Union, Any

import torch

from .device_env import DeviceEnv
from .device_env_factory import get_device
from .updater import Updater
from .updater_cuda import UpdaterCuda
from .updater_xla import UpdaterXla


def create_updater(
        optimizer: torch.optim.Optimizer,
        dev_env: Optional[DeviceEnv] = None,
        clip_value: Optional[Union[Callable, float]] = None,
        clip_mode: str = 'norm',
        scaler_kwargs: Any = None) -> Updater:

    if not dev_env:
        dev_env = get_device()

    updater_kwargs = dict(
        optimizer=optimizer, clip_value=clip_value, clip_mode=clip_mode, scaler_kwargs=scaler_kwargs)
    if dev_env.type == 'xla':
        return UpdaterXla(**updater_kwargs, use_scaler=dev_env.amp)
    elif dev_env.type == 'cuda':
        return UpdaterCuda(**updater_kwargs, use_scaler=dev_env.amp)
    else:
        updater_kwargs.pop('scaler_kwargs', None)
        return Updater(**updater_kwargs)
@@ -0,0 +1,52 @@
from typing import Callable, Optional, Union, Any

import torch

try:
    import torch_xla.core.xla_model as xm
    import torch_xla.amp as xa
    _HAS_XLA = True
except ImportError as e:
    xm = None
    xa = None
    _HAS_XLA = False

from .updater import Updater


class UpdaterXla(Updater):

    def __init__(
            self,
            optimizer: torch.optim.Optimizer,
            clip_value: Optional[Union[Callable, float]] = None,
            clip_mode: str = 'norm',
            use_scaler: bool = False,
            scaler_kwargs: Any = None,
    ):
        super().__init__(optimizer=optimizer, clip_value=clip_value, clip_mode=clip_mode)
        self.after_step_closure = True
        scaler_kwargs = scaler_kwargs or {}
        if use_scaler:
            self.scaler = xa.GradScaler(**scaler_kwargs)

    def apply(self, loss: torch.Tensor, accumulate: bool = False):
        if self.scaler is None:
            loss.backward(create_graph=self.create_graph)
            gradients = xm._fetch_gradients(self.optimizer)
            xm.all_reduce('sum', gradients, scale=1.0 / xm.xrt_world_size())
            if self.clipper is not None:
                self.clipper()
            if not accumulate:
                xm.optimizer_step(self.optimizer)
        else:
            self.scaler.scale(loss).backward(create_graph=self.create_graph)
            if self.clipper is not None:
                self.scaler.unscale_(self.optimizer)  # unscale the gradients of optimizer's assigned params in-place
                self.clipper()
            if not accumulate:
                self.scaler.step(self.optimizer)
                self.reset()
            self.scaler.update()

    def after_step(self, after_step_fn, *args):
        xm.add_step_closure(after_step_fn, *args)
@@ -0,0 +1,38 @@
import numpy as np

import torch


def fast_collate(batch):
    """ A fast collation function optimized for uint8 images (np array or torch) and int64 targets (labels)"""
    assert isinstance(batch[0], tuple)
    batch_size = len(batch)
    if isinstance(batch[0][0], tuple):
        # This branch 'deinterleaves' and flattens tuples of input tensors into one tensor ordered by position
        # such that all tuples at position n will end up in the nth chunk of torch.split(tensor, batch_size)
        inner_tuple_size = len(batch[0][0])
        flattened_batch_size = batch_size * inner_tuple_size
        targets = torch.zeros(flattened_batch_size, dtype=torch.int64)
        tensor = torch.zeros((flattened_batch_size, *batch[0][0][0].shape), dtype=torch.uint8)
        for i in range(batch_size):
            assert len(batch[i][0]) == inner_tuple_size  # all input tensor tuples must be same length
            for j in range(inner_tuple_size):
                targets[i + j * batch_size] = batch[i][1]
                tensor[i + j * batch_size] += torch.from_numpy(batch[i][0][j])
        return tensor, targets
    elif isinstance(batch[0][0], np.ndarray):
        targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
        assert len(targets) == batch_size
        tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
        for i in range(batch_size):
            tensor[i] += torch.from_numpy(batch[i][0])
        return tensor, targets
    elif isinstance(batch[0][0], torch.Tensor):
        targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
        assert len(targets) == batch_size
        tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
        for i in range(batch_size):
            tensor[i].copy_(batch[i][0])
        return tensor, targets
    else:
        assert False
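
A small sketch (not part of the diff) illustrating the tuple 'deinterleave' behaviour described in the comment inside `fast_collate` above: with 2-way tuple inputs, `torch.split(tensor, batch_size)` yields one chunk per tuple position. The import path `timm.data.collate` is an assumption, since file paths are not shown in this diff; the toy arrays are illustrative only.

```python
# Illustrative only: tuple inputs are flattened position-major by fast_collate.
import numpy as np
import torch
from timm.data.collate import fast_collate   # assumed module path for the file above

batch = [((np.zeros((3, 2, 2), dtype=np.uint8),           # tuple position 0
           np.ones((3, 2, 2), dtype=np.uint8)), i)        # tuple position 1, label i
         for i in range(4)]
tensor, targets = fast_collate(batch)
print(tensor.shape)                       # torch.Size([8, 3, 2, 2]) -> 2 * batch_size
chunk0, chunk1 = torch.split(tensor, 4)
print(chunk0.unique(), chunk1.unique())   # all zeros vs. all ones, one chunk per position
print(targets)                            # tensor([0, 1, 2, 3, 0, 1, 2, 3])
```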
@@ -0,0 +1,69 @@
import torch

from .constants import *
from .random_erasing import RandomErasing
from .mixup import FastCollateMixup


class FetcherXla:
    def __init__(self):
        pass


class Fetcher:

    def __init__(self,
                 loader,
                 mean=IMAGENET_DEFAULT_MEAN,
                 std=IMAGENET_DEFAULT_STD,
                 device=None,
                 dtype=None,
                 re_prob=0.,
                 re_mode='const',
                 re_count=1,
                 re_num_splits=0):
        self.loader = loader
        self.mean = torch.tensor([x * 255 for x in mean]).view(1, 3, 1, 1)
        self.std = torch.tensor([x * 255 for x in std]).view(1, 3, 1, 1)
        self.device = torch.device(device) if device is not None else None
        self.dtype = dtype or torch.float32
        if device:
            # tensor.to() is not in-place, re-assign the results
            self.mean = self.mean.to(device=device, dtype=self.dtype)
            self.std = self.std.to(device=device, dtype=self.dtype)
        if re_prob > 0.:
            self.random_erasing = RandomErasing(
                probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits)
        else:
            self.random_erasing = None

    def __iter__(self):
        for sample, target in self.loader:
            sample = sample.to(device=self.device)
            target = target.to(device=self.device)
            sample = sample.to(dtype=self.dtype).sub_(self.mean).div_(self.std)
            if self.random_erasing is not None:
                sample = self.random_erasing(sample)
            yield sample, target

    def __len__(self):
        return len(self.loader)

    @property
    def sampler(self):
        return self.loader.sampler

    @property
    def dataset(self):
        return self.loader.dataset

    @property
    def mixup_enabled(self):
        if isinstance(self.loader.collate_fn, FastCollateMixup):
            return self.loader.collate_fn.mixup_enabled
        else:
            return False

    @mixup_enabled.setter
    def mixup_enabled(self, x):
        if isinstance(self.loader.collate_fn, FastCollateMixup):
            self.loader.collate_fn.mixup_enabled = x
@@ -0,0 +1,79 @@
import torch.cuda

from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .mixup import FastCollateMixup
from .random_erasing import RandomErasing


class PrefetcherCuda:

    def __init__(self,
                 loader,
                 mean=IMAGENET_DEFAULT_MEAN,
                 std=IMAGENET_DEFAULT_STD,
                 fp16=False,
                 re_prob=0.,
                 re_mode='const',
                 re_count=1,
                 re_num_splits=0):
        self.loader = loader
        self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1)
        self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1)
        self.fp16 = fp16
        if fp16:
            self.mean = self.mean.half()
            self.std = self.std.half()
        if re_prob > 0.:
            self.random_erasing = RandomErasing(
                probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits)
        else:
            self.random_erasing = None

    def __iter__(self):
        stream = torch.cuda.Stream()
        first = True

        for next_input, next_target in self.loader:
            with torch.cuda.stream(stream):
                next_input = next_input.cuda(non_blocking=True)
                next_target = next_target.cuda(non_blocking=True)
                if self.fp16:
                    next_input = next_input.half().sub_(self.mean).div_(self.std)
                else:
                    next_input = next_input.float().sub_(self.mean).div_(self.std)
                if self.random_erasing is not None:
                    next_input = self.random_erasing(next_input)

            if not first:
                yield input, target
            else:
                first = False

            torch.cuda.current_stream().wait_stream(stream)
            input = next_input
            target = next_target

        yield input, target

    def __len__(self):
        return len(self.loader)

    @property
    def sampler(self):
        return self.loader.sampler

    @property
    def dataset(self):
        return self.loader.dataset

    @property
    def mixup_enabled(self):
        if isinstance(self.loader.collate_fn, FastCollateMixup):
            return self.loader.collate_fn.mixup_enabled
        else:
            return False

    @mixup_enabled.setter
    def mixup_enabled(self, x):
        if isinstance(self.loader.collate_fn, FastCollateMixup):
            self.loader.collate_fn.mixup_enabled = x
@@ -0,0 +1,4 @@
from .accuracy import Accuracy, AccuracyTopK
from .precision_recall import PrecisionRecall
from .scalar_avg import ScalarAvgMinMax
from .tensor_avg import TensorAvg, TensorEma
@@ -0,0 +1,112 @@
import torch
from typing import Optional, Tuple, Dict


class Accuracy(torch.nn.Module):

    def __init__(self, threshold=0.5, multi_label=False):
        super().__init__()
        self.threshold = threshold
        self.eps = 1e-8
        self.multi_label = multi_label

        # statistics / counts
        self._correct_sum = torch.tensor(0, dtype=torch.long)
        self._total_sum = torch.tensor(0, dtype=torch.long)

    def update(self, predictions, target):
        raise NotImplementedError()

    def reset(self):
        self._correct_sum = 0
        self._total_sum = 0

    @property
    def counts(self):
        pass

    def compute(self):
        raise NotImplementedError()


class AccuracyTopK(torch.nn.Module):

    def __init__(self, topk=(1, 5), device=None):
        super().__init__()
        self.eps = 1e-8
        self.device = device
        self.topk = topk
        self.maxk = max(topk)

        # statistics / counts
        self.reset()

    def update(self, predictions: torch.Tensor, target: torch.Tensor):
        sorted_indices = predictions.topk(self.maxk, dim=1)[1]
        sorted_indices.t_()
        correct = sorted_indices.eq(target.reshape(1, -1).expand_as(sorted_indices))

        batch_size = target.shape[0]
        correct_k = {k: correct[:k].reshape(-1).float().sum(0) for k in self.topk}
        for k, v in correct_k.items():
            attr = f'_correct_top{k}'
            old_v = getattr(self, attr)
            setattr(self, attr, old_v + v)
        self._total_sum += batch_size

    def reset(self):
        for k in self.topk:
            setattr(self, f'_correct_top{k}', torch.tensor(0, dtype=torch.float32))
        self._total_sum = torch.tensor(0, dtype=torch.float32)

    @property
    def counts(self):
        pass

    def compute(self) -> Dict[str, torch.Tensor]:
        return {f'top{k}': 100 * getattr(self, f'_correct_top{k}') / self._total_sum for k in self.topk}


#
# class AccuracyTopK:
#
#     def __init__(self, topk=(1, 5), device=None):
#         self.eps = 1e-8
#         self.device = device
#         self.topk = topk
#         self.maxk = max(topk)
#
#         # statistics / counts
#         self._correct_sum = None
#         self._total_sum = None
#
#     def _check_init(self, device):
#         to_device = self.device if self.device else device
#         if self._correct_sum is None:
#             self._correct_sum = {f'top{k}': torch.tensor(0., device=to_device) for k in self.topk}
#         if self._total_sum is None:
#             self._total_sum = torch.tensor(0, dtype=torch.long, device=to_device)
#
#     def update(self, predictions: torch.Tensor, target: torch.Tensor):
#         sorted_indices = predictions.topk(self.maxk, dim=1)[1]
#         sorted_indices.t_()
#         correct = sorted_indices.eq(target.reshape(1, -1).expand_as(sorted_indices))
#
#         batch_size = target.shape[0]
#         correct_k = {f'top{k}': correct[:k].reshape(-1).float().sum(0) for k in self.topk}
#         self._check_init(device=predictions.device)
#         for k, v in correct_k.items():
#             old_v = self._correct_sum[k]
#             self._correct_sum[k] = old_v + v
#         self._total_sum += batch_size
#
#     def reset(self):
#         self._correct_sum = None
#         self._total_sum = None
#
#     @property
#     def counts(self):
#         pass
#
#     def compute(self) -> Dict[str, torch.Tensor]:
#         assert self._correct_sum is not None and self._total_sum is not None
#         return {k: 100 * v / self._total_sum for k, v in self._correct_sum.items()}
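
For context, a minimal sketch (not part of the diff) of how the AccuracyTopK metric above is driven via its update/compute pair. The import path comes from the metrics `__init__.py` in this PR; the random logits and labels are illustrative only.

```python
# Illustrative only -- exercising AccuracyTopK from the diff above.
import torch
from timm.metrics import AccuracyTopK

metric = AccuracyTopK(topk=(1, 5))
logits = torch.randn(16, 10)            # batch of 16, 10 classes
labels = torch.randint(0, 10, (16,))
metric.update(logits, labels)
print(metric.compute())                 # something like {'top1': tensor(...), 'top5': tensor(...)}
```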
@@ -0,0 +1,117 @@
import torch
import torch.nn.functional as F


class PrecisionRecall:

    def __init__(self, threshold=0.5, multi_label=False, device=None):
        self.threshold = threshold
        self.device = device
        self.multi_label = multi_label

        # statistics

        # the total number of true positive instances under each class
        # Shape: (num_classes, )
        self._tp_sum = None

        # the total number of instances
        # Shape: (num_classes, )
        self._total_sum = None

        # the total number of instances under each _predicted_ class,
        # including true positives and false positives
        # Shape: (num_classes, )
        self._pred_sum = None

        # the total number of instances under each _true_ class,
        # including true positives and false negatives
        # Shape: (num_classes, )
        self._true_sum = None

        self.reset()

    def reset(self):
        self._tp_sum = None
        self._total_sum = None
        self._pred_sum = None
        self._true_sum = None

    def update(self, predictions, target):
        output_type = predictions.type()
        num_classes = predictions.size(-1)
        if self.multi_label:
            if self.threshold is not None:
                predictions = (predictions > self.threshold).type(output_type)
            predictions = predictions.t().reshape(num_classes, -1)
            target = target.t().reshape(num_classes, -1)
        else:
            target = F.one_hot(target.view(-1), num_classes=num_classes)
            indices = torch.argmax(predictions, dim=1).view(-1)
            predictions = F.one_hot(indices, num_classes=num_classes)
        # FIXME make sure binary case works

        target = target.type(output_type)
        correct = (target * predictions > 0).type(output_type)
        pred_positives = predictions.sum(dim=0)
        target_positives = target.sum(dim=0)
        if correct.sum() == 0:
            true_positives = torch.zeros_like(pred_positives)
        else:
            true_positives = correct.sum(dim=0)

        if self._tp_sum is None:
            self._tp_sum = torch.zeros(num_classes, device=self.device)
            self._true_sum = torch.zeros(num_classes, device=self.device)
            self._pred_sum = torch.zeros(num_classes, device=self.device)
            self._total_sum = torch.tensor(0, device=self.device)

        self._tp_sum += true_positives
        self._pred_sum += pred_positives
        self._true_sum += target_positives
        self._total_sum += target.shape[0]

    def counts_as_tuple(self, reduce=False):
        tp_sum = self._tp_sum
        pred_sum = self._pred_sum
        true_sum = self._true_sum
        total_sum = self._total_sum
        if reduce:
            # FIXME reduce_tensor_sum is not defined/imported yet, distributed reduction TBD
            tp_sum = reduce_tensor_sum(tp_sum)
            pred_sum = reduce_tensor_sum(pred_sum)
            true_sum = reduce_tensor_sum(true_sum)
            total_sum = reduce_tensor_sum(total_sum)
        return tp_sum, pred_sum, true_sum, total_sum

    def counts(self, reduce=False):
        tp_sum, pred_sum, true_sum, total_sum = self.counts_as_tuple(reduce=reduce)
        return dict(tp_sum=tp_sum, pred_sum=pred_sum, true_sum=true_sum, total_sum=total_sum)

    def confusion(self, reduce=False):
        tp_sum, pred_sum, true_sum, total_sum = self.counts_as_tuple(reduce=reduce)
        fp = pred_sum - tp_sum
        fn = true_sum - tp_sum
        tp = tp_sum
        tn = total_sum - tp - fp - fn
        return dict(tp=tp, fp=fp, fn=fn, tn=tn)

    def compute(self, fscore_beta=1, average='micro', no_reduce=False, distributed=False):
        tp_sum, pred_sum, true_sum, total_sum = self.counts_as_tuple(reduce=distributed)
        if average == 'micro':
            tp_sum = tp_sum.sum()
            pred_sum = pred_sum.sum()
            true_sum = true_sum.sum()

        precision = tp_sum / pred_sum
        recall = tp_sum / true_sum
        beta_sq = fscore_beta ** 2
        f1_denom = beta_sq * precision + recall
        fscore = (1 + beta_sq) * precision * recall / f1_denom

        if average == 'macro' and not no_reduce:
            precision = precision.mean()
            recall = recall.mean()
            fscore = fscore.mean()
            return dict(fscore=fscore, precision=precision, recall=recall)

        return dict(fscore=fscore, precision=precision, recall=recall)
@@ -0,0 +1,30 @@
class ScalarAvgMinMax:

    """Computes and stores the average, min, max, and current value"""
    def __init__(self):
        self.val = 0
        self.avg = 0
        self.min = None
        self.max = None
        self.sum = 0
        self.count = 0
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.min = None
        self.max = None
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.min = val if self.min is None else min(self.min, val)
        self.max = val if self.max is None else max(self.max, val)
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
@@ -0,0 +1,42 @@
import torch


class TensorAvg:

    """Computes and stores the average and current value of a tensor"""
    def __init__(self):
        self.sum = None
        self.count = None
        self.reset()

    def reset(self):
        self.sum = None
        self.count = None

    def update(self, val: torch.Tensor, n=1):
        if self.sum is None:
            self.sum = torch.zeros_like(val)
            self.count = torch.tensor(0, dtype=torch.long, device=val.device)
        self.sum += (val * n)
        self.count += n

    def compute(self):
        return self.sum / self.count


class TensorEma:

    """Computes and stores an exponential moving average of a tensor value"""
    def __init__(self, smoothing_factor=0.9, init_zero=False):
        self.smoothing_factor = smoothing_factor
        self.init_zero = init_zero
        self.val = None
        self.reset()

    def reset(self):
        self.val = None

    def update(self, val):
        if self.val is None:
            self.val = torch.zeros_like(val) if self.init_zero else val.clone()
        self.val = (1. - self.smoothing_factor) * val + self.smoothing_factor * self.val