Major timm.bits update. Updater and DeviceEnv now dataclasses, after_step closure used, metrics base impl w/ distributed reduce, many tweaks/fixes.

pull/1239/head
Ross Wightman 3 years ago
parent 938716c753
commit aa92d7b1c5

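For orientation before the per-file diffs: a minimal sketch of how the reworked pieces fit together after this commit. Only initialize_device, setup_model_and_optimizer, Updater/create_updater, TrainCfg and DeviceEnv come from the changed files; the model, data, and loss below are illustrative placeholders, not part of the change.

# Orientation sketch only: the timm.bits names come from this diff, the model/data/loss are placeholders.
import torch
import torch.nn as nn
from timm.bits import initialize_device, setup_model_and_optimizer, TrainCfg

dev_env = initialize_device()                      # returns an XLA, CUDA, or base (CPU) DeviceEnv dataclass
model = nn.Linear(32, 10)                          # placeholder model
train_state = setup_model_and_optimizer(
    dev_env=dev_env,
    model=model,
    optimizer=lambda model, **kw: torch.optim.SGD(model.parameters(), **kw),  # callable or create_optimizer_v2 cfg
    optimizer_cfg=dict(lr=0.1),
    clip_fn='norm',
    clip_value=1.0,
)
train_cfg = TrainCfg(num_epochs=1, log_interval=10)

loss_fn = nn.CrossEntropyLoss()
sample = torch.randn(8, 32, device=dev_env.device)
target = torch.randint(0, 10, (8,), device=dev_env.device)
with dev_env.autocast():
    loss = loss_fn(train_state.model(sample), target)
train_state.updater.apply(loss)                    # backward + clip + step + zero_grad in one call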
@@ -1,10 +1,25 @@
from .avg_scalar import AvgMinMaxScalar
from .avg_tensor import AvgTensor
from .device_env import DeviceEnv, DeviceEnvType
from .device_env_cuda import DeviceEnvCuda
from .device_env_factory import initialize_device, get_device
from .device_env_xla import DeviceEnvXla
from .distributed import distribute_bn, all_gather_recursive, all_reduce_recursive, broadcast_recursive,\
    all_reduce_sequence, all_gather_sequence
# from .evaluate import evaluate, eval_step
from .logger import Logger
from .metric import Metric, MetricValue
from .metric_accuracy import AccuracyTopK
from .tracker import Tracker
# from .task_metrics import TaskMetrics, TaskMetricsClassify
from .train_cfg import TrainCfg
from .train_services import TrainServices
from .train_setup import setup_model_and_optimizer
from .train_state import TrainState
# from .task import TaskClassify
from .updater import Updater
from .updater_cuda import UpdaterCudaWithScaler
from .updater_deepspeed import UpdaterDeepSpeed
from .updater_factory import create_updater
from .updater_xla import UpdaterXla, UpdaterXlaWithScaler
# from .train import train_one_epoch, Experiment

@@ -1,4 +1,4 @@
class AvgMinMaxScalar:
    """Computes and stores the average and current value"""
    def __init__(self):

@@ -1,7 +1,7 @@
import torch
class AvgTensor:
    """Computes and stores the average and current value"""
    def __init__(self):

@@ -0,0 +1,58 @@
import logging
import os
from collections import OrderedDict
import torch
from .train_state import TrainState, serialize_train_state, deserialize_train_state
_logger = logging.getLogger(__name__)
def resume_train_checkpoint(
train_state,
checkpoint_path,
resume_opt=True,
deserialize_fn=deserialize_train_state,
log_info=True):
raise NotImplementedError
# resume_epoch = None
# if os.path.isfile(checkpoint_path):
# checkpoint = torch.load(checkpoint_path, map_location='cpu')
#
# if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
# if log_info:
# _logger.info('Restoring model state from checkpoint...')
# new_state_dict = OrderedDict()
# for k, v in checkpoint['state_dict'].items():
# name = k[7:] if k.startswith('module') else k
# new_state_dict[name] = v
# model.load_state_dict(new_state_dict)
#
# if optimizer is not None and 'optimizer' in checkpoint:
# if log_info:
# _logger.info('Restoring optimizer state from checkpoint...')
# optimizer.load_state_dict(checkpoint['optimizer'])
#
# if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint:
# if log_info:
# _logger.info('Restoring AMP loss scaler state from checkpoint...')
# loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key])
#
# if 'epoch' in checkpoint:
# resume_epoch = checkpoint['epoch']
# if 'version' in checkpoint and checkpoint['version'] > 1:
# resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save
#
# if log_info:
# _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
# else:
# model.load_state_dict(checkpoint)
# if log_info:
# _logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
# return resume_epoch
# else:
# _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
# raise FileNotFoundError()

@@ -1,58 +1,130 @@
import abc
from contextlib import suppress
from enum import Enum
from typing import Callable, Union, Optional, List, Tuple
from dataclasses import dataclass, field, InitVar
import torch
import torch.distributed as dist
TensorList = Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]
class DeviceEnvType(Enum):
    """ Device Environment Types
    """
    CPU = "cpu"
    CUDA = "cuda"
    XLA = "xla"
@dataclass
class DeviceEnv:
    device_type: InitVar[Optional[str]] = None
    device_index: InitVar[Optional[int]] = None
    device: torch.device = field(init=False)  # set from device_type + device_index or post_init logic
    world_size: Optional[int] = None  # set by post_init from env when None
    local_rank: Optional[int] = None  # set by post_init from env when None
    global_rank: Optional[int] = None  # set by post_init from env when None
    amp: bool = False
    autocast: Optional[Callable] = None  # set by post_init from env when None
    memory_format: Optional[torch.memory_format] = None
    dtype: Optional[torch.dtype] = None
    def __post_init__(self, device_type: Optional[str], device_index: Optional[int]):
        device_type = device_type or 'cpu'
        self.device = torch.device(device_type) if device_index is None \
            else torch.device(device_type, device_index)
        self.world_size = 1 if self.world_size is None else self.world_size
        self.local_rank = 0 if self.local_rank is None else self.local_rank
        self.global_rank = 0 if self.global_rank is None else self.global_rank
        if self.autocast is None:
            self.autocast = suppress
    @property
    def type(self) -> DeviceEnvType:
        if self.device.type == 'cpu':
            return DeviceEnvType.CPU
        elif self.device.type == 'cuda':
            return DeviceEnvType.CUDA
        elif self.device.type == 'xla':
            return DeviceEnvType.XLA
        else:
            assert False, "Unexpected device type for base DevEnv impl."
    @property
    def type_cuda(self):
        # shortcut for common cuda device type
        return self.type == DeviceEnvType.CUDA
    @property
    def type_xla(self):
        # shortcut for common xla device type
        return self.type == DeviceEnvType.XLA
    @property
    def distributed(self):
        return self.world_size > 1
    @property
    def primary(self):
        return self.local_rank == 0
    @property
    def global_primary(self):
        return self.global_rank == 0
    def wrap_distributed(self, *modules):
        pass
    def wrap_parallel(self, *modules):
        pass
    def to_device(self, *modules: torch.nn.Module):
        # FIXME handling dtype / memformat... disable flags, enable flags, diff fn?
        moved = [m.to(device=self.device, memory_format=self.memory_format) for m in modules]
        return moved[0] if len(moved) == 1 else moved
    def mark_step(self):
        pass  # NO-OP for non-XLA devices
    def all_reduce_(self, tensor: TensorList, op=dist.ReduceOp.SUM, average=False):
        print(len(tensor), type(tensor))
        print(tensor.shape)
        dist.all_reduce(tensor, op=op)
        if average:
            tensor.div_(self.world_size)
        return tensor
    def all_reduce(self, tensor: torch.Tensor, op=dist.ReduceOp.SUM, average=False):
        reduce_tensor = tensor.clone()
        dist.all_reduce(reduce_tensor, op=op)
        if average:
            reduce_tensor = reduce_tensor / self.world_size
        return reduce_tensor
    def all_gather(self, tensor: torch.Tensor, cat_dim=0):
        output_tensors = [torch.empty_like(tensor) for _ in range(self.world_size)]
        dist.all_gather(output_tensors, tensor)
        return torch.cat(output_tensors, cat_dim)
    def all_to_all(self, tensor: torch.Tensor, num_splits, split_dim, cat_dim=0):
        input_tensors = torch.chunk(tensor, num_splits, split_dim)
        output_tensors = [torch.empty_like(input_tensors[0]) for _ in range(self.world_size)]
        dist.all_to_all(output_tensors, input_tensors)
        return torch.cat(output_tensors, cat_dim)
    def broadcast_(self, tensor: torch.Tensor, src_rank=0):
        dist.broadcast(tensor, src=src_rank)
        return tensor
    def broadcast(self, tensor: Optional[torch.Tensor] = None, src_rank=0):
        if self.global_rank != src_rank:
            tensor = torch.empty_like(tensor)
        assert tensor is not None
        dist.broadcast(tensor, src=src_rank)
        return tensor
    def barrier(self):
        dist.barrier()

@@ -1,92 +1,58 @@
import os
from contextlib import suppress
from dataclasses import dataclass, field, InitVar
from typing import Optional
import torch
from torch.nn.parallel import DistributedDataParallel, DataParallel
from .device_env import DeviceEnv, DeviceEnvType
def is_cuda_available():
    return torch.cuda.is_available()
@dataclass
class DeviceEnvCuda(DeviceEnv):
    def __post_init__(self, device_type: str, device_index: Optional[int]):
        assert torch.cuda.device_count()
        torch.backends.cudnn.benchmark = True
        setup_world_size = self.world_size or int(os.environ.get('WORLD_SIZE', 1))
        assert setup_world_size
        if setup_world_size > 1:
            # setup distributed
            assert device_index is None
            if self.local_rank is None:
                lr = os.environ.get('LOCAL_RANK', None)
                if lr is None:
                    raise RuntimeError(
                        'At least one of LOCAL_RANK env variable or local_rank arg must be set to valid integer.')
                self.local_rank = int(lr)
            self.device = torch.device('cuda:%d' % self.local_rank)
            torch.cuda.set_device(self.local_rank)
            torch.distributed.init_process_group(backend='nccl', init_method='env://')
            self.world_size = torch.distributed.get_world_size()
            assert self.world_size == setup_world_size
            self.global_rank = torch.distributed.get_rank()
        else:
            self.device = torch.device('cuda' if device_index is None else f'cuda:{device_index}')
            self.local_rank = 0
            self.world_size = 1
            self.global_rank = 0
        if self.autocast is None:
            self.autocast = torch.cuda.amp.autocast if self.amp else suppress
    @property
    def type(self) -> DeviceEnvType:
        return DeviceEnvType.CUDA
    def wrap_distributed(self, *modules, **kwargs):
        wrapped = [DistributedDataParallel(m, device_ids=[self.local_rank], **kwargs) for m in modules]
        return wrapped[0] if len(wrapped) == 1 else wrapped
    def wrap_parallel(self, *modules, **kwargs):
        assert not self.distributed
        wrapped = [DataParallel(m, **kwargs) for m in modules]
        return wrapped[0] if len(wrapped) == 1 else wrapped

@ -1,10 +1,11 @@
from .device_env import DeviceEnv
from .device_env_cuda import DeviceEnvCuda, is_cuda_available from .device_env_cuda import DeviceEnvCuda, is_cuda_available
from .device_env_xla import DeviceEnvXla, is_xla_available from .device_env_xla import DeviceEnvXla, is_xla_available
_device_env = None _device_env = None
def initialize_device(force_cpu: bool = False, xla_device_type=None, **kwargs): def initialize_device(force_cpu: bool = False, **kwargs) -> DeviceEnv:
global _device_env global _device_env
if _device_env is not None: if _device_env is not None:
# warning # warning
@@ -12,21 +13,22 @@ def initialize_device(force_cpu: bool = False, xla_device_type=None, **kwargs):
    denv = None
    if not force_cpu:
        xla_device_type = kwargs.get('xla_device_type', None)
        if is_xla_available(xla_device_type):
            # XLA supports more than just TPU, will search in order TPU, GPU, CPU
            denv = DeviceEnvXla(**kwargs)
        elif is_cuda_available():
            denv = DeviceEnvCuda(**kwargs)
    if denv is None:
        denv = DeviceEnv()
    print(denv)  # FIXME DEBUG
    _device_env = denv
    return denv
def get_device() -> DeviceEnv:
    if _device_env is None:
        raise RuntimeError('Please initialize device environment by calling initialize_device first.')
    return _device_env

@@ -1,6 +1,10 @@
import os
from contextlib import suppress
from dataclasses import dataclass, field, InitVar
from typing import Optional
import torch
from torch.distributed import ReduceOp
try:
    import torch_xla.core.xla_model as xm
@@ -15,78 +19,102 @@ try:
except ImportError as e:
    xa = None
from .device_env import DeviceEnv, DeviceEnvType, TensorList
_PT_TO_XM_OP = {
    ReduceOp.SUM: 'sum',
    ReduceOp.PRODUCT: 'prod',
    ReduceOp.MIN: 'min',
    ReduceOp.MAX: 'max',
    ReduceOp.BAND: 'and',
    ReduceOp.BOR: 'or',
}
def is_xla_available(xla_device_type=None):
    if not _HAS_XLA:
        return False
    supported_devs = xm.get_xla_supported_devices(devkind=xla_device_type)
    print(supported_devs)
    return len(supported_devs) >= 1
@dataclass
class DeviceEnvXla(DeviceEnv):
    def __post_init__(self, device_type: Optional[str], device_idx: Optional[int]):
        if device_type is not None:
            device_type = device_type.upper()
            assert device_type in ('TPU', 'GPU', 'CPU'), "XLA device type must be one of ('TPU', 'GPU', 'CPU')"
        self.device = xm.xla_device(n=device_idx, devkind=device_type)
        self.world_size = xm.xrt_world_size()
        if self.distributed:
            assert device_idx is None, "device_index is based on local rank for distributed XLA mode"
            self.local_rank = xm.get_local_ordinal()
            self.global_rank = xm.get_ordinal()
        else:
            self.local_rank = 0
            self.global_rank = 0
        if self.amp:
            assert xa is not None, 'XLA AMP is not present on this build'
        if self.autocast is None:
            self.autocast = xa.autocast if self.amp else suppress
    @property
    def type(self) -> DeviceEnvType:
        return DeviceEnvType.XLA
    def wrap_distributed(self, *modules):
        wrapped = [m for m in modules]  # NO-OP
        return wrapped[0] if len(wrapped) == 1 else wrapped
    def wrap_parallel(self, *modules):
        assert False, "Not implemented"
    def mark_step(self):
        xm.mark_step()
    def all_reduce(self, tensor: torch.Tensor, op=ReduceOp.SUM, average=False):
        assert isinstance(tensor, torch.Tensor)  # unlike in-place variant, lists/tuples not allowed
        op = _PT_TO_XM_OP[op]
        scale = 1.0
        if average:
            scale /= self.world_size
        return xm.all_reduce(op, tensor, scale=scale)
    def all_reduce_(self, tensor: TensorList, op=ReduceOp.SUM, average=False):
        op = _PT_TO_XM_OP[op]
        scale = 1.0
        wrapped = False
        if isinstance(tensor, torch.Tensor):
            tensor = [tensor]  # bare tensors are not operated on in-place
            wrapped = True
        if average:
            scale /= self.world_size
        xm.all_reduce(op, tensor, scale=scale)
        if wrapped:
            tensor = tensor[0]
        return tensor
    def all_gather(self, tensor: torch.Tensor, cat_dim=0):
        output = xm.all_gather(tensor, cat_dim)
        return output
    def all_to_all(self, tensor, num_splits, split_dim, cat_dim=0):
        output = xm.all_to_all(tensor, split_dim, cat_dim, num_splits)
        return output
    def broadcast(self, tensor: torch.Tensor, src_rank=0):
        if self.global_rank != src_rank:
            reduce_tensor = torch.zeros_like(tensor)
            xm.all_reduce('sum', reduce_tensor)
        else:
            xm.all_reduce('sum', tensor)
        return tensor
    def broadcast_(self, tensor: torch.Tensor, src_rank=0):
        out_tensor = self.broadcast(tensor, src_rank)
        return tensor.copy_(out_tensor)
    def barrier(self):
        xm.rendezvous('timm.bits.dist_barrier')

@@ -0,0 +1,151 @@
from typing import Dict, Tuple, List, Union, Any, Callable
import torch
from torch.distributed import ReduceOp
from timm.utils import unwrap_model
from .device_env import DeviceEnv, DeviceEnvType
from .device_env_factory import get_device
TensorSeq = Union[torch.Tensor, Tuple[torch.Tensor, ...], List[torch.Tensor], Dict[Any, torch.Tensor]]
def _validate_type(tensor: TensorSeq):
if isinstance(tensor, (dict, list, tuple)):
if not tensor:
return
else:
assert isinstance(tensor, torch.Tensor)
def distribute_bn(model: torch.nn.Module, reduce: bool = False, dev_env: DeviceEnv = None):
if dev_env is None:
dev_env = get_device()
# ensure every node has the same running bn stats
for bn_name, bn_buf in unwrap_model(model).named_buffers(recurse=True):
if ('running_mean' in bn_name) or ('running_var' in bn_name):
if reduce:
# average bn stats across whole group
dev_env.all_reduce_(bn_buf, average=True)
else:
# broadcast bn stats from rank 0 to whole group
dev_env.broadcast_(bn_buf, 0)
def all_gather_recursive(tensor: TensorSeq, cat_dim=0, dev_env: DeviceEnv = None):
""" Recursive all gather via DeviceEnv distributed primitives
FIXME add group support
"""
_validate_type(tensor)
if dev_env is None:
dev_env = get_device()
if isinstance(tensor, torch.Tensor):
return dev_env.all_gather(tensor, cat_dim=cat_dim)
elif isinstance(tensor, dict):
return {k: all_gather_recursive(v, dev_env=dev_env) for k, v in tensor.items()}
elif isinstance(tensor, (tuple, list)):
return type(tensor)(all_gather_recursive(v, dev_env=dev_env) for v in tensor)
def all_reduce_recursive(tensor: TensorSeq, op=ReduceOp.SUM, average=False, dev_env: DeviceEnv = None):
""" Recursive all reduce via DeviceEnv distributed primitives
FIXME add group support
"""
_validate_type(tensor)
if dev_env is None:
dev_env = get_device()
if isinstance(tensor, torch.Tensor):
return dev_env.all_reduce_(tensor, op=op, average=average)
elif isinstance(tensor, dict):
return {k: all_reduce_recursive(v, op=op, average=average, dev_env=dev_env) for k, v in tensor.items()}
elif isinstance(tensor, (tuple, list)):
return type(tensor)(all_reduce_recursive(v, op=op, average=average, dev_env=dev_env) for v in tensor)
def broadcast_recursive(tensor: TensorSeq, src_rank: int, dev_env: DeviceEnv = None):
""" Recursive broadcast via DeviceEnv distributed primitives
FIXME add group support
"""
_validate_type(tensor)
if dev_env is None:
dev_env = get_device()
if isinstance(tensor, torch.Tensor):
return dev_env.broadcast_(tensor, src_rank=src_rank)
elif isinstance(tensor, dict):
return {k: broadcast_recursive(v, src_rank=src_rank, dev_env=dev_env) for k, v in tensor.items()}
elif isinstance(tensor, (tuple, list)):
return type(tensor)(broadcast_recursive(v, src_rank=src_rank, dev_env=dev_env) for v in tensor)
def all_gather_sequence(tensor: TensorSeq, cat_dim: int = 0, dev_env: DeviceEnv = None):
""" All gather a flat Tensor sequence (dict, list, tuple) of same shape
"""
_validate_type(tensor)
if dev_env is None:
dev_env = get_device()
with torch.no_grad():
names = None
# merge values into one tensor for reduction
if isinstance(tensor, dict):
names = tensor.keys()
gather_values = tuple(tensor.values())
elif isinstance(tensor, (tuple, list)):
gather_values = tensor
else:
gather_values = (tensor,)
gather_values = torch.stack(gather_values, dim=0)
gather_values = dev_env.all_gather(gather_values, cat_dim=cat_dim + 1).unbind(dim=0)
# separate reduced values into original structure
if isinstance(tensor, dict):
gather_values = {k: v for k, v in zip(names, gather_values)}
elif isinstance(tensor, (tuple, list)):
gather_values = type(tensor)(v for v in gather_values)
else:
gather_values = gather_values[0]
return gather_values
def all_reduce_sequence(tensor: TensorSeq, op=ReduceOp.SUM, average=False, dev_env: DeviceEnv = None):
"""
All reduce the tensors in a flat Tensor sequence (dict, list, tuple) of same tensor shape
Args:
tensor (dict): inputs to be reduced. All the values must be scalar Tensor.
average (bool): whether to do average or sum
Returns:
a sequence with the same type as input (dict, list, tuple)
"""
_validate_type(tensor)
if dev_env is None:
dev_env = get_device()
with torch.no_grad():
names = None
# merge values into one tensor for reduction
if isinstance(tensor, dict):
names = tensor.keys()
reduce_values = tuple(tensor.values())
elif isinstance(tensor, (tuple, list)):
reduce_values = tensor
else:
reduce_values = (tensor,)
reduce_values = torch.stack(reduce_values, dim=0)
dev_env.all_reduce_(reduce_values, op=op, average=average)
reduce_values = reduce_values.unbind(dim=0)
# separate reduced values into original structure
if isinstance(tensor, dict):
reduce_values = {k: v for k, v in zip(names, reduce_values)}
elif isinstance(tensor, (tuple, list)):
reduce_values = type(tensor)(v for v in reduce_values)
else:
reduce_values = reduce_values[0]
return reduce_values

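A hedged usage sketch of the sequence helpers above: the stats dict and its values are illustrative, and it assumes initialize_device() has already been called (all_reduce_sequence falls back to get_device() when dev_env is not passed).

# Illustrative only: reduce a dict of same-shaped scalar tensors with a single collective.
import torch
from timm.bits import initialize_device, all_reduce_sequence

dev_env = initialize_device()
stats = {
    'loss_sum': torch.tensor(12.5, device=dev_env.device),       # placeholder per-rank sums
    'sample_count': torch.tensor(512., device=dev_env.device),
}
if dev_env.distributed:
    # values are stacked, all-reduced once, then unpacked back into a dict with the same keys
    stats = all_reduce_sequence(stats, dev_env=dev_env)
mean_loss = (stats['loss_sum'] / stats['sample_count']).item()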
@@ -0,0 +1,190 @@
""" PyTorch distributed helpers
Some of this lifted from Detectron2 with other fns added by myself.
FIXME many functions remain unfinished/untested
"""
from typing import Dict, Tuple, List, Union, Any, Callable
import torch
import torch.distributed as dist
from torch.distributed import ReduceOp
TensorSeq = Union[torch.Tensor, Tuple[torch.Tensor, ...], List[torch.Tensor], Dict[Any, torch.Tensor]]
def synchronize_torch():
"""
Helper function to synchronize (barrier) among all processes when
using distributed training
"""
if not dist.is_available():
return
if not dist.is_initialized():
return
world_size = dist.get_world_size()
if world_size == 1:
return
dist.barrier()
def all_reduce_sequence_torch(values: TensorSeq, op=ReduceOp.SUM, average=False, group=None):
"""
All reduce the tensors in a sequence (dict, list, tuple)
Args:
values (dict): inputs to be reduced. All the values must be scalar Tensor.
average (bool): whether to do average or sum
Returns:
a sequence with the same type as input (dict, list, tuple)
"""
world_size = dist.get_world_size(group)
if world_size <= 1:
return values
with torch.no_grad():
names = None
if isinstance(values, dict):
names = values.keys()
reduce_values = torch.stack(tuple(values.values()), dim=0)
elif isinstance(values, (tuple, list)):
reduce_values = torch.stack(values, dim=0)
else:
reduce_values = values
dist.all_reduce(reduce_values, op=op, group=group)
if average:
reduce_values /= world_size
if isinstance(values, dict):
reduce_values = {k: v for k, v in zip(names, reduce_values)}
elif isinstance(values, (tuple, list)):
reduce_values = type(values)(v for v in reduce_values)
return reduce_values
def reduce_sequence_torch(values: TensorSeq, dst_rank=0, op=ReduceOp.SUM, average=False, group=None):
"""
All reduce the tensors in a sequence (dict, list, tuple)
Args:
values (dict): inputs to be reduced. All the values must be scalar Tensor.
average (bool): whether to do average or sum
Returns:
a sequence with the same type as input (dict, list, tuple)
"""
world_size = dist.get_world_size(group)
this_rank = dist.get_rank()
if world_size <= 1:
return values
with torch.no_grad():
names = None
if isinstance(values, dict):
names = values.keys()
reduce_values = torch.stack(tuple(values.values()), dim=0)
elif isinstance(values, (tuple, list)):
reduce_values = torch.stack(values, dim=0)
else:
reduce_values = values
reduce_values = torch.stack(reduce_values, dim=0)
dist.reduce(reduce_values, dst=dst_rank, op=op, group=group)
if average and this_rank == dst_rank:
reduce_values /= world_size
if isinstance(values, dict):
reduce_values = {k: v for k, v in zip(names, reduce_values)}
elif isinstance(values, (tuple, list)):
reduce_values = type(values)(v for v in reduce_values)
return reduce_values
def all_gather_sequence_torch(values: TensorSeq, group=None, join_fn=torch.cat, join_dim=0):
world_size = dist.get_world_size(group)
def _do_gather(tensor):
tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
dist.all_gather(tensor_list, tensor, group=group)
return join_fn(tensor_list, dim=join_dim)
if isinstance(values, dict):
gathered = {k: _do_gather(v) for k, v in values.items()}
return gathered
elif isinstance(values, (list, tuple)):
gathered = type(values)(_do_gather(v) for v in values)
return gathered
else:
# if not a dict, list, tuple, expect a singular tensor
assert isinstance(values, torch.Tensor)
return _do_gather(values)
def gather_sequence_torch(values: TensorSeq, dst_rank, group=None, join_fn=torch.cat, join_dim=0):
world_size = dist.get_world_size(group)
this_rank = dist.get_rank(group)
def _do_gather(tensor):
tensor_list = None
if this_rank == dst_rank:
tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
dist.gather(tensor, tensor_list, dst=dst_rank, group=group)
return join_fn(tensor_list, dim=join_dim)
if isinstance(values, dict):
gathered = {k: _do_gather(v) for k, v in values.items()}
return gathered
elif isinstance(values, (list, tuple)):
gathered = type(values)(_do_gather(v) for v in values)
return gathered
else:
# if not a dict, list, tuple, expect a singular tensor
assert isinstance(values, torch.Tensor)
return _do_gather(values)
def all_gather_torch(value: TensorSeq, group=None, join_fn: Callable = None, join_dim=0):
if isinstance(value, torch.Tensor):
world_size = dist.get_world_size(group)
out_tensors = [torch.empty_like(value) for _ in range(world_size)]
dist.all_gather(out_tensors, value, group=group)
if join_fn is not None:
out_tensors = join_fn(out_tensors, dim=join_dim)
return out_tensors
elif isinstance(value, dict):
return {k: all_gather_torch(v, group, join_fn, join_dim) for k, v in value.items()}
elif isinstance(value, (tuple, list)):
return type(value)(all_gather_torch(v, group, join_fn, join_dim) for v in value)
def gather_torch(value: TensorSeq, dst_rank=0, group=None, join_fn: Callable = None, join_dim=0):
if isinstance(value, torch.Tensor):
world_size = dist.get_world_size(group)
this_rank = dist.get_rank()
out_tensors = None
if this_rank == dst_rank:
out_tensors = [torch.empty_like(value) for _ in range(world_size)]
dist.gather(value, out_tensors, dst=dst_rank, group=group)
if join_fn is not None:
out_tensors = join_fn(out_tensors, dim=join_dim)
return out_tensors
elif isinstance(value, dict):
return {k: gather_torch(v, dst_rank, group, join_fn, join_dim) for k, v in value.items()}
elif isinstance(value, (tuple, list)):
return type(value)(gather_torch(v, dst_rank, group, join_fn, join_dim) for v in value)
def all_reduce_torch(value: TensorSeq, op=ReduceOp.SUM, average=False, group=None):
if isinstance(value, torch.Tensor):
dist.all_reduce(value, op=op, group=group)
if average:
value /= dist.get_world_size(group)
elif isinstance(value, dict):
return {k: all_reduce_torch(v, op=op, average=average, group=group) for k, v in value.items()}
elif isinstance(value, (tuple, list)):
return type(value)(all_reduce_torch(v, op=op, average=average, group=group) for v in value)
def broadcast_torch(value: TensorSeq, src_rank: int = 0, group=None):
if isinstance(value, torch.Tensor):
return dist.broadcast(value, src=src_rank, group=group)
elif isinstance(value, dict):
return {k: broadcast_torch(v, src_rank=src_rank, group=group) for k, v in value.items()}
elif isinstance(value, (tuple, list)):
return type(value)(broadcast_torch(v, src_rank=src_rank, group=group) for v in value)

@@ -16,21 +16,11 @@ def get_clip_grad_fn(mode: str = 'norm', norm_type: float = 2.0):
        assert False, f"Unknown clip mode ({mode})."
def get_clip_parameters(model, skip_last=0):
    if hasattr(model, 'get_clip_parameters'):
        return model.get_clip_parameters()
    else:
        if skip_last:
            return list(model.parameters())[:-skip_last]
        else:
            return model.parameters()

@@ -21,6 +21,8 @@ except ImportError:
    HAS_WANDB = False
from .device_env_factory import get_device
# FIXME old formatting for reference, to remove
#
# def log_eval(batch_idx, last_idx, batch_time, loss, top1, top5, log_suffix=''):
@@ -84,10 +86,16 @@ class SummaryCsv:
            dw.writerow(row_dict)
_sci_keys = {'lr'}
def _add_kwargs(text_update, name_map=None, **kwargs):
    def _to_str(key, val):
        if isinstance(val, float):
            if key.lower() in _sci_keys:
                return f'{key}: {val:.3e} '
            else:
                return f'{key}: {val:.4f}'
        else:
            return f'{key}: {val}'
@@ -120,12 +128,13 @@ class Logger:
            self,
            experiment_name=None,
            output_dir=None,
            python_logger=None,
            hparams=None,
            log_wandb=False,
            output_enabled=True,
    ):
        self.output_dir = output_dir  # for tensorboard, csv, text file (TODO) logging
        self.logger = python_logger or logging.getLogger('log')
        hparams = hparams or {}
        # Setup CSV writer(s)
@@ -146,28 +155,32 @@ class Logger:
                _logger.warning("You've requested to log metrics to wandb but package not found. "
                                "Metrics not being logged to wandb, try `pip install wandb`")
        self.output_enabled = output_enabled
        # FIXME image save
    def log_step(
            self,
            phase: str,
            step: int,
            step_end: Optional[int] = None,
            epoch: Optional[int] = None,
            loss: Optional[float] = None,
            rate: Optional[float] = None,
            phase_suffix: str = '',
            **kwargs,
    ):
        """ log train/eval step
        """
        if not self.output_enabled:
            return
        phase_title = f'{phase.capitalize()} ({phase_suffix})' if phase_suffix else f'{phase.capitalize()}:'
        progress = 100. * step / step_end if step_end else 0.
        text_update = [
            phase_title,
            f'{epoch}' if epoch is not None else None,
            f'[{step}]' if step_end is None else None,
            f'[{step}/{step_end} ({progress:>3.0f}%)]' if step_end is not None else None,
            f'Rate: {rate:.2f}/s' if rate is not None else None,
            f'Loss: {loss:.5f}' if loss is not None else None,
        ]
@@ -187,6 +200,9 @@ class Logger:
    ):
        """log completion of evaluation or training phase
        """
        if not self.output_enabled:
            return
        title = [
            f'{phase.capitalize()}',
            f'epoch: {epoch}' if epoch is not None else None,
@@ -212,6 +228,9 @@ class Logger:
            index: value for row index (typically epoch #)
            index_name: name for row index header (typically 'epoch')
        """
        if not self.output_enabled:
            return
        row_dict = summary_row_dict(index=index, index_name=index_name, results=results)
        if self.csv_writer:
            self.csv_writer.update(row_dict)

@@ -0,0 +1,142 @@
import abc
from typing import Callable, Union, Optional, List, Tuple, Dict
from dataclasses import dataclass
import torch
from torch.distributed import ReduceOp
from .device_env import DeviceEnv
from .device_env_factory import get_device
from .distributed import all_gather_sequence, all_reduce_sequence
MetricValue = Union[float, torch.Tensor, List[float], List[torch.Tensor]]
@dataclass
class ValueInfo:
initial: Optional[MetricValue] = 0.
dtype: torch.dtype = torch.float32
dist_reduce: str = 'sum'
dist_average: bool = False
class Metric(abc.ABC):
def __init__(self, dev_env: DeviceEnv = None):
self._infos: Dict[str, ValueInfo] = {}
self._values: Dict[str, Optional[MetricValue]] = {}
self._values_dist: Dict[str, Optional[MetricValue]] = {}
if dev_env is None:
dev_env = get_device()
self._dev_env = dev_env
def _register_value(self, name: str, info: Optional[ValueInfo] = None):
info = info or ValueInfo()
self._infos[name] = info
# def get_value(self, name: str, use_dist=True):
# if use_dist:
# return self._values_dist.get(name, self._values.get(name))
# else:
# return self._values.get(name)
def __getattr__(self, item):
if item not in self._infos:
raise AttributeError
value = self._values_dist.get(item, self._values.get(item, None))
return value
def __setattr__(self, key, value):
if '_infos' in self.__dict__ and key in self._infos:
self._values[key] = value
else:
super().__setattr__(key, value)
def update(
self,
predictions: Union[torch.Tensor, Dict[str, torch.Tensor]],
target: Union[torch.Tensor, Dict[str, torch.Tensor]]):
self._update(predictions, target)
def _update(
self,
predictions: Union[torch.Tensor, Dict[str, torch.Tensor]],
target: Union[torch.Tensor, Dict[str, torch.Tensor]]):
pass
def reset(self):
self._values = {}
self._values_dist = {}
for name, info in self._infos.items():
# if info specifies an initial value, we reset here, otherwise set to None and leave it to child class
if info.initial is not None:
if isinstance(info.initial, torch.Tensor):
tensor = info.initial.detach().clone()
else:
tensor = torch.ones([], dtype=info.dtype) * info.initial # scalar
self._values[name] = tensor.to(device=self._dev_env.device, dtype=info.dtype)
else:
self._values[name] = None
self._reset()
def _reset(self):
pass
def compute(self) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[str, torch.Tensor]]:
if self._dev_env.distributed:
self._distribute_values()
results = self._compute()
self._values_dist = {}
return results
@abc.abstractmethod
def _compute(self) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[str, torch.Tensor]]:
pass
def _distribute_values(self):
if not self._infos or not self._values:
return
def _args(op: str):
if op == 'cat':
return True, dict(cat_dim=0)
else:
return False, dict(op=ReduceOp.SUM)
prev_dsr = None
same_dsr = True
names = []
values = []
reductions = []
for name, value in self._values.items():
if value is not None:
info = self._infos[name]
dsr = (value.dtype, value.shape, info.dist_reduce)
if prev_dsr is not None and prev_dsr != dsr:
same_dsr = False
prev_dsr = dsr
names.append(name)
values.append(value)
reductions.append(_args(info.dist_reduce))
same_dsr = False
if same_dsr:
do_gather, reduce_kwargs = reductions[0]
if do_gather:
reduced_values = all_gather_sequence(values, dev_env=self._dev_env, **reduce_kwargs)
else:
reduced_values = all_reduce_sequence(values, dev_env=self._dev_env, **reduce_kwargs)
for name, reduced_value in zip(names, reduced_values):
info = self._infos[name]
if info.dist_average:
reduced_value /= self._dev_env.world_size
self._values_dist[name] = reduced_value
else:
for n, v, r in zip(names, values, reductions):
info = self._infos[n]
do_gather, reduce_kwargs = r
if do_gather:
reduced_value = self._dev_env.all_gather(v, **reduce_kwargs)
else:
reduced_value = self._dev_env.all_reduce(v, **reduce_kwargs)
if info.dist_average:
reduced_value /= self._dev_env.world_size
self._values_dist[n] = reduced_value

@@ -0,0 +1,98 @@
import torch
from typing import Optional, Tuple, Dict
from .device_env import DeviceEnv
from .metric import Metric, ValueInfo
class Accuracy(Metric):
def __init__(self, threshold=0.5, multi_label=False, dev_env=None):
super().__init__(dev_env=dev_env)
self.threshold = threshold
self.eps = 1e-8
self.multi_label = multi_label
# statistics / counts
self._register_value('correct')
self._register_value('total')
def _update(self, predictions, target):
raise NotImplementedError()
def _compute(self):
raise NotImplementedError()
# class AccuracyTopK(torch.nn.Module):
#
# def __init__(self, topk=(1, 5), device=None):
# super().__init__()
# self.eps = 1e-8
# self.device = device
# self.topk = topk
# self.maxk = max(topk)
# # FIXME handle distributed operation
#
# # statistics / counts
# self.reset()
#
# def update(self, predictions: torch.Tensor, target: torch.Tensor):
# sorted_indices = predictions.topk(self.maxk, dim=1)[1]
# sorted_indices.t_()
# correct = sorted_indices.eq(target.reshape(1, -1).expand_as(sorted_indices))
#
# batch_size = target.shape[0]
# correct_k = {k: correct[:k].reshape(-1).float().sum(0) for k in self.topk}
# for k, v in correct_k.items():
# attr = f'_correct_top{k}'
# old_v = getattr(self, attr)
# setattr(self, attr, old_v + v)
# self._total_sum += batch_size
#
# def reset(self):
# for k in self.topk:
# setattr(self, f'_correct_top{k}', torch.tensor(0, dtype=torch.float32))
# self._total_sum = torch.tensor(0, dtype=torch.float32)
#
# @property
# def counts(self):
# pass
#
# def compute(self) -> Dict[str, torch.Tensor]:
# # FIXME handle distributed reduction
# return {f'top{k}': 100 * getattr(self, f'_correct_top{k}') / self._total_sum for k in self.topk}
class AccuracyTopK(Metric):
def __init__(self, topk=(1, 5), dev_env: DeviceEnv = None):
super().__init__(dev_env=dev_env)
self.eps = 1e-8
self.topk = topk
self.maxk = max(topk)
# statistics / counts
for k in self.topk:
self._register_value(f'top{k}')
self._register_value('total')
self.reset()
def _update(self, predictions: torch.Tensor, target: torch.Tensor):
batch_size = predictions.shape[0]
sorted_indices = predictions.topk(self.maxk, dim=1)[1]
target_reshape = target.reshape(-1, 1).expand_as(sorted_indices)
correct = sorted_indices.eq(target_reshape).float().sum(0)
for k in self.topk:
attr_name = f'top{k}'
correct_at_k = correct[:k].sum()
setattr(self, attr_name, getattr(self, attr_name) + correct_at_k)
self.total += batch_size
def _compute(self) -> Dict[str, torch.Tensor]:
assert self.total is not None
output = {}
for k in self.topk:
attr_name = f'top{k}'
output[attr_name] = 100 * getattr(self, attr_name) / self.total
return output

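A brief usage sketch of AccuracyTopK as defined above; the logits and targets are random placeholders, and compute() reduces the registered counters across ranks when the DeviceEnv is distributed.

# Illustrative usage of AccuracyTopK; inputs are random stand-ins for model outputs.
import torch
from timm.bits import initialize_device, AccuracyTopK

dev_env = initialize_device()
accuracy = AccuracyTopK(topk=(1, 5))               # registers 'top1', 'top5', and 'total' counters
for _ in range(10):
    logits = torch.randn(16, 100, device=dev_env.device)
    target = torch.randint(0, 100, (16,), device=dev_env.device)
    accuracy.update(logits, target)                # accumulates correct counts for this batch
results = accuracy.compute()                       # e.g. {'top1': tensor(...), 'top5': tensor(...)}
accuracy.reset()                                   # clear counters before the next evaluation pass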
@@ -1,16 +1,16 @@
import time
from typing import Optional
from .avg_scalar import AvgMinMaxScalar
class Tracker:
    def __init__(self):
        self.data_time = AvgMinMaxScalar()  # time for data loader to produce batch of samples
        self.step_time = AvgMinMaxScalar()  # time for model step
        self.iter_time = AvgMinMaxScalar()  # full iteration time incl. data, step, and book-keeping
        self.epoch_time = AvgMinMaxScalar()
        self.iter_timestamp: Optional[float] = None
        self.prev_timestamp: Optional[float] = None
@@ -48,3 +48,12 @@ class Tracker:
        self.epoch_time.update(epoch_time)
        self.epoch_timestamp = timestamp
    def get_avg_iter_rate(self, num_per_iter: int):
        if num_per_iter == 0 or self.iter_time.avg == 0:
            return 0
        return num_per_iter / self.iter_time.avg
    def get_last_iter_rate(self, num_per_iter: int):
        if num_per_iter == 0 or self.iter_time.val == 0:
            return 0
        return num_per_iter / self.iter_time.val

@@ -0,0 +1,12 @@
from dataclasses import dataclass
@dataclass
class TrainCfg:
""" Train Loop Configuration
Dataclass to propagate training configuration values
"""
num_epochs: int = 0
log_interval: int = 50
recovery_interval: int = 0
accumulate_steps: int = 0

@@ -0,0 +1,13 @@
from dataclasses import dataclass
from .logger import Logger
from timm.utils.checkpoint_saver import CheckpointSaver
@dataclass
class TrainServices:
""" Train Loop Services
"""
logger: Logger = None
saver: CheckpointSaver = None

@@ -0,0 +1,153 @@
import dataclasses
from typing import Callable, Union, Optional
import logging
import torch
import torch.nn as nn
from timm.optim import create_optimizer_v2
from timm.utils import ModelEmaV2
try:
import deepspeed as ds
except ImportError:
ds = None
from .checkpoint import resume_train_checkpoint
from .device_env import DeviceEnv
from .train_cfg import TrainCfg
from .train_state import TrainState
from .updater_factory import create_updater
_logger = logging.getLogger(__name__)
def setup_model_and_optimizer(
dev_env: DeviceEnv,
model: nn.Module,
optimizer: Union[Callable, str],
optimizer_cfg,
clip_fn: Optional[Union[Callable, str]] = None,
clip_value: Optional[float] = None,
model_ema: bool = False,
model_ema_decay: float = 0.9999,
use_syncbn: bool = False,
resume_path: str = '',
resume_opt: bool = True,
deepspeed: bool = False,
):
"""
Args:
dev_env:
model:
optimizer:
optimizer_cfg:
clip_value:
clip_fn:
model_ema:
model_ema_decay:
use_syncbn:
resume_path:
resume_opt:
Returns:
"""
if deepspeed:
return setup_model_and_optimizer_deepspeed(
dev_env=dev_env, model=model, optimizer=optimizer, optimizer_cfg=optimizer_cfg,
clip_fn=clip_fn, clip_value=clip_value, model_ema=model_ema, model_ema_decay=model_ema_decay,
resume_path=resume_path, resume_opt=resume_opt,
)
dev_env.to_device(model)
if use_syncbn and dev_env.distributed:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
if dev_env.primary:
_logger.info(
'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
if isinstance(optimizer, Callable):
optimizer = optimizer(model=model, **optimizer_cfg)
else:
optimizer = create_optimizer_v2(model=model, **optimizer_cfg)
updater = create_updater(
model=model,
optimizer=optimizer,
clip_fn=clip_fn,
clip_value=clip_value,
)
# ema model
model_ema = ModelEmaV2(model, decay=model_ema_decay) if model_ema else None
train_state = TrainState(model=model, updater=updater, model_ema=model_ema)
if resume_path:
resume_train_checkpoint(
train_state,
resume_path,
resume_opt=resume_opt,
log_info=dev_env.primary)
if dev_env.distributed:
train_state = dataclasses.replace(
train_state, model=dev_env.wrap_distributed(train_state.model))
return train_state
def setup_model_and_optimizer_deepspeed(
dev_env: DeviceEnv,
model: nn.Module,
optimizer: Union[Callable, str],
optimizer_cfg,
clip_fn: Optional[Union[Callable, str]] = None,
clip_value: Optional[float] = None,
model_ema: bool = False,
model_ema_decay: float = 0.9999,
use_syncbn: bool = False,
resume_path: str = '',
resume_opt: bool = True,
):
dev_env.to_device(model)
if isinstance(optimizer, Callable):
optimizer = optimizer(model=model, **optimizer_cfg)
else:
optimizer = create_optimizer_v2(model=model, **optimizer_cfg)
model = ds.initialize(model=model, optimizer=optimizer, dist_init_required=False)
updater = create_updater(
model=model,
optimizer=optimizer,
clip_fn=clip_fn,
clip_value=clip_value,
deepspeed=True,
)
# ema model
# FIXME how to do EMA w/ deepspeed?
model_ema = ModelEmaV2(model, decay=model_ema_decay) if model_ema else None
train_state = TrainState(model=model, updater=updater, model_ema=model_ema)
if resume_path:
# FIXME deepspeed resumes differently
resume_train_checkpoint(
train_state,
resume_path,
resume_opt=resume_opt,
log_info=dev_env.primary)
if dev_env.distributed:
train_state = dataclasses.replace(
train_state, model=dev_env.wrap_distributed(train_state.model))
return train_state

@@ -0,0 +1,33 @@
from dataclasses import dataclass
from typing import Dict, Any
from torch import nn as nn
from timm.scheduler import Scheduler
from .updater import Updater
@dataclass
class TrainState:
model: nn.Module = None
train_loss: nn.Module = None
eval_loss: nn.Module = None
updater: Updater = None
lr_scheduler: Scheduler = None
model_ema: nn.Module = None
step_count_epoch: int = 0
step_count_global: int = 0
epoch: int = 0
def __post_init__(self):
assert self.model is not None
assert self.updater is not None
def serialize_train_state(train_state: TrainState):
pass
def deserialize_train_state(checkpoint: Dict[str, Any]):
pass

@@ -1,54 +1,68 @@
from dataclasses import dataclass, field, InitVar
from functools import partial
from typing import Callable, Optional, Union
import torch
import torch.nn as nn
from .grad_clip import get_clip_grad_fn, get_clip_parameters
@dataclass
class Updater:
    model: nn.Module = None
    optimizer: torch.optim.Optimizer = None  # FIXME handle multiple optimizers per-model
    clip_fn: Optional[Union[Callable, str]] = None
    clip_value: Optional[float] = None
    clip_params_fn: Optional[Callable] = None
    grad_scaler: Optional[Callable] = None
    create_graph: Optional[bool] = None
    after_step_closure: bool = False
    def __post_init__(self):
        assert self.model is not None
        assert self.optimizer is not None
        if self.clip_fn is not None:
            if isinstance(self.clip_fn, Callable):
                skip_last = 0
            else:
                assert isinstance(self.clip_fn, str)
                skip_last = 2 if 'agc' in self.clip_fn else 0
                self.clip_fn = get_clip_grad_fn(self.clip_fn)
            assert self.clip_value is not None
            self.clip_params_fn = partial(get_clip_parameters, model=self.model, skip_last=skip_last)
        if self.create_graph is None:
            self.create_graph = getattr(self.optimizer, 'second_order', False)
        self.after_step_closure = False
    def reset(self):
        self.optimizer.zero_grad()
    def apply(self, loss: torch.Tensor, accumulate=False):
        loss.backward(create_graph=self.create_graph)
        if accumulate:
            return
        if self.clip_fn is not None:
            self.clip_fn(self.clip_params_fn(), self.clip_value)
        self.optimizer.step()
        self.reset()
    def get_average_lr(self):
        lrl = [param_group['lr'] for param_group in self.optimizer.param_groups if param_group['lr'] > 0]
        return sum(lrl) / len(lrl)
    def state_dict(self):
        state_dict = dict(optimizer=self.optimizer.state_dict())
        if self.grad_scaler is not None:
            state_dict['grad_scaler'] = self.grad_scaler.state_dict()
    def load_state_dict(self, state_dict):
        if 'optimizer' in state_dict:
            self.optimizer.load_state_dict(state_dict['optimizer'])
        if 'grad_scaler' in state_dict and self.grad_scaler is not None:
            self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
    def after_step(self, after_step_fn, *args):
        after_step_fn(*args)

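Since Updater is now a dataclass, it is constructed by field name; a small sketch (model, optimizer, and data are placeholders) of the accumulate flag and the after_step closure hook mentioned in the commit message:

# Illustrative only: dataclass-style Updater construction, gradient accumulation, and after_step.
import torch
import torch.nn as nn
from timm.bits import Updater

model = nn.Linear(16, 4)                                   # placeholder model
updater = Updater(
    model=model,
    optimizer=torch.optim.SGD(model.parameters(), lr=0.1),
    clip_fn='norm',
    clip_value=1.0,
)

def on_step_end(step):
    print('optimizer stepped at', step)

for step in range(4):
    x, y = torch.randn(8, 16), torch.randint(0, 4, (8,))
    loss = nn.functional.cross_entropy(model(x), y)
    accumulate = (step % 2 == 0)                           # every other call only accumulates gradients
    updater.apply(loss, accumulate=accumulate)
    if not accumulate:
        updater.after_step(on_step_end, step)              # plain call here; deferred via xm.add_step_closure on XLA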
@@ -1,36 +1,30 @@
from dataclasses import dataclass, field, InitVar
from typing import Dict, Any
import torch
from .updater import Updater
@dataclass
class UpdaterCudaWithScaler(Updater):
    scaler_kwargs: InitVar[Dict[str, Any]] = None
    def __post_init__(self, scaler_kwargs: Dict[str, Any]):
        super().__post_init__()
        scaler_kwargs = scaler_kwargs or {}
        self.grad_scaler = torch.cuda.amp.GradScaler(**scaler_kwargs)
    def apply(self, loss: torch.Tensor, accumulate=False):
        self.grad_scaler.scale(loss).backward(create_graph=self.create_graph)
        if accumulate:
            # unscale first?
            return
        if self.clip_fn is not None:
            # unscale the gradients of optimizer's assigned params in-place
            self.grad_scaler.unscale_(self.optimizer)
            self.clip_fn(self.clip_params_fn(), self.clip_value)
        self.grad_scaler.step(self.optimizer)
        self.grad_scaler.update()
        self.reset()

@@ -0,0 +1,26 @@
from dataclasses import dataclass, field, InitVar
import torch
try:
import deepspeed as ds
except ImportError as e:
ds = None
from .updater import Updater
@dataclass
class UpdaterDeepSpeed(Updater):
def __post_init__(self):
super().__post_init__()
# FIXME not sure how to deal with model.module / grad clipping w/ DS engine interface
assert isinstance(self.model, ds.DeepSpeedEngine)
def reset(self):
self.model.zero_grad()
def apply(self, loss: torch.Tensor, accumulate=False):
self.model.backward(loss)
self.model.step()
self.reset()

@@ -2,29 +2,38 @@ from typing import Callable, Optional, Union, Any
import torch
from .device_env import DeviceEnv, DeviceEnvType
from .device_env_factory import get_device
from .updater import Updater
from .updater_cuda import UpdaterCudaWithScaler
from .updater_deepspeed import UpdaterDeepSpeed
from .updater_xla import UpdaterXla, UpdaterXlaWithScaler
def create_updater(
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        clip_fn: Optional[Union[Callable, str]] = None,
        clip_value: Optional[float] = None,
        scaler_kwargs: Any = None,
        dev_env: Optional[DeviceEnv] = None,
        deepspeed: bool = False,
) -> Updater:
    if not dev_env:
        dev_env = get_device()
    updater_kwargs = dict(model=model, optimizer=optimizer, clip_fn=clip_fn, clip_value=clip_value)
    use_scaler = dev_env.amp
    if use_scaler:
        updater_kwargs['scaler_kwargs'] = scaler_kwargs
    updater_cls = Updater
    if dev_env.type == DeviceEnvType.XLA:
        updater_cls = UpdaterXlaWithScaler if use_scaler else UpdaterXla
    elif dev_env.type == DeviceEnvType.CUDA and use_scaler:
        updater_cls = UpdaterCudaWithScaler
    elif deepspeed:
        del updater_kwargs['scaler_kwargs']
        updater_cls = UpdaterDeepSpeed
    return updater_cls(**updater_kwargs)

@@ -1,6 +1,8 @@
from dataclasses import dataclass, field, InitVar
from typing import Any, Dict
import torch
import torch.nn as nn
try:
    import torch_xla.core.xla_model as xm
@@ -18,41 +20,49 @@ except ImportError as e:
from .updater import Updater
@dataclass
class UpdaterXla(Updater):
    def __post_init__(self):
        super().__post_init__()
        self.after_step_closure = True
    def apply(self, loss: torch.Tensor, accumulate: bool = False):
        loss.backward(create_graph=self.create_graph)
        if accumulate:
            return
        xm.reduce_gradients(self.optimizer)
        if self.clip_fn is not None:
            self.clip_fn(self.clip_params_fn(), self.clip_value)
        self.optimizer.step()
        xm.mark_step()
        self.reset()
    def after_step(self, after_step_fn, *args):
        xm.add_step_closure(after_step_fn, args)
@dataclass
class UpdaterXlaWithScaler(UpdaterXla):
    scaler_kwargs: InitVar[Dict[str, Any]] = None
    def __post_init__(self, scaler_kwargs: Dict[str, Any]):
        super().__post_init__()
        scaler_kwargs = scaler_kwargs or {}
        assert xa is not None, 'XLA AMP not present in this build'
        self.scaler = xa.GradScaler(**scaler_kwargs)
    def apply(self, loss: torch.Tensor, accumulate: bool = False):
        self.scaler.scale(loss).backward(create_graph=self.create_graph)
        if accumulate:
            # unscale first?
            return
        xm.reduce_gradients(self.optimizer)
        if self.clip_fn is not None:
            self.scaler.unscale_(self.optimizer)  # unscale the gradients of optimizer's assigned params in-place
            self.clip_fn(self.clip_params_fn(), self.clip_value)
        self.scaler.step(self.optimizer)
        self.scaler.update()
        xm.mark_step()
        self.reset()

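The after_step hook is what lets per-step bookkeeping be deferred on XLA: UpdaterXla hands the callback to xm.add_step_closure, while non-XLA updaters presumably invoke it right away. A hedged sketch of the calling convention, with updater, loss_meter, loss and sample assumed to already exist:

def on_step_end(loss_meter, loss, n):
    # runs after the optimizer step; on XLA, when the step closure fires
    loss_meter.update(loss, n)

updater.apply(loss)
updater.after_step(on_step_end, loss_meter, loss, sample.shape[0])
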
@ -23,24 +23,20 @@ class Fetcher:
                 re_count=1,
                 re_num_splits=0):
        self.loader = loader
-        self.mean = torch.tensor([x * 255 for x in mean]).view(1, 3, 1, 1)
-        self.std = torch.tensor([x * 255 for x in std]).view(1, 3, 1, 1)
        self.device = torch.device(device)
        self.dtype = dtype or torch.float32
-        if device:
-            self.mean.to(device=device, dtype=self.dtype)
-            self.std.to(device=device, dtype=self.dtype)
+        self.mean = torch.tensor([x * 255 for x in mean], dtype=self.dtype, device=self.device).view(1, 3, 1, 1)
+        self.std = torch.tensor([x * 255 for x in std], dtype=self.dtype, device=self.device).view(1, 3, 1, 1)
        if re_prob > 0.:
            self.random_erasing = RandomErasing(
-                probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits)
+                probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits, device=device)
        else:
            self.random_erasing = None

    def __iter__(self):
        for sample, target in self.loader:
-            sample = sample.to(device=self.device)
+            sample = sample.to(device=self.device, dtype=self.dtype).sub_(self.mean).div_(self.std)
            target = target.to(device=self.device)
-            sample = sample.to(dtype=self.dtype).sub_(self.mean).div_(self.std)
            if self.random_erasing is not None:
                sample = self.random_erasing(sample)
            yield sample, target

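For reference, a self-contained sketch of the per-batch normalization the revised Fetcher performs; the ImageNet mean/std values here are illustrative, not taken from this diff:

import torch

mean = torch.tensor([x * 255 for x in (0.485, 0.456, 0.406)], dtype=torch.float32).view(1, 3, 1, 1)
std = torch.tensor([x * 255 for x in (0.229, 0.224, 0.225)], dtype=torch.float32).view(1, 3, 1, 1)

batch = torch.randint(0, 256, (2, 3, 224, 224), dtype=torch.uint8)  # stand-in uint8 images
normalized = batch.to(dtype=torch.float32).sub_(mean).div_(std)     # same in-place sub_/div_ as the Fetcher
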
@ -8,7 +8,7 @@ Hacked together by / Copyright 2020 Ross Wightman
import torch.utils.data

-from timm.bits import get_device
+from timm.bits import get_device, DeviceEnvType

from .fetcher import Fetcher
from .prefetcher_cuda import PrefetcherCuda
@ -78,7 +78,7 @@ def create_loader(
    dev_env = get_device()
    sampler = None
-    if dev_env.is_distributed and not isinstance(dataset, torch.utils.data.IterableDataset):
+    if dev_env.distributed and not isinstance(dataset, torch.utils.data.IterableDataset):
        if is_training:
            sampler = torch.utils.data.distributed.DistributedSampler(
                dataset, num_replicas=dev_env.world_size, rank=dev_env.global_rank)
@ -117,7 +117,7 @@ def create_loader(
        re_count=re_count,
        re_num_splits=re_num_splits
    )
-    if dev_env.type == 'cuda':
+    if dev_env.type_cuda:
        loader = PrefetcherCuda(loader, **fetcher_kwargs)
    else:
        loader = Fetcher(loader, device=dev_env.device, **fetcher_kwargs)

@ -82,7 +82,7 @@ class ParserTfds(Parser):
        self.dist_num_replicas = 1
        dev_env = get_device()
        # FIXME allow to work without devenv usage?
-        if dev_env.is_distributed and dev_env.world_size > 1:
+        if dev_env.distributed and dev_env.world_size > 1:
            self.dist_rank = dev_env.global_rank
            self.dist_num_replicas = dev_env.world_size
        # if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
@ -150,8 +150,10 @@ class ParserTfds(Parser):
        ds = self.builder.as_dataset(
            split=self.subsplit or self.split, shuffle_files=self.shuffle, read_config=read_config)
        # avoid overloading threading w/ combo fo TF ds threads + PyTorch workers
-        ds.options().experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers)
-        ds.options().experimental_threading.max_intra_op_parallelism = 1
+        options = tf.data.Options()
+        options.experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers)
+        options.experimental_threading.max_intra_op_parallelism = 1
+        ds = ds.with_options(options)
        if self.is_training or self.repeats > 1:
            # to prevent excessive drop_last batch behaviour w/ IterableDatasets
            # see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading

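The same threading limits in isolation, assuming a TF 2.x build where tf.data.Options still exposes experimental_threading; MAX_TP_SIZE and num_workers are placeholders for the module constant and the PyTorch worker count:

import tensorflow as tf

MAX_TP_SIZE = 8
num_workers = 4

options = tf.data.Options()
options.experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers)
options.experimental_threading.max_intra_op_parallelism = 1
ds = tf.data.Dataset.range(100).with_options(options)
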
@ -1,4 +0,0 @@
from .accuracy import Accuracy, AccuracyTopK
from .precision_recall import PrecisionRecall
from .scalar_avg import ScalarAvgMinMax
from .tensor_avg import TensorAvg, TensorEma

@ -1,114 +0,0 @@
import torch
from typing import Optional, Tuple, Dict
class Accuracy(torch.nn.Module):
def __init__(self, threshold=0.5, multi_label=False):
self.threshold = threshold
self.eps = 1e-8
self.multi_label = multi_label
# statistics / counts
self._correct_sum = torch.tensor(0, dtype=torch.long)
self._total_sum = torch.tensor(0, dtype=torch.long)
def update(self, predictions, target):
raise NotImplemented()
def reset(self):
self._correct_sum = 0
self._total_sum = 0
@property
def counts(self):
pass
def compute(self):
raise NotImplemented()
class AccuracyTopK(torch.nn.Module):
def __init__(self, topk=(1, 5), device=None):
super().__init__()
self.eps = 1e-8
self.device = device
self.topk = topk
self.maxk = max(topk)
# FIXME handle distributed operation
# statistics / counts
self.reset()
def update(self, predictions: torch.Tensor, target: torch.Tensor):
sorted_indices = predictions.topk(self.maxk, dim=1)[1]
sorted_indices.t_()
correct = sorted_indices.eq(target.reshape(1, -1).expand_as(sorted_indices))
batch_size = target.shape[0]
correct_k = {k: correct[:k].reshape(-1).float().sum(0) for k in self.topk}
for k, v in correct_k.items():
attr = f'_correct_top{k}'
old_v = getattr(self, attr)
setattr(self, attr, old_v + v)
self._total_sum += batch_size
def reset(self):
for k in self.topk:
setattr(self, f'_correct_top{k}', torch.tensor(0, dtype=torch.float32))
self._total_sum = torch.tensor(0, dtype=torch.float32)
@property
def counts(self):
pass
def compute(self) -> Dict[str, torch.Tensor]:
# FIXME handle distributed reduction
return {f'top{k}': 100 * getattr(self, f'_correct_top{k}') / self._total_sum for k in self.topk}
#
# class AccuracyTopK:
#
# def __init__(self, topk=(1, 5), device=None):
# self.eps = 1e-8
# self.device = device
# self.topk = topk
# self.maxk = max(topk)
#
# # statistics / counts
# self._correct_sum = None
# self._total_sum = None
#
# def _check_init(self, device):
# to_device = self.device if self.device else device
# if self._correct_sum is None:
# self._correct_sum = {f'top{k}': torch.tensor(0., device=to_device) for k in self.topk}
# if self._total_sum is None:
# self._total_sum = torch.tensor(0, dtype=torch.long, device=to_device)
#
# def update(self, predictions: torch.Tensor, target: torch.Tensor):
# sorted_indices = predictions.topk(self.maxk, dim=1)[1]
# sorted_indices.t_()
# correct = sorted_indices.eq(target.reshape(1, -1).expand_as(sorted_indices))
#
# batch_size = target.shape[0]
# correct_k = {f'top{k}': correct[:k].reshape(-1).float().sum(0) for k in self.topk}
# self._check_init(device=predictions.device)
# for k, v in correct_k.items():
# old_v = self._correct_sum[k]
# self._correct_sum[k] = old_v + v
# self._total_sum += batch_size
#
# def reset(self):
# self._correct_sum = None
# self._total_sum = None
#
# @property
# def counts(self):
# pass
#
# def compute(self) -> Dict[str, torch.Tensor]:
# assert self._correct_sum is not None and self._total_sum is not None
# return {k: 100 * v / self._total_sum for k, v in self._correct_sum.items()}

@ -3,3 +3,4 @@ from .plateau_lr import PlateauLRScheduler
from .step_lr import StepLRScheduler
from .tanh_lr import TanhLRScheduler
from .scheduler_factory import create_scheduler
+from .scheduler import Scheduler

@ -108,7 +108,8 @@ class CheckpointSaver:
            save_state['arch'] = self.args.model
            save_state['args'] = self.args
        if self.amp_scaler is not None:
-            save_state[self.amp_scaler.state_dict_key] = self.amp_scaler.state_dict()
+            amp_key = getattr(self.amp_scaler, 'state_dict_key', 'amp_scaler')
+            save_state[amp_key] = self.amp_scaler.state_dict()
        if self.model_ema is not None:
            save_state['state_dict_ema'] = get_state_dict(self.model_ema, self.unwrap_fn)
        if metric is not None:

@ -3,7 +3,11 @@ import torch
from timm.utils.agc import adaptive_clip_grad


-def dispatch_clip_grad(parameters, value: float, mode: str = 'norm', norm_type: float = 2.0):
+def dispatch_clip_grad(
+        parameters,
+        value: float,
+        mode: str = 'norm',
+        norm_type: float = 2.0):
    """ Dispatch to gradient clipping method

    Args:

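A hedged call example for the reflowed signature above; the import path is assumed from timm's utils package and model stands for any nn.Module:

from timm.utils import dispatch_clip_grad  # assumed export

dispatch_clip_grad(model.parameters(), value=1.0, mode='norm', norm_type=2.0)
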
@ -21,8 +21,8 @@ def distribute_bn(model, world_size, reduce=False):
        if ('running_mean' in bn_name) or ('running_var' in bn_name):
            if reduce:
                # average bn stats across whole group
-                torch.distributed.all_reduce(bn_buf, op=dist.ReduceOp.SUM)
+                torch.distributed.all_reduce_recursive(bn_buf, op=dist.ReduceOp.SUM)
                bn_buf /= float(world_size)
            else:
                # broadcast bn stats from rank 0 to whole group
-                torch.distributed.broadcast(bn_buf, 0)
+                torch.distributed.broadcast_recursive(bn_buf, 0)

@ -21,13 +21,15 @@ import os
import logging
from collections import OrderedDict
from datetime import datetime
+from dataclasses import replace
+from typing import Tuple

import torch
import torch.nn as nn
import torchvision.utils

-from timm.bits import initialize_device, DeviceEnv, create_updater, Updater, Logger, Tracker
-from timm.metrics import TensorAvg, AccuracyTopK
+from timm.bits import initialize_device, setup_model_and_optimizer, DeviceEnv, Logger, Tracker,\
+    TrainState, TrainServices, TrainCfg, AccuracyTopK, AvgTensor
from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint,\
    convert_splitbn_model, model_parameters
@ -276,7 +278,7 @@ def main():
    args, args_text = _parse_args()

    dev_env = initialize_device(amp=args.amp)
-    if dev_env.is_distributed:
+    if dev_env.distributed:
        _logger.info('Training in distributed mode with multiple processes, 1 device per process. Process %d, total %d.'
                     % (dev_env.global_rank, dev_env.world_size))
    else:
@ -284,6 +286,111 @@ def main():
    random_seed(args.seed, dev_env.global_rank)
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
train_state, train_cfg = setup_train_task(args, dev_env, mixup_active)
data_config, loader_eval, loader_train = setup_data(args, dev_env, mixup_active)
# setup checkpoint saver
eval_metric = args.eval_metric
best_metric = None
best_epoch = None
saver = None
output_dir = None
if dev_env.primary:
if args.experiment:
exp_name = args.experiment
else:
exp_name = '-'.join([
datetime.now().strftime("%Y%m%d-%H%M%S"),
safe_model_name(args.model),
str(data_config['input_size'][-1])
])
output_dir = get_outdir(args.output if args.output else './output/train', exp_name)
decreasing = True if eval_metric == 'loss' else False
saver = CheckpointSaver( # TODO CheckpointSaverV2
model=train_state.model,
optimizer=train_state.updater.optimizer,
args=args,
model_ema=train_state.model_ema,
amp_scaler=train_state.updater.grad_scaler,
checkpoint_dir=output_dir,
recovery_dir=output_dir,
decreasing=decreasing,
max_history=args.checkpoint_hist)
with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
f.write(args_text)
services = TrainServices(
logger=Logger(
output_dir=output_dir, python_logger=_logger, hparams=vars(args), output_enabled=dev_env.primary),
saver=saver,
)
try:
for epoch in range(train_state.epoch, train_cfg.num_epochs):
if dev_env.distributed and hasattr(loader_train.sampler, 'set_epoch'):
loader_train.sampler.set_epoch(epoch)
if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
if loader_train.mixup_enabled:
loader_train.mixup_enabled = False
train_metrics = train_one_epoch(
dev_env=dev_env,
state=train_state,
services=services,
cfg=train_cfg,
loader=loader_train
)
if dev_env.distributed and args.dist_bn in ('broadcast', 'reduce'):
if dev_env.primary:
_logger.info("Distributing BatchNorm running means and vars")
distribute_bn(model, dev_env.world_size, args.dist_bn == 'reduce')
eval_metrics = evaluate(
train_state.model,
train_state.eval_loss,
loader_eval,
dev_env,
logger=services.logger)
if train_state.model_ema is not None and not args.model_ema_force_cpu:
if dev_env.distributed and args.dist_bn in ('broadcast', 'reduce'):
distribute_bn(train_state.model_ema, dev_env.world_size, args.dist_bn == 'reduce')
ema_eval_metrics = evaluate(
train_state.model_ema.module,
train_state.eval_loss,
loader_eval,
dev_env,
logger=services.logger,
phase_suffix='EMA')
eval_metrics = ema_eval_metrics
if train_state.lr_scheduler is not None:
# step LR for next epoch
train_state.lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
if services.logger is not None:
services.logger.write_summary(index=epoch, results=dict(train=train_metrics, eval=eval_metrics))
if saver is not None:
# save proper checkpoint with eval metric
save_metric = eval_metrics[eval_metric]
best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)
train_state = replace(train_state, epoch=epoch + 1)
except KeyboardInterrupt:
pass
if best_metric is not None:
_logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
def setup_train_task(args, dev_env: DeviceEnv, mixup_active: bool):
    model = create_model(
        args.model,
        pretrained=args.pretrained,
@ -302,82 +409,69 @@ def main():
    assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
    args.num_classes = model.num_classes  # FIXME handle model default vs config num_classes more elegantly

-    if dev_env.is_master:
+    if dev_env.primary:
        _logger.info(
            f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}')

-    data_config = resolve_data_config(vars(args), model=model, verbose=dev_env.is_master)
-
    # setup augmentation batch splits for contrastive loss or split bn
-    num_aug_splits = 0
-    if args.aug_splits > 0:
-        assert args.aug_splits > 1, 'A split of 1 makes no sense'
-        num_aug_splits = args.aug_splits
+    assert args.aug_splits == 0 or args.aug_splits > 1, 'A split of 1 makes no sense'

    # enable split bn (separate bn stats per batch-portion)
    if args.split_bn:
-        assert num_aug_splits > 1 or args.resplit
-        model = convert_splitbn_model(model, max(num_aug_splits, 2))
+        assert args.aug_splits > 1 or args.resplit
+        model = convert_splitbn_model(model, max(args.aug_splits, 2))

-    # move model to GPU, enable channels last layout if set
-    dev_env.to_device(model)
-
-    # setup synchronized BatchNorm for distributed training
-    if dev_env.is_distributed and args.sync_bn:
-        assert not args.split_bn
-        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
-        if dev_env.is_master:
-            _logger.info(
-                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
-                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
-
-    if args.torchscript:
-        assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
-        model = torch.jit.script(model)
-
-    updater = create_updater(
-        create_optimizer_v2(model, **optimizer_kwargs(cfg=args)),
-        clip_value=args.clip_grad, clip_mode=args.clip_mode)
-
-    # optionally resume from a checkpoint
-    resume_epoch = None
-    if args.resume:
-        resume_epoch = resume_checkpoint(
-            model, args.resume,
-            optimizer=None if args.no_resume_opt else updater.optimizer,
-            loss_scaler=None if args.no_resume_opt else updater.scaler,
-            log_info=dev_env.is_master)
-
-    # setup exponential moving average of model weights, SWA could be used here too
-    model_ema = None
-    if args.model_ema:
-        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
-        model_ema = ModelEmaV2(
-            model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None)
-        if args.resume:
-            load_checkpoint(model_ema.module, args.resume, use_ema=True)
-
-    # setup distributed training
-    if dev_env.is_distributed:
-        if dev_env.is_master:
-            _logger.info("Distributing model.")
-        model = dev_env.wrap_distributed(model)
-        # NOTE: EMA model does not need to be wrapped by DDP
+    train_state = setup_model_and_optimizer(
+        dev_env=dev_env,
+        model=model,
+        optimizer=args.opt,
+        optimizer_cfg=optimizer_kwargs(cfg=args),
+        clip_fn=args.clip_mode if args.clip_grad is not None else None,
+        clip_value=args.clip_grad,
+        model_ema=args.model_ema,
+        model_ema_decay=args.model_ema_decay,
+        use_syncbn=args.sync_bn,
+    )

    # setup learning rate schedule and starting epoch
-    lr_scheduler, num_epochs = create_scheduler(args, updater.optimizer)
-    start_epoch = 0
-    if args.start_epoch is not None:
-        # a specified start_epoch will always override the resume epoch
-        start_epoch = args.start_epoch
-    elif resume_epoch is not None:
-        start_epoch = resume_epoch
-    if lr_scheduler is not None and start_epoch > 0:
-        lr_scheduler.step(start_epoch)
-
-    if dev_env.is_master:
+    # FIXME move into updater?
+    lr_scheduler, num_epochs = create_scheduler(args, train_state.updater.optimizer)
+    if lr_scheduler is not None and train_state.epoch > 0:
+        lr_scheduler.step(train_state.epoch)
+
+    # setup loss function
+    if args.jsd:
+        assert args.aug_splits > 1  # JSD only valid with aug splits set
+        train_loss_fn = JsdCrossEntropy(num_splits=args.aug_splits, smoothing=args.smoothing)
+    elif mixup_active:
+        # smoothing is handled with mixup target transform
+        train_loss_fn = SoftTargetCrossEntropy()
+    elif args.smoothing:
+        train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
+    else:
+        train_loss_fn = nn.CrossEntropyLoss()
+    eval_loss_fn = nn.CrossEntropyLoss()
+    dev_env.to_device(train_loss_fn, eval_loss_fn)
+
+    if dev_env.primary:
        _logger.info('Scheduled epochs: {}'.format(num_epochs))

+    train_state = replace(
+        train_state,
+        lr_scheduler=lr_scheduler,
+        train_loss=train_loss_fn,
+        eval_loss=eval_loss_fn)
+
+    train_cfg = TrainCfg(
+        num_epochs=num_epochs,
+        log_interval=args.log_interval,
+        recovery_interval=args.recovery_interval)
+
+    return train_state, train_cfg
+
+
+def setup_data(args, dev_env, mixup_active):
+    data_config = resolve_data_config(vars(args), model=model, verbose=dev_env.primary)
+
    # create the train and eval datasets
    dataset_train = create_dataset(
        args.dataset,
@ -388,18 +482,17 @@ def main():
    # setup mixup / cutmix
    collate_fn = None
-    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_args = dict(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.num_classes)
-        assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
+        assert not args.aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
        collate_fn = FastCollateMixup(**mixup_args)

    # wrap dataset in AugMix helper
-    if num_aug_splits > 1:
-        dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
+    if args.aug_splits > 1:
+        dataset_train = AugMixDataset(dataset_train, num_splits=args.aug_splits)

    # create data loaders w/ augmentation pipeiine
    train_interpolation = args.train_interpolation
@ -421,7 +514,7 @@ def main():
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
-        num_aug_splits=num_aug_splits,
+        num_aug_splits=args.aug_splits,
        interpolation=train_interpolation,
        mean=data_config['mean'],
        std=data_config['std'],
@ -443,169 +536,107 @@ def main():
        crop_pct=data_config['crop_pct'],
        pin_memory=args.pin_mem,
    )

-    # setup loss function
-    if args.jsd:
-        assert num_aug_splits > 1  # JSD only valid with aug splits set
-        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing)
-    elif mixup_active:
-        # smoothing is handled with mixup target transform
-        train_loss_fn = SoftTargetCrossEntropy()
-    elif args.smoothing:
-        train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
-    else:
-        train_loss_fn = nn.CrossEntropyLoss()
-    validate_loss_fn = nn.CrossEntropyLoss()
-    dev_env.to_device(train_loss_fn, validate_loss_fn)
-
-    # setup checkpoint saver and eval metric tracking
-    eval_metric = args.eval_metric
-    best_metric = None
-    best_epoch = None
-    saver = None
-    output_dir = None
-    if dev_env.is_master:
-        if args.experiment:
-            exp_name = args.experiment
-        else:
-            exp_name = '-'.join([
-                datetime.now().strftime("%Y%m%d-%H%M%S"),
-                safe_model_name(args.model),
-                str(data_config['input_size'][-1])
-            ])
-        output_dir = get_outdir(args.output if args.output else './output/train', exp_name)
-        decreasing = True if eval_metric == 'loss' else False
-        saver = CheckpointSaver(
-            model=model, optimizer=updater.optimizer, args=args, model_ema=model_ema, amp_scaler=updater.scaler,
-            checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=args.checkpoint_hist)
-        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
-            f.write(args_text)
-
-    logger = Logger(output_dir=output_dir, logger=_logger, hparams=vars(args))
-
-    try:
-        for epoch in range(start_epoch, num_epochs):
-            if dev_env.is_distributed and hasattr(loader_train.sampler, 'set_epoch'):
-                loader_train.sampler.set_epoch(epoch)
-            if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
-                if loader_train.mixup_enabled:
-                    loader_train.mixup_enabled = False
-
-            train_metrics = train_one_epoch(
-                epoch, model, loader_train, updater, train_loss_fn, dev_env,
-                lr_scheduler=lr_scheduler, saver=saver, logger=logger, model_ema=model_ema,
-                log_interval=args.log_interval, recovery_interval=args.recovery_interval)
-
-            if dev_env.is_distributed and args.dist_bn in ('broadcast', 'reduce'):
-                if dev_env.is_master:
-                    _logger.info("Distributing BatchNorm running means and vars")
-                distribute_bn(model, dev_env.world_size, args.dist_bn == 'reduce')
-
-            eval_metrics = evaluate(model, loader_eval, validate_loss_fn, dev_env, logger=logger)
-
-            if model_ema is not None and not args.model_ema_force_cpu:
-                if dev_env.is_distributed and args.dist_bn in ('broadcast', 'reduce'):
-                    distribute_bn(model_ema, dev_env.world_size, args.dist_bn == 'reduce')
-                ema_eval_metrics = evaluate(
-                    model_ema.module, loader_eval, validate_loss_fn, dev_env,
-                    logger=logger, phase_suffix='EMA')
-                eval_metrics = ema_eval_metrics
-
-            if lr_scheduler is not None:
-                # step LR for next epoch
-                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
-
-            if logger is not None:
-                logger.write_summary(index=epoch, results=dict(train=train_metrics, eval=eval_metrics))
-
-            if saver is not None:
-                # save proper checkpoint with eval metric
-                save_metric = eval_metrics[eval_metric]
-                best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)
-
-    except KeyboardInterrupt:
-        pass
-    if best_metric is not None:
-        _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
-
-
-def train_one_epoch(
-        epoch: int,
-        model: nn.Module,
-        loader,
-        updater: Updater,
-        loss_fn: nn.Module,
-        dev_env: DeviceEnv,
-        lr_scheduler=None,
-        saver: CheckpointSaver = None,
-        logger: Logger = None,
-        model_ema: nn.Module = None,
-        log_interval: int = 50,
-        recovery_interval: int = 0,
-):
-    tracker = Tracker()
-    losses_m = TensorAvg()
-
-    model.train()
-
-    end_idx = len(loader) - 1
-    num_updates = epoch * len(loader)
-    batch_size = 0
-
-    tracker.mark_iter()
-    for step_idx, (sample, target) in enumerate(loader):
-        tracker.mark_iter_data_end()
-        last_step = step_idx == end_idx
-        batch_size = max(batch_size, sample.shape[0])
-
-        with dev_env.autocast():
-            output = model(sample)
-            loss = loss_fn(output, target)
-
-        updater.reset()
-        updater.apply(loss)
-
-        dev_env.mark_step()  # FIXME
-        tracker.mark_iter_step_end()
-        losses_m.update(loss, sample.size(0))
-        if model_ema is not None:
-            model_ema.update(model)
-
-        num_updates += 1
-        if last_step or (step_idx + 1) % log_interval == 0:
-            lrl = [param_group['lr'] for param_group in updater.optimizer.param_groups]
-            lr = sum(lrl) / len(lrl)
-
-            if dev_env.is_master and logger is not None:
-                loss_avg = losses_m.compute()
-                logger.log_step(
-                    'Train',
-                    step=step_idx,
-                    end_step=end_idx,
-                    loss=loss_avg.item(),
-                    rate=(dev_env.world_size * batch_size) / tracker.iter_time.avg,
-                    lr=lr,
-                )
-
-            if saver is not None and recovery_interval and (last_step or (step_idx + 1) % recovery_interval == 0):
-                saver.save_recovery(epoch, batch_idx=step_idx)
-
-        if lr_scheduler is not None:
-            lr_scheduler.step_update(num_updates=num_updates)
-
-        tracker.mark_iter()
-        # end for
-
-    if hasattr(updater.optimizer, 'sync_lookahead'):
-        updater.optimizer.sync_lookahead()
-
-    return OrderedDict([('loss', losses_m.compute().item())])
+    return data_config, loader_eval, loader_train
+
+
+def train_one_epoch(
+        dev_env: DeviceEnv,
+        state: TrainState,
+        cfg: TrainCfg,
+        services: TrainServices,
+        loader,
+):
+    tracker = Tracker()
+    loss_meter = AvgTensor()
+
+    state.model.train()
+    state.updater.reset()  # zero-grad
+
+    step_end_idx = len(loader) - 1
+    tracker.mark_iter()
+    for step_idx, (sample, target) in enumerate(loader):
+        tracker.mark_iter_data_end()
+
+        # FIXME move forward + loss into model 'task' wrapper
+        with dev_env.autocast():
+            output = state.model(sample)
+            loss = state.train_loss(output, target)
+
+        state.updater.apply(loss)
+
+        tracker.mark_iter_step_end()
+
+        state.updater.after_step(
+            after_train_step,
+            dev_env,
+            state,
+            services,
+            cfg,
+            step_idx,
+            step_end_idx,
+            tracker,
+            loss_meter,
+            (output, target, loss),
+        )
+
+        tracker.mark_iter()
+        # end for
+
+    if hasattr(state.updater.optimizer, 'sync_lookahead'):
+        state.updater.optimizer.sync_lookahead()
+
+    return OrderedDict([('loss', loss_meter.compute().item())])
+
+
+def after_train_step(
+        dev_env: DeviceEnv,
+        state: TrainState,
+        services: TrainServices,
+        cfg: TrainCfg,
+        step_idx: int,
+        step_end_idx: int,
+        tracker: Tracker,
+        loss_meter: AvgTensor,
+        tensors: Tuple[torch.Tensor, ...],
+):
+    end_step = step_idx == step_end_idx
+
+    with torch.no_grad():
+        output, target, loss = tensors
+        loss_meter.update(loss, output.shape[0])
+
+        if state.model_ema is not None:
+            state.model_ema.update(model)
+
+        state = replace(state, step_count_global=state.step_count_global + 1)
+
+        if services.logger is not None and end_step or (step_idx + 1) % cfg.log_interval == 0:
+            global_batch_size = dev_env.world_size * output.shape[0]
+            loss_avg = loss_meter.compute()
+            if services.logger is not None:
+                lr_avg = state.updater.get_average_lr()
+                services.logger.log_step(
+                    'Train',
+                    step=step_idx,
+                    step_end=step_end_idx,
+                    epoch=state.epoch,
+                    loss=loss_avg.item(),
+                    rate=tracker.get_avg_iter_rate(global_batch_size),
+                    lr=lr_avg,
+                )
+
+        if services.saver is not None and cfg.recovery_interval and (
+                end_step or (step_idx + 1) % cfg.recovery_interval == 0):
+            services.saver.save_recovery(state.epoch, batch_idx=step_idx)
+
+        if state.lr_scheduler is not None:
+            state.lr_scheduler.step_update(num_updates=state.step_count_global)


def evaluate(
        model: nn.Module,
-        loader,
        loss_fn: nn.Module,
+        loader,
        dev_env: DeviceEnv,
        logger: Logger,
        phase_suffix: str = '',
@ -613,7 +644,7 @@ def evaluate(
):
    tracker = Tracker()
-    losses_m = TensorAvg()
+    losses_m = AvgTensor()
    accuracy_m = AccuracyTopK()

    model.eval()
@ -636,13 +667,13 @@ def evaluate(
            losses_m.update(loss, output.size(0))
            accuracy_m.update(output, target)

-            if dev_env.is_master and (last_step or step_idx % log_interval == 0):
+            if last_step or step_idx % log_interval == 0:
                top1, top5 = accuracy_m.compute().values()
                loss_avg = losses_m.compute()
                logger.log_step(
                    'Eval',
                    step=step_idx,
-                    num_steps=end_idx,
+                    step_end=end_idx,
                    loss=loss_avg.item(),
                    top1=top1.item(),
                    top5=top5.item(),

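The Updater.apply variants above all accept an accumulate flag; on the CUDA/XLA paths it returns before the step/reset. A hedged sketch of how micro-batch accumulation could sit on top of that API, with accum_steps, loader and loss_fn as hypothetical names:

accum_steps = 4
for i, (sample, target) in enumerate(loader):
    with dev_env.autocast():
        loss = loss_fn(state.model(sample), target)
    # only every accum_steps-th call performs the optimizer step + zero-grad
    state.updater.apply(loss, accumulate=(i + 1) % accum_steps != 0)
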
@ -18,8 +18,7 @@ import torch.nn as nn
import torch.nn.parallel
from collections import OrderedDict

-from timm.bits import initialize_device, Tracker, Logger
-from timm.metrics import AccuracyTopK, TensorAvg
+from timm.bits import initialize_device, Tracker, Logger, AccuracyTopK, AvgTensor
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
from timm.utils import natural_key, setup_default_logging
@ -155,10 +154,10 @@ def validate(args):
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing)

-    logger = Logger(logger=_logger)
+    logger = Logger(python_logger=_logger)
    tracker = Tracker()
-    losses = TensorAvg()
-    accuracy = AccuracyTopK().to(dev_env.device)
+    losses = AvgTensor()
+    accuracy = AccuracyTopK(dev_env=dev_env)

    model.eval()
    num_steps = len(loader)
@ -175,10 +174,8 @@ def validate(args):
                output = output[:, valid_labels]
            loss = criterion(output, target)

-            if dev_env.type == 'cuda':
+            if dev_env.type_cuda:
                torch.cuda.synchronize()
-            #elif dev_env.type == 'xla':
-            #    dev_env.mark_step()

            tracker.mark_iter_step_end()
            losses.update(loss.detach(), sample.size(0))
@ -186,7 +183,7 @@ def validate(args):
                real_labels.add_result(output)
            accuracy.update(output.detach(), target)

-            if dev_env.type == 'xla':
+            if dev_env.type_xla:
                dev_env.mark_step()

            tracker.mark_iter()
