diff --git a/timm/bits/README.md b/timm/bits/README.md
new file mode 100644
index 00000000..02ba6dc6
--- /dev/null
+++ b/timm/bits/README.md
@@ -0,0 +1,8 @@
+# Timm Bits
+
+A collection of reusable components and lightweight abstractions for training and evaluating neural networks.
+
+This is an early WIP with the primary goal of getting up and running on TPUs first. Expect significant changes, rewrites, additions...
+
+The current train.py and validate.py scripts are evolving to use the timm.bits components; they will also change significantly.
+
diff --git a/timm/bits/__init__.py b/timm/bits/__init__.py
new file mode 100644
index 00000000..33080c73
--- /dev/null
+++ b/timm/bits/__init__.py
@@ -0,0 +1,10 @@
+from .device_env_factory import initialize_device, get_device
+from .device_env import DeviceEnv
+#from .evaluate import evaluate, eval_step
+from .logger import Logger
+#from .task import TaskClassify
+from .updater import Updater
+from .updater_factory import create_updater
+from .tracker import Tracker
+#from .task_metrics import TaskMetrics, TaskMetricsClassify
+#from .train import train_one_epoch, TrainServices, TrainState, TrainCfg, Experiment
\ No newline at end of file
diff --git a/timm/bits/device_env.py b/timm/bits/device_env.py
new file mode 100644
index 00000000..646d64f4
--- /dev/null
+++ b/timm/bits/device_env.py
@@ -0,0 +1,58 @@
+import torch
+import abc
+
+
+class DeviceEnv(abc.ABC):
+
+    @property
+    @abc.abstractmethod
+    def device(self) -> torch.device:
+        pass
+
+    @property
+    @abc.abstractmethod
+    def local_rank(self) -> int:
+        pass
+
+    @property
+    @abc.abstractmethod
+    def global_rank(self) -> int:
+        pass
+
+    @property
+    @abc.abstractmethod
+    def is_distributed(self) -> bool:
+        pass
+
+    @property
+    @abc.abstractmethod
+    def world_size(self) -> int:
+        pass
+
+    @property
+    @abc.abstractmethod
+    def is_master(self) -> bool:
+        pass
+
+    @property
+    @abc.abstractmethod
+    def type(self) -> str:
+        pass
+
+    @property
+    @abc.abstractmethod
+    def autocast(self):
+        pass
+
+    @abc.abstractmethod
+    def wrap_distributed(self, *modules):
+        pass
+
+    @abc.abstractmethod
+    def to_device(self, *modules: torch.nn.Module):
+        pass
+
+    #@abc.abstractmethod
+    def mark_step(self):
+        # FIXME this is for XLA only, make it common to all devices w/ appropriate no-ops?
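+        # One answer to the FIXME above (an assumption, not part of this diff): keep
+        # mark_step() a concrete no-op on the base class so callers can invoke it
+        # unconditionally after each step, e.g.
+        #
+        #     def mark_step(self):
+        #         pass  # no-op for CUDA / CPU environments
+        #
+        # and let DeviceEnvXla override it; DeviceEnvXla.mark_step() below already
+        # forwards to xm.mark_step().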
+        pass
\ No newline at end of file
diff --git a/timm/bits/device_env_cuda.py b/timm/bits/device_env_cuda.py
new file mode 100644
index 00000000..29c4d8f6
--- /dev/null
+++ b/timm/bits/device_env_cuda.py
@@ -0,0 +1,90 @@
+import os
+from contextlib import suppress
+
+import torch
+from torch.nn.parallel import DistributedDataParallel
+
+from .device_env import DeviceEnv
+
+
+def is_cuda_available():
+    return torch.cuda.is_available()
+
+
+class DeviceEnvCuda(DeviceEnv):
+
+    def __init__(self, device_idx=None, local_rank=None, amp=False, memory_format=None):
+        assert torch.cuda.device_count()
+        torch.backends.cudnn.benchmark = True
+        self._local_rank = 0
+        self._distributed = False
+        self._world_size = 1
+        self._global_rank = 0
+        if 'WORLD_SIZE' in os.environ:
+            self._distributed = int(os.environ['WORLD_SIZE']) > 1
+        if self._distributed:
+            if local_rank is None:
+                lr = os.environ.get('LOCAL_RANK', None)
+                if lr is None:
+                    raise RuntimeError(
+                        'At least one of LOCAL_RANK env variable or local_rank arg must be set to valid integer.')
+                self._local_rank = int(lr)
+            else:
+                self._local_rank = int(local_rank)
+            self._device = torch.device('cuda:%d' % self._local_rank)
+            torch.cuda.set_device(self._local_rank)
+            torch.distributed.init_process_group(backend='nccl', init_method='env://')
+            self._world_size = torch.distributed.get_world_size()
+            self._global_rank = torch.distributed.get_rank()
+        else:
+            self._device = torch.device('cuda' if device_idx is None else f'cuda:{device_idx}')
+        self._memory_format = memory_format
+        if amp:
+            self._amp = amp
+            self._autocast = torch.cuda.amp.autocast
+        else:
+            self._amp = amp
+            self._autocast = suppress
+
+    @property
+    def device(self):
+        return self._device
+
+    @property
+    def local_rank(self):
+        return self._local_rank
+
+    @property
+    def global_rank(self):
+        return self._global_rank
+
+    @property
+    def is_distributed(self):
+        return self._distributed
+
+    @property
+    def world_size(self):
+        return self._world_size
+
+    @property
+    def is_master(self):
+        return self._local_rank == 0
+
+    @property
+    def type(self) -> str:
+        return 'cuda'
+
+    @property
+    def amp(self) -> bool:
+        return self._amp
+
+    @property
+    def autocast(self):
+        return self._autocast
+
+    def wrap_distributed(self, *modules, **kwargs):
+        return [DistributedDataParallel(m, device_ids=[self._local_rank], **kwargs) for m in modules]
+
+    def to_device(self, *modules: torch.nn.Module):
+        # FIXME handling dtype / memformat... disable flags, enable flags, diff fn?
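+        # One possible direction for the FIXME above (an assumption, not implemented in
+        # this diff): accept optional per-call overrides and fall back to the env-level
+        # defaults set at construction, e.g.
+        #
+        #     def to_device(self, *modules, dtype=None, memory_format=None):
+        #         return [m.to(device=self._device,
+        #                      dtype=dtype,
+        #                      memory_format=memory_format or self._memory_format)
+        #                 for m in modules]
+        #
+        # For now only the env-level memory_format is applied below.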
+ return [m.to(device=self._device, memory_format=self._memory_format) for m in modules] diff --git a/timm/bits/device_env_factory.py b/timm/bits/device_env_factory.py new file mode 100644 index 00000000..f6dc14f3 --- /dev/null +++ b/timm/bits/device_env_factory.py @@ -0,0 +1,34 @@ +from .device_env_cuda import DeviceEnvCuda, is_cuda_available +from .device_env_xla import DeviceEnvXla, is_xla_available + +_device_env = None + + +def initialize_device(force_cpu: bool = False, xla_device_type=None, **kwargs): + global _device_env + if _device_env is not None: + # warning + return _device_env + + denv = None + if not force_cpu: + if is_xla_available(xla_device_type): + # XLA supports more than just TPU, but by default will only look at TPU + denv = DeviceEnvXla(**kwargs, xla_device_type=xla_device_type) + elif is_cuda_available(): + denv = DeviceEnvCuda(**kwargs) + + if denv is None: + # FIXME implement CPU support + raise NotImplementedError() + + _device_env = denv + return denv + + +def get_device(): + if _device_env is None: + raise RuntimeError('Please initialize device environment by calling initialize_device first.') + return _device_env + + diff --git a/timm/bits/device_env_xla.py b/timm/bits/device_env_xla.py new file mode 100644 index 00000000..385b8626 --- /dev/null +++ b/timm/bits/device_env_xla.py @@ -0,0 +1,85 @@ +import os +from contextlib import suppress +import torch + +try: + import torch_xla.core.xla_model as xm + import torch_xla.amp as xa + _HAS_XLA = True +except ImportError as e: + xm = None + _HAS_XLA = False + +from .device_env import DeviceEnv + + +def is_xla_available(xla_device_type=None): + if not _HAS_XLA: + return False + supported_devs = xm.get_xla_supported_devices(devkind=xla_device_type) + print(supported_devs) + return len(supported_devs) >= 1 + + +class DeviceEnvXla(DeviceEnv): + + def __init__(self, xla_device_type=None, device_idx=None, local_rank=0, amp=False): + self._device = xm.xla_device(n=device_idx, devkind=xla_device_type) + print(self._device) + self._local_rank = xm.get_local_ordinal(local_rank) + self._world_size = xm.xrt_world_size() + self._distributed = self._world_size > 1 + self._global_rank = 0 + if self._distributed: + self._global_rank = xm.get_ordinal() + if amp: + self._autocast = xa.autocast + else: + self._autocast = suppress + self._memory_format = None + + @property + def device(self): + return self._device + + @property + def local_rank(self): + return self._local_rank + + @property + def global_rank(self): + return self._global_rank + + @property + def is_distributed(self): + return self._distributed + + @property + def world_size(self): + return self._world_size + + @property + def is_master(self): + return self._global_rank == 0 + + @property + def type(self) -> str: + return 'xla' + + @property + def amp(self) -> bool: + return False + + @property + def autocast(self): + return self._autocast + + def wrap_distributed(self, *modules): + # NO-OP + return tuple([m for m in modules]) + + def to_device(self, *modules: torch.nn.Module): + return [m.to(device=self._device, memory_format=self._memory_format) for m in modules] + + def mark_step(self): + xm.mark_step() diff --git a/timm/bits/grad_clipper.py b/timm/bits/grad_clipper.py new file mode 100644 index 00000000..232f5fc0 --- /dev/null +++ b/timm/bits/grad_clipper.py @@ -0,0 +1,36 @@ +from functools import partial + +import torch + +from timm.utils.agc import adaptive_clip_grad + + +def get_clip_grad_fn(mode: str = 'norm', norm_type: float = 2.0): + if mode == 'norm': + return 
partial(torch.nn.utils.clip_grad_norm_, norm_type=norm_type) + elif mode == 'value': + return torch.nn.utils.clip_grad_value_ + elif mode == 'agc': + return partial(adaptive_clip_grad, norm_type=norm_type) + else: + assert False, f"Unknown clip mode ({mode})." + + +def get_clip_parameters(model): + if hasattr(model, 'get_clip_parameters'): + return model.get_clip_parameters() + else: + return model.parameters() + + +class GradClipper: + + def __init__(self, model, clip_value, clip_mode='norm'): + self.model = model + self.clip_fn = get_clip_grad_fn(clip_mode) + self.clip_value = clip_value + self.enabled = True + + def __call__(self): + if self.enabled: + self.clip_fn(get_clip_parameters(self.model), self.clip_value) \ No newline at end of file diff --git a/timm/bits/logger.py b/timm/bits/logger.py new file mode 100644 index 00000000..2e2cd9da --- /dev/null +++ b/timm/bits/logger.py @@ -0,0 +1,223 @@ +import csv +import logging +import os +from collections import OrderedDict +from typing import Optional, Tuple, Dict, Union + +import torch + +_logger = logging.getLogger(__name__) + +try: + from torch.utils.tensorboard import SummaryWriter + HAS_TB = True +except ImportError as e: + HAS_TB = False + +try: + import wandb + HAS_WANDB = True +except ImportError: + HAS_WANDB = False + + +# FIXME old formatting for reference, to remove +# +# def log_eval(batch_idx, last_idx, batch_time, loss, top1, top5, log_suffix=''): +# log_name = 'Test' + log_suffix +# logging.info( +# f'{log_name}: [{batch_idx:>4d}/{last_idx}] ' +# f'Time: {batch_time.smooth_val:.3f} ({batch_time.avg:.3f}) ' +# f'Loss: {loss.smooth_val:>7.4f} ({loss.avg:>6.4f}) ' +# f'Acc@1: {top1.smooth_val:>7.4f} ({top1.avg:>7.4f}) ' +# f'Acc@5: {top5.smooth_val:>7.4f} ({top5.avg:>7.4f})' +# ) +# +# +# def log_train(epoch, step, num_steps, loss, batch_size, batch_time, data_time, lr, world_size=1): +# last_step = max(0, num_steps - 1) +# progress = 100. * step / last_step if last_step else 0. 
+# log_str = f'Train: {epoch} [{step:>4d}/{num_steps} ({progress:>3.0f}%)]' \ +# f' Time: {batch_time.smooth_val:.3f}s, {batch_size * world_size / batch_time.smooth_val:>7.2f}/s' \ +# f' ({batch_time.avg:.3f}s, {batch_size * world_size / batch_time.avg:>7.2f}/s)' \ +# f' Data: {data_time.smooth_val:.3f} ({data_time.avg:.3f})' +# log_str += f' Loss: {loss.smooth_val:>9.6f} ({loss.avg:>6.4f}) ' +# log_str += f' LR: {lr:.3e} ' +# +# if args.save_images and output_dir: +# torchvision.utils.save_image( +# input, +# os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx), +# padding=0, +# normalize=True) + + +def summary_row_dict(results, index=None, index_name='epoch'): + assert isinstance(results, dict) + row_dict = OrderedDict() + if index is not None: + row_dict[index_name] = index + if not results: + return row_dict + if isinstance(next(iter(results.values())), dict): + # each key in results is a per-phase results dict, flatten by prefixing with phase name + for p, pr in results.keys(): + assert isinstance(dict, pr) + row_dict.update([('_'.join([p, k]), v) for k, v in pr.items()]) + else: + row_dict.update(results) + return row_dict + + +class SummaryCsv: + def __init__(self, output_dir, filename='summary.csv'): + self.output_dir = output_dir + self.filename = os.path.join(output_dir, filename) + self.needs_header = not os.path.exists(self.filename) + + def update(self, row_dict): + with open(self.filename, mode='a') as cf: + dw = csv.DictWriter(cf, fieldnames=row_dict.keys()) + if self.needs_header: # first iteration (epoch == 1 can't be used) + dw.writeheader() + self.needs_header = False + dw.writerow(row_dict) + + +def _add_kwargs(text_update, name_map=None, **kwargs): + def _to_str(key, val): + if isinstance(val, float): + return f'{key}: {val:.4f}' + else: + return f'{key}: {val}' + + def _map_name(key, name_map, capitalize=True): + if name_map is None: + if capitalize: + return key.capitalize() if not key.isupper() else key + else: + return key + return name_map.get(key, None) + + for k, v in kwargs.items(): + if isinstance(v, dict): + # log each k, v of a dict kwarg as separate items + for kk, vv in v.items(): + name = _map_name(kk, name_map) + if not name: + continue + text_update += [_to_str(kk, vv)] + else: + name = _map_name(k, name_map, capitalize=True) + if not name: + continue + text_update += [_to_str(name, v)] + + +class Logger: + + def __init__( + self, + experiment_name=None, + output_dir=None, + logger=None, + log_wandb=False, + hparams=None): + + self.output_dir = output_dir # for tensorboard, csv, console logging to file? + self.logger = logger or logging.getLogger('log') + hparams = hparams or {} + + # Setup CSV writer(s) + if output_dir is not None: + self.csv_writer = SummaryCsv(output_dir=output_dir) + else: + self.csv_writer = None + + # Setup Tensorboard + self.summary_writer = None # FIXME tensorboard + + # Setup W&B + self.wandb_run = None + if log_wandb: + if HAS_WANDB: + self.wandb_run = wandb.init(project=experiment_name, config=hparams) + else: + _logger.warning("You've requested to log metrics to wandb but package not found. 
" + "Metrics not being logged to wandb, try `pip install wandb`") + + # FIXME image save + + def log_step( + self, + phase: str, + step: int, + end_step: Optional[int] = None, + loss: Optional[float] = None, + rate: Optional[float] = None, + epoch: Optional[int] = None, + phase_suffix: str = '', + **kwargs, + ): + """ log train/eval step + """ + phase_title = f'{phase.capitalize()} ({phase_suffix})' if phase_suffix else f'{phase.capitalize()}' + progress = 100. * step / end_step if end_step else 0. + text_update = [ + phase_title, + f'Epoch: {epoch}' if epoch is not None else None, + f'Step: {step}' if end_step is None else None, + f'Step: [{step}/{end_step} ({progress:>3.0f}%)]' if end_step is not None else None, + f'Rate: {rate:.2f}/s' if rate is not None else None, + f'Loss: {loss:.5f}' if loss is not None else None, + ] + _add_kwargs(text_update, **kwargs) + log_str = ' '.join(item for item in text_update if item) + self.logger.info(log_str) + if self.summary_writer is not None: + # FIXME log step values to tensorboard + pass + + def log_phase( + self, + phase: str = 'eval', + epoch: Optional[int] = None, + name_map: Optional[dict] = None, + **kwargs + ): + """log completion of evaluation or training phase + """ + title = [ + f'{phase.capitalize()}', + f'epoch: {epoch}' if epoch is not None else None, + 'completed. ', + ] + title_str = ' '.join(i for i in title if i) + results = [] + _add_kwargs(results, name_map=name_map, **kwargs) + log_str = title_str + ', '.join(item for item in results if item) + self.logger.info(log_str) + + def write_summary( + self, + results: Dict, # Dict or Dict of Dict where first level keys are treated as per-phase results + index: Optional[Union[int, str]] = None, + index_name: str = 'epoch', + ): + """ Log complete results for all phases (typically called at end of epoch) + + Args: + results (dict or dict[dict]): dict of results to write, or multiple dicts where first level + key is the name of results dict for each phase + index: value for row index (typically epoch #) + index_name: name for row index header (typically 'epoch') + """ + + row_dict = summary_row_dict(index=index, index_name=index_name, results=results) + if self.csv_writer: + self.csv_writer.update(row_dict) + if self.wandb_run is not None: + wandb.log(row_dict) + if self.summary_writer: + # FIXME log epoch summaries to tensorboard + pass diff --git a/timm/bits/tracker.py b/timm/bits/tracker.py new file mode 100644 index 00000000..12e0106b --- /dev/null +++ b/timm/bits/tracker.py @@ -0,0 +1,50 @@ +import time +from typing import Optional + +from timm.metrics import ScalarAvgMinMax + + +class Tracker: + + def __init__(self): + self.data_time = ScalarAvgMinMax() # time for data loader to produce batch of samples + self.step_time = ScalarAvgMinMax() # time for model step + self.iter_time = ScalarAvgMinMax() # full iteration time incl. 
data, step, and book-keeping + self.epoch_time = ScalarAvgMinMax() + + self.iter_timestamp: Optional[float] = None + self.prev_timestamp: Optional[float] = None + self.epoch_timestamp: Optional[float] = None + + def _measure_iter(self, ref_timestamp=None): + timestamp = time.perf_counter() + self.prev_timestamp = timestamp + + def mark_iter(self): + timestamp = time.perf_counter() + if self.iter_timestamp is not None: + iter_time = timestamp - self.iter_timestamp + self.iter_time.update(iter_time) + self.iter_timestamp = self.prev_timestamp = timestamp + + def mark_iter_data_end(self): + assert self.prev_timestamp is not None + timestamp = time.perf_counter() + data_time = timestamp - self.prev_timestamp + self.data_time.update(data_time) + self.prev_timestamp = timestamp + + def mark_iter_step_end(self): + assert self.prev_timestamp is not None + timestamp = time.perf_counter() + step_time = timestamp - self.prev_timestamp + self.step_time.update(step_time) + self.prev_timestamp = timestamp + + def mark_epoch(self): + timestamp = time.perf_counter() + if self.epoch_timestamp is not None: + epoch_time = timestamp - self.epoch_timestamp + self.epoch_time.update(epoch_time) + self.epoch_timestamp = timestamp + diff --git a/timm/bits/updater.py b/timm/bits/updater.py new file mode 100644 index 00000000..6612c8ea --- /dev/null +++ b/timm/bits/updater.py @@ -0,0 +1,54 @@ +from typing import Callable, Optional, Union + +import torch + +from .grad_clipper import GradClipper + + +class Updater: + + def __init__( + self, + optimizer: torch.optim.Optimizer, + clip_value: Optional[Union[Callable, float]] = None, + clip_mode: str = 'norm'): + + self.optimizer = optimizer + self.clipper: Optional[GradClipper] = None + if clip_value is not None: + if isinstance(clip_value, Callable): + self.clipper = clip_value + else: + GradClipper(clip_value, clip_mode) + self.scaler = None + self.create_graph = getattr(self.optimizer, 'second_order', False) + self.num_accumulated = 0 + self.after_step_closure = False + + def apply(self, loss: torch.Tensor, accumulate=False): + loss.backward(create_graph=self.create_graph) + if self.clipper is not None: + self.clipper() + if not accumulate: + self.optimizer.step() + self.reset() + else: + self.num_accumulated += 1 + + def reset(self): + self.optimizer.zero_grad() + self.num_accumulated = 0 + + def state_dict(self): + state_dict = dict(optimizer=self.optimizer.state_dict()) + if self.scaler is not None: + state_dict['scaler'] = self.scaler.state_dict() + + def load_state_dict(self, state_dict): + if 'optimizer' in state_dict: + self.optimizer.load_state_dict(state_dict['optimizer']) + if 'scaler' in state_dict and self.scaler is not None: + self.scaler.load_state_dict(state_dict['scaler']) + + + diff --git a/timm/bits/updater_cuda.py b/timm/bits/updater_cuda.py new file mode 100644 index 00000000..799aef00 --- /dev/null +++ b/timm/bits/updater_cuda.py @@ -0,0 +1,36 @@ +from typing import Callable, Optional, Union, Any + +import torch + +from .updater import Updater + + +class UpdaterCuda(Updater): + def __init__( + self, + optimizer: torch.optim.Optimizer, + clip_value: Optional[Union[Callable, float]] = None, + clip_mode: str = 'norm', + use_scaler: bool = False, + scaler_kwargs: Any = None, + ): + super().__init__(optimizer=optimizer, clip_value=clip_value, clip_mode=clip_mode) + scaler_kwargs = scaler_kwargs or {} + if use_scaler: + self.scaler = torch.cuda.amp.GradScaler(**scaler_kwargs) + + def apply(self, loss: torch.Tensor, accumulate=False): + if self.scaler is 
not None: + self.scaler.scale(loss).backward(create_graph=self.create_graph) + if self.clipper is not None: + self.scaler.unscale_(self.optimizer) # unscale the gradients of optimizer's assigned params in-place + self.clipper() + if not accumulate: + self.scaler.step(self.optimizer) + self.reset() + else: + self.num_accumulated += 1 + self.scaler.update() + else: + Updater.apply(self, loss, accumulate) + diff --git a/timm/bits/updater_factory.py b/timm/bits/updater_factory.py new file mode 100644 index 00000000..aba008d2 --- /dev/null +++ b/timm/bits/updater_factory.py @@ -0,0 +1,30 @@ +from typing import Callable, Optional, Union, Any + +import torch + +from .device_env import DeviceEnv +from .device_env_factory import get_device +from .updater import Updater +from .updater_cuda import UpdaterCuda +from .updater_xla import UpdaterXla + + +def create_updater( + optimizer: torch.optim.Optimizer, + dev_env: Optional[DeviceEnv] = None, + clip_value: Optional[Union[Callable, float]] = None, + clip_mode: str = 'norm', + scaler_kwargs: Any = None) -> Updater: + + if not dev_env: + dev_env = get_device() + + updater_kwargs = dict( + optimizer=optimizer, clip_value=clip_value, clip_mode=clip_mode, scaler_kwargs=scaler_kwargs) + if dev_env.type == 'xla': + return UpdaterXla(**updater_kwargs, use_scaler=dev_env.amp) + elif dev_env.type == 'cuda': + return UpdaterCuda(**updater_kwargs, use_scaler=dev_env.amp) + else: + updater_kwargs.pop('scaler_kwargs', None) + return Updater(**updater_kwargs) diff --git a/timm/bits/updater_xla.py b/timm/bits/updater_xla.py new file mode 100644 index 00000000..0789f06f --- /dev/null +++ b/timm/bits/updater_xla.py @@ -0,0 +1,52 @@ +from typing import Callable, Optional, Union, Any + +import torch + +try: + import torch_xla.core.xla_model as xm + import torch_xla.amp as xa + _HAS_XLA = True +except ImportError as e: + xm = None + _HAS_XLA = False + +from .updater import Updater + + +class UpdaterXla(Updater): + + def __init__( + self, + optimizer: torch.optim.Optimizer, + clip_value: Optional[Union[Callable, float]] = None, + clip_mode: str = 'norm', + use_scaler: bool = False, + scaler_kwargs: Any = None, + ): + super().__init__(optimizer=optimizer, clip_value=clip_value, clip_mode=clip_mode) + self.after_step_closure = True + if use_scaler: + self.scaler = xa.GradScaler(**scaler_kwargs) + + def apply(self, loss: torch.Tensor, accumulate: bool = False): + if self.scaler is None: + loss.backward(create_graph=self.create_graph) + gradients = xm._fetch_gradients(self.optimizer) + xm.all_reduce('sum', gradients, scale=1.0 / xm.xrt_world_size()) + if self.clipper is not None: + self.clipper() + if not accumulate: + xm.optimizer_step(self.optimizer) + else: + self.scaler.scale(loss).backward(create_graph=self.create_graph) + if self.clipper is not None: + self.scaler.unscale_(self.optimizer) # unscale the gradients of optimizer's assigned params in-place + self.clipper() + if not accumulate: + self.scaler.step(self.optimizer) + self.reset() + self.scaler.update() + + def after_step(self, after_step_fn, *args): + xm.add_step_closure(after_step_fn, *args) + diff --git a/timm/data/collate.py b/timm/data/collate.py new file mode 100644 index 00000000..a1e37e1f --- /dev/null +++ b/timm/data/collate.py @@ -0,0 +1,38 @@ +import numpy as np + +import torch + + +def fast_collate(batch): + """ A fast collation function optimized for uint8 images (np array or torch) and int64 targets (labels)""" + assert isinstance(batch[0], tuple) + batch_size = len(batch) + if 
isinstance(batch[0][0], tuple): + # This branch 'deinterleaves' and flattens tuples of input tensors into one tensor ordered by position + # such that all tuple of position n will end up in a torch.split(tensor, batch_size) in nth position + inner_tuple_size = len(batch[0][0]) + flattened_batch_size = batch_size * inner_tuple_size + targets = torch.zeros(flattened_batch_size, dtype=torch.int64) + tensor = torch.zeros((flattened_batch_size, *batch[0][0][0].shape), dtype=torch.uint8) + for i in range(batch_size): + assert len(batch[i][0]) == inner_tuple_size # all input tensor tuples must be same length + for j in range(inner_tuple_size): + targets[i + j * batch_size] = batch[i][1] + tensor[i + j * batch_size] += torch.from_numpy(batch[i][0][j]) + return tensor, targets + elif isinstance(batch[0][0], np.ndarray): + targets = torch.tensor([b[1] for b in batch], dtype=torch.int64) + assert len(targets) == batch_size + tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) + for i in range(batch_size): + tensor[i] += torch.from_numpy(batch[i][0]) + return tensor, targets + elif isinstance(batch[0][0], torch.Tensor): + targets = torch.tensor([b[1] for b in batch], dtype=torch.int64) + assert len(targets) == batch_size + tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) + for i in range(batch_size): + tensor[i].copy_(batch[i][0]) + return tensor, targets + else: + assert False \ No newline at end of file diff --git a/timm/data/config.py b/timm/data/config.py index 38f5689a..06920d7d 100644 --- a/timm/data/config.py +++ b/timm/data/config.py @@ -70,6 +70,14 @@ def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, v elif 'crop_pct' in default_cfg: new_config['crop_pct'] = default_cfg['crop_pct'] + if getattr(args, 'mixup', 0) > 0 \ + or getattr(args, 'cutmix', 0) > 0. \ + or getattr(args, 'cutmix_minmax', None) is not None: + new_config['mixup'] = dict( + mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, + prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, + label_smoothing=args.smoothing, num_classes=args.num_classes) + if verbose: _logger.info('Data processing configuration for current model + dataset:') for n, v in new_config.items(): diff --git a/timm/data/fetcher.py b/timm/data/fetcher.py new file mode 100644 index 00000000..1cbc3fe5 --- /dev/null +++ b/timm/data/fetcher.py @@ -0,0 +1,69 @@ +import torch + +from .constants import * +from .random_erasing import RandomErasing +from. 
mixup import FastCollateMixup + + +class FetcherXla: + def __init__(self): + pass + + +class Fetcher: + + def __init__(self, + loader, + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + device=None, + dtype=None, + re_prob=0., + re_mode='const', + re_count=1, + re_num_splits=0): + self.loader = loader + self.mean = torch.tensor([x * 255 for x in mean]).view(1, 3, 1, 1) + self.std = torch.tensor([x * 255 for x in std]).view(1, 3, 1, 1) + self.device = torch.device(device) + self.dtype = dtype or torch.float32 + if device: + self.mean.to(device=device, dtype=self.dtype) + self.std.to(device=device, dtype=self.dtype) + if re_prob > 0.: + self.random_erasing = RandomErasing( + probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits) + else: + self.random_erasing = None + + def __iter__(self): + for sample, target in self.loader: + sample = sample.to(device=self.device) + target = target.to(device=self.device) + sample = sample.to(dtype=self.dtype).sub_(self.mean).div_(self.std) + if self.random_erasing is not None: + sample = self.random_erasing(sample) + yield sample, target + + def __len__(self): + return len(self.loader) + + @property + def sampler(self): + return self.loader.sampler + + @property + def dataset(self): + return self.loader.dataset + + @property + def mixup_enabled(self): + if isinstance(self.loader.collate_fn, FastCollateMixup): + return self.loader.collate_fn.mixup_enabled + else: + return False + + @mixup_enabled.setter + def mixup_enabled(self, x): + if isinstance(self.loader.collate_fn, FastCollateMixup): + self.loader.collate_fn.mixup_enabled = x \ No newline at end of file diff --git a/timm/data/loader.py b/timm/data/loader.py index 76144669..9b15eb02 100644 --- a/timm/data/loader.py +++ b/timm/data/loader.py @@ -7,122 +7,15 @@ Hacked together by / Copyright 2020 Ross Wightman """ import torch.utils.data -import numpy as np +from timm.bits import get_device + +from .fetcher import Fetcher +from .prefetcher_cuda import PrefetcherCuda +from .collate import fast_collate from .transforms_factory import create_transform from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .distributed_sampler import OrderedDistributedSampler -from .random_erasing import RandomErasing -from .mixup import FastCollateMixup - - -def fast_collate(batch): - """ A fast collation function optimized for uint8 images (np array or torch) and int64 targets (labels)""" - assert isinstance(batch[0], tuple) - batch_size = len(batch) - if isinstance(batch[0][0], tuple): - # This branch 'deinterleaves' and flattens tuples of input tensors into one tensor ordered by position - # such that all tuple of position n will end up in a torch.split(tensor, batch_size) in nth position - inner_tuple_size = len(batch[0][0]) - flattened_batch_size = batch_size * inner_tuple_size - targets = torch.zeros(flattened_batch_size, dtype=torch.int64) - tensor = torch.zeros((flattened_batch_size, *batch[0][0][0].shape), dtype=torch.uint8) - for i in range(batch_size): - assert len(batch[i][0]) == inner_tuple_size # all input tensor tuples must be same length - for j in range(inner_tuple_size): - targets[i + j * batch_size] = batch[i][1] - tensor[i + j * batch_size] += torch.from_numpy(batch[i][0][j]) - return tensor, targets - elif isinstance(batch[0][0], np.ndarray): - targets = torch.tensor([b[1] for b in batch], dtype=torch.int64) - assert len(targets) == batch_size - tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) - for i in range(batch_size): - 
tensor[i] += torch.from_numpy(batch[i][0]) - return tensor, targets - elif isinstance(batch[0][0], torch.Tensor): - targets = torch.tensor([b[1] for b in batch], dtype=torch.int64) - assert len(targets) == batch_size - tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) - for i in range(batch_size): - tensor[i].copy_(batch[i][0]) - return tensor, targets - else: - assert False - - -class PrefetchLoader: - - def __init__(self, - loader, - mean=IMAGENET_DEFAULT_MEAN, - std=IMAGENET_DEFAULT_STD, - fp16=False, - re_prob=0., - re_mode='const', - re_count=1, - re_num_splits=0): - self.loader = loader - self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1) - self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1) - self.fp16 = fp16 - if fp16: - self.mean = self.mean.half() - self.std = self.std.half() - if re_prob > 0.: - self.random_erasing = RandomErasing( - probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits) - else: - self.random_erasing = None - - def __iter__(self): - stream = torch.cuda.Stream() - first = True - - for next_input, next_target in self.loader: - with torch.cuda.stream(stream): - next_input = next_input.cuda(non_blocking=True) - next_target = next_target.cuda(non_blocking=True) - if self.fp16: - next_input = next_input.half().sub_(self.mean).div_(self.std) - else: - next_input = next_input.float().sub_(self.mean).div_(self.std) - if self.random_erasing is not None: - next_input = self.random_erasing(next_input) - - if not first: - yield input, target - else: - first = False - - torch.cuda.current_stream().wait_stream(stream) - input = next_input - target = next_target - - yield input, target - - def __len__(self): - return len(self.loader) - - @property - def sampler(self): - return self.loader.sampler - - @property - def dataset(self): - return self.loader.dataset - - @property - def mixup_enabled(self): - if isinstance(self.loader.collate_fn, FastCollateMixup): - return self.loader.collate_fn.mixup_enabled - else: - return False - - @mixup_enabled.setter - def mixup_enabled(self, x): - if isinstance(self.loader.collate_fn, FastCollateMixup): - self.loader.collate_fn.mixup_enabled = x def create_loader( @@ -130,7 +23,7 @@ def create_loader( input_size, batch_size, is_training=False, - use_prefetcher=True, + dev_env=None, no_aug=False, re_prob=0., re_mode='const', @@ -163,7 +56,7 @@ def create_loader( dataset.transform = create_transform( input_size, is_training=is_training, - use_prefetcher=use_prefetcher, + use_fetcher=True, no_aug=no_aug, scale=scale, ratio=ratio, @@ -183,6 +76,9 @@ def create_loader( separate=num_aug_splits > 0, ) + if dev_env is None: + dev_env = get_device() + sampler = None if distributed and not isinstance(dataset, torch.utils.data.IterableDataset): if is_training: @@ -193,10 +89,9 @@ def create_loader( sampler = OrderedDistributedSampler(dataset) if collate_fn is None: - collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate + collate_fn = fast_collate loader_class = torch.utils.data.DataLoader - if use_multi_epochs_loader: loader_class = MultiEpochsDataLoader @@ -214,18 +109,19 @@ def create_loader( except TypeError as e: loader_args.pop('persistent_workers') # only in Pytorch 1.7+ loader = loader_class(dataset, **loader_args) - if use_prefetcher: - prefetch_re_prob = re_prob if is_training and not no_aug else 0. 
- loader = PrefetchLoader( - loader, - mean=mean, - std=std, - fp16=fp16, - re_prob=prefetch_re_prob, - re_mode=re_mode, - re_count=re_count, - re_num_splits=re_num_splits - ) + + fetcher_kwargs = dict( + mean=mean, + std=std, + re_prob=re_prob if is_training and not no_aug else 0., + re_mode=re_mode, + re_count=re_count, + re_num_splits=re_num_splits + ) + if dev_env.type == 'cuda': + loader = PrefetcherCuda(loader, **fetcher_kwargs) + else: + loader = Fetcher(loader, device=dev_env.device, **fetcher_kwargs) return loader diff --git a/timm/data/prefetcher_cuda.py b/timm/data/prefetcher_cuda.py new file mode 100644 index 00000000..4f1c4e10 --- /dev/null +++ b/timm/data/prefetcher_cuda.py @@ -0,0 +1,79 @@ +import torch.cuda + +from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .mixup import FastCollateMixup +from .random_erasing import RandomErasing + + +class PrefetcherCuda: + + def __init__(self, + loader, + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + fp16=False, + re_prob=0., + re_mode='const', + re_count=1, + re_num_splits=0): + self.loader = loader + self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1) + self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1) + self.fp16 = fp16 + if fp16: + self.mean = self.mean.half() + self.std = self.std.half() + if re_prob > 0.: + self.random_erasing = RandomErasing( + probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits) + else: + self.random_erasing = None + + def __iter__(self): + stream = torch.cuda.Stream() + first = True + + for next_input, next_target in self.loader: + with torch.cuda.stream(stream): + next_input = next_input.cuda(non_blocking=True) + next_target = next_target.cuda(non_blocking=True) + if self.fp16: + next_input = next_input.half().sub_(self.mean).div_(self.std) + else: + next_input = next_input.float().sub_(self.mean).div_(self.std) + if self.random_erasing is not None: + next_input = self.random_erasing(next_input) + + if not first: + yield input, target + else: + first = False + + torch.cuda.current_stream().wait_stream(stream) + input = next_input + target = next_target + + yield input, target + + def __len__(self): + return len(self.loader) + + @property + def sampler(self): + return self.loader.sampler + + @property + def dataset(self): + return self.loader.dataset + + @property + def mixup_enabled(self): + if isinstance(self.loader.collate_fn, FastCollateMixup): + return self.loader.collate_fn.mixup_enabled + else: + return False + + @mixup_enabled.setter + def mixup_enabled(self, x): + if isinstance(self.loader.collate_fn, FastCollateMixup): + self.loader.collate_fn.mixup_enabled = x \ No newline at end of file diff --git a/timm/data/tf_preprocessing.py b/timm/data/tf_preprocessing.py index 44b4a3af..0e657a9a 100644 --- a/timm/data/tf_preprocessing.py +++ b/timm/data/tf_preprocessing.py @@ -22,7 +22,10 @@ Hacked together by / Copyright 2020 Ross Wightman # limitations under the License. 
# ============================================================================== """ImageNet preprocessing for MnasNet.""" -import tensorflow as tf +import tensorflow.compat.v1 as tf +tf.disable_v2_behavior() +tf.compat.v1.disable_eager_execution() + import numpy as np IMAGE_SIZE = 224 @@ -131,6 +134,39 @@ def _decode_and_center_crop(image_bytes, image_size, resize_method): return image +def crop(image_bytes, crop_window): + """Helper function to crop a jpeg or a decoded image.""" + if image_bytes.dtype == tf.dtypes.string: + image = tf.image.decode_and_crop_jpeg(image_bytes, + tf.stack(crop_window), + channels=3) + else: + image = tf.image.crop_to_bounding_box(image_bytes, *crop_window) + return image + + +def _decode_and_resize_then_crop( + image_bytes: tf.Tensor, + image_size, + crop_pct: float = 32, +) -> tf.Tensor: + """Rescales an image to image_size / crop_pct, then center crops.""" + image = tf.image.decode_jpeg(image_bytes, channels=3) + # Scale image to "scaled size" before taking a center crop + if crop_pct > 1.0: # If crop_pct is >1, treat it as num pad pixels (like VGG) + scale_size = tuple([int(x + crop_pct) for x in image_size]) + else: + scale_size = tuple([int(float(x) / crop_pct) for x in image_size]) + image = tf.image.resize(image, scale_size, tf.image.ResizeMethod.BICUBIC) + crop_height = tf.cast(image_size[0], tf.int32) + crop_width = tf.cast(image_size[1], tf.int32) + offset_height = ((scale_size[0] - crop_height) + 1) // 2 + offset_width = ((scale_size[1] - crop_width) + 1) // 2 + crop_window = [offset_height, offset_width, crop_height, crop_width] + image = crop(image, crop_window) + return image + + def _flip(image): """Random horizontal image flip.""" image = tf.image.random_flip_left_right(image) @@ -172,6 +208,7 @@ def preprocess_for_eval(image_bytes, use_bfloat16, image_size=IMAGE_SIZE, interp """ resize_method = tf.image.ResizeMethod.BICUBIC if interpolation == 'bicubic' else tf.image.ResizeMethod.BILINEAR image = _decode_and_center_crop(image_bytes, image_size, resize_method) + #image = _decode_and_resize_then_crop(image_bytes, (image_size, image_size), resize_method) image = tf.reshape(image, [image_size, image_size, 3]) image = tf.image.convert_image_dtype( image, dtype=tf.bfloat16 if use_bfloat16 else tf.float32) diff --git a/timm/data/transforms_factory.py b/timm/data/transforms_factory.py index df6e0de0..16e08a39 100644 --- a/timm/data/transforms_factory.py +++ b/timm/data/transforms_factory.py @@ -167,7 +167,7 @@ def transforms_imagenet_eval( def create_transform( input_size, is_training=False, - use_prefetcher=False, + use_fetcher=False, no_aug=False, scale=None, ratio=None, @@ -191,7 +191,7 @@ def create_transform( else: img_size = input_size - if tf_preprocessing and use_prefetcher: + if tf_preprocessing and use_fetcher: assert not separate, "Separate transforms not supported for TF preprocessing" from timm.data.tf_preprocessing import TfPreprocessTransform transform = TfPreprocessTransform( @@ -202,7 +202,7 @@ def create_transform( transform = transforms_noaug_train( img_size, interpolation=interpolation, - use_prefetcher=use_prefetcher, + use_prefetcher=use_fetcher, mean=mean, std=std) elif is_training: @@ -215,7 +215,7 @@ def create_transform( color_jitter=color_jitter, auto_augment=auto_augment, interpolation=interpolation, - use_prefetcher=use_prefetcher, + use_prefetcher=use_fetcher, mean=mean, std=std, re_prob=re_prob, @@ -228,7 +228,7 @@ def create_transform( transform = transforms_imagenet_eval( img_size, interpolation=interpolation, - 
use_prefetcher=use_prefetcher, + use_prefetcher=use_fetcher, mean=mean, std=std, crop_pct=crop_pct) diff --git a/timm/metrics/__init__.py b/timm/metrics/__init__.py new file mode 100644 index 00000000..93a2773e --- /dev/null +++ b/timm/metrics/__init__.py @@ -0,0 +1,4 @@ +from .accuracy import Accuracy, AccuracyTopK +from .precision_recall import PrecisionRecall +from .scalar_avg import ScalarAvgMinMax +from .tensor_avg import TensorAvg, TensorEma diff --git a/timm/metrics/accuracy.py b/timm/metrics/accuracy.py new file mode 100644 index 00000000..98aa59eb --- /dev/null +++ b/timm/metrics/accuracy.py @@ -0,0 +1,112 @@ +import torch +from typing import Optional, Tuple, Dict + + +class Accuracy(torch.nn.Module): + + def __init__(self, threshold=0.5, multi_label=False): + self.threshold = threshold + self.eps = 1e-8 + self.multi_label = multi_label + + # statistics / counts + self._correct_sum = torch.tensor(0, dtype=torch.long) + self._total_sum = torch.tensor(0, dtype=torch.long) + + def update(self, predictions, target): + raise NotImplemented() + + def reset(self): + self._correct_sum = 0 + self._total_sum = 0 + + @property + def counts(self): + pass + + def compute(self): + raise NotImplemented() + + +class AccuracyTopK(torch.nn.Module): + + def __init__(self, topk=(1, 5), device=None): + super().__init__() + self.eps = 1e-8 + self.device = device + self.topk = topk + self.maxk = max(topk) + + # statistics / counts + self.reset() + + def update(self, predictions: torch.Tensor, target: torch.Tensor): + sorted_indices = predictions.topk(self.maxk, dim=1)[1] + sorted_indices.t_() + correct = sorted_indices.eq(target.reshape(1, -1).expand_as(sorted_indices)) + + batch_size = target.shape[0] + correct_k = {k: correct[:k].reshape(-1).float().sum(0) for k in self.topk} + for k, v in correct_k.items(): + attr = f'_correct_top{k}' + old_v = getattr(self, attr) + setattr(self, attr, old_v + v) + self._total_sum += batch_size + + def reset(self): + for k in self.topk: + setattr(self, f'_correct_top{k}', torch.tensor(0, dtype=torch.float32)) + self._total_sum = torch.tensor(0, dtype=torch.float32) + + @property + def counts(self): + pass + + def compute(self) -> Dict[str, torch.Tensor]: + return {f'top{k}': 100 * getattr(self, f'_correct_top{k}') / self._total_sum for k in self.topk} + + +# +# class AccuracyTopK: +# +# def __init__(self, topk=(1, 5), device=None): +# self.eps = 1e-8 +# self.device = device +# self.topk = topk +# self.maxk = max(topk) +# +# # statistics / counts +# self._correct_sum = None +# self._total_sum = None +# +# def _check_init(self, device): +# to_device = self.device if self.device else device +# if self._correct_sum is None: +# self._correct_sum = {f'top{k}': torch.tensor(0., device=to_device) for k in self.topk} +# if self._total_sum is None: +# self._total_sum = torch.tensor(0, dtype=torch.long, device=to_device) +# +# def update(self, predictions: torch.Tensor, target: torch.Tensor): +# sorted_indices = predictions.topk(self.maxk, dim=1)[1] +# sorted_indices.t_() +# correct = sorted_indices.eq(target.reshape(1, -1).expand_as(sorted_indices)) +# +# batch_size = target.shape[0] +# correct_k = {f'top{k}': correct[:k].reshape(-1).float().sum(0) for k in self.topk} +# self._check_init(device=predictions.device) +# for k, v in correct_k.items(): +# old_v = self._correct_sum[k] +# self._correct_sum[k] = old_v + v +# self._total_sum += batch_size +# +# def reset(self): +# self._correct_sum = None +# self._total_sum = None +# +# @property +# def counts(self): +# pass +# +# def 
compute(self) -> Dict[str, torch.Tensor]: +# assert self._correct_sum is not None and self._total_sum is not None +# return {k: 100 * v / self._total_sum for k, v in self._correct_sum.items()} diff --git a/timm/metrics/precision_recall.py b/timm/metrics/precision_recall.py new file mode 100644 index 00000000..a5a38f91 --- /dev/null +++ b/timm/metrics/precision_recall.py @@ -0,0 +1,117 @@ +import torch +import torch.nn.functional as F + + +class PrecisionRecall: + + def __init__(self, threshold=0.5, multi_label=False, device=None): + self.threshold = threshold + self.device = device + self.multi_label = multi_label + + # statistics + + # the total number of true positive instances under each class + # Shape: (num_classes, ) + self._tp_sum = None + + # the total number of instances + # Shape: (num_classes, ) + self._total_sum = None + + # the total number of instances under each _predicted_ class, + # including true positives and false positives + # Shape: (num_classes, ) + self._pred_sum = None + + # the total number of instances under each _true_ class, + # including true positives and false negatives + # Shape: (num_classes, ) + self._true_sum = None + + self.reset() + + def reset(self): + self._tp_sum = None + self._total_sum = None + self._pred_sum = None + self._true_sum = None + + def update(self, predictions, target): + output_type = predictions.type() + num_classes = predictions.size(-1) + if self.multi_label: + if self.threshold is not None: + predictions = (predictions > self.threshold).type(output_type) + predictions = predictions.t().reshape(num_classes, -1) + target = target.t().reshape(num_classes, -1) + else: + target = F.one_hot(target.view(-1), num_classes=num_classes) + indices = torch.argmax(predictions, dim=1).view(-1) + predictions = F.one_hot(indices, num_classes=num_classes) + # FIXME make sure binary case works + + target = target.type(output_type) + correct = (target * predictions > 0).type(output_type) + pred_positives = predictions.sum(dim=0) + target_positives = target.sum(dim=0) + if correct.sum() == 0: + true_positives = torch.zeros_like(pred_positives) + else: + true_positives = correct.sum(dim=0) + + if self._tp_sum is None: + self._tp_sum = torch.zeros(num_classes, device=self.device) + self._true_sum = torch.zeros(num_classes, device=self.device) + self._pred_sum = torch.zeros(num_classes, device=self.device) + self._total_sum = torch.tensor(0, device=self.device) + + self._tp_sum += true_positives + self._pred_sum += pred_positives + self._true_sum += target_positives + self._total_sum += target.shape[0] + + def counts_as_tuple(self, reduce=False): + tp_sum = self._tp_sum + pred_sum = self._pred_sum + true_sum = self._true_sum + total_sum = self._total_sum + if reduce: + tp_sum = reduce_tensor_sum(tp_sum) + pred_sum = reduce_tensor_sum(pred_sum) + true_sum = reduce_tensor_sum(true_sum) + total_sum = reduce_tensor_sum(total_sum) + return tp_sum, pred_sum, true_sum, total_sum + + def counts(self, reduce=False): + tp_sum, pred_sum, true_sum, total_sum = self.counts_as_tuple(reduce=reduce) + return dict(tp_sum=tp_sum, pred_sum=pred_sum, true_sum=true_sum, total_sum=total_sum) + + def confusion(self, reduce=False): + tp_sum, pred_sum, true_sum, total_sum = self.counts_as_tuple(reduce=reduce) + fp = pred_sum - tp_sum + fn = true_sum - tp_sum + tp = tp_sum + tn = total_sum - tp - fp - fn + return dict(tp=tp, fp=fp, fn=fn, tn=tn) + + def compute(self, fscore_beta=1, average='micro', no_reduce=False, distributed=False): + tp_sum, pred_sum, true_sum, total_sum = 
self.counts_as_tuple(reduce=distributed) + if average == 'micro': + tp_sum = tp_sum.sum() + pred_sum = pred_sum.sum() + true_sum = true_sum.sum() + + precision = tp_sum / pred_sum + recall = tp_sum / true_sum + beta_sq = fscore_beta ** 2 + f1_denom = beta_sq * precision + recall + fscore = (1 + beta_sq) * precision * recall / f1_denom + + if average == 'macro' and not no_reduce: + precision = precision.mean() + recall = recall.mean() + fscore = fscore.mean() + return dict(fscore=fscore, precision=precision, recall=recall) + + return dict(fscore=fscore, precision=precision, recall=recall) diff --git a/timm/metrics/scalar_avg.py b/timm/metrics/scalar_avg.py new file mode 100644 index 00000000..f5d95807 --- /dev/null +++ b/timm/metrics/scalar_avg.py @@ -0,0 +1,30 @@ +class ScalarAvgMinMax: + + """Computes and stores the average and current value""" + def __init__(self): + self.val = 0 + self.avg = 0 + self.min = None + self.max = None + self.sum = 0 + self.count = 0 + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.min = None + self.max = None + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.min = val if self.min is None else min(self.min, val) + self.max = val if self.max is None else max(self.max, val) + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + + diff --git a/timm/metrics/tensor_avg.py b/timm/metrics/tensor_avg.py new file mode 100644 index 00000000..ac2fb6ed --- /dev/null +++ b/timm/metrics/tensor_avg.py @@ -0,0 +1,42 @@ +import torch + + +class TensorAvg: + + """Computes and stores the average and current value""" + def __init__(self): + self.sum = None + self.count = None + self.reset() + + def reset(self): + self.sum = None + self.count = None + + def update(self, val: torch.Tensor, n=1): + if self.sum is None: + self.sum = torch.zeros_like(val) + self.count = torch.tensor(0, dtype=torch.long, device=val.device) + self.sum += (val * n) + self.count += n + + def compute(self): + return self.sum / self.count + + +class TensorEma: + + """Computes and stores the average and current value""" + def __init__(self, smoothing_factor=0.9, init_zero=False): + self.smoothing_factor = smoothing_factor + self.init_zero = init_zero + self.val = None + self.reset() + + def reset(self): + self.val = None + + def update(self, val): + if self.val is None: + self.val = torch.zeros_like(val) if self.init_zero else val.clone() + self.val = (1. 
- self.smoothing_factor) * val + self.smoothing_factor * self.val diff --git a/timm/scheduler/scheduler.py b/timm/scheduler/scheduler.py index 21d51509..79e9a5e9 100644 --- a/timm/scheduler/scheduler.py +++ b/timm/scheduler/scheduler.py @@ -65,14 +65,16 @@ class Scheduler: return None def step(self, epoch: int, metric: float = None) -> None: - self.metric = metric + if metric is not None: + self.metric = metric values = self.get_epoch_values(epoch) if values is not None: values = self._add_noise(values, epoch) self.update_groups(values) def step_update(self, num_updates: int, metric: float = None): - self.metric = metric + if metric is not None: + self.metric = metric values = self.get_update_values(num_updates) if values is not None: values = self._add_noise(values, num_updates) diff --git a/train.py b/train.py index 85829fc1..f105e525 100755 --- a/train.py +++ b/train.py @@ -20,14 +20,14 @@ import yaml import os import logging from collections import OrderedDict -from contextlib import suppress from datetime import datetime import torch import torch.nn as nn import torchvision.utils -from torch.nn.parallel import DistributedDataParallel as NativeDDP +from timm.bits import initialize_device, DeviceEnv, create_updater, Updater, Logger, Tracker +from timm.metrics import TensorAvg, AccuracyTopK from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint,\ convert_splitbn_model, model_parameters @@ -35,32 +35,11 @@ from timm.utils import * from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy from timm.optim import create_optimizer_v2, optimizer_kwargs from timm.scheduler import create_scheduler -from timm.utils import ApexScaler, NativeScaler - -try: - from apex import amp - from apex.parallel import DistributedDataParallel as ApexDDP - from apex.parallel import convert_syncbn_model - has_apex = True -except ImportError: - has_apex = False - -has_native_amp = False -try: - if getattr(torch.cuda.amp, 'autocast') is not None: - has_native_amp = True -except AttributeError: - pass - -try: - import wandb - has_wandb = True -except ImportError: - has_wandb = False - -torch.backends.cudnn.benchmark = True + + _logger = logging.getLogger('train') + # The first arg parser parses out only the --config argument, this argument is used to # load a yaml file containing key-values that override the defaults for the main parser below config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False) @@ -254,16 +233,10 @@ parser.add_argument('--save-images', action='store_true', default=False, help='save images of input bathes every log interval for debugging') parser.add_argument('--amp', action='store_true', default=False, help='use NVIDIA Apex AMP or Native AMP for mixed precision training') -parser.add_argument('--apex-amp', action='store_true', default=False, - help='Use NVIDIA Apex AMP mixed precision') -parser.add_argument('--native-amp', action='store_true', default=False, - help='Use Native Torch AMP mixed precision') parser.add_argument('--channels-last', action='store_true', default=False, help='Use channels_last memory layout') parser.add_argument('--pin-mem', action='store_true', default=False, help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') -parser.add_argument('--no-prefetcher', action='store_true', default=False, - help='disable fast prefetcher') 
parser.add_argument('--output', default='', type=str, metavar='PATH', help='path to output folder (default: none, current dir)') parser.add_argument('--experiment', default='', type=str, metavar='NAME', @@ -301,50 +274,15 @@ def _parse_args(): def main(): setup_default_logging() args, args_text = _parse_args() - - if args.log_wandb: - if has_wandb: - wandb.init(project=args.experiment, config=args) - else: - _logger.warning("You've requested to log metrics to wandb but package not found. " - "Metrics not being logged to wandb, try `pip install wandb`") - - args.prefetcher = not args.no_prefetcher - args.distributed = False - if 'WORLD_SIZE' in os.environ: - args.distributed = int(os.environ['WORLD_SIZE']) > 1 - args.device = 'cuda:0' - args.world_size = 1 - args.rank = 0 # global rank - if args.distributed: - args.device = 'cuda:%d' % args.local_rank - torch.cuda.set_device(args.local_rank) - torch.distributed.init_process_group(backend='nccl', init_method='env://') - args.world_size = torch.distributed.get_world_size() - args.rank = torch.distributed.get_rank() - _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' - % (args.rank, args.world_size)) + + dev_env = initialize_device(amp=args.amp) + if dev_env.is_distributed: + _logger.info('Training in distributed mode with multiple processes, 1 device per process. Process %d, total %d.' + % (dev_env.global_rank, dev_env.world_size)) else: - _logger.info('Training with a single process on 1 GPUs.') - assert args.rank >= 0 - - # resolve AMP arguments based on PyTorch / Apex availability - use_amp = None - if args.amp: - # `--amp` chooses native amp before apex (APEX ver not actively maintained) - if has_native_amp: - args.native_amp = True - elif has_apex: - args.apex_amp = True - if args.apex_amp and has_apex: - use_amp = 'apex' - elif args.native_amp and has_native_amp: - use_amp = 'native' - elif args.apex_amp or args.native_amp: - _logger.warning("Neither APEX or native Torch AMP is available, using float32. " - "Install NVIDA apex or upgrade to PyTorch 1.6") - - random_seed(args.seed, args.rank) + _logger.info('Training with a single process on 1 device.') + + random_seed(args.seed, dev_env.global_rank) model = create_model( args.model, @@ -364,11 +302,11 @@ def main(): assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.' 
args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly - if args.local_rank == 0: + if dev_env.is_master: _logger.info( f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}') - data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0) + data_config = resolve_data_config(vars(args), model=model, verbose=dev_env.is_master) # setup augmentation batch splits for contrastive loss or split bn num_aug_splits = 0 @@ -382,55 +320,33 @@ def main(): model = convert_splitbn_model(model, max(num_aug_splits, 2)) # move model to GPU, enable channels last layout if set - model.cuda() - if args.channels_last: - model = model.to(memory_format=torch.channels_last) + dev_env.to_device(model) # setup synchronized BatchNorm for distributed training - if args.distributed and args.sync_bn: + if dev_env.is_distributed and args.sync_bn: assert not args.split_bn - if has_apex and use_amp != 'native': - # Apex SyncBN preferred unless native amp is activated - model = convert_syncbn_model(model) - else: - model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) - if args.local_rank == 0: + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + if dev_env.is_master: _logger.info( 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using ' 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.') if args.torchscript: - assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model' assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model' model = torch.jit.script(model) - optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args)) - - # setup automatic mixed-precision (AMP) loss scaling and op casting - amp_autocast = suppress # do nothing - loss_scaler = None - if use_amp == 'apex': - model, optimizer = amp.initialize(model, optimizer, opt_level='O1') - loss_scaler = ApexScaler() - if args.local_rank == 0: - _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.') - elif use_amp == 'native': - amp_autocast = torch.cuda.amp.autocast - loss_scaler = NativeScaler() - if args.local_rank == 0: - _logger.info('Using native Torch AMP. Training in mixed precision.') - else: - if args.local_rank == 0: - _logger.info('AMP not enabled. 
Training in float32.') + updater = create_updater( + create_optimizer_v2(model, **optimizer_kwargs(cfg=args)), + clip_value=args.clip_grad, clip_mode=args.clip_mode) # optionally resume from a checkpoint resume_epoch = None if args.resume: resume_epoch = resume_checkpoint( model, args.resume, - optimizer=None if args.no_resume_opt else optimizer, - loss_scaler=None if args.no_resume_opt else loss_scaler, - log_info=args.local_rank == 0) + optimizer=None if args.no_resume_opt else updater.optimizer, + loss_scaler=None if args.no_resume_opt else updater.scaler, + log_info=dev_env.is_master) # setup exponential moving average of model weights, SWA could be used here too model_ema = None @@ -442,20 +358,14 @@ def main(): load_checkpoint(model_ema.module, args.resume, use_ema=True) # setup distributed training - if args.distributed: - if has_apex and use_amp != 'native': - # Apex DDP preferred unless native amp is activated - if args.local_rank == 0: - _logger.info("Using NVIDIA APEX DistributedDataParallel.") - model = ApexDDP(model, delay_allreduce=True) - else: - if args.local_rank == 0: - _logger.info("Using native Torch DistributedDataParallel.") - model = NativeDDP(model, device_ids=[args.local_rank]) # can use device str in Torch >= 1.1 - # NOTE: EMA model does not need to be wrapped by DDP + if dev_env.is_distributed: + if dev_env.is_master: + _logger.info("Distributing model.") + model = dev_env.wrap_distributed(model)[0] # wrap_distributed returns a list of wrapped modules + # NOTE: EMA model does not need to be wrapped by DDP # setup learning rate schedule and starting epoch - lr_scheduler, num_epochs = create_scheduler(args, optimizer) + lr_scheduler, num_epochs = create_scheduler(args, updater.optimizer) start_epoch = 0 if args.start_epoch is not None: # a specified start_epoch will always override the resume epoch @@ -465,7 +375,7 @@ def main(): if lr_scheduler is not None and start_epoch > 0: lr_scheduler.step(start_epoch) - if args.local_rank == 0: + if dev_env.is_master: _logger.info('Scheduled epochs: {}'.format(num_epochs)) # create the train and eval datasets @@ -478,18 +388,14 @@ def main(): # setup mixup / cutmix collate_fn = None - mixup_fn = None mixup_active = args.mixup > 0 or args.cutmix > 0.
or args.cutmix_minmax is not None if mixup_active: mixup_args = dict( mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, label_smoothing=args.smoothing, num_classes=args.num_classes) - if args.prefetcher: - assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup) - collate_fn = FastCollateMixup(**mixup_args) - else: - mixup_fn = Mixup(**mixup_args) + assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup) + collate_fn = FastCollateMixup(**mixup_args) # wrap dataset in AugMix helper if num_aug_splits > 1: @@ -504,7 +410,6 @@ def main(): input_size=data_config['input_size'], batch_size=args.batch_size, is_training=True, - use_prefetcher=args.prefetcher, no_aug=args.no_aug, re_prob=args.reprob, re_mode=args.remode, @@ -521,7 +426,7 @@ def main(): mean=data_config['mean'], std=data_config['std'], num_workers=args.workers, - distributed=args.distributed, + distributed=dev_env.is_distributed, collate_fn=collate_fn, pin_memory=args.pin_mem, use_multi_epochs_loader=args.use_multi_epochs_loader @@ -532,12 +437,11 @@ def main(): input_size=data_config['input_size'], batch_size=args.validation_batch_size_multiplier * args.batch_size, is_training=False, - use_prefetcher=args.prefetcher, interpolation=data_config['interpolation'], mean=data_config['mean'], std=data_config['std'], num_workers=args.workers, - distributed=args.distributed, + distributed=dev_env.is_distributed, crop_pct=data_config['crop_pct'], pin_memory=args.pin_mem, ) @@ -545,23 +449,24 @@ def main(): # setup loss function if args.jsd: assert num_aug_splits > 1 # JSD only valid with aug splits set - train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda() + train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing) elif mixup_active: # smoothing is handled with mixup target transform - train_loss_fn = SoftTargetCrossEntropy().cuda() + train_loss_fn = SoftTargetCrossEntropy() elif args.smoothing: - train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda() + train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing) else: - train_loss_fn = nn.CrossEntropyLoss().cuda() - validate_loss_fn = nn.CrossEntropyLoss().cuda() + train_loss_fn = nn.CrossEntropyLoss() + validate_loss_fn = nn.CrossEntropyLoss() + dev_env.to_device(train_loss_fn, validate_loss_fn) # setup checkpoint saver and eval metric tracking eval_metric = args.eval_metric best_metric = None best_epoch = None saver = None - output_dir = '' - if args.local_rank == 0: + output_dir = None + if dev_env.is_master: if args.experiment: exp_name = args.experiment else: @@ -573,42 +478,48 @@ def main(): output_dir = get_outdir(args.output if args.output else './output/train', exp_name) decreasing = True if eval_metric == 'loss' else False saver = CheckpointSaver( - model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler, + model=model, optimizer=updater.optimizer, args=args, model_ema=model_ema, amp_scaler=updater.scaler, checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=args.checkpoint_hist) with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: f.write(args_text) + logger = Logger(output_dir=output_dir, logger=_logger, hparams=vars(args)) + try: for epoch in range(start_epoch, num_epochs): - if args.distributed and 
hasattr(loader_train.sampler, 'set_epoch'): + if dev_env.is_distributed and hasattr(loader_train.sampler, 'set_epoch'): loader_train.sampler.set_epoch(epoch) + if args.mixup_off_epoch and epoch >= args.mixup_off_epoch: + if loader_train.mixup_enabled: + loader_train.mixup_enabled = False train_metrics = train_one_epoch( - epoch, model, loader_train, optimizer, train_loss_fn, args, - lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, - amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn) + epoch, model, loader_train, updater, train_loss_fn, dev_env, + lr_scheduler=lr_scheduler, saver=saver, logger=logger, model_ema=model_ema, + log_interval=args.log_interval, recovery_interval=args.recovery_interval) - if args.distributed and args.dist_bn in ('broadcast', 'reduce'): - if args.local_rank == 0: + if dev_env.is_distributed and args.dist_bn in ('broadcast', 'reduce'): + if dev_env.is_master: _logger.info("Distributing BatchNorm running means and vars") - distribute_bn(model, args.world_size, args.dist_bn == 'reduce') + distribute_bn(model, dev_env.world_size, args.dist_bn == 'reduce') - eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast) + eval_metrics = evaluate(model, loader_eval, validate_loss_fn, dev_env, logger=logger) if model_ema is not None and not args.model_ema_force_cpu: - if args.distributed and args.dist_bn in ('broadcast', 'reduce'): - distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce') - ema_eval_metrics = validate( - model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)') + if dev_env.is_distributed and args.dist_bn in ('broadcast', 'reduce'): + distribute_bn(model_ema, dev_env.world_size, args.dist_bn == 'reduce') + + ema_eval_metrics = evaluate( + model_ema.module, loader_eval, validate_loss_fn, dev_env, + logger=logger, phase_suffix='EMA') eval_metrics = ema_eval_metrics if lr_scheduler is not None: # step LR for next epoch lr_scheduler.step(epoch + 1, eval_metrics[eval_metric]) - update_summary( - epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'), - write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb) + if logger is not None: + logger.write_summary(index=epoch, results=dict(train=train_metrics, eval=eval_metrics)) if saver is not None: # save proper checkpoint with eval metric @@ -622,175 +533,128 @@ def main(): def train_one_epoch( - epoch, model, loader, optimizer, loss_fn, args, - lr_scheduler=None, saver=None, output_dir='', amp_autocast=suppress, - loss_scaler=None, model_ema=None, mixup_fn=None): - - if args.mixup_off_epoch and epoch >= args.mixup_off_epoch: - if args.prefetcher and loader.mixup_enabled: - loader.mixup_enabled = False - elif mixup_fn is not None: - mixup_fn.mixup_enabled = False - - second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order - batch_time_m = AverageMeter() - data_time_m = AverageMeter() - losses_m = AverageMeter() + epoch: int, + model: nn.Module, + loader, + updater: Updater, + loss_fn: nn.Module, + dev_env: DeviceEnv, + lr_scheduler=None, + saver: CheckpointSaver = None, + logger: Logger = None, + model_ema: nn.Module = None, + log_interval: int = 50, + recovery_interval: int = 0, +): + tracker = Tracker() + losses_m = TensorAvg() model.train() - end = time.time() - last_idx = len(loader) - 1 + end_idx = len(loader) - 1 num_updates = epoch * len(loader) - for batch_idx, (input, target) in enumerate(loader): -
last_batch = batch_idx == last_idx - data_time_m.update(time.time() - end) - if not args.prefetcher: - input, target = input.cuda(), target.cuda() - if mixup_fn is not None: - input, target = mixup_fn(input, target) - if args.channels_last: - input = input.contiguous(memory_format=torch.channels_last) - - with amp_autocast(): - output = model(input) + batch_size = 0 + tracker.mark_iter() + for step_idx, (sample, target) in enumerate(loader): + tracker.mark_iter_data_end() + last_step = step_idx == end_idx + batch_size = max(batch_size, sample.shape[0]) + + with dev_env.autocast(): + output = model(sample) loss = loss_fn(output, target) - if not args.distributed: - losses_m.update(loss.item(), input.size(0)) - - optimizer.zero_grad() - if loss_scaler is not None: - loss_scaler( - loss, optimizer, - clip_grad=args.clip_grad, clip_mode=args.clip_mode, - parameters=model_parameters(model, exclude_head='agc' in args.clip_mode), - create_graph=second_order) - else: - loss.backward(create_graph=second_order) - if args.clip_grad is not None: - dispatch_clip_grad( - model_parameters(model, exclude_head='agc' in args.clip_mode), - value=args.clip_grad, mode=args.clip_mode) - optimizer.step() + updater.reset() + updater.apply(loss) + dev_env.mark_step() # FIXME + tracker.mark_iter_step_end() + losses_m.update(loss, sample.size(0)) if model_ema is not None: model_ema.update(model) - torch.cuda.synchronize() num_updates += 1 - batch_time_m.update(time.time() - end) - if last_batch or batch_idx % args.log_interval == 0: - lrl = [param_group['lr'] for param_group in optimizer.param_groups] + if last_step or (step_idx + 1) % log_interval == 0: + lrl = [param_group['lr'] for param_group in updater.optimizer.param_groups] lr = sum(lrl) / len(lrl) - if args.distributed: - reduced_loss = reduce_tensor(loss.data, args.world_size) - losses_m.update(reduced_loss.item(), input.size(0)) - - if args.local_rank == 0: - _logger.info( - 'Train: {} [{:>4d}/{} ({:>3.0f}%)] ' - 'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) ' - 'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s ' - '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) ' - 'LR: {lr:.3e} ' - 'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format( - epoch, - batch_idx, len(loader), - 100. 
* batch_idx / last_idx, - loss=losses_m, - batch_time=batch_time_m, - rate=input.size(0) * args.world_size / batch_time_m.val, - rate_avg=input.size(0) * args.world_size / batch_time_m.avg, - lr=lr, - data_time=data_time_m)) - - if args.save_images and output_dir: - torchvision.utils.save_image( - input, - os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx), - padding=0, - normalize=True) - - if saver is not None and args.recovery_interval and ( - last_batch or (batch_idx + 1) % args.recovery_interval == 0): - saver.save_recovery(epoch, batch_idx=batch_idx) + if dev_env.is_master and logger is not None: + loss_avg = losses_m.compute() + logger.log_step( + 'Train', + step=step_idx, + end_step=end_idx, + loss=loss_avg.item(), + rate=(dev_env.world_size * batch_size) / tracker.iter_time.avg, + lr=lr, + ) + + if saver is not None and recovery_interval and (last_step or (step_idx + 1) % recovery_interval == 0): + saver.save_recovery(epoch, batch_idx=step_idx) if lr_scheduler is not None: - lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg) + lr_scheduler.step_update(num_updates=num_updates) - end = time.time() + tracker.mark_iter() # end for - if hasattr(optimizer, 'sync_lookahead'): - optimizer.sync_lookahead() + if hasattr(updater.optimizer, 'sync_lookahead'): + updater.optimizer.sync_lookahead() + + return OrderedDict([('loss', losses_m.compute().item())]) - return OrderedDict([('loss', losses_m.avg)]) +def evaluate( + model: nn.Module, + loader, + loss_fn: nn.Module, + dev_env: DeviceEnv, + logger: Logger, + phase_suffix: str = '', + log_interval: int = 10, +): -def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''): - batch_time_m = AverageMeter() - losses_m = AverageMeter() - top1_m = AverageMeter() - top5_m = AverageMeter() + tracker = Tracker() + losses_m = TensorAvg() + accuracy_m = AccuracyTopK() model.eval() - end = time.time() - last_idx = len(loader) - 1 + end_idx = len(loader) - 1 + tracker.mark_iter() with torch.no_grad(): - for batch_idx, (input, target) in enumerate(loader): - last_batch = batch_idx == last_idx - if not args.prefetcher: - input = input.cuda() - target = target.cuda() - if args.channels_last: - input = input.contiguous(memory_format=torch.channels_last) - - with amp_autocast(): - output = model(input) - if isinstance(output, (tuple, list)): - output = output[0] - - # augmentation reduction - reduce_factor = args.tta - if reduce_factor > 1: - output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2) - target = target[0:target.size(0):reduce_factor] - - loss = loss_fn(output, target) - acc1, acc5 = accuracy(output, target, topk=(1, 5)) - - if args.distributed: - reduced_loss = reduce_tensor(loss.data, args.world_size) - acc1 = reduce_tensor(acc1, args.world_size) - acc5 = reduce_tensor(acc5, args.world_size) - else: - reduced_loss = loss.data - - torch.cuda.synchronize() - - losses_m.update(reduced_loss.item(), input.size(0)) - top1_m.update(acc1.item(), output.size(0)) - top5_m.update(acc5.item(), output.size(0)) - - batch_time_m.update(time.time() - end) - end = time.time() - if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0): - log_name = 'Test' + log_suffix - _logger.info( - '{0}: [{1:>4d}/{2}] ' - 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) ' - 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) ' - 'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) ' - 'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format( - log_name, batch_idx, last_idx, batch_time=batch_time_m, - loss=losses_m, 
top1=top1_m, top5=top5_m)) - - metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)]) - - return metrics + for step_idx, (sample, target) in enumerate(loader): + tracker.mark_iter_data_end() + last_step = step_idx == end_idx + + with dev_env.autocast(): + output = model(sample) + if isinstance(output, (tuple, list)): + output = output[0] + loss = loss_fn(output, target) + + dev_env.mark_step() # FIXME + tracker.mark_iter_step_end() + losses_m.update(loss, output.size(0)) + accuracy_m.update(output, target) + + if dev_env.is_master and (last_step or step_idx % log_interval == 0): + top1, top5 = accuracy_m.compute().values() + loss_avg = losses_m.compute() + logger.log_step( + 'Eval', + step=step_idx, + num_steps=end_idx, + loss=loss_avg.item(), + top1=top1.item(), + top5=top5.item(), + phase_suffix=phase_suffix, + ) + tracker.mark_iter() + + top1, top5 = accuracy_m.compute().values() + results = OrderedDict([('loss', losses_m.compute().item()), ('top1', top1.item()), ('top5', top5.item())]) + return results if __name__ == '__main__': diff --git a/validate.py b/validate.py index 74f8f435..add23469 100755 --- a/validate.py +++ b/validate.py @@ -17,27 +17,14 @@ import torch import torch.nn as nn import torch.nn.parallel from collections import OrderedDict -from contextlib import suppress +from timm.bits import initialize_device, Tracker, Logger +from timm.metrics import AccuracyTopK, TensorAvg from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet -from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_legacy - -has_apex = False -try: - from apex import amp - has_apex = True -except ImportError: - pass - -has_native_amp = False -try: - if getattr(torch.cuda.amp, 'autocast') is not None: - has_native_amp = True -except AttributeError: - pass - -torch.backends.cudnn.benchmark = True +from timm.utils import natural_key, setup_default_logging + + _logger = logging.getLogger('validate') @@ -72,36 +59,28 @@ parser.add_argument('--class-map', default='', type=str, metavar='FILENAME', help='path to class to idx mapping file (default: "")') parser.add_argument('--gp', default=None, type=str, metavar='POOL', help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). 
Model default if None.') -parser.add_argument('--log-freq', default=10, type=int, - metavar='N', help='batch logging frequency (default: 10)') +parser.add_argument('--log-freq', default=20, type=int, + metavar='N', help='batch logging frequency (default: 20)') parser.add_argument('--checkpoint', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)') parser.add_argument('--pretrained', dest='pretrained', action='store_true', help='use pre-trained model') -parser.add_argument('--num-gpu', type=int, default=1, - help='Number of GPUS to use') -parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true', - help='disable test time pool') -parser.add_argument('--no-prefetcher', action='store_true', default=False, - help='disable fast prefetcher') +# parser.add_argument('--num-gpu', type=int, default=1, +# help='Number of GPUS to use') +parser.add_argument('--test-pool', dest='test_pool', action='store_true', + help='enable test time pool') parser.add_argument('--pin-mem', action='store_true', default=False, help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') parser.add_argument('--channels-last', action='store_true', default=False, help='Use channels_last memory layout') parser.add_argument('--amp', action='store_true', default=False, help='Use AMP mixed precision. Defaults to Apex, fallback to native Torch AMP.') -parser.add_argument('--apex-amp', action='store_true', default=False, - help='Use NVIDIA Apex AMP mixed precision') -parser.add_argument('--native-amp', action='store_true', default=False, - help='Use Native Torch AMP mixed precision') parser.add_argument('--tf-preprocessing', action='store_true', default=False, help='Use Tensorflow preprocessing pipeline (require CPU TF installed') parser.add_argument('--use-ema', dest='use_ema', action='store_true', help='use ema version of weights if present') parser.add_argument('--torchscript', dest='torchscript', action='store_true', help='convert model torchscript for inference') -parser.add_argument('--legacy-jit', dest='legacy_jit', action='store_true', - help='use legacy jit mode for pytorch 1.5/1.5.1/1.6 to get back fusion performance') parser.add_argument('--results-file', default='', type=str, metavar='FILENAME', help='Output csv file for validation results (summary)') parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME', @@ -113,26 +92,8 @@ parser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME', def validate(args): # might as well try to validate something args.pretrained = args.pretrained or not args.checkpoint - args.prefetcher = not args.no_prefetcher - amp_autocast = suppress # do nothing - if args.amp: - if has_native_amp: - args.native_amp = True - elif has_apex: - args.apex_amp = True - else: - _logger.warning("Neither APEX or Native Torch AMP is available.") - assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set." - if args.native_amp: - amp_autocast = torch.cuda.amp.autocast - _logger.info('Validating in mixed precision with native PyTorch AMP.') - elif args.apex_amp: - _logger.info('Validating in mixed precision with NVIDIA APEX AMP.') - else: - _logger.info('Validating in float32. 
AMP not enabled.') - if args.legacy_jit: - set_jit_legacy() + dev_env = initialize_device(amp=args.amp) # create model model = create_model( @@ -154,24 +115,16 @@ def validate(args): data_config = resolve_data_config(vars(args), model=model, use_test_size=True, verbose=True) test_time_pool = False - if not args.no_test_pool: + if args.test_pool: model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True) if args.torchscript: torch.jit.optimized_execution(True) model = torch.jit.script(model) - model = model.cuda() - if args.apex_amp: - model = amp.initialize(model, opt_level='O1') - - if args.channels_last: - model = model.to(memory_format=torch.channels_last) - - if args.num_gpu > 1: - model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))) - - criterion = nn.CrossEntropyLoss().cuda() + # FIXME device + model, criterion = dev_env.to_device(model, nn.CrossEntropyLoss()) + model.to(dev_env.device) dataset = create_dataset( root=args.data, name=args.dataset, split=args.split, @@ -194,7 +147,6 @@ def validate(args): dataset, input_size=data_config['input_size'], batch_size=args.batch_size, - use_prefetcher=args.prefetcher, interpolation=data_config['interpolation'], mean=data_config['mean'], std=data_config['std'], @@ -203,63 +155,61 @@ def validate(args): pin_memory=args.pin_mem, tf_preprocessing=args.tf_preprocessing) - batch_time = AverageMeter() - losses = AverageMeter() - top1 = AverageMeter() - top5 = AverageMeter() + logger = Logger(logger=_logger) + tracker = Tracker() + losses = TensorAvg() + accuracy = AccuracyTopK().to(dev_env.device) model.eval() + num_steps = len(loader) with torch.no_grad(): - # warmup, reduce variability of first batch time, especially for comparing torchscript vs non - input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).cuda() - if args.channels_last: - input = input.contiguous(memory_format=torch.channels_last) - model(input) - end = time.time() - for batch_idx, (input, target) in enumerate(loader): - if args.no_prefetcher: - target = target.cuda() - input = input.cuda() - if args.channels_last: - input = input.contiguous(memory_format=torch.channels_last) + tracker.mark_iter() + for step_idx, (sample, target) in enumerate(loader): + tracker.mark_iter_data_end() # compute output - with amp_autocast(): - output = model(input) + with dev_env.autocast(): + output = model(sample) if valid_labels is not None: output = output[:, valid_labels] loss = criterion(output, target) + if dev_env.type == 'cuda': + torch.cuda.synchronize() + #elif dev_env.type == 'xla': + # dev_env.mark_step() + tracker.mark_iter_step_end() + + losses.update(loss.detach(), sample.size(0)) if real_labels is not None: real_labels.add_result(output) - - # measure accuracy and record loss - acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5)) - losses.update(loss.item(), input.size(0)) - top1.update(acc1.item(), input.size(0)) - top5.update(acc5.item(), input.size(0)) - - # measure elapsed time - batch_time.update(time.time() - end) - end = time.time() - - if batch_idx % args.log_freq == 0: - _logger.info( - 'Test: [{0:>4d}/{1}] ' - 'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) ' - 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) ' - 'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) ' - 'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format( - batch_idx, len(loader), batch_time=batch_time, - rate_avg=input.size(0) / batch_time.avg, - loss=losses, top1=top1, top5=top5)) + accuracy.update(output.detach(), target) 
+ + if dev_env.type == 'xla': + dev_env.mark_step() + + tracker.mark_iter() + if step_idx % args.log_freq == 0: + top1, top5 = accuracy.compute().values() + loss_avg = losses.compute() + logger.log_step( + phase='eval', + step=step_idx, + num_steps=num_steps, + rate=args.batch_size / tracker.iter_time.avg, + loss=loss_avg.item(), + top1=top1.item(), + top5=top5.item(), + ) if real_labels is not None: # real labels mode replaces topk values at the end top1a, top5a = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5) else: - top1a, top5a = top1.avg, top5.avg + top1a, top5a = accuracy.compute().values() + top1a, top5a = top1a.item(), top5a.item() + results = OrderedDict( top1=round(top1a, 4), top1_err=round(100 - top1a, 4), top5=round(top5a, 4), top5_err=round(100 - top5a, 4), @@ -267,9 +217,7 @@ def validate(args): img_size=data_config['input_size'][-1], cropt_pct=crop_pct, interpolation=data_config['interpolation']) - - _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format( - results['top1'], results['top1_err'], results['top5'], results['top5_err'])) + logger.log_phase(phase='eval', name_map={'top1': 'Acc@1', 'top5': 'Acc@5'}, **results) return results @@ -309,7 +257,6 @@ def main(): result = OrderedDict(model=args.model) r = {} - while not r and batch_size >= args.num_gpu: - torch.cuda.empty_cache() + while not r and batch_size >= 1: # --num-gpu arg is no longer parsed above try: args.batch_size = batch_size print('Validating with batch size: %d' % args.batch_size)
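Both refactored scripts also funnel the optimizer step through the new `Updater`: `updater.reset()` followed by `updater.apply(loss)` replaces the old `zero_grad()` / scaled backward / grad-clip / `step()` sequence, with clipping configured once via `create_updater(...)`. A minimal sketch of that per-step pattern, assuming only the `timm.bits` / `timm.metrics` names imported above (the toy model, random data, and plain SGD optimizer are illustrative, not from the patch):

```python
import torch
import torch.nn as nn

from timm.bits import initialize_device, create_updater
from timm.metrics import TensorAvg

dev_env = initialize_device(amp=False)  # amp=True would make autocast() a real mixed-precision context

model = nn.Linear(16, 2)
loss_fn = nn.CrossEntropyLoss()
dev_env.to_device(model, loss_fn)

# clip_value / clip_mode mirror the --clip-grad / --clip-mode args used in train.py
updater = create_updater(
    torch.optim.SGD(model.parameters(), lr=0.1),
    clip_value=1.0, clip_mode='norm')

losses_m = TensorAvg()
model.train()
for _ in range(4):
    sample = torch.randn(8, 16, device=dev_env.device)
    target = torch.randint(0, 2, (8,), device=dev_env.device)

    with dev_env.autocast():
        output = model(sample)
        loss = loss_fn(output, target)

    updater.reset()        # stands in for optimizer.zero_grad()
    updater.apply(loss)    # backward + optional grad clipping + optimizer step
    dev_env.mark_step()    # mirrors the FIXME'd call in train_one_epoch above

    losses_m.update(loss.detach(), sample.size(0))

print(losses_m.compute().item())
```

Keeping the loss/accuracy state as tensors (`TensorAvg`, `AccuracyTopK`) and calling `.compute()` / `.item()` only inside the logging branches mirrors how the new `train_one_epoch` and `evaluate` loops above defer host synchronization to log time.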