Major timm.bits update. Updater and DeviceEnv now dataclasses, after_step closure used, metrics base impl w/ distributed reduce, many tweaks/fixes.
parent
938716c753
commit
aa92d7b1c5
@ -1,10 +1,25 @@
|
||||
from .avg_scalar import AvgMinMaxScalar
|
||||
from .avg_tensor import AvgTensor
|
||||
from .device_env import DeviceEnv, DeviceEnvType
|
||||
from .device_env_cuda import DeviceEnvCuda
|
||||
from .device_env_factory import initialize_device, get_device
|
||||
from .device_env import DeviceEnv
|
||||
#from .evaluate import evaluate, eval_step
|
||||
from .device_env_xla import DeviceEnvXla
|
||||
from .distributed import distribute_bn, all_gather_recursive, all_reduce_recursive, broadcast_recursive,\
|
||||
all_reduce_sequence, all_gather_sequence
|
||||
# from .evaluate import evaluate, eval_step
|
||||
from .logger import Logger
|
||||
#from .task import TaskClassify
|
||||
from .metric import Metric, MetricValue
|
||||
from .metric_accuracy import AccuracyTopK
|
||||
from .tracker import Tracker
|
||||
# from .task_metrics import TaskMetrics, TaskMetricsClassify
|
||||
from .train_cfg import TrainCfg
|
||||
from .train_services import TrainServices
|
||||
from .train_setup import setup_model_and_optimizer
|
||||
from .train_state import TrainState
|
||||
# from .task import TaskClassify
|
||||
from .updater import Updater
|
||||
from .updater_cuda import UpdaterCudaWithScaler
|
||||
from .updater_deepspeed import UpdaterDeepSpeed
|
||||
from .updater_factory import create_updater
|
||||
from .tracker import Tracker
|
||||
#from .task_metrics import TaskMetrics, TaskMetricsClassify
|
||||
#from .train import train_one_epoch, TrainServices, TrainState, TrainCfg, Experiment
|
||||
from .updater_xla import UpdaterXla, UpdaterXlaWithScaler
|
||||
# from .train import train_one_epoch, Experiment
|
||||
|
@ -1,4 +1,4 @@
|
||||
class ScalarAvgMinMax:
|
||||
class AvgMinMaxScalar:
|
||||
|
||||
"""Computes and stores the average and current value"""
|
||||
def __init__(self):
|
@ -1,7 +1,7 @@
|
||||
import torch
|
||||
|
||||
|
||||
class TensorAvg:
|
||||
class AvgTensor:
|
||||
|
||||
"""Computes and stores the average and current value"""
|
||||
def __init__(self):
|
@ -0,0 +1,58 @@
|
||||
import logging
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
|
||||
from .train_state import TrainState, serialize_train_state, deserialize_train_state
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def resume_train_checkpoint(
        train_state,
        checkpoint_path,
        resume_opt=True,
        deserialize_fn=deserialize_train_state,
        log_info=True,
):
    """ Resume training from a checkpoint (placeholder)

    Not implemented yet: every call raises NotImplementedError and all arguments
    are accepted but unused.
    """
    raise NotImplementedError
|
||||
|
||||
# resume_epoch = None
|
||||
# if os.path.isfile(checkpoint_path):
|
||||
# checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
||||
#
|
||||
# if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
|
||||
# if log_info:
|
||||
# _logger.info('Restoring model state from checkpoint...')
|
||||
# new_state_dict = OrderedDict()
|
||||
# for k, v in checkpoint['state_dict'].items():
|
||||
# name = k[7:] if k.startswith('module') else k
|
||||
# new_state_dict[name] = v
|
||||
# model.load_state_dict(new_state_dict)
|
||||
#
|
||||
# if optimizer is not None and 'optimizer' in checkpoint:
|
||||
# if log_info:
|
||||
# _logger.info('Restoring optimizer state from checkpoint...')
|
||||
# optimizer.load_state_dict(checkpoint['optimizer'])
|
||||
#
|
||||
# if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint:
|
||||
# if log_info:
|
||||
# _logger.info('Restoring AMP loss scaler state from checkpoint...')
|
||||
# loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key])
|
||||
#
|
||||
# if 'epoch' in checkpoint:
|
||||
# resume_epoch = checkpoint['epoch']
|
||||
# if 'version' in checkpoint and checkpoint['version'] > 1:
|
||||
# resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save
|
||||
#
|
||||
# if log_info:
|
||||
# _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
|
||||
# else:
|
||||
# model.load_state_dict(checkpoint)
|
||||
# if log_info:
|
||||
# _logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
|
||||
# return resume_epoch
|
||||
# else:
|
||||
# _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
|
||||
# raise FileNotFoundError()
|
@ -1,58 +1,130 @@
|
||||
import torch
|
||||
import abc
|
||||
from contextlib import suppress
|
||||
from enum import Enum
|
||||
from typing import Callable, Union, Optional, List, Tuple
|
||||
from dataclasses import dataclass, field, InitVar
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
class DeviceEnv(abc.ABC):
|
||||
TensorList = Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def device(self) -> torch.device:
|
||||
pass
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def local_rank(self) -> int:
|
||||
pass
|
||||
class DeviceEnvType(Enum):
|
||||
""" Device Environment Types
|
||||
"""
|
||||
CPU = "cpu"
|
||||
CUDA = "cuda"
|
||||
XLA = "xla"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeviceEnv:
|
||||
device_type: InitVar[Optional[str]] = None
|
||||
device_index: InitVar[Optional[int]] = None
|
||||
|
||||
device: torch.device = field(init=False) # set from device_type + device_index or post_init logic
|
||||
world_size: Optional[int] = None # set by post_init from env when None
|
||||
local_rank: Optional[int] = None # set by post_init from env when None
|
||||
global_rank: Optional[int] = None # set by post_init from env when None
|
||||
amp: bool = False
|
||||
autocast: Optional[Callable] = None # set by post_init from env when None
|
||||
memory_format: Optional[torch.memory_format] = None
|
||||
dtype: Optional[torch.dtype] = None
|
||||
|
||||
def __post_init__(self, device_type: Optional[str], device_index: Optional[int]):
|
||||
device_type = device_type or 'cpu'
|
||||
self.device = torch.device(device_type) if device_index is None \
|
||||
else torch.device(device_type, device_index)
|
||||
self.world_size = 1 if self.world_size is None else self.world_size
|
||||
self.local_rank = 0 if self.local_rank is None else self.local_rank
|
||||
self.global_rank = 0 if self.global_rank is None else self.global_rank
|
||||
if self.autocast is None:
|
||||
self.autocast = suppress
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def global_rank(self) -> int:
|
||||
pass
|
||||
def type(self) -> DeviceEnvType:
|
||||
if self.device.type == 'cpu':
|
||||
return DeviceEnvType.CPU
|
||||
elif self.device.type == 'cuda':
|
||||
return DeviceEnvType.CUDA
|
||||
elif self.device.type == 'xla':
|
||||
return DeviceEnvType.XLA
|
||||
else:
|
||||
assert False, "Unexpected device type for base DevEnv impl."
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def is_distributed(self) -> bool:
|
||||
pass
|
||||
def type_cuda(self):
|
||||
# shortcut for common cuda device type
|
||||
return self.type == DeviceEnvType.CUDA
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def world_size(self) -> int:
|
||||
pass
|
||||
def type_xla(self):
|
||||
# shortcut for common xla device type
|
||||
return self.type == DeviceEnvType.XLA
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def is_master(self) -> bool:
|
||||
pass
|
||||
def distributed(self):
|
||||
return self.world_size > 1
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def type(self) -> str:
|
||||
pass
|
||||
def primary(self):
|
||||
return self.local_rank == 0
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def autocast(self):
|
||||
pass
|
||||
def global_primary(self):
|
||||
return self.global_rank == 0
|
||||
|
||||
@abc.abstractmethod
|
||||
def wrap_distributed(self, *modules):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def to_device(self, *modules: torch.nn.Module):
|
||||
def wrap_parallel(self, *modules):
|
||||
pass
|
||||
|
||||
#@abc.abstractmethod
|
||||
def to_device(self, *modules: torch.nn.Module):
|
||||
# FIXME handling dtype / memformat... disable flags, enable flags, diff fn?
|
||||
moved = [m.to(device=self.device, memory_format=self.memory_format) for m in modules]
|
||||
return moved[0] if len(moved) == 1 else moved
|
||||
|
||||
def mark_step(self):
|
||||
# FIXME this is for XLA only, make it common to all devices w/ appropriate no-ops?
|
||||
pass
|
||||
pass # NO-OP for non-XLA devices
|
||||
|
||||
def all_reduce_(self, tensor: TensorList, op=dist.ReduceOp.SUM, average=False):
    """ In-place all reduce of `tensor` across the default process group

    Args:
        tensor: tensor to reduce, modified in place
        op: reduction operation passed to torch.distributed.all_reduce
        average: when True, divide the reduced tensor in place by self.world_size

    Returns:
        the same (now reduced) tensor object
    """
    # NOTE leftover debug prints removed here; `len(tensor)` raised on 0-dim tensors
    dist.all_reduce(tensor, op=op)
    if average:
        tensor.div_(self.world_size)
    return tensor
|
||||
|
||||
def all_reduce(self, tensor: torch.Tensor, op=dist.ReduceOp.SUM, average=False):
|
||||
reduce_tensor = tensor.clone()
|
||||
dist.all_reduce(reduce_tensor, op=op)
|
||||
if average:
|
||||
reduce_tensor = reduce_tensor / self.world_size
|
||||
return reduce_tensor
|
||||
|
||||
def all_gather(self, tensor: torch.Tensor, cat_dim=0):
|
||||
output_tensors = [torch.empty_like(tensor) for _ in range(self.world_size)]
|
||||
dist.all_gather(output_tensors, tensor)
|
||||
return torch.cat(output_tensors, cat_dim)
|
||||
|
||||
def all_to_all(self, tensor: torch.Tensor, num_splits, split_dim, cat_dim=0):
|
||||
input_tensors = torch.chunk(tensor, num_splits, split_dim)
|
||||
output_tensors = [torch.empty_like(input_tensors[0]) for _ in range(self.world_size)]
|
||||
dist.all_to_all(output_tensors, input_tensors)
|
||||
return torch.cat(output_tensors, cat_dim)
|
||||
|
||||
def broadcast_(self, tensor: torch.Tensor, src_rank=0):
|
||||
dist.broadcast(tensor, src=src_rank)
|
||||
return tensor
|
||||
|
||||
def broadcast(self, tensor: Optional[torch.Tensor] = None, src_rank=0):
    """ Broadcast `tensor` from `src_rank` without modifying the caller's tensor on other ranks

    Args:
        tensor: on the source rank, the data to send; on every other rank it is only
            used as a template (via torch.empty_like) for the receive buffer
        src_rank: rank whose tensor is broadcast

    Returns:
        the input tensor on the source rank, a newly allocated tensor holding the
        received data on all other ranks
    """
    # validate before first use, otherwise empty_like(None) fails first and hides the cause
    assert tensor is not None, 'broadcast requires a tensor on every rank'
    if self.global_rank != src_rank:
        tensor = torch.empty_like(tensor)
    dist.broadcast(tensor, src=src_rank)
    return tensor
|
||||
|
||||
def barrier(self):
|
||||
dist.barrier()
|
||||
|
@ -1,92 +1,58 @@
|
||||
import os
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass, field, InitVar
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from torch.nn.parallel import DistributedDataParallel, DataParallel
|
||||
|
||||
from .device_env import DeviceEnv
|
||||
from .device_env import DeviceEnv, DeviceEnvType
|
||||
|
||||
|
||||
def is_cuda_available():
|
||||
return torch.cuda.is_available()
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeviceEnvCuda(DeviceEnv):
|
||||
|
||||
def __init__(self, device_idx=None, local_rank=None, amp=False, memory_format=None):
|
||||
def __post_init__(self, device_type: str, device_index: Optional[int]):
|
||||
assert torch.cuda.device_count()
|
||||
torch.backends.cudnn.benchmark = True
|
||||
self._local_rank = 0
|
||||
self._distributed = False
|
||||
self._world_size = 1
|
||||
self._global_rank = 0
|
||||
if 'WORLD_SIZE' in os.environ:
|
||||
self._distributed = int(os.environ['WORLD_SIZE']) > 1
|
||||
if self._distributed:
|
||||
if local_rank is None:
|
||||
setup_world_size = self.world_size or int(os.environ.get('WORLD_SIZE', 1))
|
||||
assert setup_world_size
|
||||
if setup_world_size > 1:
|
||||
# setup distributed
|
||||
assert device_index is None
|
||||
if self.local_rank is None:
|
||||
lr = os.environ.get('LOCAL_RANK', None)
|
||||
if lr is None:
|
||||
raise RuntimeError(
|
||||
'At least one of LOCAL_RANK env variable or local_rank arg must be set to valid integer.')
|
||||
self._local_rank = lr
|
||||
else:
|
||||
self._local_rank = int(local_rank)
|
||||
self._device = torch.device('cuda:%d' % self._local_rank)
|
||||
torch.cuda.set_device(self._local_rank)
|
||||
self.local_rank = int(lr)
|
||||
self.device = torch.device('cuda:%d' % self.local_rank)
|
||||
torch.cuda.set_device(self.local_rank)
|
||||
torch.distributed.init_process_group(backend='nccl', init_method='env://')
|
||||
self._world_size = torch.distributed.get_world_size()
|
||||
self._global_rank = torch.distributed.get_rank()
|
||||
self.world_size = torch.distributed.get_world_size()
|
||||
assert self.world_size == setup_world_size
|
||||
self.global_rank = torch.distributed.get_rank()
|
||||
else:
|
||||
self._device = torch.device('cuda' if device_idx is None else f'cuda:{device_idx}')
|
||||
self._memory_format = memory_format
|
||||
if amp:
|
||||
self._amp = amp
|
||||
self._autocast = torch.cuda.amp.autocast
|
||||
else:
|
||||
self._amp = amp
|
||||
self._autocast = suppress
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self._device
|
||||
|
||||
@property
|
||||
def local_rank(self):
|
||||
return self._local_rank
|
||||
|
||||
@property
|
||||
def global_rank(self):
|
||||
return self._global_rank
|
||||
|
||||
@property
|
||||
def is_distributed(self):
|
||||
return self._distributed
|
||||
self.device = torch.device('cuda' if device_index is None else f'cuda:{device_index}')
|
||||
self.local_rank = 0
|
||||
self.world_size = 1
|
||||
self.global_rank = 0
|
||||
if self.autocast is None:
|
||||
self.autocast = torch.cuda.amp.autocast if self.amp else suppress
|
||||
|
||||
@property
|
||||
def world_size(self):
|
||||
return self._world_size
|
||||
|
||||
@property
|
||||
def is_master(self):
|
||||
return self._local_rank == 0
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return 'cuda'
|
||||
|
||||
@property
|
||||
def amp(self) -> bool:
|
||||
return self._amp
|
||||
|
||||
@property
|
||||
def autocast(self):
|
||||
return self._autocast
|
||||
def type(self) -> DeviceEnvType:
|
||||
return DeviceEnvType.CUDA
|
||||
|
||||
def wrap_distributed(self, *modules, **kwargs):
|
||||
wrapped = [DistributedDataParallel(m, device_ids=[self._local_rank], **kwargs) for m in modules]
|
||||
wrapped = [DistributedDataParallel(m, device_ids=[self.local_rank], **kwargs) for m in modules]
|
||||
return wrapped[0] if len(wrapped) == 1 else wrapped
|
||||
|
||||
def to_device(self, *modules: torch.nn.Module):
|
||||
# FIXME handling dtype / memformat... disable flags, enable flags, diff fn?
|
||||
moved = [m.to(device=self._device, memory_format=self._memory_format) for m in modules]
|
||||
return moved[0] if len(moved) == 1 else moved
|
||||
def wrap_parallel(self, *modules, **kwargs):
|
||||
assert not self.distributed
|
||||
wrapped = [DataParallel(m, **kwargs) for m in modules]
|
||||
return wrapped[0] if len(wrapped) == 1 else wrapped
|
||||
|
@ -0,0 +1,151 @@
|
||||
from typing import Dict, Tuple, List, Union, Any, Callable
|
||||
|
||||
import torch
|
||||
from torch.distributed import ReduceOp
|
||||
|
||||
from timm.utils import unwrap_model
|
||||
|
||||
from .device_env import DeviceEnv, DeviceEnvType
|
||||
from .device_env_factory import get_device
|
||||
|
||||
|
||||
TensorSeq = Union[torch.Tensor, Tuple[torch.Tensor, ...], List[torch.Tensor], Dict[Any, torch.Tensor]]
|
||||
|
||||
|
||||
def _validate_type(tensor: TensorSeq):
    """ Check that `tensor` is a dict / list / tuple or a torch.Tensor

    Containers (empty or not) are accepted as-is, their elements are not inspected.
    Anything else must be a torch.Tensor or the assertion fails.
    """
    if not isinstance(tensor, (dict, list, tuple)):
        assert isinstance(tensor, torch.Tensor)
|
||||
|
||||
|
||||
def distribute_bn(model: torch.nn.Module, reduce: bool = False, dev_env: DeviceEnv = None):
    """ Synchronize running mean / var buffers of a model across processes

    Args:
        model: model whose buffers are synchronized (unwrapped via unwrap_model first)
        reduce: True averages the stats across the group, False broadcasts them from rank 0
        dev_env: device environment supplying the collectives; get_device() is used when None
    """
    dev_env = get_device() if dev_env is None else dev_env
    stat_keys = ('running_mean', 'running_var')
    for buf_name, buf in unwrap_model(model).named_buffers(recurse=True):
        # only buffers whose name mentions a running statistic are touched
        if not any(key in buf_name for key in stat_keys):
            continue
        if reduce:
            # average bn stats across whole group
            dev_env.all_reduce_(buf, average=True)
        else:
            # broadcast bn stats from rank 0 to whole group
            dev_env.broadcast_(buf, 0)
|
||||
|
||||
|
||||
def all_gather_recursive(tensor: TensorSeq, cat_dim=0, dev_env: DeviceEnv = None):
    """ Recursive all gather via DeviceEnv distributed primitives
    FIXME add group support

    Args:
        tensor: a tensor, or a (possibly nested) dict / list / tuple of tensors
        cat_dim: dimension along which the per-rank tensors are concatenated
        dev_env: device environment providing all_gather; get_device() is used when None

    Returns:
        gathered values with the same container structure as the input
    """
    _validate_type(tensor)
    if dev_env is None:
        dev_env = get_device()
    if isinstance(tensor, torch.Tensor):
        return dev_env.all_gather(tensor, cat_dim=cat_dim)
    elif isinstance(tensor, dict):
        # cat_dim must be forwarded, otherwise nested tensors silently fall back to dim 0
        return {k: all_gather_recursive(v, cat_dim=cat_dim, dev_env=dev_env) for k, v in tensor.items()}
    elif isinstance(tensor, (tuple, list)):
        return type(tensor)(all_gather_recursive(v, cat_dim=cat_dim, dev_env=dev_env) for v in tensor)
|
||||
|
||||
|
||||
def all_reduce_recursive(tensor: TensorSeq, op=ReduceOp.SUM, average=False, dev_env: DeviceEnv = None):
    """ Recursive all reduce via DeviceEnv distributed primitives
    FIXME add group support

    Leaf tensors are reduced with the in-place `all_reduce_` primitive; dicts, lists and
    tuples are rebuilt around the results.
    """
    _validate_type(tensor)
    dev_env = get_device() if dev_env is None else dev_env
    if isinstance(tensor, torch.Tensor):
        return dev_env.all_reduce_(tensor, op=op, average=average)

    def _recurse(item):
        return all_reduce_recursive(item, op=op, average=average, dev_env=dev_env)

    if isinstance(tensor, dict):
        return {key: _recurse(item) for key, item in tensor.items()}
    if isinstance(tensor, (tuple, list)):
        return type(tensor)(_recurse(item) for item in tensor)
|
||||
|
||||
|
||||
def broadcast_recursive(tensor: TensorSeq, src_rank: int, dev_env: DeviceEnv = None):
    """ Recursive broadcast via DeviceEnv distributed primitives
    FIXME add group support

    Leaf tensors are broadcast with the in-place `broadcast_` primitive; dicts, lists and
    tuples are rebuilt around the results.
    """
    _validate_type(tensor)
    dev_env = get_device() if dev_env is None else dev_env
    if isinstance(tensor, torch.Tensor):
        return dev_env.broadcast_(tensor, src_rank=src_rank)

    def _recurse(item):
        return broadcast_recursive(item, src_rank=src_rank, dev_env=dev_env)

    if isinstance(tensor, dict):
        return {key: _recurse(item) for key, item in tensor.items()}
    if isinstance(tensor, (tuple, list)):
        return type(tensor)(_recurse(item) for item in tensor)
|
||||
|
||||
|
||||
def all_gather_sequence(tensor: TensorSeq, cat_dim: int = 0, dev_env: DeviceEnv = None):
    """ All gather a flat Tensor sequence (dict, list, tuple) of same shape

    The values are stacked into a single tensor so one all_gather call covers the whole
    sequence, then split back into the structure of the input.

    Args:
        tensor: a single tensor, or a flat dict / list / tuple of stackable tensors
        cat_dim: dimension of each individual tensor along which per-rank results are concatenated
        dev_env: device environment providing all_gather; get_device() is used when None

    Returns:
        gathered values in the same kind of container as the input (a plain dict for dict
        inputs). An empty container yields an empty container without any communication.
    """
    _validate_type(tensor)
    if dev_env is None:
        dev_env = get_device()

    if isinstance(tensor, (dict, list, tuple)) and not tensor:
        # nothing to gather, and torch.stack raises on an empty sequence
        return {} if isinstance(tensor, dict) else type(tensor)()

    with torch.no_grad():
        names = None
        # merge values into one tensor for the gather
        if isinstance(tensor, dict):
            names = tensor.keys()
            gather_values = tuple(tensor.values())
        elif isinstance(tensor, (tuple, list)):
            gather_values = tensor
        else:
            gather_values = (tensor,)

        gather_values = torch.stack(gather_values, dim=0)
        # stacking added a leading dim, so the requested cat_dim shifts by one
        gather_values = dev_env.all_gather(gather_values, cat_dim=cat_dim + 1).unbind(dim=0)

        # separate gathered values into original structure
        if isinstance(tensor, dict):
            gather_values = dict(zip(names, gather_values))
        elif isinstance(tensor, (tuple, list)):
            gather_values = type(tensor)(v for v in gather_values)
        else:
            gather_values = gather_values[0]

    return gather_values
|
||||
|
||||
|
||||
def all_reduce_sequence(tensor: TensorSeq, op=ReduceOp.SUM, average=False, dev_env: DeviceEnv = None):
    """
    All reduce the tensors in a flat Tensor sequence (dict, list, tuple) of same tensor shape

    The values are stacked into a single tensor so one all_reduce call covers the whole
    sequence. The reduction runs on the stacked copy, not on the input tensors.

    Args:
        tensor: a single tensor, or a flat dict / list / tuple of stackable tensors
        op: reduction operation forwarded to the device environment
        average (bool): whether to do average or sum
        dev_env: device environment providing all_reduce_; get_device() is used when None
    Returns:
        a sequence with the same type as input (dict, list, tuple). An empty container
        yields an empty container without any communication.
    """
    _validate_type(tensor)
    if dev_env is None:
        dev_env = get_device()

    if isinstance(tensor, (dict, list, tuple)) and not tensor:
        # nothing to reduce, and torch.stack raises on an empty sequence
        return {} if isinstance(tensor, dict) else type(tensor)()

    with torch.no_grad():
        names = None
        # merge values into one tensor for reduction
        if isinstance(tensor, dict):
            names = tensor.keys()
            reduce_values = tuple(tensor.values())
        elif isinstance(tensor, (tuple, list)):
            reduce_values = tensor
        else:
            reduce_values = (tensor,)

        reduce_values = torch.stack(reduce_values, dim=0)
        dev_env.all_reduce_(reduce_values, op=op, average=average)
        reduce_values = reduce_values.unbind(dim=0)

        # separate reduced values into original structure
        if isinstance(tensor, dict):
            reduce_values = dict(zip(names, reduce_values))
        elif isinstance(tensor, (tuple, list)):
            reduce_values = type(tensor)(v for v in reduce_values)
        else:
            reduce_values = reduce_values[0]

    return reduce_values
|
@ -0,0 +1,190 @@
|
||||
""" PyTorch distributed helpers
|
||||
|
||||
Some of this lifted from Detectron2 with other fns added by myself.
|
||||
|
||||
FIXME many functions remain unfinished/untested
|
||||
"""
|
||||
from typing import Dict, Tuple, List, Union, Any, Callable
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ReduceOp
|
||||
|
||||
TensorSeq = Union[torch.Tensor, Tuple[torch.Tensor, ...], List[torch.Tensor], Dict[Any, torch.Tensor]]
|
||||
|
||||
|
||||
def synchronize_torch():
    """
    Barrier across all processes of the default process group.

    Does nothing when torch.distributed is unavailable, not initialized, or the
    world size is 1.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return
    if dist.get_world_size() == 1:
        return
    dist.barrier()
|
||||
|
||||
|
||||
def all_reduce_sequence_torch(values: TensorSeq, op=ReduceOp.SUM, average=False, group=None):
    """
    All reduce the tensors in a sequence (dict, list, tuple)

    Dict / list / tuple values are stacked into one tensor for a single all_reduce call and
    split again afterwards; a bare tensor is reduced in place.

    Args:
        values (dict): inputs to be reduced. All the values must be scalar Tensor.
        average (bool): whether to do average or sum
    Returns:
        a sequence with the same type as input (dict, list, tuple)
    """
    world_size = dist.get_world_size(group)
    if world_size <= 1:
        return values

    is_dict = isinstance(values, dict)
    is_seq = isinstance(values, (tuple, list))
    with torch.no_grad():
        if is_dict:
            packed = torch.stack(tuple(values.values()), dim=0)
        elif is_seq:
            packed = torch.stack(values, dim=0)
        else:
            packed = values
        dist.all_reduce(packed, op=op, group=group)
        if average:
            packed /= world_size
        # iterate the leading (stack) dim to recover one entry per input value
        if is_dict:
            return {name: row for name, row in zip(values.keys(), packed)}
        if is_seq:
            return type(values)(row for row in packed)
    return packed
|
||||
|
||||
|
||||
def reduce_sequence_torch(values: TensorSeq, dst_rank=0, op=ReduceOp.SUM, average=False, group=None):
    """
    Reduce the tensors in a sequence (dict, list, tuple) onto a single destination rank

    Dict / list / tuple values are stacked into one tensor for a single reduce call and
    split again afterwards; a bare tensor is passed to the reduce call directly.

    Args:
        values (dict): inputs to be reduced. All the values must be scalar Tensor.
        dst_rank (int): rank that receives the reduced result
        op: reduction operation passed to torch.distributed.reduce
        average (bool): whether to do average or sum (division is applied on dst_rank only)
        group: process group, default group when None
    Returns:
        a sequence with the same type as input (dict, list, tuple). Only the result on
        dst_rank should be relied upon.
    """
    world_size = dist.get_world_size(group)
    this_rank = dist.get_rank()
    if world_size <= 1:
        return values

    with torch.no_grad():
        names = None
        if isinstance(values, dict):
            names = values.keys()
            reduce_values = torch.stack(tuple(values.values()), dim=0)
        elif isinstance(values, (tuple, list)):
            reduce_values = torch.stack(values, dim=0)
        else:
            reduce_values = values
        # NOTE a second torch.stack over the already stacked tensor was removed here
        dist.reduce(reduce_values, dst=dst_rank, op=op, group=group)
        if average and this_rank == dst_rank:
            reduce_values /= world_size
        if isinstance(values, dict):
            reduce_values = dict(zip(names, reduce_values))
        elif isinstance(values, (tuple, list)):
            reduce_values = type(values)(v for v in reduce_values)
    return reduce_values
|
||||
|
||||
|
||||
def all_gather_sequence_torch(values: TensorSeq, group=None, join_fn=torch.cat, join_dim=0):
    """ All gather every tensor of a flat dict / list / tuple (or a single tensor)

    Each tensor is gathered with its own all_gather call and the per-rank results are
    merged with `join_fn(tensors, dim=join_dim)`.
    """
    world_size = dist.get_world_size(group)

    def _gather_one(item):
        buffers = [torch.empty_like(item) for _ in range(world_size)]
        dist.all_gather(buffers, item, group=group)
        return join_fn(buffers, dim=join_dim)

    if isinstance(values, dict):
        return {name: _gather_one(item) for name, item in values.items()}
    if isinstance(values, (list, tuple)):
        return type(values)(_gather_one(item) for item in values)
    # if not a dict, list, tuple, expect a singular tensor
    assert isinstance(values, torch.Tensor)
    return _gather_one(values)
|
||||
|
||||
|
||||
def gather_sequence_torch(values: TensorSeq, dst_rank, group=None, join_fn=torch.cat, join_dim=0):
    """ Gather every tensor of a flat dict / list / tuple (or a single tensor) onto dst_rank

    Args:
        values: a tensor, or a flat dict / list / tuple of tensors
        dst_rank: destination rank, passed as `dst` to torch.distributed.gather
        group: process group, default group when None
        join_fn: callable merging the list of gathered tensors, called as join_fn(list, dim=join_dim)
        join_dim: dimension handed to join_fn

    Returns:
        on dst_rank the joined tensors in the structure of the input; on every other rank
        the same structure with None in place of each tensor
    """
    world_size = dist.get_world_size(group)
    # torch.distributed.gather documents `dst` as a global rank, so compare against the global rank
    this_rank = dist.get_rank()

    def _do_gather(tensor):
        tensor_list = None
        if this_rank == dst_rank:
            tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
        dist.gather(tensor, tensor_list, dst=dst_rank, group=group)
        if tensor_list is None:
            # non-destination ranks receive nothing, there is nothing to join
            return None
        return join_fn(tensor_list, dim=join_dim)

    if isinstance(values, dict):
        gathered = {k: _do_gather(v) for k, v in values.items()}
        return gathered
    elif isinstance(values, (list, tuple)):
        gathered = type(values)(_do_gather(v) for v in values)
        return gathered
    else:
        # if not a dict, list, tuple, expect a singular tensor
        assert isinstance(values, torch.Tensor)
        return _do_gather(values)
|
||||
|
||||
|
||||
def all_gather_torch(value: TensorSeq, group=None, join_fn: Callable = None, join_dim=0):
    """ Recursively all gather a tensor or nested dict / list / tuple of tensors

    For each tensor the result is the list of per-rank tensors, or
    `join_fn(list, dim=join_dim)` when join_fn is given. Inputs of any other type
    yield None.
    """
    if isinstance(value, dict):
        return {key: all_gather_torch(item, group, join_fn, join_dim) for key, item in value.items()}
    if isinstance(value, (tuple, list)):
        return type(value)(all_gather_torch(item, group, join_fn, join_dim) for item in value)
    if isinstance(value, torch.Tensor):
        num_ranks = dist.get_world_size(group)
        gathered = [torch.empty_like(value) for _ in range(num_ranks)]
        dist.all_gather(gathered, value, group=group)
        return gathered if join_fn is None else join_fn(gathered, dim=join_dim)
|
||||
|
||||
|
||||
def gather_torch(value: TensorSeq, dst_rank=0, group=None, join_fn: Callable = None, join_dim=0):
    """ Recursively gather a tensor or nested dict / list / tuple of tensors onto dst_rank

    Args:
        value: a tensor, or a (possibly nested) dict / list / tuple of tensors
        dst_rank: destination rank, passed as `dst` to torch.distributed.gather
        group: process group, default group when None
        join_fn: optional callable merging the gathered list, called as join_fn(list, dim=join_dim)
        join_dim: dimension handed to join_fn

    Returns:
        on dst_rank, per tensor, the list of gathered tensors (or the joined tensor when
        join_fn is given); on every other rank None in place of each tensor
    """
    if isinstance(value, torch.Tensor):
        world_size = dist.get_world_size(group)
        this_rank = dist.get_rank()
        out_tensors = None
        if this_rank == dst_rank:
            out_tensors = [torch.empty_like(value) for _ in range(world_size)]
        dist.gather(value, out_tensors, dst=dst_rank, group=group)
        # only the destination rank holds data; joining None would raise on the others
        if out_tensors is not None and join_fn is not None:
            out_tensors = join_fn(out_tensors, dim=join_dim)
        return out_tensors
    elif isinstance(value, dict):
        return {k: gather_torch(v, dst_rank, group, join_fn, join_dim) for k, v in value.items()}
    elif isinstance(value, (tuple, list)):
        return type(value)(gather_torch(v, dst_rank, group, join_fn, join_dim) for v in value)
|
||||
|
||||
|
||||
def all_reduce_torch(value: TensorSeq, op=ReduceOp.SUM, average=False, group=None):
    """ Recursively all reduce a tensor or nested dict / list / tuple of tensors

    Tensors are reduced in place (and divided in place by the group's world size when
    `average` is True).

    Returns:
        the reduced tensor, or a new container of the same kind holding the reduced tensors
    """
    if isinstance(value, torch.Tensor):
        dist.all_reduce(value, op=op, group=group)
        if average:
            value /= dist.get_world_size(group)
        # return the tensor so the container branches below are rebuilt with real values
        return value
    elif isinstance(value, dict):
        return {k: all_reduce_torch(v, op=op, average=average, group=group) for k, v in value.items()}
    elif isinstance(value, (tuple, list)):
        return type(value)(all_reduce_torch(v, op=op, average=average, group=group) for v in value)
|
||||
|
||||
|
||||
def broadcast_torch(value: TensorSeq, src_rank: int = 0, group=None):
    """ Recursively broadcast a tensor or nested dict / list / tuple of tensors from src_rank

    Tensors are overwritten in place on the receiving ranks.

    Returns:
        the broadcast tensor, or a new container of the same kind holding the tensors
    """
    if isinstance(value, torch.Tensor):
        # return the tensor itself rather than the result of the dist.broadcast call,
        # so the container branches below are rebuilt with real values
        dist.broadcast(value, src=src_rank, group=group)
        return value
    elif isinstance(value, dict):
        return {k: broadcast_torch(v, src_rank=src_rank, group=group) for k, v in value.items()}
    elif isinstance(value, (tuple, list)):
        return type(value)(broadcast_torch(v, src_rank=src_rank, group=group) for v in value)
|
@ -0,0 +1,142 @@
|
||||
import abc
|
||||
from typing import Callable, Union, Optional, List, Tuple, Dict
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from torch.distributed import ReduceOp
|
||||
|
||||
from .device_env import DeviceEnv
|
||||
from .device_env_factory import get_device
|
||||
from .distributed import all_gather_sequence, all_reduce_sequence
|
||||
|
||||
MetricValue = Union[float, torch.Tensor, List[float], List[torch.Tensor]]
|
||||
|
||||
@dataclass
class ValueInfo:
    """ Per-value settings for a state value registered on a Metric """
    # value assigned by Metric.reset(); None leaves initialisation to the metric subclass
    initial: Optional[MetricValue] = 0.0
    # dtype the stored value is created with / converted to on reset
    dtype: torch.dtype = torch.float32
    # distributed combine mode: 'cat' gathers, any other string is sum-reduced
    dist_reduce: str = 'sum'
    # divide the combined value by the world size
    dist_average: bool = False
|
||||
|
||||
|
||||
class Metric(abc.ABC):
    """ Base class for metrics with named state values and distributed reduction

    Subclasses register state values with `_register_value`, update them in `_update`
    and derive results in `_compute`. Registered values are read and written as plain
    attributes; while `compute()` runs in a distributed environment, reads return the
    cross-process reduced values instead of the local ones.
    """

    def __init__(self, dev_env: DeviceEnv = None):
        self._infos: Dict[str, ValueInfo] = {}
        self._values: Dict[str, Optional[MetricValue]] = {}
        self._values_dist: Dict[str, Optional[MetricValue]] = {}
        if dev_env is None:
            dev_env = get_device()
        self._dev_env = dev_env

    def _register_value(self, name: str, info: Optional[ValueInfo] = None):
        """ Register `name` as a state value, with default ValueInfo when none is given """
        info = info or ValueInfo()
        self._infos[name] = info

    # def get_value(self, name: str, use_dist=True):
    #     if use_dist:
    #         return self._values_dist.get(name, self._values.get(name))
    #     else:
    #         return self._values.get(name)

    def __getattr__(self, item):
        # Only called when normal attribute lookup fails. `_infos` is read through __dict__
        # because touching self._infos before it exists (e.g. while copy / pickle probe a
        # freshly created instance) would re-enter __getattr__ and recurse without end.
        infos = self.__dict__.get('_infos')
        if infos is None or item not in infos:
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
        # prefer the distributed (reduced) value when one is present
        value = self._values_dist.get(item, self._values.get(item, None))
        return value

    def __setattr__(self, key, value):
        # registered names are stored in the local value dict, everything else is a normal attribute
        if '_infos' in self.__dict__ and key in self._infos:
            self._values[key] = value
        else:
            super().__setattr__(key, value)

    def update(
            self,
            predictions: Union[torch.Tensor, Dict[str, torch.Tensor]],
            target: Union[torch.Tensor, Dict[str, torch.Tensor]]):
        """ Accumulate state from a batch of predictions and targets (delegates to `_update`) """
        self._update(predictions, target)

    def _update(
            self,
            predictions: Union[torch.Tensor, Dict[str, torch.Tensor]],
            target: Union[torch.Tensor, Dict[str, torch.Tensor]]):
        pass

    def reset(self):
        """ Reset all registered values to their initial state, then call `_reset` """
        self._values = {}
        self._values_dist = {}
        for name, info in self._infos.items():
            # if info specifies an initial value, we reset here, otherwise set to None and leave it to child class
            if info.initial is not None:
                if isinstance(info.initial, torch.Tensor):
                    tensor = info.initial.detach().clone()
                else:
                    tensor = torch.ones([], dtype=info.dtype) * info.initial  # scalar
                self._values[name] = tensor.to(device=self._dev_env.device, dtype=info.dtype)
            else:
                self._values[name] = None
        self._reset()

    def _reset(self):
        pass

    def compute(self) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[str, torch.Tensor]]:
        """ Compute the metric result, reducing state across processes first when distributed """
        if self._dev_env.distributed:
            self._distribute_values()
        results = self._compute()
        # reduced values are only valid for this compute call
        self._values_dist = {}
        return results

    @abc.abstractmethod
    def _compute(self) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[str, torch.Tensor]]:
        pass

    def _distribute_values(self):
        """ Fill `_values_dist` with the cross-process gathered / reduced state values """
        if not self._infos or not self._values:
            return

        def _args(op: str):
            # returns (do_gather, kwargs for the collective)
            if op == 'cat':
                return True, dict(cat_dim=0)
            else:
                return False, dict(op=ReduceOp.SUM)

        prev_dsr = None
        same_dsr = True
        names = []
        values = []
        reductions = []
        for name, value in self._values.items():
            if value is not None:
                info = self._infos[name]
                dsr = (value.dtype, value.shape, info.dist_reduce)
                if prev_dsr is not None and prev_dsr != dsr:
                    same_dsr = False
                prev_dsr = dsr
                names.append(name)
                values.append(value)
                reductions.append(_args(info.dist_reduce))
        if not names:
            # every value is None, nothing to communicate
            return
        # NOTE(review): this override makes the batched branch below unreachable;
        # presumably a temporary disable - confirm before removing
        same_dsr = False
        if same_dsr:
            do_gather, reduce_kwargs = reductions[0]
            if do_gather:
                reduced_values = all_gather_sequence(values, dev_env=self._dev_env, **reduce_kwargs)
            else:
                reduced_values = all_reduce_sequence(values, dev_env=self._dev_env, **reduce_kwargs)
            for name, reduced_value in zip(names, reduced_values):
                info = self._infos[name]
                if info.dist_average:
                    reduced_value /= self._dev_env.world_size
                self._values_dist[name] = reduced_value
        else:
            for n, v, r in zip(names, values, reductions):
                info = self._infos[n]
                do_gather, reduce_kwargs = r
                if do_gather:
                    reduced_value = self._dev_env.all_gather(v, **reduce_kwargs)
                else:
                    reduced_value = self._dev_env.all_reduce(v, **reduce_kwargs)
                if info.dist_average:
                    reduced_value /= self._dev_env.world_size
                self._values_dist[n] = reduced_value
|
@ -0,0 +1,98 @@
|
||||
import torch
|
||||
from typing import Optional, Tuple, Dict
|
||||
|
||||
from .device_env import DeviceEnv
|
||||
from .metric import Metric, ValueInfo
|
||||
|
||||
|
||||
class Accuracy(Metric):
    """Accuracy metric placeholder built on the ``Metric`` base class.

    The constructor stores its settings and registers the ``correct`` and
    ``total`` values; the update and compute steps are not implemented yet.
    """

    def __init__(self, threshold=0.5, multi_label=False, dev_env=None):
        super().__init__(dev_env=dev_env)
        self.threshold = threshold
        self.eps = 1e-8
        self.multi_label = multi_label

        # statistics / counts
        self._register_value('correct')
        self._register_value('total')

    def _update(self, predictions, target):
        # `NotImplemented` is the binary-operator sentinel and is not callable;
        # the exception meant for unimplemented methods is NotImplementedError.
        raise NotImplementedError

    def _compute(self):
        raise NotImplementedError
|
||||
|
||||
|
||||
# class AccuracyTopK(torch.nn.Module):
|
||||
#
|
||||
# def __init__(self, topk=(1, 5), device=None):
|
||||
# super().__init__()
|
||||
# self.eps = 1e-8
|
||||
# self.device = device
|
||||
# self.topk = topk
|
||||
# self.maxk = max(topk)
|
||||
# # FIXME handle distributed operation
|
||||
#
|
||||
# # statistics / counts
|
||||
# self.reset()
|
||||
#
|
||||
# def update(self, predictions: torch.Tensor, target: torch.Tensor):
|
||||
# sorted_indices = predictions.topk(self.maxk, dim=1)[1]
|
||||
# sorted_indices.t_()
|
||||
# correct = sorted_indices.eq(target.reshape(1, -1).expand_as(sorted_indices))
|
||||
#
|
||||
# batch_size = target.shape[0]
|
||||
# correct_k = {k: correct[:k].reshape(-1).float().sum(0) for k in self.topk}
|
||||
# for k, v in correct_k.items():
|
||||
# attr = f'_correct_top{k}'
|
||||
# old_v = getattr(self, attr)
|
||||
# setattr(self, attr, old_v + v)
|
||||
# self._total_sum += batch_size
|
||||
#
|
||||
# def reset(self):
|
||||
# for k in self.topk:
|
||||
# setattr(self, f'_correct_top{k}', torch.tensor(0, dtype=torch.float32))
|
||||
# self._total_sum = torch.tensor(0, dtype=torch.float32)
|
||||
#
|
||||
# @property
|
||||
# def counts(self):
|
||||
# pass
|
||||
#
|
||||
# def compute(self) -> Dict[str, torch.Tensor]:
|
||||
# # FIXME handle distributed reduction
|
||||
# return {f'top{k}': 100 * getattr(self, f'_correct_top{k}') / self._total_sum for k in self.topk}
|
||||
|
||||
|
||||
class AccuracyTopK(Metric):
    """Top-k accuracy accumulated across ``update`` calls.

    One value named ``top{k}`` is registered per entry of ``topk``, plus ``total``;
    ``_compute`` reports ``100 * top{k} / total`` for every k.
    """

    def __init__(self, topk=(1, 5), dev_env: DeviceEnv = None):
        super().__init__(dev_env=dev_env)
        self.eps = 1e-8
        self.topk = topk
        self.maxk = max(topk)

        # statistics / counts
        for k in self.topk:
            self._register_value(f'top{k}')
        self._register_value('total')
        self.reset()

    def _update(self, predictions: torch.Tensor, target: torch.Tensor):
        num_samples = predictions.shape[0]
        # indices of the maxk largest entries along dim 1
        _, top_indices = predictions.topk(self.maxk, dim=1)
        expanded_target = target.reshape(-1, 1).expand_as(top_indices)
        # matches per rank position, summed over dim 0
        hits_per_rank = top_indices.eq(expanded_target).float().sum(0)
        for k in self.topk:
            key = f'top{k}'
            hits_at_k = hits_per_rank[:k].sum()
            setattr(self, key, getattr(self, key) + hits_at_k)
        self.total += num_samples

    def _compute(self) -> Dict[str, torch.Tensor]:
        assert self.total is not None
        percentages = {}
        for k in self.topk:
            key = f'top{k}'
            percentages[key] = 100 * getattr(self, key) / self.total
        return percentages
|
@ -0,0 +1,12 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
class TrainCfg:
    """Plain container for train loop configuration values.

    Every field is an int with a default, so ``TrainCfg()`` is valid.
    """
    # NOTE(review): field meanings below are inferred from the names only.
    num_epochs: int = 0  # presumably the number of epochs to run
    log_interval: int = 50  # presumably steps between log outputs
    recovery_interval: int = 0  # presumably steps between recovery checkpoints
    accumulate_steps: int = 0  # presumably gradient accumulation step count
|
@ -0,0 +1,13 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .logger import Logger
|
||||
from timm.utils.checkpoint_saver import CheckpointSaver
|
||||
|
||||
|
||||
@dataclass
class TrainServices:
    """Bundle of optional helper objects handed to the train loop.

    Both fields default to ``None``.
    """
    logger: Logger = None  # logging helper
    saver: CheckpointSaver = None  # checkpoint saving helper
|
||||
|
@ -0,0 +1,153 @@
|
||||
import dataclasses
|
||||
from typing import Callable, Union, Optional
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from timm.optim import create_optimizer_v2
|
||||
from timm.utils import ModelEmaV2
|
||||
|
||||
try:
|
||||
import deepspeed as ds
|
||||
except ImportError:
|
||||
ds = None
|
||||
|
||||
from .checkpoint import resume_train_checkpoint
|
||||
from .device_env import DeviceEnv
|
||||
from .train_cfg import TrainCfg
|
||||
from .train_state import TrainState
|
||||
from .updater_factory import create_updater
|
||||
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def setup_model_and_optimizer(
        dev_env: DeviceEnv,
        model: nn.Module,
        optimizer: Union[Callable, str],
        optimizer_cfg,
        clip_fn: Optional[Union[Callable, str]] = None,
        clip_value: Optional[float] = None,
        model_ema: bool = False,
        model_ema_decay: float = 0.9999,
        use_syncbn: bool = False,
        resume_path: str = '',
        resume_opt: bool = True,
        deepspeed: bool = False,
):
    """Create the optimizer, updater and optional EMA model and pack them in a TrainState.

    Args:
        dev_env: device environment; used to place the model, check distributed /
            primary status and wrap the model for distributed training.
        model: model to train.
        optimizer: when callable, invoked as ``optimizer(model=model, **optimizer_cfg)``;
            otherwise ``create_optimizer_v2(model=model, **optimizer_cfg)`` is used
            and this argument's value itself is not passed on.
        optimizer_cfg: keyword arguments for optimizer creation.
        clip_fn: gradient clipping function or name, forwarded to ``create_updater``.
        clip_value: gradient clipping value, forwarded to ``create_updater``.
        model_ema: when True a ``ModelEmaV2`` copy of the model is created.
        model_ema_decay: decay passed to ``ModelEmaV2``.
        use_syncbn: convert BatchNorm layers to SyncBatchNorm when running
            distributed. Not forwarded to the DeepSpeed setup path.
        resume_path: when non-empty, the checkpoint at this path is loaded into the
            train state.
        resume_opt: forwarded to ``resume_train_checkpoint``.
        deepspeed: delegate the whole setup to ``setup_model_and_optimizer_deepspeed``.

    Returns:
        The constructed ``TrainState``.
    """
    if deepspeed:
        deepspeed_kwargs = dict(
            dev_env=dev_env, model=model, optimizer=optimizer, optimizer_cfg=optimizer_cfg,
            clip_fn=clip_fn, clip_value=clip_value, model_ema=model_ema, model_ema_decay=model_ema_decay,
            resume_path=resume_path, resume_opt=resume_opt,
        )
        return setup_model_and_optimizer_deepspeed(**deepspeed_kwargs)

    dev_env.to_device(model)

    if use_syncbn and dev_env.distributed:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        if dev_env.primary:
            _logger.info(
                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')

    optimizer_factory = optimizer if isinstance(optimizer, Callable) else create_optimizer_v2
    optimizer = optimizer_factory(model=model, **optimizer_cfg)

    updater = create_updater(
        model=model,
        optimizer=optimizer,
        clip_fn=clip_fn,
        clip_value=clip_value,
    )

    # ema model
    ema = None
    if model_ema:
        ema = ModelEmaV2(model, decay=model_ema_decay)

    train_state = TrainState(model=model, updater=updater, model_ema=ema)

    if resume_path:
        resume_train_checkpoint(
            train_state,
            resume_path,
            resume_opt=resume_opt,
            log_info=dev_env.primary)

    if dev_env.distributed:
        # the distributed wrapper is applied only after any checkpoint has been loaded
        wrapped_model = dev_env.wrap_distributed(train_state.model)
        train_state = dataclasses.replace(train_state, model=wrapped_model)

    return train_state
|
||||
|
||||
|
||||
def setup_model_and_optimizer_deepspeed(
        dev_env: DeviceEnv,
        model: nn.Module,
        optimizer: Union[Callable, str],
        optimizer_cfg,
        clip_fn: Optional[Union[Callable, str]] = None,
        clip_value: Optional[float] = None,
        model_ema: bool = False,
        model_ema_decay: float = 0.9999,
        use_syncbn: bool = False,
        resume_path: str = '',
        resume_opt: bool = True,
):
    """DeepSpeed variant of ``setup_model_and_optimizer``.

    Args:
        dev_env: device environment; used to place the model, check distributed /
            primary status and wrap the model for distributed training.
        model: model to train; replaced by the DeepSpeed engine after initialization.
        optimizer: when callable, invoked as ``optimizer(model=model, **optimizer_cfg)``;
            otherwise ``create_optimizer_v2(model=model, **optimizer_cfg)`` is used.
        optimizer_cfg: keyword arguments for optimizer creation.
        clip_fn: gradient clipping function or name, forwarded to ``create_updater``.
        clip_value: gradient clipping value, forwarded to ``create_updater``.
        model_ema: when True a ``ModelEmaV2`` of the engine is created.
        model_ema_decay: decay passed to ``ModelEmaV2``.
        use_syncbn: accepted for signature parity; not used by this function.
        resume_path: when non-empty, the checkpoint at this path is loaded.
        resume_opt: forwarded to ``resume_train_checkpoint``.

    Returns:
        The constructed ``TrainState``.

    Raises:
        ImportError: if the ``deepspeed`` package could not be imported.
    """
    if ds is None:
        # fail with a clear message instead of an AttributeError on None below
        raise ImportError('deepspeed is not installed, but deepspeed setup was requested')

    dev_env.to_device(model)

    if isinstance(optimizer, Callable):
        optimizer = optimizer(model=model, **optimizer_cfg)
    else:
        optimizer = create_optimizer_v2(model=model, **optimizer_cfg)

    # deepspeed.initialize returns (engine, optimizer, training_dataloader, lr_scheduler);
    # binding the whole tuple to `model` would hand a tuple to the updater.
    model, optimizer, _, _ = ds.initialize(model=model, optimizer=optimizer, dist_init_required=False)

    updater = create_updater(
        model=model,
        optimizer=optimizer,
        clip_fn=clip_fn,
        clip_value=clip_value,
        deepspeed=True,
    )

    # ema model
    # FIXME how to do EMA w/ deepspeed?
    model_ema = ModelEmaV2(model, decay=model_ema_decay) if model_ema else None

    train_state = TrainState(model=model, updater=updater, model_ema=model_ema)

    if resume_path:
        # FIXME deepspeed resumes differently
        resume_train_checkpoint(
            train_state,
            resume_path,
            resume_opt=resume_opt,
            log_info=dev_env.primary)

    if dev_env.distributed:
        train_state = dataclasses.replace(
            train_state, model=dev_env.wrap_distributed(train_state.model))

    return train_state
|
@ -0,0 +1,33 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Any
|
||||
|
||||
from torch import nn as nn
|
||||
|
||||
from timm.scheduler import Scheduler
|
||||
from .updater import Updater
|
||||
|
||||
|
||||
@dataclass
class TrainState:
    """Mutable bundle of the objects and counters that make up a training run.

    ``model`` and ``updater`` are mandatory: ``__post_init__`` asserts that both
    were supplied. All other fields are optional.
    """
    # modules / helpers
    model: nn.Module = None
    train_loss: nn.Module = None
    eval_loss: nn.Module = None
    updater: Updater = None
    lr_scheduler: Scheduler = None
    model_ema: nn.Module = None

    # progress counters, all starting at zero
    step_count_epoch: int = 0
    step_count_global: int = 0
    epoch: int = 0

    def __post_init__(self):
        # the defaults above exist for keyword construction only
        assert self.model is not None
        assert self.updater is not None
|
||||
|
||||
|
||||
def serialize_train_state(train_state: TrainState):
    """Placeholder for train state serialization; does nothing and returns None."""
|
||||
|
||||
|
||||
def deserialize_train_state(checkpoint: Dict[str, Any]):
    """Placeholder for train state deserialization; does nothing and returns None."""
|
@ -1,54 +1,68 @@
|
||||
from dataclasses import dataclass, field, InitVar
|
||||
from functools import partial
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .grad_clipper import GradClipper
|
||||
from .grad_clip import get_clip_grad_fn, get_clip_parameters
|
||||
|
||||
|
||||
@dataclass
class Updater:
    """Applies a loss to a model: backward pass, optional gradient clipping, optimizer step.

    ``model`` and ``optimizer`` are required (asserted in ``__post_init__``).
    ``clip_fn`` may be a callable or a string; a string is resolved through
    ``get_clip_grad_fn`` and then requires ``clip_value`` to be set.
    """

    model: nn.Module = None
    optimizer: torch.optim.Optimizer = None  # FIXME handle multiple optimizers per-model
    clip_fn: Optional[Union[Callable, str]] = None
    clip_value: Optional[float] = None
    clip_params_fn: Optional[Callable] = None
    grad_scaler: Optional[Callable] = None
    create_graph: Optional[bool] = None
    after_step_closure: bool = False

    def __post_init__(self):
        assert self.model is not None
        assert self.optimizer is not None
        if self.clip_fn is not None:
            if isinstance(self.clip_fn, Callable):
                skip_last = 0
            else:
                assert isinstance(self.clip_fn, str)
                # 'agc' clip modes leave the last two parameters out of clipping
                skip_last = 2 if 'agc' in self.clip_fn else 0
                self.clip_fn = get_clip_grad_fn(self.clip_fn)
                assert self.clip_value is not None
            self.clip_params_fn = partial(get_clip_parameters, model=self.model, skip_last=skip_last)
        if self.create_graph is None:
            # default to the optimizer's own `second_order` flag, if it has one
            self.create_graph = getattr(self.optimizer, 'second_order', False)
        self.after_step_closure = False

    def reset(self):
        """Zero the optimizer's gradients."""
        self.optimizer.zero_grad()

    def apply(self, loss: torch.Tensor, accumulate=False):
        """Backpropagate ``loss``; unless ``accumulate`` is set, clip, step and reset."""
        loss.backward(create_graph=self.create_graph)
        if accumulate:
            return
        if self.clip_fn is not None:
            self.clip_fn(self.clip_params_fn(), self.clip_value)
        self.optimizer.step()
        self.reset()

    def get_average_lr(self):
        """Return the mean learning rate over param groups whose lr is > 0."""
        lrl = [param_group['lr'] for param_group in self.optimizer.param_groups if param_group['lr'] > 0]
        return sum(lrl) / len(lrl)

    def state_dict(self):
        """Return a dict with the optimizer state and, if present, the grad scaler state."""
        state_dict = dict(optimizer=self.optimizer.state_dict())
        if self.grad_scaler is not None:
            state_dict['grad_scaler'] = self.grad_scaler.state_dict()
        # the dict has to be returned, otherwise callers always receive None
        return state_dict

    def load_state_dict(self, state_dict):
        """Restore optimizer and grad scaler state from the entries present in ``state_dict``."""
        if 'optimizer' in state_dict:
            self.optimizer.load_state_dict(state_dict['optimizer'])
        if 'grad_scaler' in state_dict and self.grad_scaler is not None:
            self.grad_scaler.load_state_dict(state_dict['grad_scaler'])

    def after_step(self, after_step_fn, *args):
        """Call ``after_step_fn(*args)`` immediately."""
        after_step_fn(*args)
|
||||
|
||||
|
@ -1,36 +1,30 @@
|
||||
from typing import Callable, Optional, Union, Any
|
||||
from dataclasses import dataclass, field, InitVar
|
||||
from typing import Dict, Any
|
||||
|
||||
import torch
|
||||
|
||||
from .updater import Updater
|
||||
|
||||
|
||||
@dataclass
class UpdaterCudaWithScaler(Updater):
    """Updater that runs the backward pass and optimizer step through a CUDA AMP ``GradScaler``.

    ``scaler_kwargs`` is an init-only argument; it is passed to
    ``torch.cuda.amp.GradScaler`` and not stored on the instance.
    """

    scaler_kwargs: InitVar[Dict[str, Any]] = None

    def __post_init__(self, scaler_kwargs: Dict[str, Any]):
        super().__post_init__()
        scaler_kwargs = scaler_kwargs or {}
        self.grad_scaler = torch.cuda.amp.GradScaler(**scaler_kwargs)

    def apply(self, loss: torch.Tensor, accumulate=False):
        """Backpropagate the scaled loss; unless ``accumulate`` is set, clip, step and reset."""
        self.grad_scaler.scale(loss).backward(create_graph=self.create_graph)
        if accumulate:
            # unscale first?
            return
        if self.clip_fn is not None:
            # unscale the gradients of optimizer's assigned params in-place
            self.grad_scaler.unscale_(self.optimizer)
            self.clip_fn(self.clip_params_fn(), self.clip_value)
        self.grad_scaler.step(self.optimizer)
        self.grad_scaler.update()
        self.reset()
|
||||
|
@ -0,0 +1,26 @@
|
||||
from dataclasses import dataclass, field, InitVar
|
||||
|
||||
import torch
|
||||
try:
|
||||
import deepspeed as ds
|
||||
except ImportError as e:
|
||||
ds = None
|
||||
|
||||
from .updater import Updater
|
||||
|
||||
|
||||
@dataclass
class UpdaterDeepSpeed(Updater):
    """Updater that drives backward / step through a DeepSpeed engine held in ``model``."""

    def __post_init__(self):
        super().__post_init__()
        # FIXME not sure how to deal with model.module / grad clipping w/ DS engine interface
        assert isinstance(self.model, ds.DeepSpeedEngine)

    def reset(self):
        """Zero gradients through the engine."""
        self.model.zero_grad()

    def apply(self, loss: torch.Tensor, accumulate=False):
        """Run the engine's backward and step, then reset.

        ``accumulate`` is accepted for interface parity and is not consulted here.
        """
        engine = self.model
        engine.backward(loss)
        engine.step()
        self.reset()
|
@ -1,4 +0,0 @@
|
||||
from .accuracy import Accuracy, AccuracyTopK
|
||||
from .precision_recall import PrecisionRecall
|
||||
from .scalar_avg import ScalarAvgMinMax
|
||||
from .tensor_avg import TensorAvg, TensorEma
|
@ -1,114 +0,0 @@
|
||||
import torch
|
||||
from typing import Optional, Tuple, Dict
|
||||
|
||||
|
||||
class Accuracy(torch.nn.Module):
    """Accuracy metric placeholder; ``update`` and ``compute`` are not implemented.

    Holds two long-tensor counters, ``_correct_sum`` and ``_total_sum``.
    """

    def __init__(self, threshold=0.5, multi_label=False):
        # nn.Module must be initialised before the instance is used as a module
        super().__init__()
        self.threshold = threshold
        self.eps = 1e-8
        self.multi_label = multi_label

        # statistics / counts
        self._correct_sum = torch.tensor(0, dtype=torch.long)
        self._total_sum = torch.tensor(0, dtype=torch.long)

    def update(self, predictions, target):
        # `NotImplemented` is not callable; NotImplementedError is the right exception
        raise NotImplementedError

    def reset(self):
        """Reset both counters to zero tensors of the same dtype used at construction."""
        self._correct_sum = torch.tensor(0, dtype=torch.long)
        self._total_sum = torch.tensor(0, dtype=torch.long)

    @property
    def counts(self):
        """Not implemented yet; evaluates to None."""
        pass

    def compute(self):
        raise NotImplementedError
|
||||
|
||||
|
||||
class AccuracyTopK(torch.nn.Module):
    """Top-k accuracy accumulated across ``update`` calls.

    A float32 counter ``_correct_top{k}`` is kept per entry of ``topk`` together
    with ``_total_sum``; ``compute`` reports ``100 * correct / total`` per k.
    """

    def __init__(self, topk=(1, 5), device=None):
        super().__init__()
        self.eps = 1e-8
        self.device = device
        self.topk = topk
        self.maxk = max(topk)
        # FIXME handle distributed operation

        # statistics / counts
        self.reset()

    def update(self, predictions: torch.Tensor, target: torch.Tensor):
        # transpose so that dim 0 indexes the rank position
        ranked = predictions.topk(self.maxk, dim=1)[1].t()
        hits = ranked.eq(target.reshape(1, -1).expand_as(ranked))

        num_samples = target.shape[0]
        hits_at_k = {k: hits[:k].reshape(-1).float().sum(0) for k in self.topk}
        for k, batch_hits in hits_at_k.items():
            counter_name = f'_correct_top{k}'
            setattr(self, counter_name, getattr(self, counter_name) + batch_hits)
        self._total_sum += num_samples

    def reset(self):
        """Set every counter back to a zero float32 tensor."""
        for k in self.topk:
            setattr(self, f'_correct_top{k}', torch.tensor(0, dtype=torch.float32))
        self._total_sum = torch.tensor(0, dtype=torch.float32)

    @property
    def counts(self):
        """Not implemented yet; evaluates to None."""
        pass

    def compute(self) -> Dict[str, torch.Tensor]:
        # FIXME handle distributed reduction
        percentages = {}
        for k in self.topk:
            percentages[f'top{k}'] = 100 * getattr(self, f'_correct_top{k}') / self._total_sum
        return percentages
|
||||
|
||||
|
||||
#
|
||||
# class AccuracyTopK:
|
||||
#
|
||||
# def __init__(self, topk=(1, 5), device=None):
|
||||
# self.eps = 1e-8
|
||||
# self.device = device
|
||||
# self.topk = topk
|
||||
# self.maxk = max(topk)
|
||||
#
|
||||
# # statistics / counts
|
||||
# self._correct_sum = None
|
||||
# self._total_sum = None
|
||||
#
|
||||
# def _check_init(self, device):
|
||||
# to_device = self.device if self.device else device
|
||||
# if self._correct_sum is None:
|
||||
# self._correct_sum = {f'top{k}': torch.tensor(0., device=to_device) for k in self.topk}
|
||||
# if self._total_sum is None:
|
||||
# self._total_sum = torch.tensor(0, dtype=torch.long, device=to_device)
|
||||
#
|
||||
# def update(self, predictions: torch.Tensor, target: torch.Tensor):
|
||||
# sorted_indices = predictions.topk(self.maxk, dim=1)[1]
|
||||
# sorted_indices.t_()
|
||||
# correct = sorted_indices.eq(target.reshape(1, -1).expand_as(sorted_indices))
|
||||
#
|
||||
# batch_size = target.shape[0]
|
||||
# correct_k = {f'top{k}': correct[:k].reshape(-1).float().sum(0) for k in self.topk}
|
||||
# self._check_init(device=predictions.device)
|
||||
# for k, v in correct_k.items():
|
||||
# old_v = self._correct_sum[k]
|
||||
# self._correct_sum[k] = old_v + v
|
||||
# self._total_sum += batch_size
|
||||
#
|
||||
# def reset(self):
|
||||
# self._correct_sum = None
|
||||
# self._total_sum = None
|
||||
#
|
||||
# @property
|
||||
# def counts(self):
|
||||
# pass
|
||||
#
|
||||
# def compute(self) -> Dict[str, torch.Tensor]:
|
||||
# assert self._correct_sum is not None and self._total_sum is not None
|
||||
# return {k: 100 * v / self._total_sum for k, v in self._correct_sum.items()}
|
Loading…
Reference in new issue