Major timm.bits update. Updater and DeviceEnv now dataclasses, after_step closure used, metrics base impl w/ distributed reduce, many tweaks/fixes.

pull/1239/head
Ross Wightman 4 years ago
parent 938716c753
commit aa92d7b1c5
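For context, a minimal sketch of how the reworked bits API in this commit might be wired together (device env -> updater -> metric); the toy model, data and hyper-parameters are illustrative only, not part of the commit:

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.bits import initialize_device, create_updater, AccuracyTopK

dev_env = initialize_device()                 # resolves XLA, then CUDA, else the base (CPU) env
model = dev_env.to_device(nn.Linear(32, 10))  # hypothetical toy model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
updater = create_updater(model=model, optimizer=optimizer)
accuracy = AccuracyTopK(topk=(1, 5), dev_env=dev_env)

for sample, target in [(torch.randn(8, 32), torch.randint(0, 10, (8,)))]:
    sample = sample.to(dev_env.device)
    target = target.to(dev_env.device)
    with dev_env.autocast():
        output = model(sample)
        loss = F.cross_entropy(output, target)
    updater.apply(loss)                       # backward, optional grad clip, step, zero_grad
    accuracy.update(output, target)

print(accuracy.compute())                     # e.g. {'top1': tensor(...), 'top5': tensor(...)}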

@ -1,10 +1,25 @@
from .avg_scalar import AvgMinMaxScalar
from .avg_tensor import AvgTensor
from .device_env import DeviceEnv, DeviceEnvType
from .device_env_cuda import DeviceEnvCuda
from .device_env_factory import initialize_device, get_device
from .device_env import DeviceEnv
#from .evaluate import evaluate, eval_step
from .device_env_xla import DeviceEnvXla
from .distributed import distribute_bn, all_gather_recursive, all_reduce_recursive, broadcast_recursive,\
all_reduce_sequence, all_gather_sequence
# from .evaluate import evaluate, eval_step
from .logger import Logger
#from .task import TaskClassify
from .metric import Metric, MetricValue
from .metric_accuracy import AccuracyTopK
from .tracker import Tracker
# from .task_metrics import TaskMetrics, TaskMetricsClassify
from .train_cfg import TrainCfg
from .train_services import TrainServices
from .train_setup import setup_model_and_optimizer
from .train_state import TrainState
# from .task import TaskClassify
from .updater import Updater
from .updater_cuda import UpdaterCudaWithScaler
from .updater_deepspeed import UpdaterDeepSpeed
from .updater_factory import create_updater
from .tracker import Tracker
#from .task_metrics import TaskMetrics, TaskMetricsClassify
#from .train import train_one_epoch, TrainServices, TrainState, TrainCfg, Experiment
from .updater_xla import UpdaterXla, UpdaterXlaWithScaler
# from .train import train_one_epoch, Experiment

@ -1,4 +1,4 @@
class ScalarAvgMinMax:
class AvgMinMaxScalar:
"""Computes and stores the average and current value"""
def __init__(self):

@ -1,7 +1,7 @@
import torch
class TensorAvg:
class AvgTensor:
"""Computes and stores the average and current value"""
def __init__(self):

@ -0,0 +1,58 @@
import logging
import os
from collections import OrderedDict
import torch
from .train_state import TrainState, serialize_train_state, deserialize_train_state
_logger = logging.getLogger(__name__)
def resume_train_checkpoint(
train_state,
checkpoint_path,
resume_opt=True,
deserialize_fn=deserialize_train_state,
log_info=True):
raise NotImplementedError
# resume_epoch = None
# if os.path.isfile(checkpoint_path):
# checkpoint = torch.load(checkpoint_path, map_location='cpu')
#
# if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
# if log_info:
# _logger.info('Restoring model state from checkpoint...')
# new_state_dict = OrderedDict()
# for k, v in checkpoint['state_dict'].items():
# name = k[7:] if k.startswith('module') else k
# new_state_dict[name] = v
# model.load_state_dict(new_state_dict)
#
# if optimizer is not None and 'optimizer' in checkpoint:
# if log_info:
# _logger.info('Restoring optimizer state from checkpoint...')
# optimizer.load_state_dict(checkpoint['optimizer'])
#
# if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint:
# if log_info:
# _logger.info('Restoring AMP loss scaler state from checkpoint...')
# loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key])
#
# if 'epoch' in checkpoint:
# resume_epoch = checkpoint['epoch']
# if 'version' in checkpoint and checkpoint['version'] > 1:
# resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save
#
# if log_info:
# _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
# else:
# model.load_state_dict(checkpoint)
# if log_info:
# _logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
# return resume_epoch
# else:
# _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
# raise FileNotFoundError()

@ -1,58 +1,130 @@
import torch
import abc
from contextlib import suppress
from enum import Enum
from typing import Callable, Union, Optional, List, Tuple
from dataclasses import dataclass, field, InitVar
import torch
import torch.distributed as dist
class DeviceEnv(abc.ABC):
TensorList = Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]
@property
@abc.abstractmethod
def device(self) -> torch.device:
pass
@property
@abc.abstractmethod
def local_rank(self) -> int:
pass
class DeviceEnvType(Enum):
""" Device Environment Types
"""
CPU = "cpu"
CUDA = "cuda"
XLA = "xla"
@dataclass
class DeviceEnv:
device_type: InitVar[Optional[str]] = None
device_index: InitVar[Optional[int]] = None
device: torch.device = field(init=False) # set from device_type + device_index or post_init logic
world_size: Optional[int] = None # set by post_init from env when None
local_rank: Optional[int] = None # set by post_init from env when None
global_rank: Optional[int] = None # set by post_init from env when None
amp: bool = False
autocast: Optional[Callable] = None # set by post_init from env when None
memory_format: Optional[torch.memory_format] = None
dtype: Optional[torch.dtype] = None
def __post_init__(self, device_type: Optional[str], device_index: Optional[int]):
device_type = device_type or 'cpu'
self.device = torch.device(device_type) if device_index is None \
else torch.device(device_type, device_index)
self.world_size = 1 if self.world_size is None else self.world_size
self.local_rank = 0 if self.local_rank is None else self.local_rank
self.global_rank = 0 if self.global_rank is None else self.global_rank
if self.autocast is None:
self.autocast = suppress
@property
@abc.abstractmethod
def global_rank(self) -> int:
pass
def type(self) -> DeviceEnvType:
if self.device.type == 'cpu':
return DeviceEnvType.CPU
elif self.device.type == 'cuda':
return DeviceEnvType.CUDA
elif self.device.type == 'xla':
return DeviceEnvType.XLA
else:
assert False, "Unexpected device type for base DevEnv impl."
@property
@abc.abstractmethod
def is_distributed(self) -> bool:
pass
def type_cuda(self):
# shortcut for common cuda device type
return self.type == DeviceEnvType.CUDA
@property
@abc.abstractmethod
def world_size(self) -> int:
pass
def type_xla(self):
# shortcut for common xla device type
return self.type == DeviceEnvType.XLA
@property
@abc.abstractmethod
def is_master(self) -> bool:
pass
def distributed(self):
return self.world_size > 1
@property
@abc.abstractmethod
def type(self) -> str:
pass
def primary(self):
return self.local_rank == 0
@property
@abc.abstractmethod
def autocast(self):
pass
def global_primary(self):
return self.global_rank == 0
@abc.abstractmethod
def wrap_distributed(self, *modules):
pass
@abc.abstractmethod
def to_device(self, *modules: torch.nn.Module):
def wrap_parallel(self, *modules):
pass
#@abc.abstractmethod
def to_device(self, *modules: torch.nn.Module):
# FIXME handling dtype / memformat... disable flags, enable flags, diff fn?
moved = [m.to(device=self.device, memory_format=self.memory_format) for m in modules]
return moved[0] if len(moved) == 1 else moved
def mark_step(self):
# FIXME this is for XLA only, make it common to all devices w/ appropriate no-ops?
pass
pass # NO-OP for non-XLA devices
def all_reduce_(self, tensor: TensorList, op=dist.ReduceOp.SUM, average=False):
dist.all_reduce(tensor, op=op)
if average:
tensor.div_(self.world_size)
return tensor
def all_reduce(self, tensor: torch.Tensor, op=dist.ReduceOp.SUM, average=False):
reduce_tensor = tensor.clone()
dist.all_reduce(reduce_tensor, op=op)
if average:
reduce_tensor = reduce_tensor / self.world_size
return reduce_tensor
def all_gather(self, tensor: torch.Tensor, cat_dim=0):
output_tensors = [torch.empty_like(tensor) for _ in range(self.world_size)]
dist.all_gather(output_tensors, tensor)
return torch.cat(output_tensors, cat_dim)
def all_to_all(self, tensor: torch.Tensor, num_splits, split_dim, cat_dim=0):
input_tensors = torch.chunk(tensor, num_splits, split_dim)
output_tensors = [torch.empty_like(input_tensors[0]) for _ in range(self.world_size)]
dist.all_to_all(output_tensors, input_tensors)
return torch.cat(output_tensors, cat_dim)
def broadcast_(self, tensor: torch.Tensor, src_rank=0):
dist.broadcast(tensor, src=src_rank)
return tensor
def broadcast(self, tensor: Optional[torch.Tensor] = None, src_rank=0):
if self.global_rank != src_rank:
tensor = torch.empty_like(tensor)
assert tensor is not None
dist.broadcast(tensor, src=src_rank)
return tensor
def barrier(self):
dist.barrier()

@ -1,92 +1,58 @@
import os
from contextlib import suppress
from dataclasses import dataclass, field, InitVar
from typing import Optional
import torch
from torch.nn.parallel import DistributedDataParallel
from torch.nn.parallel import DistributedDataParallel, DataParallel
from .device_env import DeviceEnv
from .device_env import DeviceEnv, DeviceEnvType
def is_cuda_available():
return torch.cuda.is_available()
@dataclass
class DeviceEnvCuda(DeviceEnv):
def __init__(self, device_idx=None, local_rank=None, amp=False, memory_format=None):
def __post_init__(self, device_type: str, device_index: Optional[int]):
assert torch.cuda.device_count()
torch.backends.cudnn.benchmark = True
self._local_rank = 0
self._distributed = False
self._world_size = 1
self._global_rank = 0
if 'WORLD_SIZE' in os.environ:
self._distributed = int(os.environ['WORLD_SIZE']) > 1
if self._distributed:
if local_rank is None:
setup_world_size = self.world_size or int(os.environ.get('WORLD_SIZE', 1))
assert setup_world_size
if setup_world_size > 1:
# setup distributed
assert device_index is None
if self.local_rank is None:
lr = os.environ.get('LOCAL_RANK', None)
if lr is None:
raise RuntimeError(
'At least one of LOCAL_RANK env variable or local_rank arg must be set to valid integer.')
self._local_rank = lr
else:
self._local_rank = int(local_rank)
self._device = torch.device('cuda:%d' % self._local_rank)
torch.cuda.set_device(self._local_rank)
self.local_rank = int(lr)
self.device = torch.device('cuda:%d' % self.local_rank)
torch.cuda.set_device(self.local_rank)
torch.distributed.init_process_group(backend='nccl', init_method='env://')
self._world_size = torch.distributed.get_world_size()
self._global_rank = torch.distributed.get_rank()
self.world_size = torch.distributed.get_world_size()
assert self.world_size == setup_world_size
self.global_rank = torch.distributed.get_rank()
else:
self._device = torch.device('cuda' if device_idx is None else f'cuda:{device_idx}')
self._memory_format = memory_format
if amp:
self._amp = amp
self._autocast = torch.cuda.amp.autocast
else:
self._amp = amp
self._autocast = suppress
@property
def device(self):
return self._device
@property
def local_rank(self):
return self._local_rank
@property
def global_rank(self):
return self._global_rank
@property
def is_distributed(self):
return self._distributed
self.device = torch.device('cuda' if device_index is None else f'cuda:{device_index}')
self.local_rank = 0
self.world_size = 1
self.global_rank = 0
if self.autocast is None:
self.autocast = torch.cuda.amp.autocast if self.amp else suppress
@property
def world_size(self):
return self._world_size
@property
def is_master(self):
return self._local_rank == 0
@property
def type(self) -> str:
return 'cuda'
@property
def amp(self) -> bool:
return self._amp
@property
def autocast(self):
return self._autocast
def type(self) -> DeviceEnvType:
return DeviceEnvType.CUDA
def wrap_distributed(self, *modules, **kwargs):
wrapped = [DistributedDataParallel(m, device_ids=[self._local_rank], **kwargs) for m in modules]
wrapped = [DistributedDataParallel(m, device_ids=[self.local_rank], **kwargs) for m in modules]
return wrapped[0] if len(wrapped) == 1 else wrapped
def to_device(self, *modules: torch.nn.Module):
# FIXME handling dtype / memformat... disable flags, enable flags, diff fn?
moved = [m.to(device=self._device, memory_format=self._memory_format) for m in modules]
return moved[0] if len(moved) == 1 else moved
def wrap_parallel(self, *modules, **kwargs):
assert not self.distributed
wrapped = [DataParallel(m, **kwargs) for m in modules]
return wrapped[0] if len(wrapped) == 1 else wrapped

@ -1,10 +1,11 @@
from .device_env import DeviceEnv
from .device_env_cuda import DeviceEnvCuda, is_cuda_available
from .device_env_xla import DeviceEnvXla, is_xla_available
_device_env = None
def initialize_device(force_cpu: bool = False, xla_device_type=None, **kwargs):
def initialize_device(force_cpu: bool = False, **kwargs) -> DeviceEnv:
global _device_env
if _device_env is not None:
# warning
@ -12,21 +13,22 @@ def initialize_device(force_cpu: bool = False, xla_device_type=None, **kwargs):
denv = None
if not force_cpu:
xla_device_type = kwargs.get('xla_device_type', None)
if is_xla_available(xla_device_type):
# XLA supports more than just TPU, but by default will only look at TPU
denv = DeviceEnvXla(**kwargs, xla_device_type=xla_device_type)
# XLA supports more than just TPU, will search in order TPU, GPU, CPU
denv = DeviceEnvXla(**kwargs)
elif is_cuda_available():
denv = DeviceEnvCuda(**kwargs)
if denv is None:
# FIXME implement CPU support
raise NotImplementedError()
denv = DeviceEnv()
print(denv) # FIXME DEBUG
_device_env = denv
return denv
def get_device():
def get_device() -> DeviceEnv:
if _device_env is None:
raise RuntimeError('Please initialize device environment by calling initialize_device first.')
return _device_env

@ -1,6 +1,10 @@
import os
from contextlib import suppress
from dataclasses import dataclass, field, InitVar
from typing import Optional
import torch
from torch.distributed import ReduceOp
try:
import torch_xla.core.xla_model as xm
@ -15,78 +19,102 @@ try:
except ImportError as e:
xa = None
from .device_env import DeviceEnv
from .device_env import DeviceEnv, DeviceEnvType, TensorList
_PT_TO_XM_OP = {
ReduceOp.SUM: 'sum',
ReduceOp.PRODUCT: 'prod',
ReduceOp.MIN: 'min',
ReduceOp.MAX: 'max',
ReduceOp.BAND: 'and',
ReduceOp.BOR: 'or',
}
def is_xla_available(xla_device_type=None):
if not _HAS_XLA:
return False
supported_devs = xm.get_xla_supported_devices(devkind=xla_device_type)
return len(supported_devs) >= 1
@dataclass
class DeviceEnvXla(DeviceEnv):
def __init__(self, xla_device_type=None, device_idx=None, local_rank=0, amp=False):
self._device = xm.xla_device(n=device_idx, devkind=xla_device_type)
self._local_rank = xm.get_local_ordinal(local_rank)
self._world_size = xm.xrt_world_size()
self._distributed = self._world_size > 1
self._global_rank = 0
if self._distributed:
self._global_rank = xm.get_ordinal()
if amp:
assert xa is not None, 'XLA AMP is not present on this build'
self._autocast = xa.autocast
def __post_init__(self, device_type: Optional[str], device_idx: Optional[int]):
if device_type is not None:
device_type = device_type.upper()
assert device_type in ('TPU', 'GPU', 'CPU'), "XLA device type must be one of ('TPU', 'GPU', 'CPU')"
self.device = xm.xla_device(n=device_idx, devkind=device_type)
self.world_size = xm.xrt_world_size()
if self.distributed:
assert device_idx is None, "device_index is based on local rank for distributed XLA mode"
self.local_rank = xm.get_local_ordinal()
self.global_rank = xm.get_ordinal()
else:
self._autocast = suppress
self._memory_format = None
@property
def device(self):
return self._device
@property
def local_rank(self):
return self._local_rank
@property
def global_rank(self):
return self._global_rank
@property
def is_distributed(self):
return self._distributed
@property
def world_size(self):
return self._world_size
@property
def is_master(self):
return self._global_rank == 0
@property
def type(self) -> str:
return 'xla'
@property
def amp(self) -> bool:
return False
self.local_rank = 0
self.global_rank = 0
if self.amp:
assert xa is not None, 'XLA AMP is not present on this build'
if self.autocast is None:
self.autocast = xa.autocast if self.amp else suppress
@property
def autocast(self):
return self._autocast
def type(self) -> DeviceEnvType:
return DeviceEnvType.XLA
def wrap_distributed(self, *modules):
# NO-OP
wrapped = [m for m in modules]
wrapped = [m for m in modules] # NO-OP
return wrapped[0] if len(wrapped) == 1 else wrapped
def to_device(self, *modules: torch.nn.Module):
moved = [m.to(device=self._device, memory_format=self._memory_format) for m in modules]
return moved[0] if len(moved) == 1 else moved
def wrap_parallel(self, *modules):
assert False, "Not implemented"
def mark_step(self):
xm.mark_step()
def all_reduce(self, tensor: torch.Tensor, op=ReduceOp.SUM, average=False):
assert isinstance(tensor, torch.Tensor) # unlike in-place variant, lists/tuples not allowed
op = _PT_TO_XM_OP[op]
scale = 1.0
if average:
scale /= self.world_size
return xm.all_reduce(op, tensor, scale=scale)
def all_reduce_(self, tensor: TensorList, op=ReduceOp.SUM, average=False):
op = _PT_TO_XM_OP[op]
scale = 1.0
wrapped = False
if isinstance(tensor, torch.Tensor):
tensor = [tensor] # bare tensors are not operated on in-place
wrapped = True
if average:
scale /= self.world_size
xm.all_reduce(op, tensor, scale=scale)
if wrapped:
tensor = tensor[0]
return tensor
def all_gather(self, tensor: torch.Tensor, cat_dim=0):
output = xm.all_gather(tensor, cat_dim)
return output
def all_to_all(self, tensor, num_splits, split_dim, cat_dim=0):
output = xm.all_to_all(tensor, split_dim, cat_dim, num_splits)
return output
def broadcast(self, tensor: torch.Tensor, src_rank=0):
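# emulated broadcast via a sum all_reduce: ranks other than src contribute zeros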
if self.global_rank != src_rank:
reduce_tensor = torch.zeros_like(tensor)
xm.all_reduce('sum', reduce_tensor)
else:
xm.all_reduce('sum', tensor)
return tensor
def broadcast_(self, tensor: torch.Tensor, src_rank=0):
out_tensor = self.broadcast(tensor, src_rank)
return tensor.copy_(out_tensor)
def barrier(self):
xm.rendezvous('timm.bits.dist_barrier')

@ -0,0 +1,151 @@
from typing import Dict, Tuple, List, Union, Any, Callable
import torch
from torch.distributed import ReduceOp
from timm.utils import unwrap_model
from .device_env import DeviceEnv, DeviceEnvType
from .device_env_factory import get_device
TensorSeq = Union[torch.Tensor, Tuple[torch.Tensor, ...], List[torch.Tensor], Dict[Any, torch.Tensor]]
def _validate_type(tensor: TensorSeq):
if isinstance(tensor, (dict, list, tuple)):
if not tensor:
return
else:
assert isinstance(tensor, torch.Tensor)
def distribute_bn(model: torch.nn.Module, reduce: bool = False, dev_env: DeviceEnv = None):
if dev_env is None:
dev_env = get_device()
# ensure every node has the same running bn stats
for bn_name, bn_buf in unwrap_model(model).named_buffers(recurse=True):
if ('running_mean' in bn_name) or ('running_var' in bn_name):
if reduce:
# average bn stats across whole group
dev_env.all_reduce_(bn_buf, average=True)
else:
# broadcast bn stats from rank 0 to whole group
dev_env.broadcast_(bn_buf, 0)
def all_gather_recursive(tensor: TensorSeq, cat_dim=0, dev_env: DeviceEnv = None):
""" Recursive all gather via DeviceEnv distributed primitives
FIXME add group support
"""
_validate_type(tensor)
if dev_env is None:
dev_env = get_device()
if isinstance(tensor, torch.Tensor):
return dev_env.all_gather(tensor, cat_dim=cat_dim)
elif isinstance(tensor, dict):
return {k: all_gather_recursive(v, dev_env=dev_env) for k, v in tensor.items()}
elif isinstance(tensor, (tuple, list)):
return type(tensor)(all_gather_recursive(v, dev_env=dev_env) for v in tensor)
def all_reduce_recursive(tensor: TensorSeq, op=ReduceOp.SUM, average=False, dev_env: DeviceEnv = None):
""" Recursive all reduce via DeviceEnv distributed primitives
FIXME add group support
"""
_validate_type(tensor)
if dev_env is None:
dev_env = get_device()
if isinstance(tensor, torch.Tensor):
return dev_env.all_reduce_(tensor, op=op, average=average)
elif isinstance(tensor, dict):
return {k: all_reduce_recursive(v, op=op, average=average, dev_env=dev_env) for k, v in tensor.items()}
elif isinstance(tensor, (tuple, list)):
return type(tensor)(all_reduce_recursive(v, op=op, average=average, dev_env=dev_env) for v in tensor)
def broadcast_recursive(tensor: TensorSeq, src_rank: int, dev_env: DeviceEnv = None):
""" Recursive broadcast via DeviceEnv distributed primitives
FIXME add group support
"""
_validate_type(tensor)
if dev_env is None:
dev_env = get_device()
if isinstance(tensor, torch.Tensor):
return dev_env.broadcast_(tensor, src_rank=src_rank)
elif isinstance(tensor, dict):
return {k: broadcast_recursive(v, src_rank=src_rank, dev_env=dev_env) for k, v in tensor.items()}
elif isinstance(tensor, (tuple, list)):
return type(tensor)(broadcast_recursive(v, src_rank=src_rank, dev_env=dev_env) for v in tensor)
def all_gather_sequence(tensor: TensorSeq, cat_dim: int = 0, dev_env: DeviceEnv = None):
""" All gather a flat Tensor sequence (dict, list, tuple) of same shape
"""
_validate_type(tensor)
if dev_env is None:
dev_env = get_device()
with torch.no_grad():
names = None
# merge values into one tensor for reduction
if isinstance(tensor, dict):
names = tensor.keys()
gather_values = tuple(tensor.values())
elif isinstance(tensor, (tuple, list)):
gather_values = tensor
else:
gather_values = (tensor,)
gather_values = torch.stack(gather_values, dim=0)
gather_values = dev_env.all_gather(gather_values, cat_dim=cat_dim + 1).unbind(dim=0)
# separate reduced values into original structure
if isinstance(tensor, dict):
gather_values = {k: v for k, v in zip(names, gather_values)}
elif isinstance(tensor, (tuple, list)):
gather_values = type(tensor)(v for v in gather_values)
else:
gather_values = gather_values[0]
return gather_values
def all_reduce_sequence(tensor: TensorSeq, op=ReduceOp.SUM, average=False, dev_env: DeviceEnv = None):
"""
All reduce the tensors in a flat Tensor sequence (dict, list, tuple) of same tensor shape
Args:
tensor (dict): inputs to be reduced. All the values must be scalar Tensor.
average (bool): whether to do average or sum
Returns:
a sequence with the same type as input (dict, list, tuple)
"""
_validate_type(tensor)
if dev_env is None:
dev_env = get_device()
with torch.no_grad():
names = None
# merge values into one tensor for reduction
if isinstance(tensor, dict):
names = tensor.keys()
reduce_values = tuple(tensor.values())
elif isinstance(tensor, (tuple, list)):
reduce_values = tensor
else:
reduce_values = (tensor,)
reduce_values = torch.stack(reduce_values, dim=0)
dev_env.all_reduce_(reduce_values, op=op, average=average)
reduce_values = reduce_values.unbind(dim=0)
# separate reduced values into original structure
if isinstance(tensor, dict):
reduce_values = {k: v for k, v in zip(names, reduce_values)}
elif isinstance(tensor, (tuple, list)):
reduce_values = type(tensor)(v for v in reduce_values)
else:
reduce_values = reduce_values[0]
return reduce_values
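A hedged usage sketch for the sequence helpers above, assuming a distributed run where initialize_device() has already set up the process group; the stat values are placeholders:

import torch
from torch.distributed import ReduceOp

from timm.bits import get_device, all_reduce_sequence

dev_env = get_device()
stats = {
    'loss': torch.tensor(0.734, device=dev_env.device),
    'top1': torch.tensor(76.2, device=dev_env.device),
}
# stack, sum-reduce across ranks, divide by world size, then unpack back into a dict
stats = all_reduce_sequence(stats, op=ReduceOp.SUM, average=True, dev_env=dev_env)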

@ -0,0 +1,190 @@
""" PyTorch distributed helpers
Some of this lifted from Detectron2 with other fns added by myself.
FIXME many functions remain unfinished/untested
"""
from typing import Dict, Tuple, List, Union, Any, Callable
import torch
import torch.distributed as dist
from torch.distributed import ReduceOp
TensorSeq = Union[torch.Tensor, Tuple[torch.Tensor, ...], List[torch.Tensor], Dict[Any, torch.Tensor]]
def synchronize_torch():
"""
Helper function to synchronize (barrier) among all processes when
using distributed training
"""
if not dist.is_available():
return
if not dist.is_initialized():
return
world_size = dist.get_world_size()
if world_size == 1:
return
dist.barrier()
def all_reduce_sequence_torch(values: TensorSeq, op=ReduceOp.SUM, average=False, group=None):
"""
All reduce the tensors in a sequence (dict, list, tuple)
Args:
values (dict): inputs to be reduced. All the values must be scalar Tensor.
average (bool): whether to do average or sum
Returns:
a sequence with the same type as input (dict, list, tuple)
"""
world_size = dist.get_world_size(group)
if world_size <= 1:
return values
with torch.no_grad():
names = None
if isinstance(values, dict):
names = values.keys()
reduce_values = torch.stack(tuple(values.values()), dim=0)
elif isinstance(values, (tuple, list)):
reduce_values = torch.stack(values, dim=0)
else:
reduce_values = values
dist.all_reduce(reduce_values, op=op, group=group)
if average:
reduce_values /= world_size
if isinstance(values, dict):
reduce_values = {k: v for k, v in zip(names, reduce_values)}
elif isinstance(values, (tuple, list)):
reduce_values = type(values)(v for v in reduce_values)
return reduce_values
def reduce_sequence_torch(values: TensorSeq, dst_rank=0, op=ReduceOp.SUM, average=False, group=None):
"""
All reduce the tensors in a sequence (dict, list, tuple)
Args:
values (dict): inputs to be reduced. All the values must be scalar Tensor.
average (bool): whether to do average or sum
Returns:
a sequence with the same type as input (dict, list, tuple)
"""
world_size = dist.get_world_size(group)
this_rank = dist.get_rank()
if world_size <= 1:
return values
with torch.no_grad():
names = None
if isinstance(values, dict):
names = values.keys()
reduce_values = torch.stack(tuple(values.values()), dim=0)
elif isinstance(values, (tuple, list)):
reduce_values = torch.stack(values, dim=0)
else:
reduce_values = values
reduce_values = torch.stack(reduce_values, dim=0)
dist.reduce(reduce_values, dst=dst_rank, op=op, group=group)
if average and this_rank == dst_rank:
reduce_values /= world_size
if isinstance(values, dict):
reduce_values = {k: v for k, v in zip(names, reduce_values)}
elif isinstance(values, (tuple, list)):
reduce_values = type(values)(v for v in reduce_values)
return reduce_values
def all_gather_sequence_torch(values: TensorSeq, group=None, join_fn=torch.cat, join_dim=0):
world_size = dist.get_world_size(group)
def _do_gather(tensor):
tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
dist.all_gather(tensor_list, tensor, group=group)
return join_fn(tensor_list, dim=join_dim)
if isinstance(values, dict):
gathered = {k: _do_gather(v) for k, v in values.items()}
return gathered
elif isinstance(values, (list, tuple)):
gathered = type(values)(_do_gather(v) for v in values)
return gathered
else:
# if not a dict, list, tuple, expect a singular tensor
assert isinstance(values, torch.Tensor)
return _do_gather(values)
def gather_sequence_torch(values: TensorSeq, dst_rank, group=None, join_fn=torch.cat, join_dim=0):
world_size = dist.get_world_size(group)
this_rank = dist.get_rank(group)
def _do_gather(tensor):
tensor_list = None
if this_rank == dst_rank:
tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
dist.gather(tensor, tensor_list, dst=dst_rank, group=group)
return join_fn(tensor_list, dim=join_dim)
if isinstance(values, dict):
gathered = {k: _do_gather(v) for k, v in values.items()}
return gathered
elif isinstance(values, (list, tuple)):
gathered = type(values)(_do_gather(v) for v in values)
return gathered
else:
# if not a dict, list, tuple, expect a singular tensor
assert isinstance(values, torch.Tensor)
return _do_gather(values)
def all_gather_torch(value: TensorSeq, group=None, join_fn: Callable = None, join_dim=0):
if isinstance(value, torch.Tensor):
world_size = dist.get_world_size(group)
out_tensors = [torch.empty_like(value) for _ in range(world_size)]
dist.all_gather(out_tensors, value, group=group)
if join_fn is not None:
out_tensors = join_fn(out_tensors, dim=join_dim)
return out_tensors
elif isinstance(value, dict):
return {k: all_gather_torch(v, group, join_fn, join_dim) for k, v in value.items()}
elif isinstance(value, (tuple, list)):
return type(value)(all_gather_torch(v, group, join_fn, join_dim) for v in value)
def gather_torch(value: TensorSeq, dst_rank=0, group=None, join_fn: Callable = None, join_dim=0):
if isinstance(value, torch.Tensor):
world_size = dist.get_world_size(group)
this_rank = dist.get_rank()
out_tensors = None
if this_rank == dst_rank:
out_tensors = [torch.empty_like(value) for _ in range(world_size)]
dist.gather(value, out_tensors, dst=dst_rank, group=group)
if join_fn is not None:
out_tensors = join_fn(out_tensors, dim=join_dim)
return out_tensors
elif isinstance(value, dict):
return {k: gather_torch(v, dst_rank, group, join_fn, join_dim) for k, v in value.items()}
elif isinstance(value, (tuple, list)):
return type(value)(gather_torch(v, dst_rank, group, join_fn, join_dim) for v in value)
def all_reduce_torch(value: TensorSeq, op=ReduceOp.SUM, average=False, group=None):
if isinstance(value, torch.Tensor):
dist.all_reduce(value, op=op, group=group)
if average:
value /= dist.get_world_size(group)
elif isinstance(value, dict):
return {k: all_reduce_torch(v, op=op, average=average, group=group) for k, v in value.items()}
elif isinstance(value, (tuple, list)):
return type(value)(all_reduce_torch(v, op=op, average=average, group=group) for v in value)
def broadcast_torch(value: TensorSeq, src_rank: int = 0, group=None):
if isinstance(value, torch.Tensor):
return dist.broadcast(value, src=src_rank, group=group)
elif isinstance(value, dict):
return {k: broadcast_torch(v, src_rank=src_rank, group=group) for k, v in value.items()}
elif isinstance(value, (tuple, list)):
return type(value)(broadcast_torch(v, src_rank=src_rank, group=group) for v in value)

@ -16,21 +16,11 @@ def get_clip_grad_fn(mode: str = 'norm', norm_type: float = 2.0):
assert False, f"Unknown clip mode ({mode})."
def get_clip_parameters(model):
def get_clip_parameters(model, skip_last=0):
if hasattr(model, 'get_clip_parameters'):
return model.get_clip_parameters()
else:
return model.parameters()
class GradClipper:
def __init__(self, model, clip_value, clip_mode='norm'):
self.model = model
self.clip_fn = get_clip_grad_fn(clip_mode)
self.clip_value = clip_value
self.enabled = True
def __call__(self):
if self.enabled:
self.clip_fn(get_clip_parameters(self.model), self.clip_value)
if skip_last:
return list(model.parameters())[:-skip_last]
else:
return model.parameters()

@ -21,6 +21,8 @@ except ImportError:
HAS_WANDB = False
from .device_env_factory import get_device
# FIXME old formatting for reference, to remove
#
# def log_eval(batch_idx, last_idx, batch_time, loss, top1, top5, log_suffix=''):
@ -84,10 +86,16 @@ class SummaryCsv:
dw.writerow(row_dict)
_sci_keys = {'lr'}
def _add_kwargs(text_update, name_map=None, **kwargs):
def _to_str(key, val):
if isinstance(val, float):
return f'{key}: {val:.4f}'
if key.lower() in _sci_keys:
return f'{key}: {val:.3e} '
else:
return f'{key}: {val:.4f}'
else:
return f'{key}: {val}'
@ -120,12 +128,13 @@ class Logger:
self,
experiment_name=None,
output_dir=None,
logger=None,
python_logger=None,
hparams=None,
log_wandb=False,
hparams=None):
self.output_dir = output_dir # for tensorboard, csv, console logging to file?
self.logger = logger or logging.getLogger('log')
output_enabled=True,
):
self.output_dir = output_dir # for tensorboard, csv, text file (TODO) logging
self.logger = python_logger or logging.getLogger('log')
hparams = hparams or {}
# Setup CSV writer(s)
@ -146,28 +155,32 @@ class Logger:
_logger.warning("You've requested to log metrics to wandb but package not found. "
"Metrics not being logged to wandb, try `pip install wandb`")
self.output_enabled = output_enabled
# FIXME image save
def log_step(
self,
phase: str,
step: int,
end_step: Optional[int] = None,
step_end: Optional[int] = None,
epoch: Optional[int] = None,
loss: Optional[float] = None,
rate: Optional[float] = None,
epoch: Optional[int] = None,
phase_suffix: str = '',
**kwargs,
):
""" log train/eval step
"""
phase_title = f'{phase.capitalize()} ({phase_suffix})' if phase_suffix else f'{phase.capitalize()}'
progress = 100. * step / end_step if end_step else 0.
if not self.output_enabled:
return
phase_title = f'{phase.capitalize()} ({phase_suffix})' if phase_suffix else f'{phase.capitalize()}:'
progress = 100. * step / step_end if step_end else 0.
text_update = [
phase_title,
f'Epoch: {epoch}' if epoch is not None else None,
f'Step: {step}' if end_step is None else None,
f'Step: [{step}/{end_step} ({progress:>3.0f}%)]' if end_step is not None else None,
f'{epoch}' if epoch is not None else None,
f'[{step}]' if step_end is None else None,
f'[{step}/{step_end} ({progress:>3.0f}%)]' if step_end is not None else None,
f'Rate: {rate:.2f}/s' if rate is not None else None,
f'Loss: {loss:.5f}' if loss is not None else None,
]
@ -187,6 +200,9 @@ class Logger:
):
"""log completion of evaluation or training phase
"""
if not self.output_enabled:
return
title = [
f'{phase.capitalize()}',
f'epoch: {epoch}' if epoch is not None else None,
@ -212,6 +228,9 @@ class Logger:
index: value for row index (typically epoch #)
index_name: name for row index header (typically 'epoch')
"""
if not self.output_enabled:
return
row_dict = summary_row_dict(index=index, index_name=index_name, results=results)
if self.csv_writer:
self.csv_writer.update(row_dict)

@ -0,0 +1,142 @@
import abc
from typing import Callable, Union, Optional, List, Tuple, Dict
from dataclasses import dataclass
import torch
from torch.distributed import ReduceOp
from .device_env import DeviceEnv
from .device_env_factory import get_device
from .distributed import all_gather_sequence, all_reduce_sequence
MetricValue = Union[float, torch.Tensor, List[float], List[torch.Tensor]]
@dataclass
class ValueInfo:
initial: Optional[MetricValue] = 0.
dtype: torch.dtype = torch.float32
dist_reduce: str = 'sum'
dist_average: bool = False
class Metric(abc.ABC):
def __init__(self, dev_env: DeviceEnv = None):
self._infos: Dict[str, ValueInfo] = {}
self._values: Dict[str, Optional[MetricValue]] = {}
self._values_dist: Dict[str, Optional[MetricValue]] = {}
if dev_env is None:
dev_env = get_device()
self._dev_env = dev_env
def _register_value(self, name: str, info: Optional[ValueInfo] = None):
info = info or ValueInfo()
self._infos[name] = info
# def get_value(self, name: str, use_dist=True):
# if use_dist:
# return self._values_dist.get(name, self._values.get(name))
# else:
# return self._values.get(name)
def __getattr__(self, item):
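# registered metric values read as attributes; a distributed-reduced copy, if present, takes precedence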
if item not in self._infos:
raise AttributeError
value = self._values_dist.get(item, self._values.get(item, None))
return value
def __setattr__(self, key, value):
if '_infos' in self.__dict__ and key in self._infos:
self._values[key] = value
else:
super().__setattr__(key, value)
def update(
self,
predictions: Union[torch.Tensor, Dict[str, torch.Tensor]],
target: Union[torch.Tensor, Dict[str, torch.Tensor]]):
self._update(predictions, target)
def _update(
self,
predictions: Union[torch.Tensor, Dict[str, torch.Tensor]],
target: Union[torch.Tensor, Dict[str, torch.Tensor]]):
pass
def reset(self):
self._values = {}
self._values_dist = {}
for name, info in self._infos.items():
# if info specifies an initial value, we reset here, otherwise set to None and leave it to child class
if info.initial is not None:
if isinstance(info.initial, torch.Tensor):
tensor = info.initial.detach().clone()
else:
tensor = torch.ones([], dtype=info.dtype) * info.initial # scalar
self._values[name] = tensor.to(device=self._dev_env.device, dtype=info.dtype)
else:
self._values[name] = None
self._reset()
def _reset(self):
pass
def compute(self) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[str, torch.Tensor]]:
if self._dev_env.distributed:
self._distribute_values()
results = self._compute()
self._values_dist = {}
return results
@abc.abstractmethod
def _compute(self) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[str, torch.Tensor]]:
pass
def _distribute_values(self):
if not self._infos or not self._values:
return
def _args(op: str):
if op == 'cat':
return True, dict(cat_dim=0)
else:
return False, dict(op=ReduceOp.SUM)
prev_dsr = None
same_dsr = True
names = []
values = []
reductions = []
for name, value in self._values.items():
if value is not None:
info = self._infos[name]
dsr = (value.dtype, value.shape, info.dist_reduce)
if prev_dsr is not None and prev_dsr != dsr:
same_dsr = False
prev_dsr = dsr
names.append(name)
values.append(value)
reductions.append(_args(info.dist_reduce))
same_dsr = False
if same_dsr:
do_gather, reduce_kwargs = reductions[0]
if do_gather:
reduced_values = all_gather_sequence(values, dev_env=self._dev_env, **reduce_kwargs)
else:
reduced_values = all_reduce_sequence(values, dev_env=self._dev_env, **reduce_kwargs)
for name, reduced_value in zip(names, reduced_values):
info = self._infos[name]
if info.dist_average:
reduced_value /= self._dev_env.world_size
self._values_dist[name] = reduced_value
else:
for n, v, r in zip(names, values, reductions):
info = self._infos[n]
do_gather, reduce_kwargs = r
if do_gather:
reduced_value = self._dev_env.all_gather(v, **reduce_kwargs)
else:
reduced_value = self._dev_env.all_reduce(v, **reduce_kwargs)
if info.dist_average:
reduced_value /= self._dev_env.world_size
self._values_dist[n] = reduced_value

@ -0,0 +1,98 @@
import torch
from typing import Optional, Tuple, Dict
from .device_env import DeviceEnv
from .metric import Metric, ValueInfo
class Accuracy(Metric):
def __init__(self, threshold=0.5, multi_label=False, dev_env=None):
super().__init__(dev_env=dev_env)
self.threshold = threshold
self.eps = 1e-8
self.multi_label = multi_label
# statistics / counts
self._register_value('correct')
self._register_value('total')
def _update(self, predictions, target):
raise NotImplementedError()
def _compute(self):
raise NotImplementedError()
# class AccuracyTopK(torch.nn.Module):
#
# def __init__(self, topk=(1, 5), device=None):
# super().__init__()
# self.eps = 1e-8
# self.device = device
# self.topk = topk
# self.maxk = max(topk)
# # FIXME handle distributed operation
#
# # statistics / counts
# self.reset()
#
# def update(self, predictions: torch.Tensor, target: torch.Tensor):
# sorted_indices = predictions.topk(self.maxk, dim=1)[1]
# sorted_indices.t_()
# correct = sorted_indices.eq(target.reshape(1, -1).expand_as(sorted_indices))
#
# batch_size = target.shape[0]
# correct_k = {k: correct[:k].reshape(-1).float().sum(0) for k in self.topk}
# for k, v in correct_k.items():
# attr = f'_correct_top{k}'
# old_v = getattr(self, attr)
# setattr(self, attr, old_v + v)
# self._total_sum += batch_size
#
# def reset(self):
# for k in self.topk:
# setattr(self, f'_correct_top{k}', torch.tensor(0, dtype=torch.float32))
# self._total_sum = torch.tensor(0, dtype=torch.float32)
#
# @property
# def counts(self):
# pass
#
# def compute(self) -> Dict[str, torch.Tensor]:
# # FIXME handle distributed reduction
# return {f'top{k}': 100 * getattr(self, f'_correct_top{k}') / self._total_sum for k in self.topk}
class AccuracyTopK(Metric):
def __init__(self, topk=(1, 5), dev_env: DeviceEnv = None):
super().__init__(dev_env=dev_env)
self.eps = 1e-8
self.topk = topk
self.maxk = max(topk)
# statistics / counts
for k in self.topk:
self._register_value(f'top{k}')
self._register_value('total')
self.reset()
def _update(self, predictions: torch.Tensor, target: torch.Tensor):
batch_size = predictions.shape[0]
sorted_indices = predictions.topk(self.maxk, dim=1)[1]
target_reshape = target.reshape(-1, 1).expand_as(sorted_indices)
correct = sorted_indices.eq(target_reshape).float().sum(0)
for k in self.topk:
attr_name = f'top{k}'
correct_at_k = correct[:k].sum()
setattr(self, attr_name, getattr(self, attr_name) + correct_at_k)
self.total += batch_size
def _compute(self) -> Dict[str, torch.Tensor]:
assert self.total is not None
output = {}
for k in self.topk:
attr_name = f'top{k}'
output[attr_name] = 100 * getattr(self, attr_name) / self.total
return output

@ -1,16 +1,16 @@
import time
from typing import Optional
from timm.metrics import ScalarAvgMinMax
from .avg_scalar import AvgMinMaxScalar
class Tracker:
def __init__(self):
self.data_time = ScalarAvgMinMax() # time for data loader to produce batch of samples
self.step_time = ScalarAvgMinMax() # time for model step
self.iter_time = ScalarAvgMinMax() # full iteration time incl. data, step, and book-keeping
self.epoch_time = ScalarAvgMinMax()
self.data_time = AvgMinMaxScalar() # time for data loader to produce batch of samples
self.step_time = AvgMinMaxScalar() # time for model step
self.iter_time = AvgMinMaxScalar() # full iteration time incl. data, step, and book-keeping
self.epoch_time = AvgMinMaxScalar()
self.iter_timestamp: Optional[float] = None
self.prev_timestamp: Optional[float] = None
@ -48,3 +48,12 @@ class Tracker:
self.epoch_time.update(epoch_time)
self.epoch_timestamp = timestamp
def get_avg_iter_rate(self, num_per_iter: int):
if num_per_iter == 0 or self.iter_time.avg == 0:
return 0
return num_per_iter / self.iter_time.avg
def get_last_iter_rate(self, num_per_iter: int):
if num_per_iter == 0 or self.iter_time.val == 0:
return 0
return num_per_iter / self.iter_time.val

@ -0,0 +1,12 @@
from dataclasses import dataclass
@dataclass
class TrainCfg:
""" Train Loop Configuration
Dataclass to propagate training configuration values
"""
num_epochs: int = 0
log_interval: int = 50
recovery_interval: int = 0
accumulate_steps: int = 0

@ -0,0 +1,13 @@
from dataclasses import dataclass
from .logger import Logger
from timm.utils.checkpoint_saver import CheckpointSaver
@dataclass
class TrainServices:
""" Train Loop Services
"""
logger: Logger = None
saver: CheckpointSaver = None

@ -0,0 +1,153 @@
import dataclasses
from typing import Callable, Union, Optional
import logging
import torch
import torch.nn as nn
from timm.optim import create_optimizer_v2
from timm.utils import ModelEmaV2
try:
import deepspeed as ds
except ImportError:
ds = None
from .checkpoint import resume_train_checkpoint
from .device_env import DeviceEnv
from .train_cfg import TrainCfg
from .train_state import TrainState
from .updater_factory import create_updater
_logger = logging.getLogger(__name__)
def setup_model_and_optimizer(
dev_env: DeviceEnv,
model: nn.Module,
optimizer: Union[Callable, str],
optimizer_cfg,
clip_fn: Optional[Union[Callable, str]] = None,
clip_value: Optional[float] = None,
model_ema: bool = False,
model_ema_decay: float = 0.9999,
use_syncbn: bool = False,
resume_path: str = '',
resume_opt: bool = True,
deepspeed: bool = False,
):
"""
Args:
dev_env:
model:
optimizer:
optimizer_cfg:
clip_value:
clip_fn:
model_ema:
model_ema_decay:
use_syncbn:
resume_path:
resume_opt:
Returns:
"""
if deepspeed:
return setup_model_and_optimizer_deepspeed(
dev_env=dev_env, model=model, optimizer=optimizer, optimizer_cfg=optimizer_cfg,
clip_fn=clip_fn, clip_value=clip_value, model_ema=model_ema, model_ema_decay=model_ema_decay,
resume_path=resume_path, resume_opt=resume_opt,
)
dev_env.to_device(model)
if use_syncbn and dev_env.distributed:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
if dev_env.primary:
_logger.info(
'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
if isinstance(optimizer, Callable):
optimizer = optimizer(model=model, **optimizer_cfg)
else:
optimizer = create_optimizer_v2(model=model, **optimizer_cfg)
updater = create_updater(
model=model,
optimizer=optimizer,
clip_fn=clip_fn,
clip_value=clip_value,
)
# ema model
model_ema = ModelEmaV2(model, decay=model_ema_decay) if model_ema else None
train_state = TrainState(model=model, updater=updater, model_ema=model_ema)
if resume_path:
resume_train_checkpoint(
train_state,
resume_path,
resume_opt=resume_opt,
log_info=dev_env.primary)
if dev_env.distributed:
train_state = dataclasses.replace(
train_state, model=dev_env.wrap_distributed(train_state.model))
return train_state
def setup_model_and_optimizer_deepspeed(
dev_env: DeviceEnv,
model: nn.Module,
optimizer: Union[Callable, str],
optimizer_cfg,
clip_fn: Optional[Union[Callable, str]] = None,
clip_value: Optional[float] = None,
model_ema: bool = False,
model_ema_decay: float = 0.9999,
use_syncbn: bool = False,
resume_path: str = '',
resume_opt: bool = True,
):
dev_env.to_device(model)
if isinstance(optimizer, Callable):
optimizer = optimizer(model=model, **optimizer_cfg)
else:
optimizer = create_optimizer_v2(model=model, **optimizer_cfg)
model = ds.initialize(model=model, optimizer=optimizer, dist_init_required=False)
updater = create_updater(
model=model,
optimizer=optimizer,
clip_fn=clip_fn,
clip_value=clip_value,
deepspeed=True,
)
# ema model
# FIXME how to do EMA w/ deepspeed?
model_ema = ModelEmaV2(model, decay=model_ema_decay) if model_ema else None
train_state = TrainState(model=model, updater=updater, model_ema=model_ema)
if resume_path:
# FIXME deepspeed resumes differently
resume_train_checkpoint(
train_state,
resume_path,
resume_opt=resume_opt,
log_info=dev_env.primary)
if dev_env.distributed:
train_state = dataclasses.replace(
train_state, model=dev_env.wrap_distributed(train_state.model))
return train_state
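A hedged sketch of the higher-level setup path above; the optimizer_cfg keys are assumed to mirror create_optimizer_v2, and the toy model is illustrative only:

import torch.nn as nn

from timm.bits import initialize_device, setup_model_and_optimizer

dev_env = initialize_device()
train_state = setup_model_and_optimizer(
    dev_env=dev_env,
    model=nn.Linear(32, 10),                  # hypothetical toy model
    optimizer='sgd',                          # non-callable, so resolved via create_optimizer_v2
    optimizer_cfg=dict(opt='sgd', lr=0.1, momentum=0.9, weight_decay=1e-4),
    clip_fn='norm',
    clip_value=5.0,
)
# train_state.model and train_state.updater are ready for the train loop (model_ema is None unless requested)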

@ -0,0 +1,33 @@
from dataclasses import dataclass
from typing import Dict, Any
from torch import nn as nn
from timm.scheduler import Scheduler
from .updater import Updater
@dataclass
class TrainState:
model: nn.Module = None
train_loss: nn.Module = None
eval_loss: nn.Module = None
updater: Updater = None
lr_scheduler: Scheduler = None
model_ema: nn.Module = None
step_count_epoch: int = 0
step_count_global: int = 0
epoch: int = 0
def __post_init__(self):
assert self.model is not None
assert self.updater is not None
def serialize_train_state(train_state: TrainState):
pass
def deserialize_train_state(checkpoint: Dict[str, Any]):
pass

@ -1,54 +1,68 @@
from dataclasses import dataclass, field, InitVar
from functools import partial
from typing import Callable, Optional, Union
import torch
import torch.nn as nn
from .grad_clipper import GradClipper
from .grad_clip import get_clip_grad_fn, get_clip_parameters
@dataclass
class Updater:
def __init__(
self,
optimizer: torch.optim.Optimizer,
clip_value: Optional[Union[Callable, float]] = None,
clip_mode: str = 'norm'):
self.optimizer = optimizer
self.clipper: Optional[GradClipper] = None
if clip_value is not None:
if isinstance(clip_value, Callable):
self.clipper = clip_value
model: nn.Module = None
optimizer: torch.optim.Optimizer = None # FIXME handle multiple optimizers per-model
clip_fn: Optional[Union[Callable, str]] = None
clip_value: Optional[float] = None
clip_params_fn: Optional[Callable] = None
grad_scaler: Optional[Callable] = None
create_graph: Optional[bool] = None
after_step_closure: bool = False
def __post_init__(self):
assert self.model is not None
assert self.optimizer is not None
if self.clip_fn is not None:
if isinstance(self.clip_fn, Callable):
skip_last = 0
else:
GradClipper(clip_value, clip_mode)
self.scaler = None
self.create_graph = getattr(self.optimizer, 'second_order', False)
self.num_accumulated = 0
assert isinstance(self.clip_fn, str)
skip_last = 2 if 'agc' in self.clip_fn else 0
self.clip_fn = get_clip_grad_fn(self.clip_fn)
assert self.clip_value is not None
self.clip_params_fn = partial(get_clip_parameters, model=self.model, skip_last=skip_last)
if self.create_graph is None:
self.create_graph = getattr(self.optimizer, 'second_order', False)
self.after_step_closure = False
def reset(self):
self.optimizer.zero_grad()
def apply(self, loss: torch.Tensor, accumulate=False):
loss.backward(create_graph=self.create_graph)
if self.clipper is not None:
self.clipper()
if not accumulate:
self.optimizer.step()
self.reset()
else:
self.num_accumulated += 1
if accumulate:
return
if self.clip_fn is not None:
self.clip_fn(self.clip_params_fn(), self.clip_value)
self.optimizer.step()
self.reset()
def reset(self):
self.optimizer.zero_grad()
self.num_accumulated = 0
def get_average_lr(self):
lrl = [param_group['lr'] for param_group in self.optimizer.param_groups if param_group['lr'] > 0]
return sum(lrl) / len(lrl)
def state_dict(self):
state_dict = dict(optimizer=self.optimizer.state_dict())
if self.scaler is not None:
state_dict['scaler'] = self.scaler.state_dict()
if self.grad_scaler is not None:
state_dict['grad_scaler'] = self.grad_scaler.state_dict()
return state_dict
def load_state_dict(self, state_dict):
if 'optimizer' in state_dict:
self.optimizer.load_state_dict(state_dict['optimizer'])
if 'scaler' in state_dict and self.scaler is not None:
self.scaler.load_state_dict(state_dict['scaler'])
if 'grad_scaler' in state_dict and self.grad_scaler is not None:
self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
def after_step(self, after_step_fn, *args):
after_step_fn(*args)

@ -1,36 +1,30 @@
from typing import Callable, Optional, Union, Any
from dataclasses import dataclass, field, InitVar
from typing import Dict, Any
import torch
from .updater import Updater
class UpdaterCuda(Updater):
def __init__(
self,
optimizer: torch.optim.Optimizer,
clip_value: Optional[Union[Callable, float]] = None,
clip_mode: str = 'norm',
use_scaler: bool = False,
scaler_kwargs: Any = None,
):
super().__init__(optimizer=optimizer, clip_value=clip_value, clip_mode=clip_mode)
@dataclass
class UpdaterCudaWithScaler(Updater):
scaler_kwargs: InitVar[Dict[str, Any]] = None
def __post_init__(self, scaler_kwargs: Dict[str, Any]):
super().__post_init__()
scaler_kwargs = scaler_kwargs or {}
if use_scaler:
self.scaler = torch.cuda.amp.GradScaler(**scaler_kwargs)
self.grad_scaler = torch.cuda.amp.GradScaler(**scaler_kwargs)
def apply(self, loss: torch.Tensor, accumulate=False):
if self.scaler is not None:
self.scaler.scale(loss).backward(create_graph=self.create_graph)
if self.clipper is not None:
self.scaler.unscale_(self.optimizer) # unscale the gradients of optimizer's assigned params in-place
self.clipper()
if not accumulate:
self.scaler.step(self.optimizer)
self.reset()
else:
self.num_accumulated += 1
self.scaler.update()
else:
Updater.apply(self, loss, accumulate)
self.grad_scaler.scale(loss).backward(create_graph=self.create_graph)
if accumulate:
# unscale first?
return
if self.clip_fn is not None:
# unscale the gradients of optimizer's assigned params in-place
self.grad_scaler.unscale_(self.optimizer)
self.clip_fn(self.clip_params_fn(), self.clip_value)
self.grad_scaler.step(self.optimizer)
self.grad_scaler.update()
self.reset()

@ -0,0 +1,26 @@
from dataclasses import dataclass, field, InitVar
import torch
try:
import deepspeed as ds
except ImportError as e:
ds = None
from .updater import Updater
@dataclass
class UpdaterDeepSpeed(Updater):
def __post_init__(self):
super().__post_init__()
# FIXME not sure how to deal with model.module / grad clipping w/ DS engine interface
assert isinstance(self.model, ds.DeepSpeedEngine)
def reset(self):
self.model.zero_grad()
def apply(self, loss: torch.Tensor, accumulate=False):
self.model.backward(loss)
self.model.step()
self.reset()

@ -2,29 +2,38 @@ from typing import Callable, Optional, Union, Any
import torch
from .device_env import DeviceEnv
from .device_env import DeviceEnv, DeviceEnvType
from .device_env_factory import get_device
from .updater import Updater
from .updater_cuda import UpdaterCuda
from .updater_xla import UpdaterXla
from .updater_cuda import UpdaterCudaWithScaler
from .updater_deepspeed import UpdaterDeepSpeed
from .updater_xla import UpdaterXla, UpdaterXlaWithScaler
def create_updater(
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
clip_fn: Optional[Union[Callable, str]] = None,
clip_value: Optional[float] = None,
scaler_kwargs: Any = None,
dev_env: Optional[DeviceEnv] = None,
clip_value: Optional[Union[Callable, float]] = None,
clip_mode: str = 'norm',
scaler_kwargs: Any = None) -> Updater:
deepspeed: bool = False,
) -> Updater:
if not dev_env:
dev_env = get_device()
updater_kwargs = dict(
optimizer=optimizer, clip_value=clip_value, clip_mode=clip_mode, scaler_kwargs=scaler_kwargs)
if dev_env.type == 'xla':
return UpdaterXla(**updater_kwargs, use_scaler=dev_env.amp)
elif dev_env.type == 'cuda':
return UpdaterCuda(**updater_kwargs, use_scaler=dev_env.amp)
else:
updater_kwargs.pop('scaler_kwargs', None)
return Updater(**updater_kwargs)
updater_kwargs = dict(model=model, optimizer=optimizer, clip_fn=clip_fn, clip_value=clip_value)
use_scaler = dev_env.amp
if use_scaler:
updater_kwargs['scaler_kwargs'] = scaler_kwargs
updater_cls = Updater
if dev_env.type == DeviceEnvType.XLA:
updater_cls = UpdaterXlaWithScaler if use_scaler else UpdaterXla
elif dev_env.type == DeviceEnvType.CUDA and use_scaler:
updater_cls = UpdaterCudaWithScaler
elif deepspeed:
updater_kwargs.pop('scaler_kwargs', None)
updater_cls = UpdaterDeepSpeed
return updater_cls(**updater_kwargs)

@ -1,6 +1,8 @@
from typing import Callable, Optional, Union, Any
from dataclasses import dataclass, field, InitVar
from typing import Any, Dict
import torch
import torch.nn as nn
try:
import torch_xla.core.xla_model as xm
@ -18,41 +20,49 @@ except ImportError as e:
from .updater import Updater
@dataclass
class UpdaterXla(Updater):
def __init__(
self,
optimizer: torch.optim.Optimizer,
clip_value: Optional[Union[Callable, float]] = None,
clip_mode: str = 'norm',
use_scaler: bool = False,
scaler_kwargs: Any = None,
):
super().__init__(optimizer=optimizer, clip_value=clip_value, clip_mode=clip_mode)
def __post_init__(self):
super().__post_init__()
self.after_step_closure = True
if use_scaler:
assert xa is not None, 'XLA AMP not present in this build'
self.scaler = xa.GradScaler(**scaler_kwargs)
def apply(self, loss: torch.Tensor, accumulate: bool = False):
if self.scaler is None:
loss.backward(create_graph=self.create_graph)
gradients = xm._fetch_gradients(self.optimizer)
xm.all_reduce('sum', gradients, scale=1.0 / xm.xrt_world_size())
if self.clipper is not None:
self.clipper()
if not accumulate:
xm.optimizer_step(self.optimizer)
else:
self.scaler.scale(loss).backward(create_graph=self.create_graph)
if self.clipper is not None:
self.scaler.unscale_(self.optimizer) # unscale the gradients of optimizer's assigned params in-place
self.clipper()
if not accumulate:
self.scaler.step(self.optimizer)
self.reset()
self.scaler.update()
loss.backward(create_graph=self.create_graph)
if accumulate:
return
xm.reduce_gradients(self.optimizer)
if self.clip_fn is not None:
self.clip_fn(self.clip_params_fn(), self.clip_value)
self.optimizer.step()
xm.mark_step()
self.reset()
def after_step(self, after_step_fn, *args):
xm.add_step_closure(after_step_fn, *args)
xm.add_step_closure(after_step_fn, args)
@dataclass
class UpdaterXlaWithScaler(UpdaterXla):
scaler_kwargs: InitVar[Dict[str, Any]] = None
def __post_init__(self, scaler_kwargs: Dict[str, Any]):
super().__post_init__()
scaler_kwargs = scaler_kwargs or {}
assert xa is not None, 'XLA AMP not present in this build'
self.scaler = xa.GradScaler(**scaler_kwargs)
def apply(self, loss: torch.Tensor, accumulate: bool = False):
self.scaler.scale(loss).backward(create_graph=self.create_graph)
if accumulate:
# unscale first?
return
xm.reduce_gradients(self.optimizer)
if self.clip_fn is not None:
self.scaler.unscale_(self.optimizer) # unscale the gradients of optimizer's assigned params in-place
self.clip_fn(self.clip_params_fn(), self.clip_value)
self.scaler.step(self.optimizer)
self.scaler.update()
xm.mark_step()
self.reset()

@ -23,24 +23,20 @@ class Fetcher:
re_count=1,
re_num_splits=0):
self.loader = loader
self.mean = torch.tensor([x * 255 for x in mean]).view(1, 3, 1, 1)
self.std = torch.tensor([x * 255 for x in std]).view(1, 3, 1, 1)
self.device = torch.device(device)
self.dtype = dtype or torch.float32
if device:
self.mean.to(device=device, dtype=self.dtype)
self.std.to(device=device, dtype=self.dtype)
self.mean = torch.tensor([x * 255 for x in mean], dtype=self.dtype, device=self.device).view(1, 3, 1, 1)
self.std = torch.tensor([x * 255 for x in std], dtype=self.dtype, device=self.device).view(1, 3, 1, 1)
if re_prob > 0.:
self.random_erasing = RandomErasing(
probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits)
probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits, device=device)
else:
self.random_erasing = None
def __iter__(self):
for sample, target in self.loader:
sample = sample.to(device=self.device)
sample = sample.to(device=self.device, dtype=self.dtype).sub_(self.mean).div_(self.std)
target = target.to(device=self.device)
sample = sample.to(dtype=self.dtype).sub_(self.mean).div_(self.std)
if self.random_erasing is not None:
sample = self.random_erasing(sample)
yield sample, target

@ -8,7 +8,7 @@ Hacked together by / Copyright 2020 Ross Wightman
import torch.utils.data
from timm.bits import get_device
from timm.bits import get_device, DeviceEnvType
from .fetcher import Fetcher
from .prefetcher_cuda import PrefetcherCuda
@ -78,7 +78,7 @@ def create_loader(
dev_env = get_device()
sampler = None
if dev_env.is_distributed and not isinstance(dataset, torch.utils.data.IterableDataset):
if dev_env.distributed and not isinstance(dataset, torch.utils.data.IterableDataset):
if is_training:
sampler = torch.utils.data.distributed.DistributedSampler(
dataset, num_replicas=dev_env.world_size, rank=dev_env.global_rank)
@ -117,7 +117,7 @@ def create_loader(
re_count=re_count,
re_num_splits=re_num_splits
)
if dev_env.type == 'cuda':
if dev_env.type_cuda:
loader = PrefetcherCuda(loader, **fetcher_kwargs)
else:
loader = Fetcher(loader, device=dev_env.device, **fetcher_kwargs)

@ -82,7 +82,7 @@ class ParserTfds(Parser):
self.dist_num_replicas = 1
dev_env = get_device()
# FIXME allow to work without devenv usage?
if dev_env.is_distributed and dev_env.world_size > 1:
if dev_env.distributed and dev_env.world_size > 1:
self.dist_rank = dev_env.global_rank
self.dist_num_replicas = dev_env.world_size
# if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
@ -150,8 +150,10 @@ class ParserTfds(Parser):
ds = self.builder.as_dataset(
split=self.subsplit or self.split, shuffle_files=self.shuffle, read_config=read_config)
# avoid overloading threading w/ combo of TF ds threads + PyTorch workers
ds.options().experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers)
ds.options().experimental_threading.max_intra_op_parallelism = 1
options = tf.data.Options()
options.experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers)
options.experimental_threading.max_intra_op_parallelism = 1
ds = ds.with_options(options)
if self.is_training or self.repeats > 1:
# to prevent excessive drop_last batch behaviour w/ IterableDatasets
# see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading
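
The hunk above moves the threading tweaks onto a tf.data.Options object applied via with_options. A small standalone sketch of that TF 2.x pattern (MAX_TP_SIZE and num_workers are placeholders; newer TF releases expose the same knobs under options.threading):

import tensorflow as tf

MAX_TP_SIZE = 8   # placeholder thread-pool cap
num_workers = 4   # placeholder PyTorch worker count

options = tf.data.Options()
options.experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers)
options.experimental_threading.max_intra_op_parallelism = 1

ds = tf.data.Dataset.range(100)
ds = ds.with_options(options)  # threading options apply to everything downstream of this point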

@ -1,4 +0,0 @@
from .accuracy import Accuracy, AccuracyTopK
from .precision_recall import PrecisionRecall
from .scalar_avg import ScalarAvgMinMax
from .tensor_avg import TensorAvg, TensorEma

@ -1,114 +0,0 @@
import torch
from typing import Optional, Tuple, Dict
class Accuracy(torch.nn.Module):
def __init__(self, threshold=0.5, multi_label=False):
self.threshold = threshold
self.eps = 1e-8
self.multi_label = multi_label
# statistics / counts
self._correct_sum = torch.tensor(0, dtype=torch.long)
self._total_sum = torch.tensor(0, dtype=torch.long)
def update(self, predictions, target):
raise NotImplemented()
def reset(self):
self._correct_sum = 0
self._total_sum = 0
@property
def counts(self):
pass
def compute(self):
raise NotImplemented()
class AccuracyTopK(torch.nn.Module):
def __init__(self, topk=(1, 5), device=None):
super().__init__()
self.eps = 1e-8
self.device = device
self.topk = topk
self.maxk = max(topk)
# FIXME handle distributed operation
# statistics / counts
self.reset()
def update(self, predictions: torch.Tensor, target: torch.Tensor):
sorted_indices = predictions.topk(self.maxk, dim=1)[1]
sorted_indices.t_()
correct = sorted_indices.eq(target.reshape(1, -1).expand_as(sorted_indices))
batch_size = target.shape[0]
correct_k = {k: correct[:k].reshape(-1).float().sum(0) for k in self.topk}
for k, v in correct_k.items():
attr = f'_correct_top{k}'
old_v = getattr(self, attr)
setattr(self, attr, old_v + v)
self._total_sum += batch_size
def reset(self):
for k in self.topk:
setattr(self, f'_correct_top{k}', torch.tensor(0, dtype=torch.float32))
self._total_sum = torch.tensor(0, dtype=torch.float32)
@property
def counts(self):
pass
def compute(self) -> Dict[str, torch.Tensor]:
# FIXME handle distributed reduction
return {f'top{k}': 100 * getattr(self, f'_correct_top{k}') / self._total_sum for k in self.topk}
#
# class AccuracyTopK:
#
# def __init__(self, topk=(1, 5), device=None):
# self.eps = 1e-8
# self.device = device
# self.topk = topk
# self.maxk = max(topk)
#
# # statistics / counts
# self._correct_sum = None
# self._total_sum = None
#
# def _check_init(self, device):
# to_device = self.device if self.device else device
# if self._correct_sum is None:
# self._correct_sum = {f'top{k}': torch.tensor(0., device=to_device) for k in self.topk}
# if self._total_sum is None:
# self._total_sum = torch.tensor(0, dtype=torch.long, device=to_device)
#
# def update(self, predictions: torch.Tensor, target: torch.Tensor):
# sorted_indices = predictions.topk(self.maxk, dim=1)[1]
# sorted_indices.t_()
# correct = sorted_indices.eq(target.reshape(1, -1).expand_as(sorted_indices))
#
# batch_size = target.shape[0]
# correct_k = {f'top{k}': correct[:k].reshape(-1).float().sum(0) for k in self.topk}
# self._check_init(device=predictions.device)
# for k, v in correct_k.items():
# old_v = self._correct_sum[k]
# self._correct_sum[k] = old_v + v
# self._total_sum += batch_size
#
# def reset(self):
# self._correct_sum = None
# self._total_sum = None
#
# @property
# def counts(self):
# pass
#
# def compute(self) -> Dict[str, torch.Tensor]:
# assert self._correct_sum is not None and self._total_sum is not None
# return {k: 100 * v / self._total_sum for k, v in self._correct_sum.items()}

@ -3,3 +3,4 @@ from .plateau_lr import PlateauLRScheduler
from .step_lr import StepLRScheduler
from .tanh_lr import TanhLRScheduler
from .scheduler_factory import create_scheduler
from .scheduler import Scheduler

@ -108,7 +108,8 @@ class CheckpointSaver:
save_state['arch'] = self.args.model
save_state['args'] = self.args
if self.amp_scaler is not None:
save_state[self.amp_scaler.state_dict_key] = self.amp_scaler.state_dict()
amp_key = getattr(self.amp_scaler, 'state_dict_key', 'amp_scaler')
save_state[amp_key] = self.amp_scaler.state_dict()
if self.model_ema is not None:
save_state['state_dict_ema'] = get_state_dict(self.model_ema, self.unwrap_fn)
if metric is not None:
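
The CheckpointSaver change above falls back to a default key because the native torch.cuda.amp.GradScaler does not define a state_dict_key attribute (timm's own scaler wrappers do). A minimal sketch of that fallback:

import torch

scaler = torch.cuda.amp.GradScaler(enabled=False)  # enabled=False so this also runs on CPU-only hosts
save_state = {}
amp_key = getattr(scaler, 'state_dict_key', 'amp_scaler')  # native GradScaler has no state_dict_key
save_state[amp_key] = scaler.state_dict()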

@ -3,7 +3,11 @@ import torch
from timm.utils.agc import adaptive_clip_grad
def dispatch_clip_grad(parameters, value: float, mode: str = 'norm', norm_type: float = 2.0):
def dispatch_clip_grad(
parameters,
value: float,
mode: str = 'norm',
norm_type: float = 2.0):
""" Dispatch to gradient clipping method
Args:
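
A minimal sketch of a clip-mode dispatcher along the lines of dispatch_clip_grad above, covering only the 'norm' and 'value' modes; names are illustrative:

import torch


def clip_grad(parameters, value: float, mode: str = 'norm', norm_type: float = 2.0):
    # 'norm' and 'value' only; timm's dispatch_clip_grad also supports 'agc' via adaptive_clip_grad
    if mode == 'norm':
        torch.nn.utils.clip_grad_norm_(parameters, value, norm_type=norm_type)
    elif mode == 'value':
        torch.nn.utils.clip_grad_value_(parameters, value)
    else:
        raise ValueError(f'unknown clip mode: {mode}')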

@ -21,8 +21,8 @@ def distribute_bn(model, world_size, reduce=False):
if ('running_mean' in bn_name) or ('running_var' in bn_name):
if reduce:
# average bn stats across whole group
torch.distributed.all_reduce(bn_buf, op=dist.ReduceOp.SUM)
torch.distributed.all_reduce_recursive(bn_buf, op=dist.ReduceOp.SUM)
bn_buf /= float(world_size)
else:
# broadcast bn stats from rank 0 to whole group
torch.distributed.broadcast(bn_buf, 0)
torch.distributed.broadcast_recursive(bn_buf, 0)

@ -21,13 +21,15 @@ import os
import logging
from collections import OrderedDict
from datetime import datetime
from dataclasses import replace
from typing import Tuple
import torch
import torch.nn as nn
import torchvision.utils
from timm.bits import initialize_device, DeviceEnv, create_updater, Updater, Logger, Tracker
from timm.metrics import TensorAvg, AccuracyTopK
from timm.bits import initialize_device, setup_model_and_optimizer, DeviceEnv, Logger, Tracker,\
TrainState, TrainServices, TrainCfg, AccuracyTopK, AvgTensor
from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint,\
convert_splitbn_model, model_parameters
@ -276,7 +278,7 @@ def main():
args, args_text = _parse_args()
dev_env = initialize_device(amp=args.amp)
if dev_env.is_distributed:
if dev_env.distributed:
_logger.info('Training in distributed mode with multiple processes, 1 device per process. Process %d, total %d.'
% (dev_env.global_rank, dev_env.world_size))
else:
@ -284,6 +286,111 @@ def main():
random_seed(args.seed, dev_env.global_rank)
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
train_state, train_cfg = setup_train_task(args, dev_env, mixup_active)
data_config, loader_eval, loader_train = setup_data(args, dev_env, mixup_active)
# setup checkpoint saver
eval_metric = args.eval_metric
best_metric = None
best_epoch = None
saver = None
output_dir = None
if dev_env.primary:
if args.experiment:
exp_name = args.experiment
else:
exp_name = '-'.join([
datetime.now().strftime("%Y%m%d-%H%M%S"),
safe_model_name(args.model),
str(data_config['input_size'][-1])
])
output_dir = get_outdir(args.output if args.output else './output/train', exp_name)
decreasing = True if eval_metric == 'loss' else False
saver = CheckpointSaver( # TODO CheckpointSaverV2
model=train_state.model,
optimizer=train_state.updater.optimizer,
args=args,
model_ema=train_state.model_ema,
amp_scaler=train_state.updater.grad_scaler,
checkpoint_dir=output_dir,
recovery_dir=output_dir,
decreasing=decreasing,
max_history=args.checkpoint_hist)
with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
f.write(args_text)
services = TrainServices(
logger=Logger(
output_dir=output_dir, python_logger=_logger, hparams=vars(args), output_enabled=dev_env.primary),
saver=saver,
)
try:
for epoch in range(train_state.epoch, train_cfg.num_epochs):
if dev_env.distributed and hasattr(loader_train.sampler, 'set_epoch'):
loader_train.sampler.set_epoch(epoch)
if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
if loader_train.mixup_enabled:
loader_train.mixup_enabled = False
train_metrics = train_one_epoch(
dev_env=dev_env,
state=train_state,
services=services,
cfg=train_cfg,
loader=loader_train
)
if dev_env.distributed and args.dist_bn in ('broadcast', 'reduce'):
if dev_env.primary:
_logger.info("Distributing BatchNorm running means and vars")
distribute_bn(model, dev_env.world_size, args.dist_bn == 'reduce')
eval_metrics = evaluate(
train_state.model,
train_state.eval_loss,
loader_eval,
dev_env,
logger=services.logger)
if train_state.model_ema is not None and not args.model_ema_force_cpu:
if dev_env.distributed and args.dist_bn in ('broadcast', 'reduce'):
distribute_bn(train_state.model_ema, dev_env.world_size, args.dist_bn == 'reduce')
ema_eval_metrics = evaluate(
train_state.model_ema.module,
train_state.eval_loss,
loader_eval,
dev_env,
logger=services.logger,
phase_suffix='EMA')
eval_metrics = ema_eval_metrics
if train_state.lr_scheduler is not None:
# step LR for next epoch
train_state.lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
if services.logger is not None:
services.logger.write_summary(index=epoch, results=dict(train=train_metrics, eval=eval_metrics))
if saver is not None:
# save proper checkpoint with eval metric
save_metric = eval_metrics[eval_metric]
best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)
train_state = replace(train_state, epoch=epoch + 1)
except KeyboardInterrupt:
pass
if best_metric is not None:
_logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
def setup_train_task(args, dev_env: DeviceEnv, mixup_active: bool):
model = create_model(
args.model,
pretrained=args.pretrained,
@ -302,82 +409,69 @@ def main():
assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly
if dev_env.is_master:
if dev_env.primary:
_logger.info(
f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}')
data_config = resolve_data_config(vars(args), model=model, verbose=dev_env.is_master)
# setup augmentation batch splits for contrastive loss or split bn
num_aug_splits = 0
if args.aug_splits > 0:
assert args.aug_splits > 1, 'A split of 1 makes no sense'
num_aug_splits = args.aug_splits
assert args.aug_splits == 0 or args.aug_splits > 1, 'A split of 1 makes no sense'
# enable split bn (separate bn stats per batch-portion)
if args.split_bn:
assert num_aug_splits > 1 or args.resplit
model = convert_splitbn_model(model, max(num_aug_splits, 2))
# move model to GPU, enable channels last layout if set
dev_env.to_device(model)
# setup synchronized BatchNorm for distributed training
if dev_env.is_distributed and args.sync_bn:
assert not args.split_bn
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
if dev_env.is_master:
_logger.info(
'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
if args.torchscript:
assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
model = torch.jit.script(model)
updater = create_updater(
create_optimizer_v2(model, **optimizer_kwargs(cfg=args)),
clip_value=args.clip_grad, clip_mode=args.clip_mode)
# optionally resume from a checkpoint
resume_epoch = None
if args.resume:
resume_epoch = resume_checkpoint(
model, args.resume,
optimizer=None if args.no_resume_opt else updater.optimizer,
loss_scaler=None if args.no_resume_opt else updater.scaler,
log_info=dev_env.is_master)
# setup exponential moving average of model weights, SWA could be used here too
model_ema = None
if args.model_ema:
# Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
model_ema = ModelEmaV2(
model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None)
if args.resume:
load_checkpoint(model_ema.module, args.resume, use_ema=True)
# setup distributed training
if dev_env.is_distributed:
if dev_env.is_master:
_logger.info("Distributing model.")
model = dev_env.wrap_distributed(model)
# NOTE: EMA model does not need to be wrapped by DDP
assert args.aug_splits > 1 or args.resplit
model = convert_splitbn_model(model, max(args.aug_splits, 2))
train_state = setup_model_and_optimizer(
dev_env=dev_env,
model=model,
optimizer=args.opt,
optimizer_cfg=optimizer_kwargs(cfg=args),
clip_fn=args.clip_mode if args.clip_grad is not None else None,
clip_value=args.clip_grad,
model_ema=args.model_ema,
model_ema_decay=args.model_ema_decay,
use_syncbn=args.sync_bn,
)
# setup learning rate schedule and starting epoch
lr_scheduler, num_epochs = create_scheduler(args, updater.optimizer)
start_epoch = 0
if args.start_epoch is not None:
# a specified start_epoch will always override the resume epoch
start_epoch = args.start_epoch
elif resume_epoch is not None:
start_epoch = resume_epoch
if lr_scheduler is not None and start_epoch > 0:
lr_scheduler.step(start_epoch)
if dev_env.is_master:
# FIXME move into updater?
lr_scheduler, num_epochs = create_scheduler(args, train_state.updater.optimizer)
if lr_scheduler is not None and train_state.epoch > 0:
lr_scheduler.step(train_state.epoch)
# setup loss function
if args.jsd:
assert args.aug_splits > 1 # JSD only valid with aug splits set
train_loss_fn = JsdCrossEntropy(num_splits=args.aug_splits, smoothing=args.smoothing)
elif mixup_active:
# smoothing is handled with mixup target transform
train_loss_fn = SoftTargetCrossEntropy()
elif args.smoothing:
train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
else:
train_loss_fn = nn.CrossEntropyLoss()
eval_loss_fn = nn.CrossEntropyLoss()
dev_env.to_device(train_loss_fn, eval_loss_fn)
if dev_env.primary:
_logger.info('Scheduled epochs: {}'.format(num_epochs))
train_state = replace(
train_state,
lr_scheduler=lr_scheduler,
train_loss=train_loss_fn,
eval_loss=eval_loss_fn)
train_cfg = TrainCfg(
num_epochs=num_epochs,
log_interval=args.log_interval,
recovery_interval=args.recovery_interval)
return train_state, train_cfg
def setup_data(args, dev_env, mixup_active):
data_config = resolve_data_config(vars(args), model=model, verbose=dev_env.primary)
# create the train and eval datasets
dataset_train = create_dataset(
args.dataset,
@ -388,18 +482,17 @@ def main():
# setup mixup / cutmix
collate_fn = None
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
if mixup_active:
mixup_args = dict(
mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
label_smoothing=args.smoothing, num_classes=args.num_classes)
assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup)
assert not args.aug_splits # collate conflict (need to support deinterleaving in collate mixup)
collate_fn = FastCollateMixup(**mixup_args)
# wrap dataset in AugMix helper
if num_aug_splits > 1:
dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
if args.aug_splits > 1:
dataset_train = AugMixDataset(dataset_train, num_splits=args.aug_splits)
# create data loaders w/ augmentation pipeline
train_interpolation = args.train_interpolation
@ -421,7 +514,7 @@ def main():
vflip=args.vflip,
color_jitter=args.color_jitter,
auto_augment=args.aa,
num_aug_splits=num_aug_splits,
num_aug_splits=args.aug_splits,
interpolation=train_interpolation,
mean=data_config['mean'],
std=data_config['std'],
@ -443,169 +536,107 @@ def main():
crop_pct=data_config['crop_pct'],
pin_memory=args.pin_mem,
)
return data_config, loader_eval, loader_train
# setup loss function
if args.jsd:
assert num_aug_splits > 1 # JSD only valid with aug splits set
train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing)
elif mixup_active:
# smoothing is handled with mixup target transform
train_loss_fn = SoftTargetCrossEntropy()
elif args.smoothing:
train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
else:
train_loss_fn = nn.CrossEntropyLoss()
validate_loss_fn = nn.CrossEntropyLoss()
dev_env.to_device(train_loss_fn, validate_loss_fn)
# setup checkpoint saver and eval metric tracking
eval_metric = args.eval_metric
best_metric = None
best_epoch = None
saver = None
output_dir = None
if dev_env.is_master:
if args.experiment:
exp_name = args.experiment
else:
exp_name = '-'.join([
datetime.now().strftime("%Y%m%d-%H%M%S"),
safe_model_name(args.model),
str(data_config['input_size'][-1])
])
output_dir = get_outdir(args.output if args.output else './output/train', exp_name)
decreasing = True if eval_metric == 'loss' else False
saver = CheckpointSaver(
model=model, optimizer=updater.optimizer, args=args, model_ema=model_ema, amp_scaler=updater.scaler,
checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=args.checkpoint_hist)
with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
f.write(args_text)
logger = Logger(output_dir=output_dir, logger=_logger, hparams=vars(args))
try:
for epoch in range(start_epoch, num_epochs):
if dev_env.is_distributed and hasattr(loader_train.sampler, 'set_epoch'):
loader_train.sampler.set_epoch(epoch)
if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
if loader_train.mixup_enabled:
loader_train.mixup_enabled = False
def train_one_epoch(
dev_env: DeviceEnv,
state: TrainState,
cfg: TrainCfg,
services: TrainServices,
loader,
):
tracker = Tracker()
loss_meter = AvgTensor()
train_metrics = train_one_epoch(
epoch, model, loader_train, updater, train_loss_fn, dev_env,
lr_scheduler=lr_scheduler, saver=saver, logger=logger, model_ema=model_ema,
log_interval=args.log_interval, recovery_interval=args.recovery_interval)
state.model.train()
state.updater.reset() # zero-grad
if dev_env.is_distributed and args.dist_bn in ('broadcast', 'reduce'):
if dev_env.is_master:
_logger.info("Distributing BatchNorm running means and vars")
distribute_bn(model, dev_env.world_size, args.dist_bn == 'reduce')
step_end_idx = len(loader) - 1
tracker.mark_iter()
for step_idx, (sample, target) in enumerate(loader):
tracker.mark_iter_data_end()
eval_metrics = evaluate(model, loader_eval, validate_loss_fn, dev_env, logger=logger)
# FIXME move forward + loss into model 'task' wrapper
with dev_env.autocast():
output = state.model(sample)
loss = state.train_loss(output, target)
if model_ema is not None and not args.model_ema_force_cpu:
if dev_env.is_distributed and args.dist_bn in ('broadcast', 'reduce'):
distribute_bn(model_ema, dev_env.world_size, args.dist_bn == 'reduce')
state.updater.apply(loss)
ema_eval_metrics = evaluate(
model_ema.module, loader_eval, validate_loss_fn, dev_env,
logger=logger, phase_suffix='EMA')
eval_metrics = ema_eval_metrics
tracker.mark_iter_step_end()
if lr_scheduler is not None:
# step LR for next epoch
lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
state.updater.after_step(
after_train_step,
dev_env,
state,
services,
cfg,
step_idx,
step_end_idx,
tracker,
loss_meter,
(output, target, loss),
)
if logger is not None:
logger.write_summary(index=epoch, results=dict(train=train_metrics, eval=eval_metrics))
tracker.mark_iter()
# end for
if saver is not None:
# save proper checkpoint with eval metric
save_metric = eval_metrics[eval_metric]
best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)
if hasattr(state.updater.optimizer, 'sync_lookahead'):
state.updater.optimizer.sync_lookahead()
except KeyboardInterrupt:
pass
if best_metric is not None:
_logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
return OrderedDict([('loss', loss_meter.compute().item())])
def train_one_epoch(
epoch: int,
model: nn.Module,
loader,
updater: Updater,
loss_fn: nn.Module,
def after_train_step(
dev_env: DeviceEnv,
lr_scheduler=None,
saver: CheckpointSaver = None,
logger: Logger = None,
model_ema: nn.Module = None,
log_interval: int = 50,
recovery_interval: int = 0,
state: TrainState,
services: TrainServices,
cfg: TrainCfg,
step_idx: int,
step_end_idx: int,
tracker: Tracker,
loss_meter: AvgTensor,
tensors: Tuple[torch.Tensor, ...],
):
tracker = Tracker()
losses_m = TensorAvg()
end_step = step_idx == step_end_idx
model.train()
end_idx = len(loader) - 1
num_updates = epoch * len(loader)
batch_size = 0
tracker.mark_iter()
for step_idx, (sample, target) in enumerate(loader):
tracker.mark_iter_data_end()
last_step = step_idx == end_idx
batch_size = max(batch_size, sample.shape[0])
with dev_env.autocast():
output = model(sample)
loss = loss_fn(output, target)
updater.reset()
updater.apply(loss)
with torch.no_grad():
output, target, loss = tensors
loss_meter.update(loss, output.shape[0])
dev_env.mark_step() # FIXME
tracker.mark_iter_step_end()
losses_m.update(loss, sample.size(0))
if model_ema is not None:
model_ema.update(model)
if state.model_ema is not None:
state.model_ema.update(model)
num_updates += 1
if last_step or (step_idx + 1) % log_interval == 0:
lrl = [param_group['lr'] for param_group in updater.optimizer.param_groups]
lr = sum(lrl) / len(lrl)
state = replace(state, step_count_global=state.step_count_global + 1)
if dev_env.is_master and logger is not None:
loss_avg = losses_m.compute()
logger.log_step(
if services.logger is not None and end_step or (step_idx + 1) % cfg.log_interval == 0:
global_batch_size = dev_env.world_size * output.shape[0]
loss_avg = loss_meter.compute()
if services.logger is not None:
lr_avg = state.updater.get_average_lr()
services.logger.log_step(
'Train',
step=step_idx,
end_step=end_idx,
step_end=step_end_idx,
epoch=state.epoch,
loss=loss_avg.item(),
rate=(dev_env.world_size * batch_size) / tracker.iter_time.avg,
lr=lr,
rate=tracker.get_avg_iter_rate(global_batch_size),
lr=lr_avg,
)
if saver is not None and recovery_interval and (last_step or (step_idx + 1) % recovery_interval == 0):
saver.save_recovery(epoch, batch_idx=step_idx)
if lr_scheduler is not None:
lr_scheduler.step_update(num_updates=num_updates)
if services.saver is not None and cfg.recovery_interval and (
end_step or (step_idx + 1) % cfg.recovery_interval == 0):
services.saver.save_recovery(state.epoch, batch_idx=step_idx)
tracker.mark_iter()
# end for
if hasattr(updater.optimizer, 'sync_lookahead'):
updater.optimizer.sync_lookahead()
return OrderedDict([('loss', losses_m.compute().item())])
if state.lr_scheduler is not None:
state.lr_scheduler.step_update(num_updates=state.step_count_global)
def evaluate(
model: nn.Module,
loader,
loss_fn: nn.Module,
loader,
dev_env: DeviceEnv,
logger: Logger,
phase_suffix: str = '',
@ -613,7 +644,7 @@ def evaluate(
):
tracker = Tracker()
losses_m = TensorAvg()
losses_m = AvgTensor()
accuracy_m = AccuracyTopK()
model.eval()
@ -636,13 +667,13 @@ def evaluate(
losses_m.update(loss, output.size(0))
accuracy_m.update(output, target)
if dev_env.is_master and (last_step or step_idx % log_interval == 0):
if last_step or step_idx % log_interval == 0:
top1, top5 = accuracy_m.compute().values()
loss_avg = losses_m.compute()
logger.log_step(
'Eval',
step=step_idx,
num_steps=end_idx,
step_end=end_idx,
loss=loss_avg.item(),
top1=top1.item(),
top5=top5.item(),
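
The training loop above treats TrainState as a dataclass and advances it with dataclasses.replace rather than mutating fields in place. A reduced sketch of that pattern, with an illustrative two-field state:

from dataclasses import dataclass, replace


@dataclass
class MiniTrainState:
    # reduced, illustrative stand-in for the TrainState dataclass
    epoch: int = 0
    step_count_global: int = 0


state = MiniTrainState()
state = replace(state, step_count_global=state.step_count_global + 1)  # after each optimizer step
state = replace(state, epoch=state.epoch + 1)                          # after each epoch
print(state)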

@ -18,8 +18,7 @@ import torch.nn as nn
import torch.nn.parallel
from collections import OrderedDict
from timm.bits import initialize_device, Tracker, Logger
from timm.metrics import AccuracyTopK, TensorAvg
from timm.bits import initialize_device, Tracker, Logger, AccuracyTopK, AvgTensor
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
from timm.utils import natural_key, setup_default_logging
@ -155,10 +154,10 @@ def validate(args):
pin_memory=args.pin_mem,
tf_preprocessing=args.tf_preprocessing)
logger = Logger(logger=_logger)
logger = Logger(python_logger=_logger)
tracker = Tracker()
losses = TensorAvg()
accuracy = AccuracyTopK().to(dev_env.device)
losses = AvgTensor()
accuracy = AccuracyTopK(dev_env=dev_env)
model.eval()
num_steps = len(loader)
@ -175,10 +174,8 @@ def validate(args):
output = output[:, valid_labels]
loss = criterion(output, target)
if dev_env.type == 'cuda':
if dev_env.type_cuda:
torch.cuda.synchronize()
#elif dev_env.type == 'xla':
# dev_env.mark_step()
tracker.mark_iter_step_end()
losses.update(loss.detach(), sample.size(0))
@ -186,7 +183,7 @@ def validate(args):
real_labels.add_result(output)
accuracy.update(output.detach(), target)
if dev_env.type == 'xla':
if dev_env.type_xla:
dev_env.mark_step()
tracker.mark_iter()
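
For reference, a pure-PyTorch sketch of the top-k accuracy computation behind the AccuracyTopK metric used above; the timm.bits class adds device placement and distributed reduction on top of this:

import torch


def topk_accuracy(output: torch.Tensor, target: torch.Tensor, topk=(1, 5)):
    maxk = max(topk)
    _, pred = output.topk(maxk, dim=1)            # indices of the maxk highest logits per sample
    pred = pred.t()                               # shape (maxk, batch)
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))
    return {f'top{k}': correct[:k].reshape(-1).float().sum(0) * 100.0 / target.shape[0] for k in topk}


logits = torch.randn(8, 1000)
labels = torch.randint(0, 1000, (8,))
print(topk_accuracy(logits, labels))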
