You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
pytorch-image-models/timm/bits/train_setup.py

152 lines
4.2 KiB

import dataclasses
from typing import Callable, Union, Optional
import logging
import torch
import torch.nn as nn
from timm.models.layers import convert_sync_batchnorm
from timm.optim import create_optimizer_v2
from timm.utils import ModelEmaV2
try:
import deepspeed as ds
except ImportError:
ds = None
from .checkpoint import load_train_state
from .device_env import DeviceEnv
from .train_cfg import TrainCfg
from .train_state import TrainState
from .updater_factory import create_updater
_logger = logging.getLogger(__name__)
def setup_model_and_optimizer(
        dev_env: DeviceEnv,
        model: nn.Module,
        optimizer: Union[Callable, str],
        optimizer_cfg,
        clip_fn: Optional[Union[Callable, str]] = None,
        clip_value: Optional[float] = None,
        model_ema: bool = False,
        model_ema_decay: float = 0.9999,
        use_syncbn: bool = False,
        resume_path: str = '',
        resume_opt: bool = True,
        deepspeed: bool = False,
):
    """Assemble the train state (model, updater, optional EMA model) for training.

    The steps are, in order: move the model to the device, optionally convert it
    to synchronized BatchNorm, create the optimizer, wrap model + optimizer in an
    updater, optionally create an EMA copy, optionally restore a checkpoint, and
    finally wrap the model for distributed training.

    Args:
        dev_env: Device environment. Used to place the model on the device, to
            query the ``distributed`` / ``primary`` flags, and to wrap the model
            for distributed training.
        model: Model to train.
        optimizer: If callable, it is invoked as ``optimizer(model, **optimizer_cfg)``
            to build the optimizer. Otherwise ``create_optimizer_v2(model, **optimizer_cfg)``
            is used; note the non-callable value itself is not forwarded, so any
            optimizer selection must be carried by ``optimizer_cfg``.
        optimizer_cfg: Mapping of keyword arguments expanded into the optimizer
            factory call.
        clip_fn: Gradient clipping function or mode, passed through to ``create_updater``.
        clip_value: Gradient clipping value, passed through to ``create_updater``.
        model_ema: Whether to create a ``ModelEmaV2`` copy of the model.
        model_ema_decay: Decay passed to ``ModelEmaV2``.
        use_syncbn: Convert the model to synchronized BatchNorm; only applied when
            ``dev_env.distributed`` is true.
        resume_path: Checkpoint path; when non-empty the train state is restored
            from it via ``load_train_state``.
        resume_opt: Passed to ``load_train_state`` as ``load_opt``.
        deepspeed: Delegate the whole setup to ``setup_model_and_optimizer_deepspeed``.

    Returns:
        The ``TrainState`` holding the (possibly distributed-wrapped) model, the
        updater, and the EMA model (``None`` when ``model_ema`` is false).
    """
    if deepspeed:
        # use_syncbn is forwarded only for interface parity; the DeepSpeed setup
        # accepts the argument but does not currently act on it.
        return setup_model_and_optimizer_deepspeed(
            dev_env=dev_env, model=model, optimizer=optimizer, optimizer_cfg=optimizer_cfg,
            clip_fn=clip_fn, clip_value=clip_value, model_ema=model_ema, model_ema_decay=model_ema_decay,
            use_syncbn=use_syncbn, resume_path=resume_path, resume_opt=resume_opt,
        )

    dev_env.to_device(model)

    if use_syncbn and dev_env.distributed:
        model = convert_sync_batchnorm(model)
        if dev_env.primary:
            _logger.info(
                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')

    # `callable()` is the runtime check; typing.Callable is meant for annotations.
    if callable(optimizer):
        # FIXME this interface needs to be figured out, model, model and/or parameters, or just parameters?
        optimizer = optimizer(model, **optimizer_cfg)
    else:
        optimizer = create_optimizer_v2(model, **optimizer_cfg)

    updater = create_updater(
        model=model,
        optimizer=optimizer,
        clip_fn=clip_fn,
        clip_value=clip_value,
    )

    # The EMA copy is built from the model before any distributed wrapping below.
    # A separate local keeps the boolean `model_ema` flag from being rebound.
    ema_model = ModelEmaV2(model, decay=model_ema_decay) if model_ema else None

    train_state = TrainState(model=model, updater=updater, model_ema=ema_model)

    if resume_path:
        load_train_state(
            train_state,
            resume_path,
            load_opt=resume_opt,
            log_info=dev_env.primary)

    if dev_env.distributed:
        # Wrap only after the checkpoint restore above, so loading runs against
        # the unwrapped model.
        train_state = dataclasses.replace(
            train_state, model=dev_env.wrap_distributed(train_state.model))

    return train_state
def setup_model_and_optimizer_deepspeed(
        dev_env: DeviceEnv,
        model: nn.Module,
        optimizer: Union[Callable, str],
        optimizer_cfg,
        clip_fn: Optional[Union[Callable, str]] = None,
        clip_value: Optional[float] = None,
        model_ema: bool = False,
        model_ema_decay: float = 0.9999,
        use_syncbn: bool = False,
        resume_path: str = '',
        resume_opt: bool = True,
):
    """Assemble the train state using a DeepSpeed engine in place of the bare model.

    Args:
        dev_env: Device environment. Used to place the model on the device, to
            query the ``distributed`` flag, and to wrap the model for distributed
            training.
        model: Model to train; it is handed to ``deepspeed.initialize``.
        optimizer: If callable, it is invoked as ``optimizer(model=model, **optimizer_cfg)``.
            Otherwise ``create_optimizer_v2(model=model, **optimizer_cfg)`` is used and
            the non-callable value itself is not forwarded.
        optimizer_cfg: Mapping of keyword arguments expanded into the optimizer
            factory call.
        clip_fn: Gradient clipping function or mode, passed through to ``create_updater``.
        clip_value: Gradient clipping value, passed through to ``create_updater``.
        model_ema: Whether to create a ``ModelEmaV2`` copy of the engine.
        model_ema_decay: Decay passed to ``ModelEmaV2``.
        use_syncbn: Accepted for signature parity with ``setup_model_and_optimizer``;
            currently unused here.
        resume_path: Resuming is not implemented for DeepSpeed; any non-empty
            value raises ``NotImplementedError``.
        resume_opt: Accepted for signature parity; currently unused here.

    Returns:
        The ``TrainState`` holding the DeepSpeed engine as ``model``, the updater,
        and the EMA model (``None`` when ``model_ema`` is false).

    Raises:
        RuntimeError: If the ``deepspeed`` package could not be imported.
        NotImplementedError: If ``resume_path`` is non-empty.
    """
    # Fail fast, before moving the model or building an optimizer.
    if ds is None:
        raise RuntimeError(
            'DeepSpeed setup was requested but the `deepspeed` package could not be imported.')
    if resume_path:
        # FIXME deepspeed resumes differently
        # Raised explicitly instead of `assert False`, which is stripped under `python -O`.
        raise NotImplementedError('Resuming from a checkpoint is not yet supported with DeepSpeed.')

    dev_env.to_device(model)

    # `callable()` is the runtime check; typing.Callable is meant for annotations.
    if callable(optimizer):
        optimizer = optimizer(model=model, **optimizer_cfg)
    else:
        optimizer = create_optimizer_v2(model=model, **optimizer_cfg)

    # deepspeed.initialize returns (engine, optimizer, dataloader, lr_scheduler).
    # Unpack it so `model` is the engine rather than the whole tuple, and keep the
    # optimizer DeepSpeed hands back, since it may be a wrapped version of ours.
    model, optimizer, _, _ = ds.initialize(model=model, optimizer=optimizer, dist_init_required=False)

    updater = create_updater(
        model=model,
        optimizer=optimizer,
        clip_fn=clip_fn,
        clip_value=clip_value,
        deepspeed=True,
    )

    # ema model
    # FIXME how to do EMA w/ deepspeed?
    # A separate local keeps the boolean `model_ema` flag from being rebound.
    ema_model = ModelEmaV2(model, decay=model_ema_decay) if model_ema else None

    train_state = TrainState(model=model, updater=updater, model_ema=ema_model)

    if dev_env.distributed:
        # NOTE(review): the DeepSpeed engine presumably manages data parallelism
        # itself; confirm that wrapping it again here is intended.
        train_state = dataclasses.replace(
            train_state, model=dev_env.wrap_distributed(train_state.model))

    return train_state