Add proper TrainState checkpoint save/load. Some reorg/refactoring and other cleanup. More to go...

pull/1239/head
Ross Wightman 3 years ago
parent 5b9c69e80a
commit 91ab0b6ce5

@ -1,14 +1,15 @@
from .avg_scalar import AvgMinMaxScalar
from .avg_tensor import AvgTensor
from .device_env import DeviceEnv, DeviceEnvType
from .checkpoint_manager import CheckpointManager
from .device_env import DeviceEnv, DeviceEnvType, get_global_device, set_global_device, is_global_device
from .device_env_cuda import DeviceEnvCuda
from .device_env_factory import initialize_device, get_device
from .device_env_factory import initialize_device
from .device_env_xla import DeviceEnvXla
from .distributed import distribute_bn, all_gather_recursive, all_reduce_recursive, broadcast_recursive,\
all_reduce_sequence, all_gather_sequence
# from .evaluate import evaluate, eval_step
from .logger import Logger
from .metric import Metric, MetricValue
from .monitor import Monitor
from .metric import Metric, MetricValueT
from .metric_accuracy import AccuracyTopK
from .tracker import Tracker
# from .task_metrics import TaskMetrics, TaskMetricsClassify

@ -1,17 +1,73 @@
import logging
import os
from collections import OrderedDict
from typing import Dict, Any, Callable
import torch
from timm.utils import unwrap_model
from .train_state import TrainState, serialize_train_state, deserialize_train_state
from .device_env import DeviceEnv
from .train_state import TrainState
_logger = logging.getLogger(__name__)
def _load_state_dict(checkpoint, state_dict_key='state_dict'):
def save_train_state(
checkpoint_path: str, # FIXME pass base path + file pattern + epoch / step separately for DS?
train_state: TrainState,
extra_state: Dict[str, Any] = None,
unwrap_fn: Callable = unwrap_model,
dev_env: DeviceEnv = None,
log_info: bool = True):
assert not train_state.updater.deepspeed
# DeepSpeed has a fully custom checkpoint saving setup, it is not possible
# to specify a filename, and checkpoints need to be saved from all ranks, etc.
# if train_state.updater.deepspeed:
# save_train_state_deepspeed(train_state, checkpoint_path)
dev_env = dev_env or DeviceEnv.instance()
state_dict = train_state.state_dict(unwrap_fn=unwrap_fn)
if extra_state:
state_dict.update(extra_state)
if dev_env.type_xla:
# XLA state dict needs to be moved to CPU before save, this is normally done by xm.save
state_dict = dev_env.state_dict_to_cpu(state_dict)
torch.save(state_dict, checkpoint_path)
def load_train_state(
train_state: TrainState,
checkpoint_path: str, # FIXME pass base path + file pattern + epoch / step separately for DS
unwrap_fn: Callable = None,
load_opt: bool = True,
dev_env: DeviceEnv = None,
log_info: bool = True
):
unwrap_fn = unwrap_fn or unwrap_model
if not os.path.isfile(checkpoint_path):
_logger.error("No valid resume checkpoint found at '{}'".format(checkpoint_path))
raise FileNotFoundError()
if log_info:
_logger.info('Restoring training state from checkpoint...')
checkpoint = torch.load(checkpoint_path, map_location='cpu')
assert isinstance(checkpoint, dict)
if not checkpoint.get('version', 0) > 2:
load_legacy_checkpoint(train_state, checkpoint=checkpoint, load_opt=load_opt, log_info=log_info)
if log_info:
_logger.info("Loaded legacy checkpoint '{}' (epoch {})".format(checkpoint_path, train_state.epoch))
return
train_state.load_state_dict(checkpoint, unwrap_fn=unwrap_fn)
if log_info:
_logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, train_state.epoch))
def _get_state_dict(checkpoint, state_dict_key='state_dict'):
new_state_dict = OrderedDict()
for k, v in checkpoint[state_dict_key].items():
name = k[7:] if k.startswith('module.') else k
@ -19,48 +75,35 @@ def _load_state_dict(checkpoint, state_dict_key='state_dict'):
return new_state_dict
def resume_train_checkpoint(
def load_legacy_checkpoint(
train_state: TrainState,
checkpoint_path,
resume_opt=True,
deserialize_fn=deserialize_train_state,
checkpoint,
load_opt=True,
log_info=True):
# FIXME this is a hacky adaptation of pre-bits resume to get up and running quickly
resume_epoch = None
if os.path.isfile(checkpoint_path):
checkpoint = torch.load(checkpoint_path, map_location='cpu')
assert isinstance(checkpoint, dict) and 'state_dict' in checkpoint
assert isinstance(checkpoint, dict) and 'state_dict' in checkpoint
train_state.model.load_state_dict(_get_state_dict(checkpoint))
if train_state.model_ema is not None and 'state_dict_ema' in checkpoint:
if log_info:
_logger.info('Restoring model state from checkpoint...')
_logger.info('Restoring model (EMA) state from checkpoint...')
unwrap_model(train_state.model_ema).load_state_dict(_get_state_dict(checkpoint, 'state_dict_ema'))
train_state.model.load_state_dict(_load_state_dict(checkpoint))
if load_opt:
if train_state.updater.optimizer is not None and 'optimizer' in checkpoint:
if log_info:
_logger.info('Restoring optimizer state from checkpoint...')
train_state.updater.optimizer.load_state_dict(checkpoint['optimizer'])
if train_state.model_ema is not None and 'state_dict_ema' in checkpoint:
scaler_state_dict_key = 'amp_scaler'
if train_state.updater.grad_scaler is not None and scaler_state_dict_key in checkpoint:
if log_info:
_logger.info('Restoring model (EMA) state from checkpoint...')
unwrap_model(train_state.model_ema).load_state_dict(_load_state_dict(checkpoint, 'state_dict_ema'))
if resume_opt:
if train_state.updater.optimizer is not None and 'optimizer' in checkpoint:
if log_info:
_logger.info('Restoring optimizer state from checkpoint...')
train_state.updater.optimizer.load_state_dict(checkpoint['optimizer'])
scaler_state_dict_key = 'amp_scaler'
if train_state.updater.grad_scaler is not None and scaler_state_dict_key in checkpoint:
if log_info:
_logger.info('Restoring AMP loss scaler state from checkpoint...')
train_state.updater.grad_scaler.load_state_dict(checkpoint[scaler_state_dict_key])
if 'epoch' in checkpoint:
resume_epoch = checkpoint['epoch']
if 'version' in checkpoint and checkpoint['version'] > 1:
resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save
train_state.epoch = resume_epoch # FIXME use replace if we make train_state read-only
_logger.info('Restoring AMP loss scaler state from checkpoint...')
train_state.updater.grad_scaler.load_state_dict(checkpoint[scaler_state_dict_key])
if 'epoch' in checkpoint:
resume_epoch = checkpoint['epoch']
if 'version' in checkpoint and checkpoint['version'] > 1:
resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save
train_state.epoch = resume_epoch # FIXME use replace if we make train_state read-only
if log_info:
_logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
else:
_logger.error("No valid resume checkpoint found at '{}'".format(checkpoint_path))
raise FileNotFoundError()
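For context, a minimal round-trip with the new helpers might look like the sketch below. It assumes a train_state already built via setup_model_and_optimizer and an initialized device environment; the direct timm.bits.checkpoint import path and the file names are assumptions, not part of this commit:

from timm.bits import initialize_device
from timm.bits.checkpoint import save_train_state, load_train_state  # assumed module path

dev_env = initialize_device()  # registers the global DeviceEnv used by default below
# ... build train_state, e.g. via setup_model_and_optimizer(...) ...
save_train_state('output/last.pth.tar', train_state, extra_state=dict(version=3))
load_train_state(train_state, 'output/last.pth.tar', load_opt=True)  # restores counters, weights, optimizer in place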

@ -0,0 +1,219 @@
""" Checkpoint Manager
Track top-n training checkpoints and maintain recovery checkpoints on specified intervals.
Hacked together by / Copyright 2021 Ross Wightman
"""
import glob
import logging
import operator
import os
import shutil
from typing import Optional, Dict, Callable, List
from dataclasses import dataclass, replace
from .checkpoint import save_train_state
from .train_state import TrainState
_logger = logging.getLogger(__name__)
@dataclass
class CheckpointInfo:
path: str = ''
metrics: Dict[str, float] = None # all metrics at time of checkpoint save
metric_name: str = 'loss'
metric_decreasing: bool = True
epoch: int = 0
global_step: int = 0
@property
def valid_key(self):
return self.metric_name and self.metrics and self.metric_name in self.metrics
@property
def sort_key(self):
return self.metrics[self.metric_name] if self.valid_key else self.epoch
@property
def decreasing_key(self):
return self.metric_decreasing if self.valid_key else False
class CheckpointManager:
def __init__(
self,
hparams=None,
save_state_fn=None,
checkpoint_dir='',
recovery_dir='',
checkpoint_tmpl=None,
recovery_tmpl=None,
metric_name='loss',
metric_decreasing=True,
max_history=10):
# extra items to include in checkpoint
self.hparams = hparams # train arguments (config / hparams) # FIXME this will change with new config system
# state
self.checkpoint_files: List[CheckpointInfo] = [] # CheckpointInfo entries in order of decreasing "betterness"
self.best_checkpoint = None
self.curr_recovery_file = ''
self.prev_recovery_file = ''
self.can_hardlink = True
# util / helper fn
self.save_state_fn = save_state_fn or save_train_state
# file / folder config
self.extension = '.pth.tar'
self.checkpoint_dir = checkpoint_dir
self.recovery_dir = recovery_dir
self.checkpoint_tmpl = (checkpoint_tmpl or 'checkpoint-{index}') + self.extension
self.recovery_tmpl = (recovery_tmpl or 'recovery-{index}') + self.extension
# ordering / history config
self.metric_name = metric_name
self.metric_decreasing = metric_decreasing
self.metric_cmp_fn = operator.lt if metric_decreasing else operator.gt
self.max_history = max_history
assert self.max_history >= 1
def _replace(self, src, dst):
if self.can_hardlink:
try:
if os.path.exists(dst):
os.unlink(dst) # required for Windows support.
except Exception as e:
self.can_hardlink = False
os.replace(src, dst)
def _duplicate(self, src, dst):
if self.can_hardlink:
try:
if os.path.exists(dst):
# for Windows
os.unlink(dst)
os.link(src, dst)
return
except Exception as e:
self.can_hardlink = False
shutil.copy2(src, dst)
def _save(self, save_path, train_state: TrainState, metrics: Optional[Dict[str, float]] = None):
extra_state = dict(
# version < 2 increments epoch before save
# version < 3, pre timm bits
# version 3, first timm bits checkpoints
version=3,
)
if self.hparams is not None:
extra_state.update(dict(arch=self.hparams['model'], hparams=self.hparams))
else:
arch = getattr(train_state.model, 'default_cfg', dict()).get('architecture', None)
if arch is None:
arch = type(train_state.model).__name__.lower()
extra_state.update(dict(arch=arch))
if metrics is not None:
# save the metrics and how we originally sorted them in the checkpoint for future comparisons
extra_state.update(dict(
metrics=metrics,
metric_name=self.metric_name,
metric_decreasing=self.metric_decreasing
))
self.save_state_fn(save_path, train_state, extra_state)
checkpoint_info = CheckpointInfo(
path=save_path,
metrics=metrics,
metric_name=self.metric_name,
metric_decreasing=self.metric_decreasing,
epoch=train_state.epoch,
global_step=train_state.step_count_global,
)
return checkpoint_info
def _update_checkpoints(self, info: CheckpointInfo):
self.checkpoint_files.append(info)
self.checkpoint_files = sorted(
self.checkpoint_files,
key=lambda x: x.sort_key,
reverse=not info.decreasing_key, # sort in descending order if a lower metric is not better
)
def _cleanup_checkpoints(self, trim=0):
trim = min(len(self.checkpoint_files), trim)
delete_index = self.max_history - trim
if delete_index < 0 or len(self.checkpoint_files) <= delete_index:
return
to_delete = self.checkpoint_files[delete_index:]
for d in to_delete:
try:
_logger.debug("Cleaning checkpoint: {}".format(d))
os.remove(d.path)
except Exception as e:
_logger.error("Exception '{}' while deleting checkpoint".format(e))
self.checkpoint_files = self.checkpoint_files[:delete_index]
def _compare_metric(self, lhs: CheckpointInfo, rhs: CheckpointInfo):
# compare metrics against an existing checkpoint
if not lhs or not lhs.valid_key or not rhs or not rhs.valid_key:
# always assume lhs metrics are better if there are no usable metrics to compare
return True
return self.metric_cmp_fn(lhs.sort_key, rhs.sort_key)
def save_checkpoint(self, train_state: TrainState, metrics: Optional[Dict[str, float]] = None):
assert train_state.epoch >= 0
tmp_save_path = os.path.join(self.checkpoint_dir, 'tmp' + self.extension)
last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension)
curr_checkpoint = self._save(tmp_save_path, train_state, metrics)
self._replace(tmp_save_path, last_save_path)
worst_checkpoint = self.checkpoint_files[-1] if self.checkpoint_files else None
if len(self.checkpoint_files) < self.max_history or self._compare_metric(curr_checkpoint, worst_checkpoint):
if len(self.checkpoint_files) >= self.max_history:
self._cleanup_checkpoints(1)
filename = self.checkpoint_tmpl.format(index=train_state.epoch)
save_path = os.path.join(self.checkpoint_dir, filename)
curr_checkpoint = replace(curr_checkpoint, path=save_path)
self._duplicate(last_save_path, save_path)
self._update_checkpoints(curr_checkpoint)
checkpoints_str = "Current checkpoints:\n"
for c in self.checkpoint_files:
checkpoints_str += f' {c.path}, {c.sort_key}\n'
_logger.info(checkpoints_str)
if curr_checkpoint.valid_key and self._compare_metric(curr_checkpoint, self.best_checkpoint):
self.best_checkpoint = curr_checkpoint
best_save_path = os.path.join(self.checkpoint_dir, 'best' + self.extension)
self._duplicate(last_save_path, best_save_path)
return None if self.best_checkpoint is None else curr_checkpoint
def save_recovery(self, train_state: TrainState):
tmp_save_path = os.path.join(self.recovery_dir, 'recovery_tmp' + self.extension)
self._save(tmp_save_path, train_state)
filename = self.recovery_tmpl.format(index=train_state.step_count_global)
save_path = os.path.join(self.recovery_dir, filename)
self._replace(tmp_save_path, save_path)
if os.path.exists(self.prev_recovery_file):
try:
_logger.debug("Cleaning recovery: {}".format(self.prev_recovery_file))
os.remove(self.prev_recovery_file)
except Exception as e:
_logger.error("Exception '{}' while removing {}".format(e, self.prev_recovery_file))
self.prev_recovery_file = self.curr_recovery_file
self.curr_recovery_file = save_path
def find_recovery(self):
recovery_path = os.path.join(self.recovery_dir, self.recovery_tmpl.split('{')[0])  # prefix before the {index} placeholder
files = glob.glob(recovery_path + '*' + self.extension)
files = sorted(files)
return files[0] if len(files) else ''
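Taken together, a rough usage sketch of the manager (mirroring how it is wired up in train.py further below; the directory, metric name, and history length are illustrative):

manager = CheckpointManager(
    hparams=vars(args),
    checkpoint_dir=output_dir,
    recovery_dir=output_dir,
    metric_name='top1',
    metric_decreasing=False,  # higher top-1 is better
    max_history=10)
# primary process only, once per epoch after eval
best = manager.save_checkpoint(train_state, metrics=eval_metrics)
# periodically inside the train loop for crash recovery
manager.save_recovery(train_state)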

@ -1,7 +1,7 @@
import abc
from contextlib import suppress
from enum import Enum
from typing import Callable, Union, Optional, List, Tuple
from typing import Callable, Union, Optional, List, Tuple, Dict, Any
from dataclasses import dataclass, field, InitVar
import torch
@ -18,10 +18,21 @@ class DeviceEnvType(Enum):
XLA = "xla"
def state_dict_apply(state_dict: Dict[str, Any], apply_fn, select_fn=lambda x: isinstance(x, torch.Tensor)):
out_dict = {}
for k, v in state_dict.items():
if isinstance(v, dict):
out_dict[k] = state_dict_apply(v, apply_fn, select_fn)
else:
out_dict[k] = apply_fn(v) if select_fn(v) else v
return out_dict
@dataclass
class DeviceEnv:
device_type: InitVar[Optional[str]] = None
device_index: InitVar[Optional[int]] = None
channels_last: InitVar[bool] = False
device: torch.device = field(init=False) # set from device_type + device_index or post_init logic
world_size: Optional[int] = None # set by post_init from env when None
@ -32,7 +43,12 @@ class DeviceEnv:
memory_format: Optional[torch.memory_format] = None
dtype: Optional[torch.dtype] = None
def __post_init__(self, device_type: Optional[str], device_index: Optional[int]):
def __post_init__(
self,
device_type: Optional[str],
device_index: Optional[int],
channels_last: bool,
):
device_type = device_type or 'cpu'
self.device = torch.device(device_type) if device_index is None \
else torch.device(device_type, device_index)
@ -41,6 +57,17 @@ class DeviceEnv:
self.global_rank = 0 if self.global_rank is None else self.global_rank
if self.autocast is None:
self.autocast = suppress
if channels_last:
self.memory_format = torch.channels_last
@staticmethod
def is_instance():
return is_global_device()
@staticmethod
def instance():
# throws if called before global device is set / initialized
return get_global_device()
@property
def type(self) -> DeviceEnvType:
@ -81,11 +108,23 @@ class DeviceEnv:
def wrap_parallel(self, *modules):
pass
def to_cpu(self, *modules: torch.nn.Module):
moved = [m.cpu() for m in modules]
return moved[0] if len(moved) == 1 else moved
def to_device(self, *modules: torch.nn.Module):
# FIXME handling dtype / memformat... disable flags, enable flags, diff fn?
# FIXME handling dtype? Do we want separate dtype for data vs model?
moved = [m.to(device=self.device, memory_format=self.memory_format) for m in modules]
return moved[0] if len(moved) == 1 else moved
def state_dict_to_cpu(self, state: Dict[str, Any]):
cpu_state = state_dict_apply(state, apply_fn=lambda x: x.cpu())
return cpu_state
def state_dict_to_device(self, state: Dict[str, Any]):
device_state = state_dict_apply(state, apply_fn=lambda x: x.to(self.device))
return device_state
def mark_step(self):
pass # NO-OP for non-XLA devices
@ -126,3 +165,24 @@ class DeviceEnv:
def barrier(self):
dist.barrier()
# Global device environment singleton instance
_global_device_env: Optional[DeviceEnv] = None
def is_global_device():
return _global_device_env is not None
def get_global_device() -> DeviceEnv:
if not is_global_device():
raise RuntimeError('Please initialize device environment by calling initialize_device / set_global_device.')
return _global_device_env
def set_global_device(device: DeviceEnv):
global _global_device_env
if _global_device_env is not None:
raise RuntimeError('Global device is already set, it should NOT be set again.')
_global_device_env = device
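A short sketch of how the new global singleton is intended to be used (the same pattern appears in the factory and the distributed helpers below; the model variable is assumed to be in scope):

dev_env = initialize_device()   # builds a DeviceEnv and calls set_global_device()
assert DeviceEnv.is_instance()
env = DeviceEnv.instance()      # same object, retrievable anywhere without passing it around
cpu_state = env.state_dict_to_cpu(model.state_dict())  # nested tensors moved via state_dict_apply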

@ -16,7 +16,12 @@ def is_cuda_available():
@dataclass
class DeviceEnvCuda(DeviceEnv):
def __post_init__(self, device_type: str, device_index: Optional[int]):
def __post_init__(
self,
device_type: Optional[str],
device_index: Optional[int],
channels_last: bool,
):
assert torch.cuda.device_count()
torch.backends.cudnn.benchmark = True
setup_world_size = self.world_size or int(os.environ.get('WORLD_SIZE', 1))
@ -43,6 +48,8 @@ class DeviceEnvCuda(DeviceEnv):
self.global_rank = 0
if self.autocast is None:
self.autocast = torch.cuda.amp.autocast if self.amp else suppress
if channels_last:
self.memory_format = torch.channels_last
@property
def type(self) -> DeviceEnvType:

@ -1,15 +1,15 @@
from .device_env import DeviceEnv
import logging
from .device_env import DeviceEnv, is_global_device, get_global_device, set_global_device
from .device_env_cuda import DeviceEnvCuda, is_cuda_available
from .device_env_xla import DeviceEnvXla, is_xla_available
_device_env = None
_logger = logging.getLogger(__name__)
def initialize_device(force_cpu: bool = False, **kwargs) -> DeviceEnv:
global _device_env
if _device_env is not None:
# warning
return _device_env
if is_global_device():
return get_global_device()
denv = None
if not force_cpu:
@ -23,14 +23,10 @@ def initialize_device(force_cpu: bool = False, **kwargs) -> DeviceEnv:
if denv is None:
denv = DeviceEnv()
print(denv) # FIXME DEBUG
_device_env = denv
return denv
def get_device() -> DeviceEnv:
if _device_env is None:
raise RuntimeError('Please initialize device environment by calling initialize_device first.')
return _device_env
_logger.info(f'Initialized device {denv.device}. '
f'Rank: {denv.global_rank} ({denv.local_rank}) of {denv.world_size}.')
print(denv) # FIXME temporary print for debugging
set_global_device(denv)
return denv

@ -1,7 +1,7 @@
import os
from contextlib import suppress
from dataclasses import dataclass, field, InitVar
from typing import Optional
from typing import Optional, Dict
import torch
from torch.distributed import ReduceOp
@ -42,7 +42,12 @@ def is_xla_available(xla_device_type=None):
@dataclass
class DeviceEnvXla(DeviceEnv):
def __post_init__(self, device_type: Optional[str], device_idx: Optional[int]):
def __post_init__(
self,
device_type: Optional[str],
device_idx: Optional[int],
channels_last: bool,
):
if device_type is not None:
device_type = device_type.upper()
assert device_type in ('TPU', 'GPU', 'CPU'), "XLA device type must be one of ('TPU', 'GPU', 'CPU')"
@ -59,6 +64,8 @@ class DeviceEnvXla(DeviceEnv):
assert xa is not None, 'XLA AMP is not present on this build'
if self.autocast is None:
self.autocast = xa.autocast if self.amp else suppress
if channels_last:
self.memory_format = torch.channels_last
@property
def type(self) -> DeviceEnvType:
@ -114,3 +121,11 @@ class DeviceEnvXla(DeviceEnv):
def barrier(self):
xm.rendezvous('timm.bits.dist_barrier')
def state_dict_to_cpu(self, state: Dict[str, torch.Tensor]):
cpu_state = xm._maybe_convert_to_cpu(state, convert=True)
return cpu_state
def state_dict_to_device(self, state: Dict[str, torch.Tensor]):
device_state = xm.send_cpu_data_to_device(state, device=self.device)
return device_state

@ -5,8 +5,7 @@ from torch.distributed import ReduceOp
from timm.utils import unwrap_model
from .device_env import DeviceEnv, DeviceEnvType
from .device_env_factory import get_device
from .device_env import DeviceEnv
TensorSeq = Union[torch.Tensor, Tuple[torch.Tensor, ...], List[torch.Tensor], Dict[Any, torch.Tensor]]
@ -22,7 +21,7 @@ def _validate_type(tensor: TensorSeq):
def distribute_bn(model: torch.nn.Module, reduce: bool = False, dev_env: DeviceEnv = None):
if dev_env is None:
dev_env = get_device()
dev_env = DeviceEnv.instance()
# ensure every node has the same running bn stats
for bn_name, bn_buf in unwrap_model(model).named_buffers(recurse=True):
if ('running_mean' in bn_name) or ('running_var' in bn_name):
@ -40,7 +39,7 @@ def all_gather_recursive(tensor: TensorSeq, cat_dim=0, dev_env: DeviceEnv = None
"""
_validate_type(tensor)
if dev_env is None:
dev_env = get_device()
dev_env = DeviceEnv.instance()
if isinstance(tensor, torch.Tensor):
return dev_env.all_gather(tensor, cat_dim=cat_dim)
elif isinstance(tensor, dict):
@ -55,7 +54,7 @@ def all_reduce_recursive(tensor: TensorSeq, op=ReduceOp.SUM, average=False, dev_
"""
_validate_type(tensor)
if dev_env is None:
dev_env = get_device()
dev_env = DeviceEnv.instance()
if isinstance(tensor, torch.Tensor):
return dev_env.all_reduce_(tensor, op=op, average=average)
elif isinstance(tensor, dict):
@ -70,7 +69,7 @@ def broadcast_recursive(tensor: TensorSeq, src_rank: int, dev_env: DeviceEnv = N
"""
_validate_type(tensor)
if dev_env is None:
dev_env = get_device()
dev_env = DeviceEnv.instance()
if isinstance(tensor, torch.Tensor):
return dev_env.broadcast_(tensor, src_rank=src_rank)
elif isinstance(tensor, dict):
@ -85,7 +84,7 @@ def all_gather_sequence(tensor: TensorSeq, cat_dim: int = 0, dev_env: DeviceEnv
"""
_validate_type(tensor)
if dev_env is None:
dev_env = get_device()
dev_env = DeviceEnv.instance()
with torch.no_grad():
names = None
@ -124,7 +123,7 @@ def all_reduce_sequence(tensor: TensorSeq, op=ReduceOp.SUM, average=False, dev_e
"""
_validate_type(tensor)
if dev_env is None:
dev_env = get_device()
dev_env = DeviceEnv.instance()
with torch.no_grad():
names = None

@ -6,14 +6,13 @@ import torch
from torch.distributed import ReduceOp
from .device_env import DeviceEnv
from .device_env_factory import get_device
from .distributed import all_gather_sequence, all_reduce_sequence
MetricValue = Union[float, torch.Tensor, List[float], List[torch.Tensor]]
MetricValueT = Union[float, torch.Tensor, List[float], List[torch.Tensor]]
@dataclass
class ValueInfo:
initial: Optional[MetricValue] = 0.
initial: Optional[MetricValueT] = 0.
dtype: torch.dtype = torch.float32
dist_reduce: str = 'sum'
dist_average: bool = False
@ -23,10 +22,10 @@ class Metric(abc.ABC):
def __init__(self, dev_env: DeviceEnv = None):
self._infos: Dict[str, ValueInfo] = {}
self._values: Dict[str, Optional[MetricValue]] = {}
self._values_dist: Dict[str, Optional[MetricValue]] = {}
self._values: Dict[str, Optional[MetricValueT]] = {}
self._values_dist: Dict[str, Optional[MetricValueT]] = {}
if dev_env is None:
dev_env = get_device()
dev_env = DeviceEnv.instance()
self._dev_env = dev_env
def _register_value(self, name: str, info: Optional[ValueInfo] = None):
@ -117,7 +116,7 @@ class Metric(abc.ABC):
names.append(name)
values.append(value)
reductions.append(_args(info.dist_reduce))
same_dsr = False
if same_dsr:
do_gather, reduce_kwargs = reductions[0]
if do_gather:

@ -21,8 +21,6 @@ except ImportError:
HAS_WANDB = False
from .device_env_factory import get_device
# FIXME old formatting for reference, to remove
#
# def log_eval(batch_idx, last_idx, batch_time, loss, top1, top5, log_suffix=''):
@ -122,19 +120,19 @@ def _add_kwargs(text_update, name_map=None, **kwargs):
text_update += [_to_str(name, v)]
class Logger:
class Monitor:
def __init__(
self,
experiment_name=None,
output_dir=None,
python_logger=None,
logger=None,
hparams=None,
log_wandb=False,
output_enabled=True,
):
self.output_dir = output_dir # for tensorboard, csv, text file (TODO) logging
self.logger = python_logger or logging.getLogger('log')
self.logger = logger or logging.getLogger('log')
hparams = hparams or {}
# Setup CSV writer(s)

@ -1,13 +1,13 @@
from dataclasses import dataclass
from .logger import Logger
from timm.utils.checkpoint_saver import CheckpointSaver
from .monitor import Monitor
from .checkpoint_manager import CheckpointManager
@dataclass
class TrainServices:
""" Train Loop Services
"""
logger: Logger = None
saver: CheckpointSaver = None
logger: Monitor = None
checkpoint_manager: CheckpointManager = None

@ -13,7 +13,7 @@ try:
except ImportError:
ds = None
from .checkpoint import resume_train_checkpoint
from .checkpoint import load_train_state, load_legacy_checkpoint
from .device_env import DeviceEnv
from .train_cfg import TrainCfg
from .train_state import TrainState
@ -90,10 +90,10 @@ def setup_model_and_optimizer(
if resume_path:
# FIXME this is not implemented yet, do a hack job before proper TrainState serialization?
resume_train_checkpoint(
load_train_state(
train_state,
resume_path,
resume_opt=resume_opt,
load_opt=resume_opt,
log_info=dev_env.primary)
if dev_env.distributed:
@ -141,10 +141,10 @@ def setup_model_and_optimizer_deepspeed(
if resume_path:
# FIXME deepspeed resumes differently
resume_train_checkpoint(
load_legacy_checkpoint(
train_state,
resume_path,
resume_opt=resume_opt,
load_opt=resume_opt,
log_info=dev_env.primary)
if dev_env.distributed:

@ -4,6 +4,8 @@ from typing import Dict, Any
from torch import nn as nn
from timm.scheduler import Scheduler
from timm.utils import get_state_dict, unwrap_model
from .updater import Updater
@ -16,18 +18,33 @@ class TrainState:
lr_scheduler: Scheduler = None
model_ema: nn.Module = None
step_count_epoch: int = 0
step_count_global: int = 0
epoch: int = 0
step_count: int = 0
step_count_global: int = 0
def __post_init__(self):
assert self.model is not None
assert self.updater is not None
def serialize_train_state(train_state: TrainState):
pass
def deserialize_train_state(checkpoint: Dict[str, Any]):
pass
def state_dict(self, unwrap_fn=unwrap_model):
state = dict(
epoch=self.epoch,
step_count=self.step_count,
step_count_global=self.step_count_global,
model=get_state_dict(self.model, unwrap_fn),
model_ema=None if self.model_ema is None else get_state_dict(self.model_ema, unwrap_fn),
)
# FIXME lr_scheduler state save?
state.update(self.updater.state_dict())
return state
def load_state_dict(self, state_dict, unwrap_fn=unwrap_model):
self.epoch = state_dict['epoch']
self.step_count = state_dict['step_count']
self.step_count_global = state_dict['step_count_global']
unwrap_fn(self.model).load_state_dict(state_dict.get('model'))
if 'model_ema' in state_dict and self.model_ema is not None:
unwrap_fn(self.model_ema).load_state_dict(state_dict.get('model_ema'))
self.updater.load_state_dict(state_dict)
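A quick way to see what the new serialization produces, given a constructed train_state (a sketch; the exact keys depend on whether EMA and an AMP grad scaler are in use):

sd = train_state.state_dict()
print(sorted(sd.keys()))
# roughly: ['epoch', 'grad_scaler', 'model', 'model_ema', 'optimizer', 'step_count', 'step_count_global']
# 'grad_scaler' only appears when the updater holds an AMP scaler; 'model_ema' may be None
train_state.load_state_dict(sd)  # restores counters, model / EMA weights, and updater state in place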

@ -56,6 +56,7 @@ class Updater:
state_dict = dict(optimizer=self.optimizer.state_dict())
if self.grad_scaler is not None:
state_dict['grad_scaler'] = self.grad_scaler.state_dict()
return state_dict
def load_state_dict(self, state_dict):
if 'optimizer' in state_dict:
@ -66,3 +67,6 @@ class Updater:
def after_step(self, after_step_fn, *args):
after_step_fn(*args)
@property
def deepspeed(self):
return False

@ -24,3 +24,7 @@ class UpdaterDeepSpeed(Updater):
self.model.backward(loss)
self.model.step()
self.reset()
@property
def deepspeed(self):
return True

@ -3,7 +3,6 @@ from typing import Callable, Optional, Union, Any
import torch
from .device_env import DeviceEnv, DeviceEnvType
from .device_env_factory import get_device
from .updater import Updater
from .updater_cuda import UpdaterCudaWithScaler
from .updater_deepspeed import UpdaterDeepSpeed
@ -21,7 +20,7 @@ def create_updater(
) -> Updater:
if not dev_env:
dev_env = get_device()
dev_env = DeviceEnv.instance()
updater_kwargs = dict(model=model, optimizer=optimizer, clip_fn=clip_fn, clip_value=clip_value)
use_scaler = dev_env.amp

@ -8,7 +8,7 @@ Hacked together by / Copyright 2020 Ross Wightman
import torch.utils.data
from timm.bits import get_device, DeviceEnvType
from timm.bits import DeviceEnv
from .fetcher import Fetcher
from .prefetcher_cuda import PrefetcherCuda
@ -75,7 +75,7 @@ def create_loader(
)
if dev_env is None:
dev_env = get_device()
dev_env = DeviceEnv.instance()
sampler = None
if dev_env.distributed and not isinstance(dataset, torch.utils.data.IterableDataset):

@ -23,7 +23,7 @@ except ImportError as e:
exit(1)
from .parser import Parser
from timm.bits import get_device
from timm.bits import get_global_device
MAX_TP_SIZE = 8 # maximum TF threadpool size, only doing jpeg decodes and queuing activities
SHUFFLE_SIZE = 16834 # samples to shuffle in DS queue
@ -80,7 +80,7 @@ class ParserTfds(Parser):
self.worker_info = None
self.dist_rank = 0
self.dist_num_replicas = 1
dev_env = get_device()
dev_env = get_global_device()
# FIXME allow to work without devenv usage?
if dev_env.distributed and dev_env.world_size > 1:
self.dist_rank = dev_env.global_rank

@ -3,33 +3,38 @@
Hacked together by / Copyright 2020 Ross Wightman
"""
from .model_ema import ModelEma
import torch
import torch
import fnmatch
def unwrap_model(model):
if isinstance(model, ModelEma):
return unwrap_model(model.ema)
else:
return model.module if hasattr(model, 'module') else model
_SUB_MODULE_ATTR = ('module', 'model')
def unwrap_model(model, recursive=True):
for attr in _SUB_MODULE_ATTR:
sub_module = getattr(model, attr, None)
if sub_module is not None:
return unwrap_model(sub_module) if recursive else sub_module
return model
def get_state_dict(model, unwrap_fn=unwrap_model):
return unwrap_fn(model).state_dict()
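The reworked unwrap_model now peels any .module / .model wrapper recursively instead of special-casing ModelEma. A small self-contained sketch (the Wrapper class below is purely illustrative, standing in for DDP or EMA style wrappers):

import torch.nn as nn
from timm.utils import unwrap_model, get_state_dict

class Wrapper(nn.Module):  # stand-in for DistributedDataParallel (.module) or EMA (.model) wrappers
    def __init__(self, model):
        super().__init__()
        self.module = model

net = nn.Linear(4, 4)
wrapped = Wrapper(Wrapper(net))      # nested wrapping, e.g. EMA of a DDP model
assert unwrap_model(wrapped) is net  # recursion follows the .module / .model attributes
assert get_state_dict(wrapped).keys() == net.state_dict().keys()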
def avg_sq_ch_mean(model, input, output):
"calculate average channel square mean of output activations"
return torch.mean(output.mean(axis=[0,2,3])**2).item()
def avg_sq_ch_mean(model, input, output):
"""calculate average channel square mean of output activations
"""
return torch.mean(output.mean(axis=[0, 2, 3]) ** 2).item()
def avg_ch_var(model, input, output):
"calculate average channel variance of output activations"
return torch.mean(output.var(axis=[0,2,3])).item()\
def avg_ch_var(model, input, output):
"""calculate average channel variance of output activations"""
return torch.mean(output.var(axis=[0, 2, 3])).item()
def avg_ch_var_residual(model, input, output):
"calculate average channel variance of output activations"
return torch.mean(output.var(axis=[0,2,3])).item()
def avg_ch_var_residual(model, input, output):
"""calculate average channel variance of output activations"""
return torch.mean(output.var(axis=[0, 2, 3])).item()
class ActivationStatsHook:
@ -58,15 +63,16 @@ class ActivationStatsHook:
raise ValueError("Please provide `hook_fns` for each `hook_fn_locs`, \
their lengths are different.")
self.stats = dict((hook_fn.__name__, []) for hook_fn in hook_fns)
for hook_fn_loc, hook_fn in zip(hook_fn_locs, hook_fns):
for hook_fn_loc, hook_fn in zip(hook_fn_locs, hook_fns):
self.register_hook(hook_fn_loc, hook_fn)
def _create_hook(self, hook_fn):
def append_activation_stats(module, input, output):
out = hook_fn(module, input, output)
self.stats[hook_fn.__name__].append(out)
return append_activation_stats
def register_hook(self, hook_fn_loc, hook_fn):
for name, module in self.model.named_modules():
if not fnmatch.fnmatch(name, hook_fn_loc):
@ -74,9 +80,9 @@ class ActivationStatsHook:
module.register_forward_hook(self._create_hook(hook_fn))
def extract_spp_stats(model,
def extract_spp_stats(model,
hook_fn_locs,
hook_fns,
hook_fns,
input_shape=[8, 3, 224, 224]):
"""Extract average square channel mean and variance of activations during
forward pass to plot Signal Propagation Plots (SPP).
@ -84,9 +90,8 @@ def extract_spp_stats(model,
Paper: https://arxiv.org/abs/2101.08692
Example Usage: https://gist.github.com/amaarora/6e56942fcb46e67ba203f3009b30d950
"""
"""
x = torch.normal(0., 1., input_shape)
hook = ActivationStatsHook(model, hook_fn_locs=hook_fn_locs, hook_fns=hook_fns)
_ = model(x)
return hook.stats
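A short usage sketch for the activation stats helpers above (the model name and fnmatch patterns are illustrative, in the spirit of the gist referenced in the docstring):

from timm import create_model

model = create_model('resnet50')
stats = extract_spp_stats(
    model,
    hook_fn_locs=['layer?.?.bn3', 'layer?.?.bn3', 'layer?.?.downsample.1'],
    hook_fns=[avg_sq_ch_mean, avg_ch_var, avg_ch_var_residual])
# stats maps each hook fn name to the per-module values collected during one forward pass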

@ -28,14 +28,14 @@ import torch
import torch.nn as nn
import torchvision.utils
from timm.bits import initialize_device, setup_model_and_optimizer, DeviceEnv, Logger, Tracker,\
TrainState, TrainServices, TrainCfg, AccuracyTopK, AvgTensor, distribute_bn
from timm.bits import initialize_device, setup_model_and_optimizer, DeviceEnv, Monitor, Tracker,\
TrainState, TrainServices, TrainCfg, CheckpointManager, AccuracyTopK, AvgTensor, distribute_bn
from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from timm.models import create_model, safe_model_name, convert_splitbn_model
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
from timm.optim import create_optimizer_v2, optimizer_kwargs
from timm.optim import optimizer_kwargs
from timm.scheduler import create_scheduler
from timm.utils import setup_default_logging, random_seed, get_outdir, CheckpointSaver
from timm.utils import setup_default_logging, random_seed, get_outdir, unwrap_model
_logger = logging.getLogger('train')
@ -276,7 +276,7 @@ def main():
setup_default_logging()
args, args_text = _parse_args()
dev_env = initialize_device(amp=args.amp)
dev_env = initialize_device(amp=args.amp, channels_last=args.channels_last)
if dev_env.distributed:
_logger.info('Training in distributed mode with multiple processes, 1 device per process. Process %d, total %d.'
% (dev_env.global_rank, dev_env.world_size))
@ -293,13 +293,17 @@ def main():
# FIXME perhaps keep the same and just set diff seeds for dataloader worker process? what about TFDS?
random_seed(args.seed, dev_env.global_rank)
data_config, loader_eval, loader_train = setup_data(args, train_state.model.default_cfg, dev_env, mixup_active)
data_config, loader_eval, loader_train = setup_data(
args,
unwrap_model(train_state.model).default_cfg,
dev_env,
mixup_active)
# setup checkpoint saver
# setup checkpoint manager
eval_metric = args.eval_metric
best_metric = None
best_epoch = None
saver = None
checkpoint_manager = None
output_dir = None
if dev_env.primary:
if args.experiment:
@ -311,24 +315,20 @@ def main():
str(data_config['input_size'][-1])
])
output_dir = get_outdir(args.output if args.output else './output/train', exp_name)
decreasing = True if eval_metric == 'loss' else False
saver = CheckpointSaver( # TODO CheckpointSaverV2
model=train_state.model,
optimizer=train_state.updater.optimizer,
args=args,
model_ema=train_state.model_ema,
amp_scaler=train_state.updater.grad_scaler,
checkpoint_manager = CheckpointManager(
hparams=vars(args),
checkpoint_dir=output_dir,
recovery_dir=output_dir,
decreasing=decreasing,
metric_name=eval_metric,
metric_decreasing=True if eval_metric == 'loss' else False,
max_history=args.checkpoint_hist)
with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
f.write(args_text)
services = TrainServices(
logger=Logger(
output_dir=output_dir, python_logger=_logger, hparams=vars(args), output_enabled=dev_env.primary),
saver=saver,
logger=Monitor(
output_dir=output_dir, logger=_logger, hparams=vars(args), output_enabled=dev_env.primary),
checkpoint_manager=checkpoint_manager,
)
try:
@ -379,10 +379,10 @@ def main():
if services.logger is not None:
services.logger.write_summary(index=epoch, results=dict(train=train_metrics, eval=eval_metrics))
if saver is not None:
if checkpoint_manager is not None:
# save proper checkpoint with eval metric
save_metric = eval_metrics[eval_metric]
best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)
best_checkpoint = checkpoint_manager.save_checkpoint(train_state, eval_metrics)
best_metric, best_epoch = best_checkpoint.sort_key, best_checkpoint.epoch
train_state = replace(train_state, epoch=epoch + 1)
@ -629,9 +629,9 @@ def after_train_step(
lr=lr_avg,
)
if services.saver is not None and cfg.recovery_interval and (
if services.checkpoint_manager is not None and cfg.recovery_interval and (
end_step or (step_idx + 1) % cfg.recovery_interval == 0):
services.saver.save_recovery(state.epoch, batch_idx=step_idx)
services.checkpoint_manager.save_recovery(state)
if state.lr_scheduler is not None:
state.lr_scheduler.step_update(num_updates=state.step_count_global)
@ -641,7 +641,7 @@ def evaluate(
model: nn.Module,
loss_fn: nn.Module,
loader,
logger: Logger,
logger: Monitor,
dev_env: DeviceEnv,
phase_suffix: str = '',
log_interval: int = 10,

@ -18,7 +18,7 @@ import torch.nn as nn
import torch.nn.parallel
from collections import OrderedDict
from timm.bits import initialize_device, Tracker, Logger, AccuracyTopK, AvgTensor
from timm.bits import initialize_device, Tracker, Monitor, AccuracyTopK, AvgTensor
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
from timm.utils import natural_key, setup_default_logging
@ -154,7 +154,7 @@ def validate(args):
pin_memory=args.pin_mem,
tf_preprocessing=args.tf_preprocessing)
logger = Logger(python_logger=_logger)
logger = Monitor(logger=_logger)
tracker = Tracker()
losses = AvgTensor()
accuracy = AccuracyTopK(dev_env=dev_env)
