First timm.bits commit, add initial abstractions, WIP updates to train, val... some of it working

Ross Wightman
parent 1b0c8e7b01
commit 12d9a6d4d2

@ -0,0 +1,8 @@
# Timm Bits
A collection of reusable components and lightweight abstractions for training and evaluating neural networks.
This is an early WIP with the primary goal of getting up and running on TPUs first. Expect significant changes, rewrites, additions...
The current train.py and validate.py scripts are evolving to use the timm.bits components, and they will also change significantly.
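
A hypothetical usage sketch of how the bits added below are intended to compose (an assumption based on the APIs in this commit, not verified against later code; assumes a single CUDA GPU since CPU support is still a FIXME):

import torch
import torch.nn as nn
from timm.bits import initialize_device, create_updater, Logger, Tracker

dev_env = initialize_device(amp=True)    # resolves XLA before CUDA; CPU is still a FIXME
model, loss_fn = dev_env.to_device(nn.Linear(16, 10), nn.CrossEntropyLoss())
updater = create_updater(torch.optim.SGD(model.parameters(), lr=0.1))
logger = Logger(output_dir=None)
tracker = Tracker()

tracker.mark_iter()
for step, (x, y) in enumerate([(torch.randn(8, 16), torch.randint(0, 10, (8,)))]):
    x, y = x.to(dev_env.device), y.to(dev_env.device)
    tracker.mark_iter_data_end()
    with dev_env.autocast():
        loss = loss_fn(model(x), y)
    updater.reset()
    updater.apply(loss)                  # backward + (scaled) optimizer step
    tracker.mark_iter_step_end()
    logger.log_step('train', step, loss=loss.item())
    tracker.mark_iter()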

@ -0,0 +1,10 @@
from .device_env_factory import initialize_device, get_device
from .device_env import DeviceEnv
#from .evaluate import evaluate, eval_step
from .logger import Logger
#from .task import TaskClassify
from .updater import Updater
from .updater_factory import create_updater
from .tracker import Tracker
#from .task_metrics import TaskMetrics, TaskMetricsClassify
#from .train import train_one_epoch, TrainServices, TrainState, TrainCfg, Experiment

@ -0,0 +1,58 @@
import torch
import abc
class DeviceEnv(abc.ABC):
@property
@abc.abstractmethod
def device(self) -> torch.device:
pass
@property
@abc.abstractmethod
def local_rank(self) -> int:
pass
@property
@abc.abstractmethod
def global_rank(self) -> int:
pass
@property
@abc.abstractmethod
def is_distributed(self) -> bool:
pass
@property
@abc.abstractmethod
def world_size(self) -> int:
pass
@property
@abc.abstractmethod
def is_master(self) -> bool:
pass
@property
@abc.abstractmethod
def type(self) -> str:
pass
@property
@abc.abstractmethod
def autocast(self):
pass
@abc.abstractmethod
def wrap_distributed(self, *modules):
pass
@abc.abstractmethod
def to_device(self, *modules: torch.nn.Module):
pass
#@abc.abstractmethod
def mark_step(self):
# FIXME this is for XLA only, make it common to all devices w/ appropriate no-ops?
pass
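
A minimal sketch of what a CPU implementation of this ABC could look like (hypothetical, this class does not exist in the commit; the factory below still raises NotImplementedError for CPU):

from contextlib import suppress
import torch
from .device_env import DeviceEnv

class DeviceEnvCpu(DeviceEnv):
    @property
    def device(self) -> torch.device:
        return torch.device('cpu')
    @property
    def local_rank(self) -> int:
        return 0
    @property
    def global_rank(self) -> int:
        return 0
    @property
    def is_distributed(self) -> bool:
        return False
    @property
    def world_size(self) -> int:
        return 1
    @property
    def is_master(self) -> bool:
        return True
    @property
    def type(self) -> str:
        return 'cpu'
    @property
    def autocast(self):
        return suppress  # no AMP context on CPU here
    def wrap_distributed(self, *modules):
        return list(modules)  # no-op for a single process
    def to_device(self, *modules: torch.nn.Module):
        return [m.to(self.device) for m in modules]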

@ -0,0 +1,90 @@
import os
from contextlib import suppress
import torch
from torch.nn.parallel import DistributedDataParallel
from .device_env import DeviceEnv
def is_cuda_available():
return torch.cuda.is_available()
class DeviceEnvCuda(DeviceEnv):
def __init__(self, device_idx=None, local_rank=None, amp=False, memory_format=None):
assert torch.cuda.device_count()
torch.backends.cudnn.benchmark = True
self._local_rank = 0
self._distributed = False
self._world_size = 1
self._global_rank = 0
if 'WORLD_SIZE' in os.environ:
self._distributed = int(os.environ['WORLD_SIZE']) > 1
if self._distributed:
if local_rank is None:
lr = os.environ.get('LOCAL_RANK', None)
if lr is None:
raise RuntimeError(
'At least one of LOCAL_RANK env variable or local_rank arg must be set to valid integer.')
                    self._local_rank = int(lr)
else:
self._local_rank = int(local_rank)
self._device = torch.device('cuda:%d' % self._local_rank)
torch.cuda.set_device(self._local_rank)
torch.distributed.init_process_group(backend='nccl', init_method='env://')
self._world_size = torch.distributed.get_world_size()
self._global_rank = torch.distributed.get_rank()
else:
self._device = torch.device('cuda' if device_idx is None else f'cuda:{device_idx}')
self._memory_format = memory_format
        self._amp = amp
        self._autocast = torch.cuda.amp.autocast if amp else suppress
@property
def device(self):
return self._device
@property
def local_rank(self):
return self._local_rank
@property
def global_rank(self):
return self._global_rank
@property
def is_distributed(self):
return self._distributed
@property
def world_size(self):
return self._world_size
@property
def is_master(self):
return self._local_rank == 0
@property
def type(self) -> str:
return 'cuda'
@property
def amp(self) -> bool:
return self._amp
@property
def autocast(self):
return self._autocast
def wrap_distributed(self, *modules, **kwargs):
return [DistributedDataParallel(m, device_ids=[self._local_rank], **kwargs) for m in modules]
def to_device(self, *modules: torch.nn.Module):
# FIXME handling dtype / memformat... disable flags, enable flags, diff fn?
return [m.to(device=self._device, memory_format=self._memory_format) for m in modules]

@ -0,0 +1,34 @@
from .device_env_cuda import DeviceEnvCuda, is_cuda_available
from .device_env_xla import DeviceEnvXla, is_xla_available
_device_env = None
def initialize_device(force_cpu: bool = False, xla_device_type=None, **kwargs):
global _device_env
if _device_env is not None:
        # FIXME warn if already initialized?
return _device_env
denv = None
if not force_cpu:
if is_xla_available(xla_device_type):
# XLA supports more than just TPU, but by default will only look at TPU
denv = DeviceEnvXla(**kwargs, xla_device_type=xla_device_type)
elif is_cuda_available():
denv = DeviceEnvCuda(**kwargs)
if denv is None:
# FIXME implement CPU support
raise NotImplementedError()
_device_env = denv
return denv
def get_device():
if _device_env is None:
raise RuntimeError('Please initialize device environment by calling initialize_device first.')
return _device_env
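
Hypothetical factory usage: initialize once near program start, then fetch the singleton anywhere else:

from timm.bits import initialize_device, get_device

dev_env = initialize_device()     # XLA if available, else CUDA
assert get_device() is dev_env    # later calls return the same environment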

@ -0,0 +1,85 @@
import os
from contextlib import suppress
import torch
try:
import torch_xla.core.xla_model as xm
import torch_xla.amp as xa
_HAS_XLA = True
except ImportError as e:
xm = None
_HAS_XLA = False
from .device_env import DeviceEnv
def is_xla_available(xla_device_type=None):
if not _HAS_XLA:
return False
supported_devs = xm.get_xla_supported_devices(devkind=xla_device_type)
print(supported_devs)
return len(supported_devs) >= 1
class DeviceEnvXla(DeviceEnv):
def __init__(self, xla_device_type=None, device_idx=None, local_rank=0, amp=False):
self._device = xm.xla_device(n=device_idx, devkind=xla_device_type)
print(self._device)
self._local_rank = xm.get_local_ordinal(local_rank)
self._world_size = xm.xrt_world_size()
self._distributed = self._world_size > 1
self._global_rank = 0
if self._distributed:
self._global_rank = xm.get_ordinal()
if amp:
self._autocast = xa.autocast
else:
self._autocast = suppress
self._memory_format = None
@property
def device(self):
return self._device
@property
def local_rank(self):
return self._local_rank
@property
def global_rank(self):
return self._global_rank
@property
def is_distributed(self):
return self._distributed
@property
def world_size(self):
return self._world_size
@property
def is_master(self):
return self._global_rank == 0
@property
def type(self) -> str:
return 'xla'
@property
def amp(self) -> bool:
return False
@property
def autocast(self):
return self._autocast
def wrap_distributed(self, *modules):
# NO-OP
return tuple([m for m in modules])
def to_device(self, *modules: torch.nn.Module):
return [m.to(device=self._device, memory_format=self._memory_format) for m in modules]
def mark_step(self):
xm.mark_step()

@ -0,0 +1,36 @@
from functools import partial
import torch
from timm.utils.agc import adaptive_clip_grad
def get_clip_grad_fn(mode: str = 'norm', norm_type: float = 2.0):
if mode == 'norm':
return partial(torch.nn.utils.clip_grad_norm_, norm_type=norm_type)
elif mode == 'value':
return torch.nn.utils.clip_grad_value_
elif mode == 'agc':
return partial(adaptive_clip_grad, norm_type=norm_type)
else:
assert False, f"Unknown clip mode ({mode})."
def get_clip_parameters(model):
    if hasattr(model, 'get_clip_parameters'):
        return model.get_clip_parameters()
    elif hasattr(model, 'parameters'):
        return model.parameters()
    else:
        # assume an iterable of parameters was passed directly
        return model
class GradClipper:
def __init__(self, model, clip_value, clip_mode='norm'):
self.model = model
self.clip_fn = get_clip_grad_fn(clip_mode)
self.clip_value = clip_value
self.enabled = True
def __call__(self):
if self.enabled:
self.clip_fn(get_clip_parameters(self.model), self.clip_value)
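
A hypothetical, self-contained usage sketch for GradClipper:

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
clipper = GradClipper(model, clip_value=1.0, clip_mode='norm')

loss = model(torch.randn(8, 4)).sum()
loss.backward()
clipper()          # clips model.parameters() (or model.get_clip_parameters() if defined)
optimizer.step()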

@ -0,0 +1,223 @@
import csv
import logging
import os
from collections import OrderedDict
from typing import Optional, Tuple, Dict, Union
import torch
_logger = logging.getLogger(__name__)
try:
from torch.utils.tensorboard import SummaryWriter
HAS_TB = True
except ImportError as e:
HAS_TB = False
try:
import wandb
HAS_WANDB = True
except ImportError:
HAS_WANDB = False
# FIXME old formatting for reference, to remove
#
# def log_eval(batch_idx, last_idx, batch_time, loss, top1, top5, log_suffix=''):
# log_name = 'Test' + log_suffix
# logging.info(
# f'{log_name}: [{batch_idx:>4d}/{last_idx}] '
# f'Time: {batch_time.smooth_val:.3f} ({batch_time.avg:.3f}) '
# f'Loss: {loss.smooth_val:>7.4f} ({loss.avg:>6.4f}) '
# f'Acc@1: {top1.smooth_val:>7.4f} ({top1.avg:>7.4f}) '
# f'Acc@5: {top5.smooth_val:>7.4f} ({top5.avg:>7.4f})'
# )
#
#
# def log_train(epoch, step, num_steps, loss, batch_size, batch_time, data_time, lr, world_size=1):
# last_step = max(0, num_steps - 1)
# progress = 100. * step / last_step if last_step else 0.
# log_str = f'Train: {epoch} [{step:>4d}/{num_steps} ({progress:>3.0f}%)]' \
# f' Time: {batch_time.smooth_val:.3f}s, {batch_size * world_size / batch_time.smooth_val:>7.2f}/s' \
# f' ({batch_time.avg:.3f}s, {batch_size * world_size / batch_time.avg:>7.2f}/s)' \
# f' Data: {data_time.smooth_val:.3f} ({data_time.avg:.3f})'
# log_str += f' Loss: {loss.smooth_val:>9.6f} ({loss.avg:>6.4f}) '
# log_str += f' LR: {lr:.3e} '
#
# if args.save_images and output_dir:
# torchvision.utils.save_image(
# input,
# os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
# padding=0,
# normalize=True)
def summary_row_dict(results, index=None, index_name='epoch'):
assert isinstance(results, dict)
row_dict = OrderedDict()
if index is not None:
row_dict[index_name] = index
if not results:
return row_dict
if isinstance(next(iter(results.values())), dict):
# each key in results is a per-phase results dict, flatten by prefixing with phase name
        for p, pr in results.items():
            assert isinstance(pr, dict)
row_dict.update([('_'.join([p, k]), v) for k, v in pr.items()])
else:
row_dict.update(results)
return row_dict
class SummaryCsv:
def __init__(self, output_dir, filename='summary.csv'):
self.output_dir = output_dir
self.filename = os.path.join(output_dir, filename)
self.needs_header = not os.path.exists(self.filename)
def update(self, row_dict):
with open(self.filename, mode='a') as cf:
dw = csv.DictWriter(cf, fieldnames=row_dict.keys())
if self.needs_header: # first iteration (epoch == 1 can't be used)
dw.writeheader()
self.needs_header = False
dw.writerow(row_dict)
def _add_kwargs(text_update, name_map=None, **kwargs):
def _to_str(key, val):
if isinstance(val, float):
return f'{key}: {val:.4f}'
else:
return f'{key}: {val}'
def _map_name(key, name_map, capitalize=True):
if name_map is None:
if capitalize:
return key.capitalize() if not key.isupper() else key
else:
return key
return name_map.get(key, None)
for k, v in kwargs.items():
if isinstance(v, dict):
# log each k, v of a dict kwarg as separate items
for kk, vv in v.items():
name = _map_name(kk, name_map)
if not name:
continue
                text_update += [_to_str(name, vv)]
else:
name = _map_name(k, name_map, capitalize=True)
if not name:
continue
text_update += [_to_str(name, v)]
class Logger:
def __init__(
self,
experiment_name=None,
output_dir=None,
logger=None,
log_wandb=False,
hparams=None):
self.output_dir = output_dir # for tensorboard, csv, console logging to file?
self.logger = logger or logging.getLogger('log')
hparams = hparams or {}
# Setup CSV writer(s)
if output_dir is not None:
self.csv_writer = SummaryCsv(output_dir=output_dir)
else:
self.csv_writer = None
# Setup Tensorboard
self.summary_writer = None # FIXME tensorboard
# Setup W&B
self.wandb_run = None
if log_wandb:
if HAS_WANDB:
self.wandb_run = wandb.init(project=experiment_name, config=hparams)
else:
_logger.warning("You've requested to log metrics to wandb but package not found. "
"Metrics not being logged to wandb, try `pip install wandb`")
# FIXME image save
def log_step(
self,
phase: str,
step: int,
end_step: Optional[int] = None,
loss: Optional[float] = None,
rate: Optional[float] = None,
epoch: Optional[int] = None,
phase_suffix: str = '',
**kwargs,
):
""" log train/eval step
"""
phase_title = f'{phase.capitalize()} ({phase_suffix})' if phase_suffix else f'{phase.capitalize()}'
progress = 100. * step / end_step if end_step else 0.
text_update = [
phase_title,
f'Epoch: {epoch}' if epoch is not None else None,
f'Step: {step}' if end_step is None else None,
f'Step: [{step}/{end_step} ({progress:>3.0f}%)]' if end_step is not None else None,
f'Rate: {rate:.2f}/s' if rate is not None else None,
f'Loss: {loss:.5f}' if loss is not None else None,
]
_add_kwargs(text_update, **kwargs)
log_str = ' '.join(item for item in text_update if item)
self.logger.info(log_str)
if self.summary_writer is not None:
# FIXME log step values to tensorboard
pass
def log_phase(
self,
phase: str = 'eval',
epoch: Optional[int] = None,
name_map: Optional[dict] = None,
**kwargs
):
"""log completion of evaluation or training phase
"""
title = [
f'{phase.capitalize()}',
f'epoch: {epoch}' if epoch is not None else None,
'completed. ',
]
title_str = ' '.join(i for i in title if i)
results = []
_add_kwargs(results, name_map=name_map, **kwargs)
log_str = title_str + ', '.join(item for item in results if item)
self.logger.info(log_str)
def write_summary(
self,
results: Dict, # Dict or Dict of Dict where first level keys are treated as per-phase results
index: Optional[Union[int, str]] = None,
index_name: str = 'epoch',
):
""" Log complete results for all phases (typically called at end of epoch)
Args:
results (dict or dict[dict]): dict of results to write, or multiple dicts where first level
key is the name of results dict for each phase
index: value for row index (typically epoch #)
index_name: name for row index header (typically 'epoch')
"""
row_dict = summary_row_dict(index=index, index_name=index_name, results=results)
if self.csv_writer:
self.csv_writer.update(row_dict)
if self.wandb_run is not None:
wandb.log(row_dict)
if self.summary_writer:
# FIXME log epoch summaries to tensorboard
pass
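
A hypothetical Logger usage sketch over one epoch; an output_dir enables the CSV summary writer:

import os
os.makedirs('./output', exist_ok=True)

logger = Logger(experiment_name='demo', output_dir='./output', hparams={'lr': 0.1})
logger.log_step('train', step=10, end_step=100, epoch=0, loss=0.731, rate=256.0)
logger.log_phase('eval', epoch=0, top1=75.2, loss=0.98)
logger.write_summary(index=0, results=dict(train={'loss': 0.731}, eval={'top1': 75.2, 'loss': 0.98}))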

@ -0,0 +1,50 @@
import time
from typing import Optional
from timm.metrics import ScalarAvgMinMax
class Tracker:
def __init__(self):
self.data_time = ScalarAvgMinMax() # time for data loader to produce batch of samples
self.step_time = ScalarAvgMinMax() # time for model step
self.iter_time = ScalarAvgMinMax() # full iteration time incl. data, step, and book-keeping
self.epoch_time = ScalarAvgMinMax()
self.iter_timestamp: Optional[float] = None
self.prev_timestamp: Optional[float] = None
self.epoch_timestamp: Optional[float] = None
    def _measure_iter(self, ref_timestamp=None):
        # FIXME unused placeholder, not wired up yet
        timestamp = time.perf_counter()
        self.prev_timestamp = timestamp
def mark_iter(self):
timestamp = time.perf_counter()
if self.iter_timestamp is not None:
iter_time = timestamp - self.iter_timestamp
self.iter_time.update(iter_time)
self.iter_timestamp = self.prev_timestamp = timestamp
def mark_iter_data_end(self):
assert self.prev_timestamp is not None
timestamp = time.perf_counter()
data_time = timestamp - self.prev_timestamp
self.data_time.update(data_time)
self.prev_timestamp = timestamp
def mark_iter_step_end(self):
assert self.prev_timestamp is not None
timestamp = time.perf_counter()
step_time = timestamp - self.prev_timestamp
self.step_time.update(step_time)
self.prev_timestamp = timestamp
def mark_epoch(self):
timestamp = time.perf_counter()
if self.epoch_timestamp is not None:
epoch_time = timestamp - self.epoch_timestamp
self.epoch_time.update(epoch_time)
self.epoch_timestamp = timestamp
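
A hypothetical Tracker usage sketch; the sleep stands in for a real model step:

import time

tracker = Tracker()
loader = [(0, 0), (1, 1), (2, 2)]      # stand-in for a real data loader
tracker.mark_iter()
for sample, target in loader:
    tracker.mark_iter_data_end()       # data loading time for this iter
    time.sleep(0.01)                   # stand-in for forward/backward/update
    tracker.mark_iter_step_end()       # model step time
    tracker.mark_iter()                # closes the full iteration timing
tracker.mark_epoch()
print(f'data {tracker.data_time.avg:.4f}s, step {tracker.step_time.avg:.4f}s')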

@ -0,0 +1,54 @@
from typing import Callable, Optional, Union
import torch
from .grad_clipper import GradClipper
class Updater:
def __init__(
self,
optimizer: torch.optim.Optimizer,
clip_value: Optional[Union[Callable, float]] = None,
clip_mode: str = 'norm'):
self.optimizer = optimizer
self.clipper: Optional[GradClipper] = None
if clip_value is not None:
if isinstance(clip_value, Callable):
self.clipper = clip_value
            else:
                # no model reference here; pass the optimizer's parameters to the clipper
                params = [p for g in self.optimizer.param_groups for p in g['params']]
                self.clipper = GradClipper(params, clip_value, clip_mode)
self.scaler = None
self.create_graph = getattr(self.optimizer, 'second_order', False)
self.num_accumulated = 0
self.after_step_closure = False
def apply(self, loss: torch.Tensor, accumulate=False):
loss.backward(create_graph=self.create_graph)
if self.clipper is not None:
self.clipper()
if not accumulate:
self.optimizer.step()
self.reset()
else:
self.num_accumulated += 1
def reset(self):
self.optimizer.zero_grad()
self.num_accumulated = 0
    def state_dict(self):
        state_dict = dict(optimizer=self.optimizer.state_dict())
        if self.scaler is not None:
            state_dict['scaler'] = self.scaler.state_dict()
        return state_dict
def load_state_dict(self, state_dict):
if 'optimizer' in state_dict:
self.optimizer.load_state_dict(state_dict['optimizer'])
if 'scaler' in state_dict and self.scaler is not None:
self.scaler.load_state_dict(state_dict['scaler'])
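
A hypothetical sketch of Updater with two accumulated micro-batches:

import torch

model = torch.nn.Linear(4, 2)
updater = Updater(torch.optim.SGD(model.parameters(), lr=0.1), clip_value=1.0)
updater.reset()
for i, x in enumerate([torch.randn(8, 4), torch.randn(8, 4)]):
    loss = model(x).sum()
    updater.apply(loss, accumulate=i < 1)   # optimizer.step() + zero_grad() fire on the last micro-batch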

@ -0,0 +1,36 @@
from typing import Callable, Optional, Union, Any
import torch
from .updater import Updater
class UpdaterCuda(Updater):
def __init__(
self,
optimizer: torch.optim.Optimizer,
clip_value: Optional[Union[Callable, float]] = None,
clip_mode: str = 'norm',
use_scaler: bool = False,
scaler_kwargs: Any = None,
):
super().__init__(optimizer=optimizer, clip_value=clip_value, clip_mode=clip_mode)
scaler_kwargs = scaler_kwargs or {}
if use_scaler:
self.scaler = torch.cuda.amp.GradScaler(**scaler_kwargs)
def apply(self, loss: torch.Tensor, accumulate=False):
if self.scaler is not None:
self.scaler.scale(loss).backward(create_graph=self.create_graph)
if self.clipper is not None:
self.scaler.unscale_(self.optimizer) # unscale the gradients of optimizer's assigned params in-place
self.clipper()
            if not accumulate:
                self.scaler.step(self.optimizer)
                self.scaler.update()  # update() should only follow a step()
                self.reset()
            else:
                self.num_accumulated += 1
else:
Updater.apply(self, loss, accumulate)

@ -0,0 +1,30 @@
from typing import Callable, Optional, Union, Any
import torch
from .device_env import DeviceEnv
from .device_env_factory import get_device
from .updater import Updater
from .updater_cuda import UpdaterCuda
from .updater_xla import UpdaterXla
def create_updater(
optimizer: torch.optim.Optimizer,
dev_env: Optional[DeviceEnv] = None,
clip_value: Optional[Union[Callable, float]] = None,
clip_mode: str = 'norm',
scaler_kwargs: Any = None) -> Updater:
if not dev_env:
dev_env = get_device()
updater_kwargs = dict(
optimizer=optimizer, clip_value=clip_value, clip_mode=clip_mode, scaler_kwargs=scaler_kwargs)
if dev_env.type == 'xla':
return UpdaterXla(**updater_kwargs, use_scaler=dev_env.amp)
elif dev_env.type == 'cuda':
return UpdaterCuda(**updater_kwargs, use_scaler=dev_env.amp)
else:
updater_kwargs.pop('scaler_kwargs', None)
return Updater(**updater_kwargs)

@ -0,0 +1,52 @@
from typing import Callable, Optional, Union, Any
import torch
try:
import torch_xla.core.xla_model as xm
import torch_xla.amp as xa
_HAS_XLA = True
except ImportError as e:
xm = None
_HAS_XLA = False
from .updater import Updater
class UpdaterXla(Updater):
def __init__(
self,
optimizer: torch.optim.Optimizer,
clip_value: Optional[Union[Callable, float]] = None,
clip_mode: str = 'norm',
use_scaler: bool = False,
scaler_kwargs: Any = None,
):
super().__init__(optimizer=optimizer, clip_value=clip_value, clip_mode=clip_mode)
self.after_step_closure = True
if use_scaler:
self.scaler = xa.GradScaler(**scaler_kwargs)
def apply(self, loss: torch.Tensor, accumulate: bool = False):
if self.scaler is None:
loss.backward(create_graph=self.create_graph)
gradients = xm._fetch_gradients(self.optimizer)
xm.all_reduce('sum', gradients, scale=1.0 / xm.xrt_world_size())
if self.clipper is not None:
self.clipper()
            if not accumulate:
                xm.optimizer_step(self.optimizer)
                self.reset()
else:
self.scaler.scale(loss).backward(create_graph=self.create_graph)
if self.clipper is not None:
self.scaler.unscale_(self.optimizer) # unscale the gradients of optimizer's assigned params in-place
self.clipper()
            if not accumulate:
                self.scaler.step(self.optimizer)
                self.scaler.update()  # update() should only follow a step()
                self.reset()
def after_step(self, after_step_fn, *args):
xm.add_step_closure(after_step_fn, *args)

@ -0,0 +1,38 @@
import numpy as np
import torch
def fast_collate(batch):
""" A fast collation function optimized for uint8 images (np array or torch) and int64 targets (labels)"""
assert isinstance(batch[0], tuple)
batch_size = len(batch)
if isinstance(batch[0][0], tuple):
# This branch 'deinterleaves' and flattens tuples of input tensors into one tensor ordered by position
# such that all tuple of position n will end up in a torch.split(tensor, batch_size) in nth position
inner_tuple_size = len(batch[0][0])
flattened_batch_size = batch_size * inner_tuple_size
targets = torch.zeros(flattened_batch_size, dtype=torch.int64)
tensor = torch.zeros((flattened_batch_size, *batch[0][0][0].shape), dtype=torch.uint8)
for i in range(batch_size):
assert len(batch[i][0]) == inner_tuple_size # all input tensor tuples must be same length
for j in range(inner_tuple_size):
targets[i + j * batch_size] = batch[i][1]
tensor[i + j * batch_size] += torch.from_numpy(batch[i][0][j])
return tensor, targets
elif isinstance(batch[0][0], np.ndarray):
targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
assert len(targets) == batch_size
tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
for i in range(batch_size):
tensor[i] += torch.from_numpy(batch[i][0])
return tensor, targets
elif isinstance(batch[0][0], torch.Tensor):
targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
assert len(targets) == batch_size
tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
for i in range(batch_size):
tensor[i].copy_(batch[i][0])
return tensor, targets
else:
assert False
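
A hypothetical fast_collate call; in practice it is passed as collate_fn to torch.utils.data.DataLoader:

import numpy as np

batch = [(np.zeros((3, 224, 224), dtype=np.uint8), 1) for _ in range(4)]
tensor, targets = fast_collate(batch)
print(tensor.shape, tensor.dtype, targets.shape)   # torch.Size([4, 3, 224, 224]) torch.uint8 torch.Size([4])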

@ -70,6 +70,14 @@ def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, v
elif 'crop_pct' in default_cfg:
new_config['crop_pct'] = default_cfg['crop_pct']
if getattr(args, 'mixup', 0) > 0 \
or getattr(args, 'cutmix', 0) > 0. \
or getattr(args, 'cutmix_minmax', None) is not None:
new_config['mixup'] = dict(
mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
label_smoothing=args.smoothing, num_classes=args.num_classes)
if verbose:
_logger.info('Data processing configuration for current model + dataset:')
for n, v in new_config.items():

@ -0,0 +1,69 @@
import torch
from .constants import *
from .random_erasing import RandomErasing
from .mixup import FastCollateMixup
class FetcherXla:
def __init__(self):
pass
class Fetcher:
def __init__(self,
loader,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
device=None,
dtype=None,
re_prob=0.,
re_mode='const',
re_count=1,
re_num_splits=0):
self.loader = loader
self.mean = torch.tensor([x * 255 for x in mean]).view(1, 3, 1, 1)
self.std = torch.tensor([x * 255 for x in std]).view(1, 3, 1, 1)
        self.device = torch.device(device) if device is not None else torch.device('cpu')
        self.dtype = dtype or torch.float32
        if device is not None:
            self.mean = self.mean.to(device=self.device, dtype=self.dtype)
            self.std = self.std.to(device=self.device, dtype=self.dtype)
if re_prob > 0.:
self.random_erasing = RandomErasing(
probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits)
else:
self.random_erasing = None
def __iter__(self):
for sample, target in self.loader:
sample = sample.to(device=self.device)
target = target.to(device=self.device)
sample = sample.to(dtype=self.dtype).sub_(self.mean).div_(self.std)
if self.random_erasing is not None:
sample = self.random_erasing(sample)
yield sample, target
def __len__(self):
return len(self.loader)
@property
def sampler(self):
return self.loader.sampler
@property
def dataset(self):
return self.loader.dataset
@property
def mixup_enabled(self):
if isinstance(self.loader.collate_fn, FastCollateMixup):
return self.loader.collate_fn.mixup_enabled
else:
return False
@mixup_enabled.setter
def mixup_enabled(self, x):
if isinstance(self.loader.collate_fn, FastCollateMixup):
self.loader.collate_fn.mixup_enabled = x
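
A hypothetical Fetcher sketch; the device normally comes from dev_env.device, and fast_collate is the collation function above:

import numpy as np
import torch

samples = [(np.zeros((3, 32, 32), dtype=np.uint8), 0) for _ in range(8)]
loader = torch.utils.data.DataLoader(samples, batch_size=4, collate_fn=fast_collate)
fetcher = Fetcher(loader, device='cpu')
for sample, target in fetcher:
    print(sample.shape, sample.dtype)   # torch.Size([4, 3, 32, 32]) torch.float32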

@ -7,122 +7,15 @@ Hacked together by / Copyright 2020 Ross Wightman
"""
import torch.utils.data
import numpy as np
from timm.bits import get_device
from .fetcher import Fetcher
from .prefetcher_cuda import PrefetcherCuda
from .collate import fast_collate
from .transforms_factory import create_transform
from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .distributed_sampler import OrderedDistributedSampler
from .random_erasing import RandomErasing
from .mixup import FastCollateMixup
def fast_collate(batch):
""" A fast collation function optimized for uint8 images (np array or torch) and int64 targets (labels)"""
assert isinstance(batch[0], tuple)
batch_size = len(batch)
if isinstance(batch[0][0], tuple):
# This branch 'deinterleaves' and flattens tuples of input tensors into one tensor ordered by position
# such that all tuple of position n will end up in a torch.split(tensor, batch_size) in nth position
inner_tuple_size = len(batch[0][0])
flattened_batch_size = batch_size * inner_tuple_size
targets = torch.zeros(flattened_batch_size, dtype=torch.int64)
tensor = torch.zeros((flattened_batch_size, *batch[0][0][0].shape), dtype=torch.uint8)
for i in range(batch_size):
assert len(batch[i][0]) == inner_tuple_size # all input tensor tuples must be same length
for j in range(inner_tuple_size):
targets[i + j * batch_size] = batch[i][1]
tensor[i + j * batch_size] += torch.from_numpy(batch[i][0][j])
return tensor, targets
elif isinstance(batch[0][0], np.ndarray):
targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
assert len(targets) == batch_size
tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
for i in range(batch_size):
tensor[i] += torch.from_numpy(batch[i][0])
return tensor, targets
elif isinstance(batch[0][0], torch.Tensor):
targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
assert len(targets) == batch_size
tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
for i in range(batch_size):
tensor[i].copy_(batch[i][0])
return tensor, targets
else:
assert False
class PrefetchLoader:
def __init__(self,
loader,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
fp16=False,
re_prob=0.,
re_mode='const',
re_count=1,
re_num_splits=0):
self.loader = loader
self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1)
self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1)
self.fp16 = fp16
if fp16:
self.mean = self.mean.half()
self.std = self.std.half()
if re_prob > 0.:
self.random_erasing = RandomErasing(
probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits)
else:
self.random_erasing = None
def __iter__(self):
stream = torch.cuda.Stream()
first = True
for next_input, next_target in self.loader:
with torch.cuda.stream(stream):
next_input = next_input.cuda(non_blocking=True)
next_target = next_target.cuda(non_blocking=True)
if self.fp16:
next_input = next_input.half().sub_(self.mean).div_(self.std)
else:
next_input = next_input.float().sub_(self.mean).div_(self.std)
if self.random_erasing is not None:
next_input = self.random_erasing(next_input)
if not first:
yield input, target
else:
first = False
torch.cuda.current_stream().wait_stream(stream)
input = next_input
target = next_target
yield input, target
def __len__(self):
return len(self.loader)
@property
def sampler(self):
return self.loader.sampler
@property
def dataset(self):
return self.loader.dataset
@property
def mixup_enabled(self):
if isinstance(self.loader.collate_fn, FastCollateMixup):
return self.loader.collate_fn.mixup_enabled
else:
return False
@mixup_enabled.setter
def mixup_enabled(self, x):
if isinstance(self.loader.collate_fn, FastCollateMixup):
self.loader.collate_fn.mixup_enabled = x
def create_loader(
@ -130,7 +23,7 @@ def create_loader(
input_size,
batch_size,
is_training=False,
use_prefetcher=True,
dev_env=None,
no_aug=False,
re_prob=0.,
re_mode='const',
@ -163,7 +56,7 @@ def create_loader(
dataset.transform = create_transform(
input_size,
is_training=is_training,
use_prefetcher=use_prefetcher,
use_fetcher=True,
no_aug=no_aug,
scale=scale,
ratio=ratio,
@ -183,6 +76,9 @@ def create_loader(
separate=num_aug_splits > 0,
)
if dev_env is None:
dev_env = get_device()
sampler = None
if distributed and not isinstance(dataset, torch.utils.data.IterableDataset):
if is_training:
@ -193,10 +89,9 @@ def create_loader(
sampler = OrderedDistributedSampler(dataset)
if collate_fn is None:
collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate
collate_fn = fast_collate
loader_class = torch.utils.data.DataLoader
if use_multi_epochs_loader:
loader_class = MultiEpochsDataLoader
@ -214,18 +109,19 @@ def create_loader(
except TypeError as e:
loader_args.pop('persistent_workers') # only in Pytorch 1.7+
loader = loader_class(dataset, **loader_args)
if use_prefetcher:
prefetch_re_prob = re_prob if is_training and not no_aug else 0.
loader = PrefetchLoader(
loader,
mean=mean,
std=std,
fp16=fp16,
re_prob=prefetch_re_prob,
re_mode=re_mode,
re_count=re_count,
re_num_splits=re_num_splits
)
fetcher_kwargs = dict(
mean=mean,
std=std,
re_prob=re_prob if is_training and not no_aug else 0.,
re_mode=re_mode,
re_count=re_count,
re_num_splits=re_num_splits
)
if dev_env.type == 'cuda':
loader = PrefetcherCuda(loader, **fetcher_kwargs)
else:
loader = Fetcher(loader, device=dev_env.device, **fetcher_kwargs)
return loader

@ -0,0 +1,79 @@
import torch.cuda
from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .mixup import FastCollateMixup
from .random_erasing import RandomErasing
class PrefetcherCuda:
def __init__(self,
loader,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
fp16=False,
re_prob=0.,
re_mode='const',
re_count=1,
re_num_splits=0):
self.loader = loader
self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1)
self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1)
self.fp16 = fp16
if fp16:
self.mean = self.mean.half()
self.std = self.std.half()
if re_prob > 0.:
self.random_erasing = RandomErasing(
probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits)
else:
self.random_erasing = None
def __iter__(self):
stream = torch.cuda.Stream()
first = True
for next_input, next_target in self.loader:
with torch.cuda.stream(stream):
next_input = next_input.cuda(non_blocking=True)
next_target = next_target.cuda(non_blocking=True)
if self.fp16:
next_input = next_input.half().sub_(self.mean).div_(self.std)
else:
next_input = next_input.float().sub_(self.mean).div_(self.std)
if self.random_erasing is not None:
next_input = self.random_erasing(next_input)
if not first:
yield input, target
else:
first = False
torch.cuda.current_stream().wait_stream(stream)
input = next_input
target = next_target
yield input, target
def __len__(self):
return len(self.loader)
@property
def sampler(self):
return self.loader.sampler
@property
def dataset(self):
return self.loader.dataset
@property
def mixup_enabled(self):
if isinstance(self.loader.collate_fn, FastCollateMixup):
return self.loader.collate_fn.mixup_enabled
else:
return False
@mixup_enabled.setter
def mixup_enabled(self, x):
if isinstance(self.loader.collate_fn, FastCollateMixup):
self.loader.collate_fn.mixup_enabled = x

@ -22,7 +22,10 @@ Hacked together by / Copyright 2020 Ross Wightman
# limitations under the License.
# ==============================================================================
"""ImageNet preprocessing for MnasNet."""
import tensorflow as tf
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
tf.compat.v1.disable_eager_execution()
import numpy as np
IMAGE_SIZE = 224
@ -131,6 +134,39 @@ def _decode_and_center_crop(image_bytes, image_size, resize_method):
return image
def crop(image_bytes, crop_window):
"""Helper function to crop a jpeg or a decoded image."""
if image_bytes.dtype == tf.dtypes.string:
image = tf.image.decode_and_crop_jpeg(image_bytes,
tf.stack(crop_window),
channels=3)
else:
image = tf.image.crop_to_bounding_box(image_bytes, *crop_window)
return image
def _decode_and_resize_then_crop(
image_bytes: tf.Tensor,
image_size,
crop_pct: float = 32,
) -> tf.Tensor:
"""Rescales an image to image_size / crop_pct, then center crops."""
image = tf.image.decode_jpeg(image_bytes, channels=3)
# Scale image to "scaled size" before taking a center crop
if crop_pct > 1.0: # If crop_pct is >1, treat it as num pad pixels (like VGG)
scale_size = tuple([int(x + crop_pct) for x in image_size])
else:
scale_size = tuple([int(float(x) / crop_pct) for x in image_size])
image = tf.image.resize(image, scale_size, tf.image.ResizeMethod.BICUBIC)
crop_height = tf.cast(image_size[0], tf.int32)
crop_width = tf.cast(image_size[1], tf.int32)
offset_height = ((scale_size[0] - crop_height) + 1) // 2
offset_width = ((scale_size[1] - crop_width) + 1) // 2
crop_window = [offset_height, offset_width, crop_height, crop_width]
image = crop(image, crop_window)
return image
def _flip(image):
"""Random horizontal image flip."""
image = tf.image.random_flip_left_right(image)
@ -172,6 +208,7 @@ def preprocess_for_eval(image_bytes, use_bfloat16, image_size=IMAGE_SIZE, interp
"""
resize_method = tf.image.ResizeMethod.BICUBIC if interpolation == 'bicubic' else tf.image.ResizeMethod.BILINEAR
image = _decode_and_center_crop(image_bytes, image_size, resize_method)
#image = _decode_and_resize_then_crop(image_bytes, (image_size, image_size), resize_method)
image = tf.reshape(image, [image_size, image_size, 3])
image = tf.image.convert_image_dtype(
image, dtype=tf.bfloat16 if use_bfloat16 else tf.float32)

@ -167,7 +167,7 @@ def transforms_imagenet_eval(
def create_transform(
input_size,
is_training=False,
use_prefetcher=False,
use_fetcher=False,
no_aug=False,
scale=None,
ratio=None,
@ -191,7 +191,7 @@ def create_transform(
else:
img_size = input_size
if tf_preprocessing and use_prefetcher:
if tf_preprocessing and use_fetcher:
assert not separate, "Separate transforms not supported for TF preprocessing"
from timm.data.tf_preprocessing import TfPreprocessTransform
transform = TfPreprocessTransform(
@ -202,7 +202,7 @@ def create_transform(
transform = transforms_noaug_train(
img_size,
interpolation=interpolation,
use_prefetcher=use_prefetcher,
use_prefetcher=use_fetcher,
mean=mean,
std=std)
elif is_training:
@ -215,7 +215,7 @@ def create_transform(
color_jitter=color_jitter,
auto_augment=auto_augment,
interpolation=interpolation,
use_prefetcher=use_prefetcher,
use_prefetcher=use_fetcher,
mean=mean,
std=std,
re_prob=re_prob,
@ -228,7 +228,7 @@ def create_transform(
transform = transforms_imagenet_eval(
img_size,
interpolation=interpolation,
use_prefetcher=use_prefetcher,
use_prefetcher=use_fetcher,
mean=mean,
std=std,
crop_pct=crop_pct)

@ -0,0 +1,4 @@
from .accuracy import Accuracy, AccuracyTopK
from .precision_recall import PrecisionRecall
from .scalar_avg import ScalarAvgMinMax
from .tensor_avg import TensorAvg, TensorEma

@ -0,0 +1,112 @@
import torch
from typing import Optional, Tuple, Dict
class Accuracy(torch.nn.Module):
    def __init__(self, threshold=0.5, multi_label=False):
        super().__init__()
        self.threshold = threshold
self.eps = 1e-8
self.multi_label = multi_label
# statistics / counts
self._correct_sum = torch.tensor(0, dtype=torch.long)
self._total_sum = torch.tensor(0, dtype=torch.long)
def update(self, predictions, target):
        raise NotImplementedError()
def reset(self):
self._correct_sum = 0
self._total_sum = 0
@property
def counts(self):
pass
def compute(self):
        raise NotImplementedError()
class AccuracyTopK(torch.nn.Module):
def __init__(self, topk=(1, 5), device=None):
super().__init__()
self.eps = 1e-8
self.device = device
self.topk = topk
self.maxk = max(topk)
# statistics / counts
self.reset()
def update(self, predictions: torch.Tensor, target: torch.Tensor):
sorted_indices = predictions.topk(self.maxk, dim=1)[1]
sorted_indices.t_()
correct = sorted_indices.eq(target.reshape(1, -1).expand_as(sorted_indices))
batch_size = target.shape[0]
correct_k = {k: correct[:k].reshape(-1).float().sum(0) for k in self.topk}
for k, v in correct_k.items():
attr = f'_correct_top{k}'
old_v = getattr(self, attr)
setattr(self, attr, old_v + v)
self._total_sum += batch_size
def reset(self):
for k in self.topk:
setattr(self, f'_correct_top{k}', torch.tensor(0, dtype=torch.float32))
self._total_sum = torch.tensor(0, dtype=torch.float32)
@property
def counts(self):
pass
def compute(self) -> Dict[str, torch.Tensor]:
return {f'top{k}': 100 * getattr(self, f'_correct_top{k}') / self._total_sum for k in self.topk}
#
# class AccuracyTopK:
#
# def __init__(self, topk=(1, 5), device=None):
# self.eps = 1e-8
# self.device = device
# self.topk = topk
# self.maxk = max(topk)
#
# # statistics / counts
# self._correct_sum = None
# self._total_sum = None
#
# def _check_init(self, device):
# to_device = self.device if self.device else device
# if self._correct_sum is None:
# self._correct_sum = {f'top{k}': torch.tensor(0., device=to_device) for k in self.topk}
# if self._total_sum is None:
# self._total_sum = torch.tensor(0, dtype=torch.long, device=to_device)
#
# def update(self, predictions: torch.Tensor, target: torch.Tensor):
# sorted_indices = predictions.topk(self.maxk, dim=1)[1]
# sorted_indices.t_()
# correct = sorted_indices.eq(target.reshape(1, -1).expand_as(sorted_indices))
#
# batch_size = target.shape[0]
# correct_k = {f'top{k}': correct[:k].reshape(-1).float().sum(0) for k in self.topk}
# self._check_init(device=predictions.device)
# for k, v in correct_k.items():
# old_v = self._correct_sum[k]
# self._correct_sum[k] = old_v + v
# self._total_sum += batch_size
#
# def reset(self):
# self._correct_sum = None
# self._total_sum = None
#
# @property
# def counts(self):
# pass
#
# def compute(self) -> Dict[str, torch.Tensor]:
# assert self._correct_sum is not None and self._total_sum is not None
# return {k: 100 * v / self._total_sum for k, v in self._correct_sum.items()}
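
A hypothetical AccuracyTopK usage sketch (values below depend on the random logits):

import torch

acc = AccuracyTopK(topk=(1, 5))
logits = torch.randn(32, 10)
target = torch.randint(0, 10, (32,))
acc.update(logits, target)
print(acc.compute())   # e.g. {'top1': tensor(12.5000), 'top5': tensor(53.1250)}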

@ -0,0 +1,117 @@
import torch
import torch.nn.functional as F
class PrecisionRecall:
def __init__(self, threshold=0.5, multi_label=False, device=None):
self.threshold = threshold
self.device = device
self.multi_label = multi_label
# statistics
# the total number of true positive instances under each class
# Shape: (num_classes, )
self._tp_sum = None
# the total number of instances
# Shape: (num_classes, )
self._total_sum = None
# the total number of instances under each _predicted_ class,
# including true positives and false positives
# Shape: (num_classes, )
self._pred_sum = None
# the total number of instances under each _true_ class,
# including true positives and false negatives
# Shape: (num_classes, )
self._true_sum = None
self.reset()
def reset(self):
self._tp_sum = None
self._total_sum = None
self._pred_sum = None
self._true_sum = None
def update(self, predictions, target):
output_type = predictions.type()
num_classes = predictions.size(-1)
if self.multi_label:
if self.threshold is not None:
predictions = (predictions > self.threshold).type(output_type)
predictions = predictions.t().reshape(num_classes, -1)
target = target.t().reshape(num_classes, -1)
else:
target = F.one_hot(target.view(-1), num_classes=num_classes)
indices = torch.argmax(predictions, dim=1).view(-1)
predictions = F.one_hot(indices, num_classes=num_classes)
# FIXME make sure binary case works
target = target.type(output_type)
correct = (target * predictions > 0).type(output_type)
pred_positives = predictions.sum(dim=0)
target_positives = target.sum(dim=0)
if correct.sum() == 0:
true_positives = torch.zeros_like(pred_positives)
else:
true_positives = correct.sum(dim=0)
if self._tp_sum is None:
self._tp_sum = torch.zeros(num_classes, device=self.device)
self._true_sum = torch.zeros(num_classes, device=self.device)
self._pred_sum = torch.zeros(num_classes, device=self.device)
self._total_sum = torch.tensor(0, device=self.device)
self._tp_sum += true_positives
self._pred_sum += pred_positives
self._true_sum += target_positives
self._total_sum += target.shape[0]
def counts_as_tuple(self, reduce=False):
tp_sum = self._tp_sum
pred_sum = self._pred_sum
true_sum = self._true_sum
total_sum = self._total_sum
        if reduce:
            # FIXME reduce_tensor_sum (a distributed sum-reduce helper) is not defined or imported yet
            tp_sum = reduce_tensor_sum(tp_sum)
            pred_sum = reduce_tensor_sum(pred_sum)
            true_sum = reduce_tensor_sum(true_sum)
            total_sum = reduce_tensor_sum(total_sum)
return tp_sum, pred_sum, true_sum, total_sum
def counts(self, reduce=False):
tp_sum, pred_sum, true_sum, total_sum = self.counts_as_tuple(reduce=reduce)
return dict(tp_sum=tp_sum, pred_sum=pred_sum, true_sum=true_sum, total_sum=total_sum)
def confusion(self, reduce=False):
tp_sum, pred_sum, true_sum, total_sum = self.counts_as_tuple(reduce=reduce)
fp = pred_sum - tp_sum
fn = true_sum - tp_sum
tp = tp_sum
tn = total_sum - tp - fp - fn
return dict(tp=tp, fp=fp, fn=fn, tn=tn)
def compute(self, fscore_beta=1, average='micro', no_reduce=False, distributed=False):
tp_sum, pred_sum, true_sum, total_sum = self.counts_as_tuple(reduce=distributed)
if average == 'micro':
tp_sum = tp_sum.sum()
pred_sum = pred_sum.sum()
true_sum = true_sum.sum()
precision = tp_sum / pred_sum
recall = tp_sum / true_sum
beta_sq = fscore_beta ** 2
f1_denom = beta_sq * precision + recall
fscore = (1 + beta_sq) * precision * recall / f1_denom
        if average == 'macro' and not no_reduce:
            precision = precision.mean()
            recall = recall.mean()
            fscore = fscore.mean()
        return dict(fscore=fscore, precision=precision, recall=recall)
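
A hypothetical PrecisionRecall usage sketch for single-label classification:

import torch

pr = PrecisionRecall()
logits = torch.randn(32, 10)
target = torch.randint(0, 10, (32,))
pr.update(logits, target)
print(pr.compute())    # micro-averaged dict with fscore, precision, recall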

@ -0,0 +1,30 @@
class ScalarAvgMinMax:
"""Computes and stores the average and current value"""
def __init__(self):
self.val = 0
self.avg = 0
self.min = None
self.max = None
self.sum = 0
self.count = 0
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.min = None
self.max = None
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.min = val if self.min is None else min(self.min, val)
self.max = val if self.max is None else max(self.max, val)
self.sum += val * n
self.count += n
self.avg = self.sum / self.count

@ -0,0 +1,42 @@
import torch
class TensorAvg:
"""Computes and stores the average and current value"""
def __init__(self):
self.sum = None
self.count = None
self.reset()
def reset(self):
self.sum = None
self.count = None
def update(self, val: torch.Tensor, n=1):
if self.sum is None:
self.sum = torch.zeros_like(val)
self.count = torch.tensor(0, dtype=torch.long, device=val.device)
self.sum += (val * n)
self.count += n
def compute(self):
return self.sum / self.count
class TensorEma:
"""Computes and stores the average and current value"""
def __init__(self, smoothing_factor=0.9, init_zero=False):
self.smoothing_factor = smoothing_factor
self.init_zero = init_zero
self.val = None
self.reset()
def reset(self):
self.val = None
def update(self, val):
if self.val is None:
self.val = torch.zeros_like(val) if self.init_zero else val.clone()
self.val = (1. - self.smoothing_factor) * val + self.smoothing_factor * self.val
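
A hypothetical sketch of TensorAvg / TensorEma for on-device loss averaging:

import torch

avg, ema = TensorAvg(), TensorEma(smoothing_factor=0.9)
for loss in (torch.tensor(1.0), torch.tensor(0.5)):
    avg.update(loss, n=8)        # weight by batch size
    ema.update(loss)
print(avg.compute(), ema.val)    # tensor(0.7500) tensor(0.9500)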

@ -65,14 +65,16 @@ class Scheduler:
return None
def step(self, epoch: int, metric: float = None) -> None:
self.metric = metric
if metric is not None:
self.metric = metric
values = self.get_epoch_values(epoch)
if values is not None:
values = self._add_noise(values, epoch)
self.update_groups(values)
def step_update(self, num_updates: int, metric: float = None):
self.metric = metric
if metric is not None:
self.metric = metric
values = self.get_update_values(num_updates)
if values is not None:
values = self._add_noise(values, num_updates)

@ -20,14 +20,14 @@ import yaml
import os
import logging
from collections import OrderedDict
from contextlib import suppress
from datetime import datetime
import torch
import torch.nn as nn
import torchvision.utils
from torch.nn.parallel import DistributedDataParallel as NativeDDP
from timm.bits import initialize_device, DeviceEnv, create_updater, Updater, Logger, Tracker
from timm.metrics import TensorAvg, AccuracyTopK
from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint,\
convert_splitbn_model, model_parameters
@ -35,32 +35,11 @@ from timm.utils import *
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
from timm.optim import create_optimizer_v2, optimizer_kwargs
from timm.scheduler import create_scheduler
from timm.utils import ApexScaler, NativeScaler
try:
from apex import amp
from apex.parallel import DistributedDataParallel as ApexDDP
from apex.parallel import convert_syncbn_model
has_apex = True
except ImportError:
has_apex = False
has_native_amp = False
try:
if getattr(torch.cuda.amp, 'autocast') is not None:
has_native_amp = True
except AttributeError:
pass
try:
import wandb
has_wandb = True
except ImportError:
has_wandb = False
torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('train')
# The first arg parser parses out only the --config argument, this argument is used to
# load a yaml file containing key-values that override the defaults for the main parser below
config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
@ -254,16 +233,10 @@ parser.add_argument('--save-images', action='store_true', default=False,
help='save images of input bathes every log interval for debugging')
parser.add_argument('--amp', action='store_true', default=False,
help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
parser.add_argument('--apex-amp', action='store_true', default=False,
help='Use NVIDIA Apex AMP mixed precision')
parser.add_argument('--native-amp', action='store_true', default=False,
help='Use Native Torch AMP mixed precision')
parser.add_argument('--channels-last', action='store_true', default=False,
help='Use channels_last memory layout')
parser.add_argument('--pin-mem', action='store_true', default=False,
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--no-prefetcher', action='store_true', default=False,
help='disable fast prefetcher')
parser.add_argument('--output', default='', type=str, metavar='PATH',
help='path to output folder (default: none, current dir)')
parser.add_argument('--experiment', default='', type=str, metavar='NAME',
@ -301,50 +274,15 @@ def _parse_args():
def main():
setup_default_logging()
args, args_text = _parse_args()
if args.log_wandb:
if has_wandb:
wandb.init(project=args.experiment, config=args)
else:
_logger.warning("You've requested to log metrics to wandb but package not found. "
"Metrics not being logged to wandb, try `pip install wandb`")
args.prefetcher = not args.no_prefetcher
args.distributed = False
if 'WORLD_SIZE' in os.environ:
args.distributed = int(os.environ['WORLD_SIZE']) > 1
args.device = 'cuda:0'
args.world_size = 1
args.rank = 0 # global rank
if args.distributed:
args.device = 'cuda:%d' % args.local_rank
torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend='nccl', init_method='env://')
args.world_size = torch.distributed.get_world_size()
args.rank = torch.distributed.get_rank()
_logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
% (args.rank, args.world_size))
dev_env = initialize_device(amp=args.amp)
if dev_env.is_distributed:
_logger.info('Training in distributed mode with multiple processes, 1 device per process. Process %d, total %d.'
% (dev_env.global_rank, dev_env.world_size))
else:
_logger.info('Training with a single process on 1 GPUs.')
assert args.rank >= 0
# resolve AMP arguments based on PyTorch / Apex availability
use_amp = None
if args.amp:
# `--amp` chooses native amp before apex (APEX ver not actively maintained)
if has_native_amp:
args.native_amp = True
elif has_apex:
args.apex_amp = True
if args.apex_amp and has_apex:
use_amp = 'apex'
elif args.native_amp and has_native_amp:
use_amp = 'native'
elif args.apex_amp or args.native_amp:
_logger.warning("Neither APEX or native Torch AMP is available, using float32. "
"Install NVIDA apex or upgrade to PyTorch 1.6")
random_seed(args.seed, args.rank)
_logger.info('Training with a single process on 1 device.')
random_seed(args.seed, dev_env.global_rank)
model = create_model(
args.model,
@ -364,11 +302,11 @@ def main():
assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly
if args.local_rank == 0:
if dev_env.is_master:
_logger.info(
f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}')
data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)
data_config = resolve_data_config(vars(args), model=model, verbose=dev_env.is_master)
# setup augmentation batch splits for contrastive loss or split bn
num_aug_splits = 0
@ -382,55 +320,33 @@ def main():
model = convert_splitbn_model(model, max(num_aug_splits, 2))
# move model to GPU, enable channels last layout if set
model.cuda()
if args.channels_last:
model = model.to(memory_format=torch.channels_last)
dev_env.to_device(model)
# setup synchronized BatchNorm for distributed training
if args.distributed and args.sync_bn:
if dev_env.is_distributed and args.sync_bn:
assert not args.split_bn
if has_apex and use_amp != 'native':
# Apex SyncBN preferred unless native amp is activated
model = convert_syncbn_model(model)
else:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
if args.local_rank == 0:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
if dev_env.is_master:
_logger.info(
'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
if args.torchscript:
assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
model = torch.jit.script(model)
optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))
# setup automatic mixed-precision (AMP) loss scaling and op casting
amp_autocast = suppress # do nothing
loss_scaler = None
if use_amp == 'apex':
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
loss_scaler = ApexScaler()
if args.local_rank == 0:
_logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
elif use_amp == 'native':
amp_autocast = torch.cuda.amp.autocast
loss_scaler = NativeScaler()
if args.local_rank == 0:
_logger.info('Using native Torch AMP. Training in mixed precision.')
else:
if args.local_rank == 0:
_logger.info('AMP not enabled. Training in float32.')
updater = create_updater(
create_optimizer_v2(model, **optimizer_kwargs(cfg=args)),
clip_value=args.clip_grad, clip_mode=args.clip_mode)
# optionally resume from a checkpoint
resume_epoch = None
if args.resume:
resume_epoch = resume_checkpoint(
model, args.resume,
optimizer=None if args.no_resume_opt else optimizer,
loss_scaler=None if args.no_resume_opt else loss_scaler,
log_info=args.local_rank == 0)
optimizer=None if args.no_resume_opt else updater.optimizer,
loss_scaler=None if args.no_resume_opt else updater.scaler,
log_info=dev_env.is_master)
# setup exponential moving average of model weights, SWA could be used here too
model_ema = None
@ -442,20 +358,14 @@ def main():
load_checkpoint(model_ema.module, args.resume, use_ema=True)
# setup distributed training
if args.distributed:
if has_apex and use_amp != 'native':
# Apex DDP preferred unless native amp is activated
if args.local_rank == 0:
_logger.info("Using NVIDIA APEX DistributedDataParallel.")
model = ApexDDP(model, delay_allreduce=True)
else:
if args.local_rank == 0:
_logger.info("Using native Torch DistributedDataParallel.")
model = NativeDDP(model, device_ids=[args.local_rank]) # can use device str in Torch >= 1.1
# NOTE: EMA model does not need to be wrapped by DDP
if dev_env.is_distributed:
if dev_env.is_master:
_logger.info("Distributing model.")
model = dev_env.wrap_distributed(model)
# NOTE: EMA model does not need to be wrapped by DDP
# setup learning rate schedule and starting epoch
lr_scheduler, num_epochs = create_scheduler(args, optimizer)
lr_scheduler, num_epochs = create_scheduler(args, updater.optimizer)
start_epoch = 0
if args.start_epoch is not None:
# a specified start_epoch will always override the resume epoch
@ -465,7 +375,7 @@ def main():
if lr_scheduler is not None and start_epoch > 0:
lr_scheduler.step(start_epoch)
if args.local_rank == 0:
if dev_env.is_master:
_logger.info('Scheduled epochs: {}'.format(num_epochs))
# create the train and eval datasets
@ -478,18 +388,14 @@ def main():
# setup mixup / cutmix
collate_fn = None
mixup_fn = None
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
if mixup_active:
mixup_args = dict(
mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
label_smoothing=args.smoothing, num_classes=args.num_classes)
if args.prefetcher:
assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup)
collate_fn = FastCollateMixup(**mixup_args)
else:
mixup_fn = Mixup(**mixup_args)
assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup)
collate_fn = FastCollateMixup(**mixup_args)
# wrap dataset in AugMix helper
if num_aug_splits > 1:
@ -504,7 +410,6 @@ def main():
input_size=data_config['input_size'],
batch_size=args.batch_size,
is_training=True,
use_prefetcher=args.prefetcher,
no_aug=args.no_aug,
re_prob=args.reprob,
re_mode=args.remode,
@ -521,7 +426,7 @@ def main():
mean=data_config['mean'],
std=data_config['std'],
num_workers=args.workers,
distributed=args.distributed,
distributed=dev_env.is_distributed,
collate_fn=collate_fn,
pin_memory=args.pin_mem,
use_multi_epochs_loader=args.use_multi_epochs_loader
@ -532,12 +437,11 @@ def main():
input_size=data_config['input_size'],
batch_size=args.validation_batch_size_multiplier * args.batch_size,
is_training=False,
use_prefetcher=args.prefetcher,
interpolation=data_config['interpolation'],
mean=data_config['mean'],
std=data_config['std'],
num_workers=args.workers,
distributed=args.distributed,
distributed=dev_env.is_distributed,
crop_pct=data_config['crop_pct'],
pin_memory=args.pin_mem,
)
@ -545,23 +449,24 @@ def main():
# setup loss function
if args.jsd:
assert num_aug_splits > 1 # JSD only valid with aug splits set
train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()
train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing)
elif mixup_active:
# smoothing is handled with mixup target transform
train_loss_fn = SoftTargetCrossEntropy().cuda()
train_loss_fn = SoftTargetCrossEntropy()
elif args.smoothing:
train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
else:
train_loss_fn = nn.CrossEntropyLoss().cuda()
validate_loss_fn = nn.CrossEntropyLoss().cuda()
train_loss_fn = nn.CrossEntropyLoss()
validate_loss_fn = nn.CrossEntropyLoss()
dev_env.to_device(train_loss_fn, validate_loss_fn)
# setup checkpoint saver and eval metric tracking
eval_metric = args.eval_metric
best_metric = None
best_epoch = None
saver = None
output_dir = ''
if args.local_rank == 0:
output_dir = None
if dev_env.is_master:
if args.experiment:
exp_name = args.experiment
else:
@ -573,42 +478,48 @@ def main():
output_dir = get_outdir(args.output if args.output else './output/train', exp_name)
decreasing = True if eval_metric == 'loss' else False
saver = CheckpointSaver(
model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
model=model, optimizer=updater.optimizer, args=args, model_ema=model_ema, amp_scaler=updater.scaler,
checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=args.checkpoint_hist)
with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
f.write(args_text)
logger = Logger(output_dir=output_dir, logger=_logger, hparams=vars(args))
try:
for epoch in range(start_epoch, num_epochs):
if args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
if dev_env.is_distributed and hasattr(loader_train.sampler, 'set_epoch'):
loader_train.sampler.set_epoch(epoch)
if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
if loader_train.mixup_enabled:
loader_train.mixup_enabled = False
train_metrics = train_one_epoch(
epoch, model, loader_train, optimizer, train_loss_fn, args,
lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)
epoch, model, loader_train, updater, train_loss_fn, dev_env,
lr_scheduler=lr_scheduler, saver=saver, logger=logger, model_ema=model_ema,
log_interval=args.log_interval, recovery_interval=args.recovery_interval)
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
if args.local_rank == 0:
if dev_env.is_distributed and args.dist_bn in ('broadcast', 'reduce'):
if dev_env.is_master:
_logger.info("Distributing BatchNorm running means and vars")
distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
distribute_bn(model, dev_env.world_size, args.dist_bn == 'reduce')
eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)
eval_metrics = evaluate(model, loader_eval, validate_loss_fn, dev_env, logger=logger)
if model_ema is not None and not args.model_ema_force_cpu:
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
ema_eval_metrics = validate(
model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)')
if dev_env.is_distributed and args.dist_bn in ('broadcast', 'reduce'):
distribute_bn(model_ema, dev_env.world_size, args.dist_bn == 'reduce')
ema_eval_metrics = evaluate(
model_ema.module, loader_eval, validate_loss_fn, dev_env,
logger=logger, phase_suffix='EMA')
eval_metrics = ema_eval_metrics
if lr_scheduler is not None:
# step LR for next epoch
lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
update_summary(
epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb)
if logger is not None:
logger.write_summary(index=epoch, results=dict(train=train_metrics, eval=eval_metrics))
if saver is not None:
# save proper checkpoint with eval metric
@ -622,175 +533,128 @@ def main():
def train_one_epoch(
epoch, model, loader, optimizer, loss_fn, args,
lr_scheduler=None, saver=None, output_dir='', amp_autocast=suppress,
loss_scaler=None, model_ema=None, mixup_fn=None):
if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
if args.prefetcher and loader.mixup_enabled:
loader.mixup_enabled = False
elif mixup_fn is not None:
mixup_fn.mixup_enabled = False
second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
batch_time_m = AverageMeter()
data_time_m = AverageMeter()
losses_m = AverageMeter()
epoch: int,
model: nn.Module,
loader,
updater: Updater,
loss_fn: nn.Module,
dev_env: DeviceEnv,
lr_scheduler=None,
saver: CheckpointSaver = None,
logger: Logger = None,
model_ema: nn.Module = None,
log_interval: int = 50,
recovery_interval: int = 0,
):
tracker = Tracker()
losses_m = TensorAvg()
model.train()
end = time.time()
last_idx = len(loader) - 1
end_idx = len(loader) - 1
num_updates = epoch * len(loader)
for batch_idx, (input, target) in enumerate(loader):
last_batch = batch_idx == last_idx
data_time_m.update(time.time() - end)
if not args.prefetcher:
input, target = input.cuda(), target.cuda()
if mixup_fn is not None:
input, target = mixup_fn(input, target)
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)
with amp_autocast():
output = model(input)
batch_size = 0
tracker.mark_iter()
for step_idx, (sample, target) in enumerate(loader):
tracker.mark_iter_data_end()
last_step = step_idx == end_idx
batch_size = max(batch_size, sample.shape[0])
with dev_env.autocast():
output = model(sample)
loss = loss_fn(output, target)
if not args.distributed:
losses_m.update(loss.item(), input.size(0))
optimizer.zero_grad()
if loss_scaler is not None:
loss_scaler(
loss, optimizer,
clip_grad=args.clip_grad, clip_mode=args.clip_mode,
parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
create_graph=second_order)
else:
loss.backward(create_graph=second_order)
if args.clip_grad is not None:
dispatch_clip_grad(
model_parameters(model, exclude_head='agc' in args.clip_mode),
value=args.clip_grad, mode=args.clip_mode)
optimizer.step()
updater.reset()
updater.apply(loss)
dev_env.mark_step() # FIXME
tracker.mark_iter_step_end()
losses_m.update(loss, sample.size(0))
if model_ema is not None:
model_ema.update(model)
torch.cuda.synchronize()
num_updates += 1
batch_time_m.update(time.time() - end)
if last_batch or batch_idx % args.log_interval == 0:
lrl = [param_group['lr'] for param_group in optimizer.param_groups]
if last_step or (step_idx + 1) % log_interval == 0:
lrl = [param_group['lr'] for param_group in updater.optimizer.param_groups]
lr = sum(lrl) / len(lrl)
if args.distributed:
reduced_loss = reduce_tensor(loss.data, args.world_size)
losses_m.update(reduced_loss.item(), input.size(0))
if args.local_rank == 0:
_logger.info(
'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) '
'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
'({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
'LR: {lr:.3e} '
'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
epoch,
batch_idx, len(loader),
100. * batch_idx / last_idx,
loss=losses_m,
batch_time=batch_time_m,
rate=input.size(0) * args.world_size / batch_time_m.val,
rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
lr=lr,
data_time=data_time_m))
if args.save_images and output_dir:
torchvision.utils.save_image(
input,
os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
padding=0,
normalize=True)
if saver is not None and args.recovery_interval and (
last_batch or (batch_idx + 1) % args.recovery_interval == 0):
saver.save_recovery(epoch, batch_idx=batch_idx)
if dev_env.is_master and logger is not None:
loss_avg = losses_m.compute()
logger.log_step(
'Train',
step=step_idx,
num_steps=end_idx,
loss=loss_avg.item(),
rate=(dev_env.world_size * batch_size) / tracker.iter_time.avg,
lr=lr,
)
if saver is not None and recovery_interval and (last_step or (step_idx + 1) % recovery_interval == 0):
saver.save_recovery(epoch, batch_idx=step_idx)
if lr_scheduler is not None:
lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
lr_scheduler.step_update(num_updates=num_updates)
end = time.time()
tracker.mark_iter()
# end for
if hasattr(optimizer, 'sync_lookahead'):
optimizer.sync_lookahead()
if hasattr(updater.optimizer, 'sync_lookahead'):
updater.optimizer.sync_lookahead()
return OrderedDict([('loss', losses_m.compute().item())])
return OrderedDict([('loss', losses_m.avg)])
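The explicit zero_grad/backward/clip/step (and loss-scaler) sequence above collapses into `updater.reset()` plus `updater.apply(loss)`. A minimal sketch of the contract that implies, assuming the scaler is a `torch.cuda.amp.GradScaler` when AMP is enabled; this is an illustration of the pattern, not the actual timm.bits Updater.

import torch

class UpdaterSketch:
    """Hypothetical minimal version of the reset()/apply() contract."""

    def __init__(self, optimizer, scaler=None, clip_value=None):
        self.optimizer = optimizer
        self.scaler = scaler        # e.g. torch.cuda.amp.GradScaler, or None
        self.clip_value = clip_value

    def reset(self):
        self.optimizer.zero_grad()

    def apply(self, loss):
        if self.scaler is None:
            loss.backward()
            self._clip()
            self.optimizer.step()
        else:
            self.scaler.scale(loss).backward()
            if self.clip_value is not None:
                self.scaler.unscale_(self.optimizer)  # clip on true gradients
                self._clip()
            self.scaler.step(self.optimizer)
            self.scaler.update()

    def _clip(self):
        if self.clip_value is not None:
            params = [p for g in self.optimizer.param_groups for p in g['params']]
            torch.nn.utils.clip_grad_norm_(params, self.clip_value)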
def evaluate(
model: nn.Module,
loader,
loss_fn: nn.Module,
dev_env: DeviceEnv,
logger: Logger,
phase_suffix: str = '',
log_interval: int = 10,
):
def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''):
batch_time_m = AverageMeter()
losses_m = AverageMeter()
top1_m = AverageMeter()
top5_m = AverageMeter()
tracker = Tracker()
losses_m = TensorAvg()
accuracy_m = AccuracyTopK()
model.eval()
end = time.time()
last_idx = len(loader) - 1
end_idx = len(loader) - 1
tracker.mark_iter()
with torch.no_grad():
for batch_idx, (input, target) in enumerate(loader):
last_batch = batch_idx == last_idx
if not args.prefetcher:
input = input.cuda()
target = target.cuda()
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)
with amp_autocast():
output = model(input)
if isinstance(output, (tuple, list)):
output = output[0]
# augmentation reduction
reduce_factor = args.tta
if reduce_factor > 1:
output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
target = target[0:target.size(0):reduce_factor]
loss = loss_fn(output, target)
acc1, acc5 = accuracy(output, target, topk=(1, 5))
if args.distributed:
reduced_loss = reduce_tensor(loss.data, args.world_size)
acc1 = reduce_tensor(acc1, args.world_size)
acc5 = reduce_tensor(acc5, args.world_size)
else:
reduced_loss = loss.data
torch.cuda.synchronize()
losses_m.update(reduced_loss.item(), input.size(0))
top1_m.update(acc1.item(), output.size(0))
top5_m.update(acc5.item(), output.size(0))
batch_time_m.update(time.time() - end)
end = time.time()
if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
log_name = 'Test' + log_suffix
_logger.info(
'{0}: [{1:>4d}/{2}] '
'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
log_name, batch_idx, last_idx, batch_time=batch_time_m,
loss=losses_m, top1=top1_m, top5=top5_m))
metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
return metrics
for step_idx, (sample, target) in enumerate(loader):
tracker.mark_iter_data_end()
last_step = step_idx == end_idx
with dev_env.autocast():
output = model(sample)
if isinstance(output, (tuple, list)):
output = output[0]
loss = loss_fn(output, target)
dev_env.mark_step() # FIXME
tracker.mark_iter_step_end()
losses_m.update(loss, output.size(0))
accuracy_m.update(output, target)
if dev_env.is_master and (last_step or step_idx % log_interval == 0):
top1, top5 = accuracy_m.compute().values()
loss_avg = losses_m.compute()
logger.log_step(
'Eval',
step=step_idx,
num_steps=end_idx,
loss=loss_avg.item(),
top1=top1.item(),
top5=top5.item(),
phase_suffix=phase_suffix,
)
tracker.mark_iter()
top1, top5 = accuracy_m.compute().values()
results = OrderedDict([('loss', losses_m.compute().item()), ('top1', top1.item()), ('top5', top5.item())])
return results
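`TensorAvg` and `AccuracyTopK` keep their state as tensors on the active device, so the per-step `.item()` / `reduce_tensor` syncs of the old loop disappear and a single `compute()` pays the sync cost. A rough sketch of that pattern (a simplification, not the shipped timm.metrics classes):

import torch
from collections import OrderedDict

class TensorAvgSketch:
    """Simplified running average kept on-device; one sync at compute()."""

    def __init__(self):
        self._sum = None
        self._count = None

    def update(self, value: torch.Tensor, n: int = 1):
        value = value.detach()
        if self._sum is None:
            self._sum = value * n
            self._count = torch.tensor(float(n), device=value.device)
        else:
            self._sum = self._sum + value * n
            self._count = self._count + n

    def compute(self) -> torch.Tensor:
        return self._sum / self._count

class AccuracyTopKSketch:
    """Simplified top-1/top-5 counter; compute() preserves (top1, top5) order."""

    def __init__(self, topk=(1, 5)):
        self.topk = topk
        self.correct = {k: 0 for k in topk}
        self.total = 0

    def update(self, output: torch.Tensor, target: torch.Tensor):
        _, pred = output.topk(max(self.topk), dim=1)
        hits = pred.eq(target.view(-1, 1))          # (batch, maxk) bool
        for k in self.topk:
            self.correct[k] = self.correct[k] + hits[:, :k].any(dim=1).sum()
        self.total += target.size(0)

    def compute(self) -> OrderedDict:
        return OrderedDict(
            (f'top{k}', self.correct[k].float() * 100. / self.total) for k in self.topk)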
if __name__ == '__main__':

@ -17,27 +17,14 @@ import torch
import torch.nn as nn
import torch.nn.parallel
from collections import OrderedDict
from contextlib import suppress
from timm.bits import initialize_device, Tracker, Logger
from timm.metrics import AccuracyTopK, TensorAvg
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_legacy
has_apex = False
try:
from apex import amp
has_apex = True
except ImportError:
pass
has_native_amp = False
try:
if getattr(torch.cuda.amp, 'autocast') is not None:
has_native_amp = True
except AttributeError:
pass
torch.backends.cudnn.benchmark = True
from timm.utils import natural_key, setup_default_logging
_logger = logging.getLogger('validate')
@ -72,36 +59,28 @@ parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
help='path to class to idx mapping file (default: "")')
parser.add_argument('--gp', default=None, type=str, metavar='POOL',
help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
parser.add_argument('--log-freq', default=10, type=int,
parser.add_argument('--log-freq', default=20, type=int,
metavar='N', help='batch logging frequency (default: 20)')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
help='path to latest checkpoint (default: none)')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
help='use pre-trained model')
parser.add_argument('--num-gpu', type=int, default=1,
help='Number of GPUS to use')
parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true',
help='disable test time pool')
parser.add_argument('--no-prefetcher', action='store_true', default=False,
help='disable fast prefetcher')
# parser.add_argument('--num-gpu', type=int, default=1,
# help='Number of GPUS to use')
parser.add_argument('--test-pool', dest='test_pool', action='store_true',
help='enable test time pool')
parser.add_argument('--pin-mem', action='store_true', default=False,
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--channels-last', action='store_true', default=False,
help='Use channels_last memory layout')
parser.add_argument('--amp', action='store_true', default=False,
help='Use AMP mixed precision (backend selection is handled by the device environment).')
parser.add_argument('--apex-amp', action='store_true', default=False,
help='Use NVIDIA Apex AMP mixed precision')
parser.add_argument('--native-amp', action='store_true', default=False,
help='Use Native Torch AMP mixed precision')
parser.add_argument('--tf-preprocessing', action='store_true', default=False,
help='Use Tensorflow preprocessing pipeline (requires CPU TF installed)')
parser.add_argument('--use-ema', dest='use_ema', action='store_true',
help='use ema version of weights if present')
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
help='convert model torchscript for inference')
parser.add_argument('--legacy-jit', dest='legacy_jit', action='store_true',
help='use legacy jit mode for pytorch 1.5/1.5.1/1.6 to get back fusion performance')
parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
help='Output csv file for validation results (summary)')
parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME',
@ -113,26 +92,8 @@ parser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME',
def validate(args):
# might as well try to validate something
args.pretrained = args.pretrained or not args.checkpoint
args.prefetcher = not args.no_prefetcher
amp_autocast = suppress # do nothing
if args.amp:
if has_native_amp:
args.native_amp = True
elif has_apex:
args.apex_amp = True
else:
_logger.warning("Neither APEX or Native Torch AMP is available.")
assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
if args.native_amp:
amp_autocast = torch.cuda.amp.autocast
_logger.info('Validating in mixed precision with native PyTorch AMP.')
elif args.apex_amp:
_logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
else:
_logger.info('Validating in float32. AMP not enabled.')
if args.legacy_jit:
set_jit_legacy()
dev_env = initialize_device(amp=args.amp)
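All of the Apex/native-AMP probing above collapses into this one factory call. A hypothetical sketch of how such a factory might select an environment; `DeviceEnvXla` is an assumed name (only the CUDA env appears in this commit) and the real selection logic may differ.

# Hypothetical factory sketch for timm.bits.device_env_factory.
_device_env = None

def initialize_device(amp=False, **kwargs):
    global _device_env
    try:
        import torch_xla.core.xla_model as xm  # noqa: F401 -- probe for TPU/XLA
        _device_env = DeviceEnvXla(amp=amp, **kwargs)   # assumed XLA env
    except ImportError:
        _device_env = DeviceEnvCuda(amp=amp, **kwargs)  # CUDA env from this commit
    return _device_env

def get_device():
    assert _device_env is not None, 'initialize_device() must be called first'
    return _device_env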
# create model
model = create_model(
@ -154,24 +115,16 @@ def validate(args):
data_config = resolve_data_config(vars(args), model=model, use_test_size=True, verbose=True)
test_time_pool = False
if not args.no_test_pool:
if args.test_pool:
model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True)
if args.torchscript:
torch.jit.optimized_execution(True)
model = torch.jit.script(model)
model = model.cuda()
if args.apex_amp:
model = amp.initialize(model, opt_level='O1')
if args.channels_last:
model = model.to(memory_format=torch.channels_last)
if args.num_gpu > 1:
model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))
criterion = nn.CrossEntropyLoss().cuda()
# FIXME device
model, criterion = dev_env.to_device(model, nn.CrossEntropyLoss())
model.to(dev_env.device)
dataset = create_dataset(
root=args.data, name=args.dataset, split=args.split,
@ -194,7 +147,6 @@ def validate(args):
dataset,
input_size=data_config['input_size'],
batch_size=args.batch_size,
use_prefetcher=args.prefetcher,
interpolation=data_config['interpolation'],
mean=data_config['mean'],
std=data_config['std'],
@ -203,63 +155,61 @@ def validate(args):
pin_memory=args.pin_mem,
tf_preprocessing=args.tf_preprocessing)
batch_time = AverageMeter()
losses = AverageMeter()
top1 = AverageMeter()
top5 = AverageMeter()
logger = Logger(logger=_logger)
tracker = Tracker()
losses = TensorAvg()
accuracy = AccuracyTopK().to(dev_env.device)
model.eval()
num_steps = len(loader)
with torch.no_grad():
# warmup, reduce variability of first batch time, especially for comparing torchscript vs non
input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).cuda()
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)
model(input)
end = time.time()
for batch_idx, (input, target) in enumerate(loader):
if args.no_prefetcher:
target = target.cuda()
input = input.cuda()
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)
tracker.mark_iter()
for step_idx, (sample, target) in enumerate(loader):
tracker.mark_iter_data_end()
# compute output
with amp_autocast():
output = model(input)
with dev_env.autocast():
output = model(sample)
if valid_labels is not None:
output = output[:, valid_labels]
loss = criterion(output, target)
if dev_env.type == 'cuda':
torch.cuda.synchronize()
#elif dev_env.type == 'xla':
# dev_env.mark_step()
tracker.mark_iter_step_end()
losses.update(loss.detach(), sample.size(0))
if real_labels is not None:
real_labels.add_result(output)
# measure accuracy and record loss
acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
losses.update(loss.item(), input.size(0))
top1.update(acc1.item(), input.size(0))
top5.update(acc5.item(), input.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if batch_idx % args.log_freq == 0:
_logger.info(
'Test: [{0:>4d}/{1}] '
'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
batch_idx, len(loader), batch_time=batch_time,
rate_avg=input.size(0) / batch_time.avg,
loss=losses, top1=top1, top5=top5))
accuracy.update(output.detach(), target)
if dev_env.type == 'xla':
dev_env.mark_step()
tracker.mark_iter()
if step_idx % args.log_freq == 0:
top1, top5 = accuracy.compute().values()
loss_avg = losses.compute()
logger.log_step(
phase='eval',
step=step_idx,
num_steps=num_steps,
rate=args.batch_size / tracker.iter_time.avg,
loss=loss_avg.item(),
top1=top1.item(),
top5=top5.item(),
)
if real_labels is not None:
# real labels mode replaces topk values at the end
top1a, top5a = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5)
else:
top1a, top5a = top1.avg, top5.avg
top1a, top5a = accuracy.compute().values()
top1a, top5a = top1a.item(), top5a.item()
results = OrderedDict(
top1=round(top1a, 4), top1_err=round(100 - top1a, 4),
top5=round(top5a, 4), top5_err=round(100 - top5a, 4),
@ -267,9 +217,7 @@ def validate(args):
img_size=data_config['input_size'][-1],
crop_pct=crop_pct,
interpolation=data_config['interpolation'])
_logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
results['top1'], results['top1_err'], results['top5'], results['top5_err']))
logger.log_phase(phase='eval', name_map={'top1': 'Acc@1', 'top5': 'Acc@5'}, **results)
return results
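The `AverageMeter`-plus-`time.time()` bookkeeping is replaced by a `Tracker` with three marks per iteration. A sketch of the timing protocol those marks imply; the attribute names beyond `iter_time` (which feeds the `rate=` fields above) are guesses, not the real implementation.

import time

class _AvgTime:
    def __init__(self):
        self.sum, self.count = 0.0, 0

    @property
    def avg(self):
        return self.sum / max(self.count, 1)

    def update(self, dt):
        self.sum += dt
        self.count += 1

class TrackerSketch:
    """Hypothetical: mark_iter() closes one iteration, the other marks split it."""

    def __init__(self):
        self.iter_time = _AvgTime()   # whole iteration; drives the rate= fields
        self.data_time = _AvgTime()   # time spent waiting on the loader
        self.step_time = _AvgTime()   # forward/backward/optimizer step
        self._iter_start = None
        self._last = time.perf_counter()

    def mark_iter(self):
        now = time.perf_counter()
        if self._iter_start is not None:   # skip the priming call before the loop
            self.iter_time.update(now - self._iter_start)
        self._iter_start = self._last = now

    def mark_iter_data_end(self):
        now = time.perf_counter()
        self.data_time.update(now - self._last)
        self._last = now

    def mark_iter_step_end(self):
        now = time.perf_counter()
        self.step_time.update(now - self._last)
        self._last = now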
@ -309,7 +257,6 @@ def main():
result = OrderedDict(model=args.model)
r = {}
while not r and batch_size >= 1:  # FIXME --num-gpu arg was removed above; floor is now 1
torch.cuda.empty_cache()
try:
args.batch_size = batch_size
print('Validating with batch size: %d' % args.batch_size)
