Add proper TrainState checkpoint save/load. Some reorg/refactoring and other cleanup. More to go...

pull/1239/head
Ross Wightman 3 years ago
parent 5b9c69e80a
commit 91ab0b6ce5

timm/bits/__init__.py
@@ -1,14 +1,15 @@
 from .avg_scalar import AvgMinMaxScalar
 from .avg_tensor import AvgTensor
-from .device_env import DeviceEnv, DeviceEnvType
+from .checkpoint_manager import CheckpointManager
+from .device_env import DeviceEnv, DeviceEnvType, get_global_device, set_global_device, is_global_device
 from .device_env_cuda import DeviceEnvCuda
-from .device_env_factory import initialize_device, get_device
+from .device_env_factory import initialize_device
 from .device_env_xla import DeviceEnvXla
 from .distributed import distribute_bn, all_gather_recursive, all_reduce_recursive, broadcast_recursive,\
     all_reduce_sequence, all_gather_sequence
 # from .evaluate import evaluate, eval_step
-from .logger import Logger
+from .monitor import Monitor
-from .metric import Metric, MetricValue
+from .metric import Metric, MetricValueT
 from .metric_accuracy import AccuracyTopK
 from .tracker import Tracker
 # from .task_metrics import TaskMetrics, TaskMetricsClassify

timm/bits/checkpoint.py
@@ -1,17 +1,73 @@
 import logging
 import os
 from collections import OrderedDict
+from typing import Dict, Any, Callable

 import torch

 from timm.utils import unwrap_model

-from .train_state import TrainState, serialize_train_state, deserialize_train_state
+from .device_env import DeviceEnv
+from .train_state import TrainState

 _logger = logging.getLogger(__name__)


-def _load_state_dict(checkpoint, state_dict_key='state_dict'):
+def save_train_state(
+        checkpoint_path: str,  # FIXME pass base path + file pattern + epoch / step separately for DS?
+        train_state: TrainState,
+        extra_state: Dict[str, Any] = None,
+        unwrap_fn: Callable = unwrap_model,
+        dev_env: DeviceEnv = None,
+        log_info: bool = True):
+    assert not train_state.updater.deepspeed
+    # DeepSpeed has a fully custom checkpoint saving setup, it is not possible to
+    # specify a filename, checkpoints need to be saved from all ranks, etc.
+    # if train_state.updater.deepspeed:
+    #     save_train_state_deepspeed(train_state, checkpoint_path)
+
+    dev_env = dev_env or DeviceEnv.instance()
+
+    state_dict = train_state.state_dict(unwrap_fn=unwrap_fn)
+    if extra_state:
+        state_dict.update(extra_state)
+    if dev_env.type_xla:
+        # XLA state dict needs to be moved to CPU before save, this is normally done by xm.save
+        state_dict = dev_env.state_dict_to_cpu(state_dict)
+    torch.save(state_dict, checkpoint_path)
+
+
+def load_train_state(
+        train_state: TrainState,
+        checkpoint_path: str,  # FIXME pass base path + file pattern + epoch / step separately for DS
+        unwrap_fn: Callable = None,
+        load_opt: bool = True,
+        dev_env: DeviceEnv = None,
+        log_info: bool = True
+):
+    unwrap_fn = unwrap_fn or unwrap_model
+    if not os.path.isfile(checkpoint_path):
+        _logger.error("No valid resume checkpoint found at '{}'".format(checkpoint_path))
+        raise FileNotFoundError()
+
+    if log_info:
+        _logger.info('Restoring training state from checkpoint...')
+
+    checkpoint = torch.load(checkpoint_path, map_location='cpu')
+    assert isinstance(checkpoint, dict)
+
+    if not checkpoint.get('version', 0) > 2:
+        load_legacy_checkpoint(train_state, checkpoint=checkpoint, load_opt=load_opt, log_info=log_info)
+        if log_info:
+            _logger.info("Loaded legacy checkpoint '{}' (epoch {})".format(checkpoint_path, train_state.epoch))
+        return
+
+    train_state.load_state_dict(checkpoint, unwrap_fn=unwrap_fn)
+    if log_info:
+        _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, train_state.epoch))
+
+
+def _get_state_dict(checkpoint, state_dict_key='state_dict'):
     new_state_dict = OrderedDict()
     for k, v in checkpoint[state_dict_key].items():
         name = k[7:] if k.startswith('module') else k
@@ -19,48 +75,35 @@ def _load_state_dict(checkpoint, state_dict_key='state_dict'):
     return new_state_dict


-def resume_train_checkpoint(
+def load_legacy_checkpoint(
         train_state: TrainState,
-        checkpoint_path,
-        resume_opt=True,
-        deserialize_fn=deserialize_train_state,
+        checkpoint,
+        load_opt=True,
         log_info=True):
-    # FIXME this is a hacky adaptation of pre-bits resume to get up and running quickly
-    resume_epoch = None
-    if os.path.isfile(checkpoint_path):
-        checkpoint = torch.load(checkpoint_path, map_location='cpu')
-        assert isinstance(checkpoint, dict) and 'state_dict' in checkpoint
-        if log_info:
-            _logger.info('Restoring model state from checkpoint...')
-        train_state.model.load_state_dict(_load_state_dict(checkpoint))
-
-        if train_state.model_ema is not None and 'state_dict_ema' in checkpoint:
-            if log_info:
-                _logger.info('Restoring model (EMA) state from checkpoint...')
-            unwrap_model(train_state.model_ema).load_state_dict(_load_state_dict(checkpoint, 'state_dict_ema'))
-
-        if resume_opt:
-            if train_state.updater.optimizer is not None and 'optimizer' in checkpoint:
-                if log_info:
-                    _logger.info('Restoring optimizer state from checkpoint...')
-                train_state.updater.optimizer.load_state_dict(checkpoint['optimizer'])
-
-            scaler_state_dict_key = 'amp_scaler'
-            if train_state.updater.grad_scaler is not None and scaler_state_dict_key in checkpoint:
-                if log_info:
-                    _logger.info('Restoring AMP loss scaler state from checkpoint...')
-                train_state.updater.grad_scaler.load_state_dict(checkpoint[scaler_state_dict_key])
-
-        if 'epoch' in checkpoint:
-            resume_epoch = checkpoint['epoch']
-            if 'version' in checkpoint and checkpoint['version'] > 1:
-                resume_epoch += 1  # start at the next epoch, old checkpoints incremented before save
-            train_state.epoch = resume_epoch  # FIXME use replace if we make train_state read-only
-
-        if log_info:
-            _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
-    else:
-        _logger.error("No valid resume checkpoint found at '{}'".format(checkpoint_path))
-        raise FileNotFoundError()
+    assert isinstance(checkpoint, dict) and 'state_dict' in checkpoint
+    train_state.model.load_state_dict(_get_state_dict(checkpoint))
+
+    if train_state.model_ema is not None and 'state_dict_ema' in checkpoint:
+        if log_info:
+            _logger.info('Restoring model (EMA) state from checkpoint...')
+        unwrap_model(train_state.model_ema).load_state_dict(_get_state_dict(checkpoint, 'state_dict_ema'))
+
+    if load_opt:
+        if train_state.updater.optimizer is not None and 'optimizer' in checkpoint:
+            if log_info:
+                _logger.info('Restoring optimizer state from checkpoint...')
+            train_state.updater.optimizer.load_state_dict(checkpoint['optimizer'])
+
+        scaler_state_dict_key = 'amp_scaler'
+        if train_state.updater.grad_scaler is not None and scaler_state_dict_key in checkpoint:
+            if log_info:
+                _logger.info('Restoring AMP loss scaler state from checkpoint...')
+            train_state.updater.grad_scaler.load_state_dict(checkpoint[scaler_state_dict_key])
+
+    if 'epoch' in checkpoint:
+        resume_epoch = checkpoint['epoch']
+        if 'version' in checkpoint and checkpoint['version'] > 1:
+            resume_epoch += 1  # start at the next epoch, old checkpoints incremented before save
+        train_state.epoch = resume_epoch  # FIXME use replace if we make train_state read-only
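
A minimal round-trip sketch of the new helpers, for reference only (not part of the commit). The `create_updater` keyword arguments are assumed from the factory signature further down, and `save_train_state` / `load_train_state` are imported from the module directly since only `CheckpointManager` is re-exported from `timm.bits`:

import torch
import torch.nn as nn

from timm.bits import initialize_device
from timm.bits.checkpoint import save_train_state, load_train_state
from timm.bits.train_state import TrainState
from timm.bits.updater_factory import create_updater

dev_env = initialize_device()
model = dev_env.to_device(nn.Linear(8, 2))
updater = create_updater(model=model, optimizer=torch.optim.SGD(model.parameters(), lr=0.1))
train_state = TrainState(model=model, updater=updater)

# extra_state rides along at the top level of the saved dict; version >= 3
# marks the new TrainState format, anything older falls back to
# load_legacy_checkpoint() inside load_train_state()
save_train_state('./checkpoint-tmp.pth.tar', train_state, extra_state=dict(version=3))
load_train_state(train_state, './checkpoint-tmp.pth.tar', load_opt=True)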

timm/bits/checkpoint_manager.py
@@ -0,0 +1,219 @@
+""" Checkpoint Manager
+
+Track top-n training checkpoints and maintain recovery checkpoints on specified intervals.
+
+Hacked together by / Copyright 2021 Ross Wightman
+"""
+import glob
+import logging
+import operator
+import os
+import shutil
+from typing import Optional, Dict, Callable, List
+from dataclasses import dataclass, replace
+
+from .checkpoint import save_train_state
+from .train_state import TrainState
+
+_logger = logging.getLogger(__name__)
+
+
+@dataclass
+class CheckpointInfo:
+    path: str = ''
+    metrics: Dict[str, float] = None  # all metrics at time of checkpoint save
+    metric_name: str = 'loss'
+    metric_decreasing: bool = True
+    epoch: int = 0
+    global_step: int = 0
+
+    @property
+    def valid_key(self):
+        return self.metric_name and self.metrics and self.metric_name in self.metrics
+
+    @property
+    def sort_key(self):
+        return self.metrics[self.metric_name] if self.valid_key else self.epoch
+
+    @property
+    def decreasing_key(self):
+        return self.metric_decreasing if self.valid_key else False
+
+
+class CheckpointManager:
+    def __init__(
+            self,
+            hparams=None,
+            save_state_fn=None,
+            checkpoint_dir='',
+            recovery_dir='',
+            checkpoint_tmpl=None,
+            recovery_tmpl=None,
+            metric_name='loss',
+            metric_decreasing=True,
+            max_history=10):
+
+        # extra items to include in checkpoint
+        self.hparams = hparams  # train arguments (config / hparams)  # FIXME this will change with new config system
+
+        # state
+        self.checkpoint_files: List[CheckpointInfo] = []  # CheckpointInfo entries, ordered best to worst
+        self.best_checkpoint = None
+        self.curr_recovery_file = ''
+        self.prev_recovery_file = ''
+        self.can_hardlink = True
+
+        # util / helper fn
+        self.save_state_fn = save_state_fn or save_train_state
+
+        # file / folder config
+        self.extension = '.pth.tar'
+        self.checkpoint_dir = checkpoint_dir
+        self.recovery_dir = recovery_dir
+        self.checkpoint_tmpl = (checkpoint_tmpl or 'checkpoint-{index}') + self.extension
+        self.recovery_tmpl = (recovery_tmpl or 'recovery-{index}') + self.extension
+
+        # ordering / history config
+        self.metric_name = metric_name
+        self.metric_decreasing = metric_decreasing
+        self.metric_cmp_fn = operator.lt if metric_decreasing else operator.gt
+        self.max_history = max_history
+        assert self.max_history >= 1
+
+    def _replace(self, src, dst):
+        if self.can_hardlink:
+            try:
+                if os.path.exists(dst):
+                    os.unlink(dst)  # required for Windows support.
+            except Exception as e:
+                self.can_hardlink = False
+        os.replace(src, dst)
+
+    def _duplicate(self, src, dst):
+        if self.can_hardlink:
+            try:
+                if os.path.exists(dst):
+                    # for Windows
+                    os.unlink(dst)
+                os.link(src, dst)
+                return
+            except Exception as e:
+                self.can_hardlink = False
+        shutil.copy2(src, dst)
+
+    def _save(self, save_path, train_state: TrainState, metrics: Optional[Dict[str, float]] = None):
+        extra_state = dict(
+            # version < 2 increments epoch before save
+            # version < 3, pre timm bits
+            # version 3, first timm bits checkpoints
+            version=3,
+        )
+        if self.hparams is not None:
+            extra_state.update(dict(arch=self.hparams['model'], hparams=self.hparams))
+        else:
+            arch = getattr(train_state.model, 'default_cfg', dict()).get('architecture', None)
+            if arch is None:
+                arch = type(train_state.model).__name__.lower()
+            extra_state.update(dict(arch=arch))
+        if metrics is not None:
+            # save the metrics and how we originally sorted them in the checkpoint for future comparisons
+            extra_state.update(dict(
+                metrics=metrics,
+                metric_name=self.metric_name,
+                metric_decreasing=self.metric_decreasing
+            ))
+
+        self.save_state_fn(save_path, train_state, extra_state)
+
+        checkpoint_info = CheckpointInfo(
+            path=save_path,
+            metrics=metrics,
+            metric_name=self.metric_name,
+            metric_decreasing=self.metric_decreasing,
+            epoch=train_state.epoch,
+            global_step=train_state.step_count_global,
+        )
+        return checkpoint_info
+
+    def _update_checkpoints(self, info: CheckpointInfo):
+        self.checkpoint_files.append(info)
+        self.checkpoint_files = sorted(
+            self.checkpoint_files,
+            key=lambda x: x.sort_key,
+            reverse=not info.decreasing_key,  # sort in descending order if a lower metric is not better
+        )
+
+    def _cleanup_checkpoints(self, trim=0):
+        trim = min(len(self.checkpoint_files), trim)
+        delete_index = self.max_history - trim
+        if delete_index < 0 or len(self.checkpoint_files) <= delete_index:
+            return
+        to_delete = self.checkpoint_files[delete_index:]
+        for d in to_delete:
+            try:
+                _logger.debug("Cleaning checkpoint: {}".format(d))
+                os.remove(d.path)
+            except Exception as e:
+                _logger.error("Exception '{}' while deleting checkpoint".format(e))
+        self.checkpoint_files = self.checkpoint_files[:delete_index]
+
+    def _compare_metric(self, lhs: CheckpointInfo, rhs: CheckpointInfo):
+        # compare metrics against an existing checkpoint
+        if not lhs or not lhs.valid_key or not rhs or not rhs.valid_key:
+            # always assume lhs metrics are better if there are no usable metrics to compare
+            return True
+        return self.metric_cmp_fn(lhs.sort_key, rhs.sort_key)
+
+    def save_checkpoint(self, train_state: TrainState, metrics: Optional[Dict[str, float]] = None):
+        assert train_state.epoch >= 0
+        tmp_save_path = os.path.join(self.checkpoint_dir, 'tmp' + self.extension)
+        last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension)
+        curr_checkpoint = self._save(tmp_save_path, train_state, metrics)
+        self._replace(tmp_save_path, last_save_path)
+
+        worst_checkpoint = self.checkpoint_files[-1] if self.checkpoint_files else None
+        if len(self.checkpoint_files) < self.max_history or self._compare_metric(curr_checkpoint, worst_checkpoint):
+            if len(self.checkpoint_files) >= self.max_history:
+                self._cleanup_checkpoints(1)
+
+            filename = self.checkpoint_tmpl.format(index=train_state.epoch)
+            save_path = os.path.join(self.checkpoint_dir, filename)
+            curr_checkpoint = replace(curr_checkpoint, path=save_path)
+            self._duplicate(last_save_path, save_path)
+            self._update_checkpoints(curr_checkpoint)
+
+            checkpoints_str = "Current checkpoints:\n"
+            for c in self.checkpoint_files:
+                checkpoints_str += f'  {c.path}, {c.sort_key}\n'
+            _logger.info(checkpoints_str)
+
+            if curr_checkpoint.valid_key and self._compare_metric(curr_checkpoint, self.best_checkpoint):
+                self.best_checkpoint = curr_checkpoint
+                best_save_path = os.path.join(self.checkpoint_dir, 'best' + self.extension)
+                self._duplicate(last_save_path, best_save_path)
+
+        return self.best_checkpoint
+
+    def save_recovery(self, train_state: TrainState):
+        tmp_save_path = os.path.join(self.recovery_dir, 'recovery_tmp' + self.extension)
+        self._save(tmp_save_path, train_state)
+
+        filename = self.recovery_tmpl.format(index=train_state.step_count_global)
+        save_path = os.path.join(self.recovery_dir, filename)
+        self._replace(tmp_save_path, save_path)
+
+        if os.path.exists(self.prev_recovery_file):
+            try:
+                _logger.debug("Cleaning recovery: {}".format(self.prev_recovery_file))
+                os.remove(self.prev_recovery_file)
+            except Exception as e:
+                _logger.error("Exception '{}' while removing {}".format(e, self.prev_recovery_file))
+        self.prev_recovery_file = self.curr_recovery_file
+        self.curr_recovery_file = save_path
+
+    def find_recovery(self):
+        # recovery_tmpl already includes the extension, so glob on the index field
+        recovery_path = os.path.join(self.recovery_dir, self.recovery_tmpl.format(index='*'))
+        files = glob.glob(recovery_path)
+        files = sorted(files)
+        return files[0] if len(files) else ''
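
A sketch of how the manager is meant to be driven from a training loop, mirroring the train.py changes below (`args`, `output_dir`, `train_state`, and the per-epoch functions are placeholders):

from timm.bits import CheckpointManager

checkpoint_manager = CheckpointManager(
    hparams=vars(args),                 # stored in each checkpoint as 'hparams'
    checkpoint_dir=output_dir,
    recovery_dir=output_dir,
    metric_name='top1',
    metric_decreasing=False,            # higher top1 is better
    max_history=5)

for epoch in range(num_epochs):
    train_metrics = train_one_epoch(...)    # placeholder
    eval_metrics = evaluate(...)            # placeholder, e.g. {'top1': 76.2, 'loss': 1.1}
    # saves 'tmp', atomically promotes it to 'last', hardlinks an epoch-indexed
    # copy into the top-n history, and refreshes 'best' when the metric improves
    best_checkpoint = checkpoint_manager.save_checkpoint(train_state, eval_metrics)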

timm/bits/device_env.py
@@ -1,7 +1,7 @@
 import abc
 from contextlib import suppress
 from enum import Enum
-from typing import Callable, Union, Optional, List, Tuple
+from typing import Callable, Union, Optional, List, Tuple, Dict, Any
 from dataclasses import dataclass, field, InitVar

 import torch
@@ -18,10 +18,21 @@ class DeviceEnvType(Enum):
     XLA = "xla"


+def state_dict_apply(state_dict: Dict[str, Any], apply_fn, select_fn=lambda x: isinstance(x, torch.Tensor)):
+    out_dict = {}
+    for k, v in state_dict.items():
+        if isinstance(v, dict):
+            out_dict[k] = state_dict_apply(v, apply_fn, select_fn)
+        else:
+            out_dict[k] = apply_fn(v) if select_fn(v) else v
+    return out_dict
+
+
 @dataclass
 class DeviceEnv:
     device_type: InitVar[Optional[str]] = None
     device_index: InitVar[Optional[int]] = None
+    channels_last: InitVar[bool] = False

     device: torch.device = field(init=False)  # set from device_type + device_index or post_init logic
     world_size: Optional[int] = None  # set by post_init from env when None
@@ -32,7 +43,12 @@ class DeviceEnv:
     memory_format: Optional[torch.memory_format] = None
     dtype: Optional[torch.dtype] = None

-    def __post_init__(self, device_type: Optional[str], device_index: Optional[int]):
+    def __post_init__(
+            self,
+            device_type: Optional[str],
+            device_index: Optional[int],
+            channels_last: bool,
+    ):
         device_type = device_type or 'cpu'
         self.device = torch.device(device_type) if device_index is None \
             else torch.device(device_type, device_index)
@@ -41,6 +57,17 @@ class DeviceEnv:
         self.global_rank = 0 if self.global_rank is None else self.global_rank
         if self.autocast is None:
             self.autocast = suppress
+        if channels_last:
+            self.memory_format = torch.channels_last
+
+    @staticmethod
+    def is_instance():
+        return is_global_device()
+
+    @staticmethod
+    def instance():
+        # throws if called before global device is set / initialized
+        return get_global_device()

     @property
     def type(self) -> DeviceEnvType:
@@ -81,11 +108,23 @@ class DeviceEnv:
     def wrap_parallel(self, *modules):
         pass

+    def to_cpu(self, *modules: torch.nn.Module):
+        moved = [m.cpu() for m in modules]
+        return moved[0] if len(moved) == 1 else moved
+
     def to_device(self, *modules: torch.nn.Module):
-        # FIXME handling dtype / memformat... disable flags, enable flags, diff fn?
+        # FIXME handling dtype? Do we want separate dtype for data vs model?
         moved = [m.to(device=self.device, memory_format=self.memory_format) for m in modules]
         return moved[0] if len(moved) == 1 else moved

+    def state_dict_to_cpu(self, state: Dict[str, Any]):
+        cpu_state = state_dict_apply(state, apply_fn=lambda x: x.cpu())
+        return cpu_state
+
+    def state_dict_to_device(self, state: Dict[str, Any]):
+        device_state = state_dict_apply(state, apply_fn=lambda x: x.to(self.device))
+        return device_state
+
     def mark_step(self):
         pass  # NO-OP for non-XLA devices
@@ -126,3 +165,24 @@ class DeviceEnv:

     def barrier(self):
         dist.barrier()
+
+
+# Global device environment singleton instance
+_global_device_env: Optional[DeviceEnv] = None
+
+
+def is_global_device():
+    return _global_device_env is not None
+
+
+def get_global_device() -> DeviceEnv:
+    if not is_global_device():
+        raise RuntimeError('Please initialize device environment by calling initialize_device / set_global_device.')
+    return _global_device_env
+
+
+def set_global_device(device: DeviceEnv):
+    global _global_device_env
+    if _global_device_env is not None:
+        raise RuntimeError('Global device is already set, it should NOT be set again.')
+    _global_device_env = device
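
The net effect of the singleton, as a sketch (the `initialize_device` kwargs shown are the ones used in train.py below; `model` is a placeholder nn.Module):

from timm.bits import initialize_device, DeviceEnv

dev_env = initialize_device(amp=True, channels_last=True)  # sets the global instance once

# ... later, anywhere in library code, without threading dev_env through every call:
if DeviceEnv.is_instance():
    env = DeviceEnv.instance()       # same object; raises RuntimeError if never initialized
    model = env.to_device(model)     # applies device and channels_last memory_format together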

timm/bits/device_env_cuda.py
@@ -16,7 +16,12 @@ def is_cuda_available():

 @dataclass
 class DeviceEnvCuda(DeviceEnv):

-    def __post_init__(self, device_type: str, device_index: Optional[int]):
+    def __post_init__(
+            self,
+            device_type: Optional[str],
+            device_index: Optional[int],
+            channels_last: bool,
+    ):
         assert torch.cuda.device_count()
         torch.backends.cudnn.benchmark = True
         setup_world_size = self.world_size or int(os.environ.get('WORLD_SIZE', 1))
@@ -43,6 +48,8 @@ class DeviceEnvCuda(DeviceEnv):
             self.global_rank = 0
         if self.autocast is None:
             self.autocast = torch.cuda.amp.autocast if self.amp else suppress
+        if channels_last:
+            self.memory_format = torch.channels_last

     @property
     def type(self) -> DeviceEnvType:

timm/bits/device_env_factory.py
@@ -1,15 +1,15 @@
-from .device_env import DeviceEnv
+import logging
+
+from .device_env import DeviceEnv, is_global_device, get_global_device, set_global_device
 from .device_env_cuda import DeviceEnvCuda, is_cuda_available
 from .device_env_xla import DeviceEnvXla, is_xla_available

-_device_env = None
+_logger = logging.getLogger(__name__)


 def initialize_device(force_cpu: bool = False, **kwargs) -> DeviceEnv:
-    global _device_env
-    if _device_env is not None:
-        # warning
-        return _device_env
+    if is_global_device():
+        return get_global_device()

     denv = None
     if not force_cpu:
@@ -23,14 +23,10 @@ def initialize_device(force_cpu: bool = False, **kwargs) -> DeviceEnv:
     if denv is None:
         denv = DeviceEnv()

-    print(denv)  # FIXME DEBUG
-    _device_env = denv
-    return denv
-
-
-def get_device() -> DeviceEnv:
-    if _device_env is None:
-        raise RuntimeError('Please initialize device environment by calling initialize_device first.')
-    return _device_env
+    _logger.info(f'Initialized device {denv.device}. '
+                 f'Rank: {denv.global_rank} ({denv.local_rank}) of {denv.world_size}.')
+    print(denv)  # FIXME temporary print for debugging
+
+    set_global_device(denv)
+    return denv

timm/bits/device_env_xla.py
@@ -1,7 +1,7 @@
 import os
 from contextlib import suppress
 from dataclasses import dataclass, field, InitVar
-from typing import Optional
+from typing import Optional, Dict

 import torch
 from torch.distributed import ReduceOp
@@ -42,7 +42,12 @@ def is_xla_available(xla_device_type=None):

 @dataclass
 class DeviceEnvXla(DeviceEnv):

-    def __post_init__(self, device_type: Optional[str], device_idx: Optional[int]):
+    def __post_init__(
+            self,
+            device_type: Optional[str],
+            device_idx: Optional[int],
+            channels_last: bool,
+    ):
         if device_type is not None:
             device_type = device_type.upper()
             assert device_type in ('TPU', 'GPU', 'CPU'), "XLA device type must be one of ('TPU', 'GPU', 'CPU')"
@@ -59,6 +64,8 @@ class DeviceEnvXla(DeviceEnv):
             assert xa is not None, 'XLA AMP is not present on this build'
         if self.autocast is None:
             self.autocast = xa.autocast if self.amp else suppress
+        if channels_last:
+            self.memory_format = torch.channels_last

     @property
     def type(self) -> DeviceEnvType:
@@ -114,3 +121,11 @@ class DeviceEnvXla(DeviceEnv):

     def barrier(self):
         xm.rendezvous('timm.bits.dist_barrier')
+
+    def state_dict_to_cpu(self, state: Dict[str, torch.Tensor]):
+        cpu_state = xm._maybe_convert_to_cpu(state, convert=True)
+        return cpu_state
+
+    def state_dict_to_device(self, state: Dict[str, torch.Tensor]):
+        device_state = xm.send_cpu_data_to_device(state, device=self.device)
+        return device_state
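
The XLA overrides exist because XLA tensors must be materialized on the host before torch.save (xm.save normally handles this). A sketch of the intended call path on a TPU build (`train_state` assumed from setup):

import torch
from timm.bits import DeviceEnv

dev_env = DeviceEnv.instance()    # a DeviceEnvXla when running on TPU
cpu_state = dev_env.state_dict_to_cpu(train_state.state_dict())
torch.save(cpu_state, 'checkpoint.pth.tar')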

timm/bits/distributed.py
@@ -5,8 +5,7 @@ from torch.distributed import ReduceOp

 from timm.utils import unwrap_model

-from .device_env import DeviceEnv, DeviceEnvType
-from .device_env_factory import get_device
+from .device_env import DeviceEnv

 TensorSeq = Union[torch.Tensor, Tuple[torch.Tensor, ...], List[torch.Tensor], Dict[Any, torch.Tensor]]
@@ -22,7 +21,7 @@ def _validate_type(tensor: TensorSeq):

 def distribute_bn(model: torch.nn.Module, reduce: bool = False, dev_env: DeviceEnv = None):
     if dev_env is None:
-        dev_env = get_device()
+        dev_env = DeviceEnv.instance()
     # ensure every node has the same running bn stats
     for bn_name, bn_buf in unwrap_model(model).named_buffers(recurse=True):
         if ('running_mean' in bn_name) or ('running_var' in bn_name):
@@ -40,7 +39,7 @@ def all_gather_recursive(tensor: TensorSeq, cat_dim=0, dev_env: DeviceEnv = None
     """
     _validate_type(tensor)
     if dev_env is None:
-        dev_env = get_device()
+        dev_env = DeviceEnv.instance()
     if isinstance(tensor, torch.Tensor):
         return dev_env.all_gather(tensor, cat_dim=cat_dim)
     elif isinstance(tensor, dict):
@@ -55,7 +54,7 @@ def all_reduce_recursive(tensor: TensorSeq, op=ReduceOp.SUM, average=False, dev_
     """
     _validate_type(tensor)
     if dev_env is None:
-        dev_env = get_device()
+        dev_env = DeviceEnv.instance()
     if isinstance(tensor, torch.Tensor):
         return dev_env.all_reduce_(tensor, op=op, average=average)
     elif isinstance(tensor, dict):
@@ -70,7 +69,7 @@ def broadcast_recursive(tensor: TensorSeq, src_rank: int, dev_env: DeviceEnv = N
     """
     _validate_type(tensor)
     if dev_env is None:
-        dev_env = get_device()
+        dev_env = DeviceEnv.instance()
     if isinstance(tensor, torch.Tensor):
         return dev_env.broadcast_(tensor, src_rank=src_rank)
     elif isinstance(tensor, dict):
@@ -85,7 +84,7 @@ def all_gather_sequence(tensor: TensorSeq, cat_dim: int = 0, dev_env: DeviceEnv
     """
     _validate_type(tensor)
     if dev_env is None:
-        dev_env = get_device()
+        dev_env = DeviceEnv.instance()

     with torch.no_grad():
         names = None
@@ -124,7 +123,7 @@ def all_reduce_sequence(tensor: TensorSeq, op=ReduceOp.SUM, average=False, dev_e
     """
     _validate_type(tensor)
     if dev_env is None:
-        dev_env = get_device()
+        dev_env = DeviceEnv.instance()

     with torch.no_grad():
         names = None

timm/bits/metric.py
@@ -6,14 +6,13 @@ import torch
 from torch.distributed import ReduceOp

 from .device_env import DeviceEnv
-from .device_env_factory import get_device
 from .distributed import all_gather_sequence, all_reduce_sequence

-MetricValue = Union[float, torch.Tensor, List[float], List[torch.Tensor]]
+MetricValueT = Union[float, torch.Tensor, List[float], List[torch.Tensor]]


 @dataclass
 class ValueInfo:
-    initial: Optional[MetricValue] = 0.
+    initial: Optional[MetricValueT] = 0.
     dtype: torch.dtype = torch.float32
     dist_reduce: str = 'sum'
     dist_average: bool = False
@@ -23,10 +22,10 @@ class Metric(abc.ABC):

     def __init__(self, dev_env: DeviceEnv = None):
         self._infos: Dict[str, ValueInfo] = {}
-        self._values: Dict[str, Optional[MetricValue]] = {}
-        self._values_dist: Dict[str, Optional[MetricValue]] = {}
+        self._values: Dict[str, Optional[MetricValueT]] = {}
+        self._values_dist: Dict[str, Optional[MetricValueT]] = {}
         if dev_env is None:
-            dev_env = get_device()
+            dev_env = DeviceEnv.instance()
         self._dev_env = dev_env

     def _register_value(self, name: str, info: Optional[ValueInfo] = None):
@@ -117,7 +116,7 @@ class Metric(abc.ABC):
             names.append(name)
             values.append(value)
             reductions.append(_args(info.dist_reduce))
-
+        same_dsr = False
         if same_dsr:
             do_gather, reduce_kwargs = reductions[0]
             if do_gather:

timm/bits/logger.py → timm/bits/monitor.py
@@ -21,8 +21,6 @@ except ImportError:
     HAS_WANDB = False

-from .device_env_factory import get_device
-
 # FIXME old formatting for reference, to remove
 #
 # def log_eval(batch_idx, last_idx, batch_time, loss, top1, top5, log_suffix=''):
@@ -122,19 +120,19 @@ def _add_kwargs(text_update, name_map=None, **kwargs):
             text_update += [_to_str(name, v)]


-class Logger:
+class Monitor:

     def __init__(
             self,
             experiment_name=None,
             output_dir=None,
-            python_logger=None,
+            logger=None,
             hparams=None,
             log_wandb=False,
             output_enabled=True,
     ):
         self.output_dir = output_dir  # for tensorboard, csv, text file (TODO) logging
-        self.logger = python_logger or logging.getLogger('log')
+        self.logger = logger or logging.getLogger('log')
         hparams = hparams or {}

         # Setup CSV writer(s)

timm/bits/train_services.py
@@ -1,13 +1,13 @@
 from dataclasses import dataclass

-from .logger import Logger
-from timm.utils.checkpoint_saver import CheckpointSaver
+from .monitor import Monitor
+from .checkpoint_manager import CheckpointManager


 @dataclass
 class TrainServices:
     """ Train Loop Services
     """
-    logger: Logger = None
-    saver: CheckpointSaver = None
+    logger: Monitor = None
+    checkpoint_manager: CheckpointManager = None

timm/bits/train_setup.py
@@ -13,7 +13,7 @@ try:
 except ImportError:
     ds = None

-from .checkpoint import resume_train_checkpoint
+from .checkpoint import load_train_state, load_legacy_checkpoint
 from .device_env import DeviceEnv
 from .train_cfg import TrainCfg
 from .train_state import TrainState
@@ -90,10 +90,10 @@ def setup_model_and_optimizer(

     if resume_path:
-        # FIXME this is not implemented yet, do a hack job before proper TrainState serialization?
-        resume_train_checkpoint(
+        load_train_state(
             train_state,
             resume_path,
-            resume_opt=resume_opt,
+            load_opt=resume_opt,
             log_info=dev_env.primary)

     if dev_env.distributed:
@@ -141,10 +141,10 @@ def setup_model_and_optimizer_deepspeed(

     if resume_path:
         # FIXME deepspeed resumes differently
-        resume_train_checkpoint(
+        load_legacy_checkpoint(
             train_state,
             resume_path,
-            resume_opt=resume_opt,
+            load_opt=resume_opt,
             log_info=dev_env.primary)

     if dev_env.distributed:

timm/bits/train_state.py
@@ -4,6 +4,8 @@ from typing import Dict, Any
 from torch import nn as nn

 from timm.scheduler import Scheduler
+from timm.utils import get_state_dict, unwrap_model
+
 from .updater import Updater
@@ -16,18 +18,33 @@ class TrainState:
     lr_scheduler: Scheduler = None
     model_ema: nn.Module = None

-    step_count_epoch: int = 0
-    step_count_global: int = 0
     epoch: int = 0
+    step_count: int = 0
+    step_count_global: int = 0

     def __post_init__(self):
         assert self.model is not None
         assert self.updater is not None

-
-def serialize_train_state(train_state: TrainState):
-    pass
-
-
-def deserialize_train_state(checkpoint: Dict[str, Any]):
-    pass
+    def state_dict(self, unwrap_fn=unwrap_model):
+        state = dict(
+            epoch=self.epoch,
+            step_count=self.step_count,
+            step_count_global=self.step_count_global,
+            model=get_state_dict(self.model, unwrap_fn),
+            model_ema=None if self.model_ema is None else get_state_dict(self.model_ema, unwrap_fn),
+        )
+        # FIXME lr_scheduler state save?
+        state.update(self.updater.state_dict())
+        return state
+
+    def load_state_dict(self, state_dict, unwrap_fn=unwrap_model):
+        self.epoch = state_dict['epoch']
+        self.step_count = state_dict['step_count']
+        self.step_count_global = state_dict['step_count_global']
+        unwrap_fn(self.model).load_state_dict(state_dict.get('model'))
+        if state_dict.get('model_ema') is not None and self.model_ema is not None:
+            unwrap_fn(self.model_ema).load_state_dict(state_dict.get('model_ema'))
+        self.updater.load_state_dict(state_dict)
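
Worth noting: Updater.state_dict() is merged into the top level rather than nested, so a version-3 checkpoint looks roughly like this (a sketch, AMP run assumed):

state = train_state.state_dict()
# {
#     'epoch': 12,
#     'step_count': 0,
#     'step_count_global': 46800,
#     'model': ...,            # model weights via get_state_dict()
#     'model_ema': ...,        # or None when no EMA is kept
#     'optimizer': ...,        # merged in from Updater.state_dict()
#     'grad_scaler': ...,      # only present when an AMP scaler is active
# }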

timm/bits/updater.py
@@ -56,6 +56,7 @@ class Updater:
         state_dict = dict(optimizer=self.optimizer.state_dict())
         if self.grad_scaler is not None:
             state_dict['grad_scaler'] = self.grad_scaler.state_dict()
+        return state_dict

     def load_state_dict(self, state_dict):
         if 'optimizer' in state_dict:
@@ -66,3 +67,6 @@ class Updater:

     def after_step(self, after_step_fn, *args):
         after_step_fn(*args)
+
+    @property
+    def deepspeed(self):
+        return False

timm/bits/updater_deepspeed.py
@@ -24,3 +24,7 @@ class UpdaterDeepSpeed(Updater):
         self.model.backward(loss)
         self.model.step()
         self.reset()
+
+    @property
+    def deepspeed(self):
+        return True

timm/bits/updater_factory.py
@@ -3,7 +3,6 @@ from typing import Callable, Optional, Union, Any
 import torch

 from .device_env import DeviceEnv, DeviceEnvType
-from .device_env_factory import get_device
 from .updater import Updater
 from .updater_cuda import UpdaterCudaWithScaler
 from .updater_deepspeed import UpdaterDeepSpeed
@@ -21,7 +20,7 @@ def create_updater(
 ) -> Updater:

     if not dev_env:
-        dev_env = get_device()
+        dev_env = DeviceEnv.instance()
     updater_kwargs = dict(model=model, optimizer=optimizer, clip_fn=clip_fn, clip_value=clip_value)
     use_scaler = dev_env.amp

timm/data/loader.py
@@ -8,7 +8,7 @@ Hacked together by / Copyright 2020 Ross Wightman

 import torch.utils.data

-from timm.bits import get_device, DeviceEnvType
+from timm.bits import DeviceEnv

 from .fetcher import Fetcher
 from .prefetcher_cuda import PrefetcherCuda
@@ -75,7 +75,7 @@ def create_loader(
     )

     if dev_env is None:
-        dev_env = get_device()
+        dev_env = DeviceEnv.instance()

     sampler = None
     if dev_env.distributed and not isinstance(dataset, torch.utils.data.IterableDataset):

timm/data/parsers/parser_tfds.py
@@ -23,7 +23,7 @@ except ImportError as e:
     exit(1)

 from .parser import Parser
-from timm.bits import get_device
+from timm.bits import get_global_device

 MAX_TP_SIZE = 8  # maximum TF threadpool size, only doing jpeg decodes and queuing activities
 SHUFFLE_SIZE = 16834  # samples to shuffle in DS queue
@@ -80,7 +80,7 @@ class ParserTfds(Parser):
         self.worker_info = None
         self.dist_rank = 0
         self.dist_num_replicas = 1
-        dev_env = get_device()
+        dev_env = get_global_device()
         # FIXME allow to work without devenv usage?
         if dev_env.distributed and dev_env.world_size > 1:
             self.dist_rank = dev_env.global_rank

timm/utils/model.py
@@ -3,33 +3,38 @@
 Hacked together by / Copyright 2020 Ross Wightman
 """
-from .model_ema import ModelEma
 import torch
 import fnmatch

-def unwrap_model(model):
-    if isinstance(model, ModelEma):
-        return unwrap_model(model.ema)
-    else:
-        return model.module if hasattr(model, 'module') else model
+_SUB_MODULE_ATTR = ('module', 'model')
+
+
+def unwrap_model(model, recursive=True):
+    for attr in _SUB_MODULE_ATTR:
+        sub_module = getattr(model, attr, None)
+        if sub_module is not None:
+            return unwrap_model(sub_module) if recursive else sub_module
+    return model


 def get_state_dict(model, unwrap_fn=unwrap_model):
     return unwrap_fn(model).state_dict()


 def avg_sq_ch_mean(model, input, output):
-    "calculate average channel square mean of output activations"
-    return torch.mean(output.mean(axis=[0,2,3])**2).item()
+    """calculate average channel square mean of output activations
+    """
+    return torch.mean(output.mean(axis=[0, 2, 3]) ** 2).item()


 def avg_ch_var(model, input, output):
-    "calculate average channel variance of output activations"
-    return torch.mean(output.var(axis=[0,2,3])).item()\
+    """calculate average channel variance of output activations"""
+    return torch.mean(output.var(axis=[0, 2, 3])).item()


 def avg_ch_var_residual(model, input, output):
-    "calculate average channel variance of output activations"
-    return torch.mean(output.var(axis=[0,2,3])).item()
+    """calculate average channel variance of output activations"""
+    return torch.mean(output.var(axis=[0, 2, 3])).item()


 class ActivationStatsHook:
@@ -58,15 +63,16 @@ class ActivationStatsHook:
             raise ValueError("Please provide `hook_fns` for each `hook_fn_locs`, \
                 their lengths are different.")
         self.stats = dict((hook_fn.__name__, []) for hook_fn in hook_fns)
         for hook_fn_loc, hook_fn in zip(hook_fn_locs, hook_fns):
             self.register_hook(hook_fn_loc, hook_fn)

     def _create_hook(self, hook_fn):
         def append_activation_stats(module, input, output):
             out = hook_fn(module, input, output)
             self.stats[hook_fn.__name__].append(out)
+
         return append_activation_stats

     def register_hook(self, hook_fn_loc, hook_fn):
         for name, module in self.model.named_modules():
             if not fnmatch.fnmatch(name, hook_fn_loc):
@@ -74,9 +80,9 @@ class ActivationStatsHook:
             module.register_forward_hook(self._create_hook(hook_fn))


 def extract_spp_stats(model,
                       hook_fn_locs,
                       hook_fns,
                       input_shape=[8, 3, 224, 224]):
     """Extract average square channel mean and variance of activations during
     forward pass to plot Signal Propagation Plots (SPP).
@@ -84,9 +90,8 @@ def extract_spp_stats(model,
     Paper: https://arxiv.org/abs/2101.08692
     Example Usage: https://gist.github.com/amaarora/6e56942fcb46e67ba203f3009b30d950
     """
-
     x = torch.normal(0., 1., input_shape)
     hook = ActivationStatsHook(model, hook_fn_locs=hook_fn_locs, hook_fns=hook_fns)
     _ = model(x)
     return hook.stats
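
A quick behavior sketch of the attribute-based unwrap, which now peels any wrapper exposing a .module or .model attribute, so no ModelEma import is needed (DDP-style wrapping shown):

import torch.nn as nn
from timm.utils import unwrap_model

net = nn.Linear(4, 4)
wrapper = nn.Module()
wrapper.module = net            # mimics DistributedDataParallel-style wrapping

assert unwrap_model(wrapper) is net     # follows .module, then recurses
assert unwrap_model(net) is net         # plain modules pass through unchanged
# unwrap_model(wrapper, recursive=False) would peel only one layer of wrapping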

train.py
@@ -28,14 +28,14 @@ import torch
 import torch.nn as nn
 import torchvision.utils

-from timm.bits import initialize_device, setup_model_and_optimizer, DeviceEnv, Logger, Tracker,\
-    TrainState, TrainServices, TrainCfg, AccuracyTopK, AvgTensor, distribute_bn
+from timm.bits import initialize_device, setup_model_and_optimizer, DeviceEnv, Monitor, Tracker,\
+    TrainState, TrainServices, TrainCfg, CheckpointManager, AccuracyTopK, AvgTensor, distribute_bn
 from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
 from timm.models import create_model, safe_model_name, convert_splitbn_model
 from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
-from timm.optim import create_optimizer_v2, optimizer_kwargs
+from timm.optim import optimizer_kwargs
 from timm.scheduler import create_scheduler
-from timm.utils import setup_default_logging, random_seed, get_outdir, CheckpointSaver
+from timm.utils import setup_default_logging, random_seed, get_outdir, unwrap_model

 _logger = logging.getLogger('train')
@@ -276,7 +276,7 @@ def main():
     setup_default_logging()
     args, args_text = _parse_args()

-    dev_env = initialize_device(amp=args.amp)
+    dev_env = initialize_device(amp=args.amp, channels_last=args.channels_last)
     if dev_env.distributed:
         _logger.info('Training in distributed mode with multiple processes, 1 device per process. Process %d, total %d.'
                      % (dev_env.global_rank, dev_env.world_size))
@@ -293,13 +293,17 @@ def main():
     # FIXME perhaps keep the same and just set diff seeds for dataloader worker process? what about TFDS?
     random_seed(args.seed, dev_env.global_rank)

-    data_config, loader_eval, loader_train = setup_data(args, train_state.model.default_cfg, dev_env, mixup_active)
+    data_config, loader_eval, loader_train = setup_data(
+        args,
+        unwrap_model(train_state.model).default_cfg,
+        dev_env,
+        mixup_active)

-    # setup checkpoint saver
+    # setup checkpoint manager
     eval_metric = args.eval_metric
     best_metric = None
     best_epoch = None
-    saver = None
+    checkpoint_manager = None
     output_dir = None
     if dev_env.primary:
         if args.experiment:
@@ -311,24 +315,20 @@ def main():
             str(data_config['input_size'][-1])
         ])
         output_dir = get_outdir(args.output if args.output else './output/train', exp_name)
-        decreasing = True if eval_metric == 'loss' else False
-        saver = CheckpointSaver(  # TODO CheckpointSaverV2
-            model=train_state.model,
-            optimizer=train_state.updater.optimizer,
-            args=args,
-            model_ema=train_state.model_ema,
-            amp_scaler=train_state.updater.grad_scaler,
+        checkpoint_manager = CheckpointManager(
+            hparams=vars(args),
             checkpoint_dir=output_dir,
             recovery_dir=output_dir,
-            decreasing=decreasing,
+            metric_name=eval_metric,
+            metric_decreasing=True if eval_metric == 'loss' else False,
             max_history=args.checkpoint_hist)
         with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
             f.write(args_text)

     services = TrainServices(
-        logger=Logger(
-            output_dir=output_dir, python_logger=_logger, hparams=vars(args), output_enabled=dev_env.primary),
-        saver=saver,
+        logger=Monitor(
+            output_dir=output_dir, logger=_logger, hparams=vars(args), output_enabled=dev_env.primary),
+        checkpoint_manager=checkpoint_manager,
     )

     try:
@@ -379,10 +379,10 @@ def main():
         if services.logger is not None:
             services.logger.write_summary(index=epoch, results=dict(train=train_metrics, eval=eval_metrics))

-        if saver is not None:
+        if checkpoint_manager is not None:
             # save proper checkpoint with eval metric
-            save_metric = eval_metrics[eval_metric]
-            best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)
+            best_checkpoint = checkpoint_manager.save_checkpoint(train_state, eval_metrics)
+            if best_checkpoint is not None:
+                best_metric, best_epoch = best_checkpoint.sort_key, best_checkpoint.epoch

         train_state = replace(train_state, epoch=epoch + 1)
@@ -629,9 +629,9 @@ def after_train_step(
                 lr=lr_avg,
             )

-        if services.saver is not None and cfg.recovery_interval and (
+        if services.checkpoint_manager is not None and cfg.recovery_interval and (
                 end_step or (step_idx + 1) % cfg.recovery_interval == 0):
-            services.saver.save_recovery(state.epoch, batch_idx=step_idx)
+            services.checkpoint_manager.save_recovery(state)

         if state.lr_scheduler is not None:
             state.lr_scheduler.step_update(num_updates=state.step_count_global)
@@ -641,7 +641,7 @@ def evaluate(
         model: nn.Module,
         loss_fn: nn.Module,
         loader,
-        logger: Logger,
+        logger: Monitor,
         dev_env: DeviceEnv,
         phase_suffix: str = '',
         log_interval: int = 10,

validate.py
@@ -18,7 +18,7 @@ import torch.nn as nn
 import torch.nn.parallel
 from collections import OrderedDict

-from timm.bits import initialize_device, Tracker, Logger, AccuracyTopK, AvgTensor
+from timm.bits import initialize_device, Tracker, Monitor, AccuracyTopK, AvgTensor
 from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
 from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
 from timm.utils import natural_key, setup_default_logging
@@ -154,7 +154,7 @@ def validate(args):
         pin_memory=args.pin_mem,
         tf_preprocessing=args.tf_preprocessing)

-    logger = Logger(python_logger=_logger)
+    logger = Monitor(logger=_logger)
     tracker = Tracker()
     losses = AvgTensor()
     accuracy = AccuracyTopK(dev_env=dev_env)
