pytorch-image-models/timm/bits/train_state.py

from dataclasses import dataclass, field
from typing import Dict, Any

from torch import nn as nn

from timm.scheduler import Scheduler
from timm.utils import get_state_dict, unwrap_model

from .train_cfg import TrainCfg
from .updater import Updater

@dataclass
class TrainState:
    model: nn.Module = None
    train_loss: nn.Module = None
    eval_loss: nn.Module = None
    updater: Updater = None
    lr_scheduler: Scheduler = None
    model_ema: nn.Module = None
    # default_factory avoids the mutable-default error a bare TrainCfg() default raises on newer Python
    train_cfg: TrainCfg = field(default_factory=TrainCfg)
    # FIXME collect & include other cfg (data, model, etc.) so it's all in one spot for checkpoints / logging / debugging?

    epoch: int = 0
    step_count: int = 0
    step_count_global: int = 0

    def __post_init__(self):
        assert self.model is not None
        assert self.updater is not None

    def state_dict(self, unwrap_fn=unwrap_model):
        state = dict(
            # train loop state (counters, etc.), saved and restored
            epoch=self.epoch,
            step_count=self.step_count,
            step_count_global=self.step_count_global,
            # model params / state, saved and restored
            model=get_state_dict(self.model, unwrap_fn),
            model_ema=None if self.model_ema is None else get_state_dict(self.model_ema, unwrap_fn),
            # configuration, saved but currently not restored; determined by args / config file for each run
            train_cfg=vars(self.train_cfg),
        )
        # FIXME include lr_scheduler state?
        # updater (optimizer, grad scaler, etc.) state is merged into the same flat dict
        state.update(self.updater.state_dict())
        return state

    def load_state_dict(self, state_dict, unwrap_fn=unwrap_model, load_opt=True):
        # restore train loop state; resume at the epoch after the one checkpointed
        self.epoch = state_dict['epoch'] + 1
        self.step_count = 0  # FIXME needs more logic to restore part-way through an epoch
        self.step_count_global = state_dict['step_count_global']

        # restore model params / state
        unwrap_fn(self.model).load_state_dict(state_dict.get('model'))
        if 'model_ema' in state_dict and self.model_ema is not None:
            unwrap_fn(self.model_ema).load_state_dict(state_dict.get('model_ema'))

        # restore optimizer / updater state
        if load_opt:
            self.updater.load_state_dict(state_dict)
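

# A minimal checkpoint round-trip sketch, assuming a TrainState has already been
# constructed with a model and an Updater (Updater construction isn't shown in
# this file). torch.save / torch.load are standard PyTorch; save_checkpoint and
# resume_checkpoint are hypothetical helpers, not part of timm.

import torch


def save_checkpoint(train_state: TrainState, path: str):
    # TrainState.state_dict() returns one flat dict: loop counters, model
    # (and optional EMA) weights, train_cfg, plus the updater's optimizer state
    torch.save(train_state.state_dict(), path)


def resume_checkpoint(train_state: TrainState, path: str):
    checkpoint = torch.load(path, map_location='cpu')
    # restores counters and weights in place; load_opt=True (the default)
    # also passes the dict to the updater to restore optimizer state
    train_state.load_state_dict(checkpoint, load_opt=True)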