You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
59 lines
2.3 KiB
59 lines
2.3 KiB
4 years ago
|
import logging
|
||
|
import os
|
||
|
from collections import OrderedDict
|
||
|
|
||
|
import torch
|
||
|
|
||
|
from .train_state import TrainState, serialize_train_state, deserialize_train_state
|
||
|
|
||
|
# Module-level logger, named after this module via __name__.
_logger = logging.getLogger(__name__)
|
||
|
|
||
|
|
||
|
def resume_train_checkpoint(
        train_state,
        checkpoint_path,
        resume_opt=True,
        deserialize_fn=deserialize_train_state,
        log_info=True):
    """Resume a training run from a checkpoint (not implemented yet).

    This is currently a placeholder: every call raises ``NotImplementedError``
    before using any of its arguments. The signature is kept stable so that
    callers can already be written against it.

    Args:
        train_state: State object to restore into. Presumably a ``TrainState``
            (imported from ``.train_state``) -- TODO confirm when implemented.
        checkpoint_path: Path of the checkpoint file to load.
        resume_opt: Presumably selects whether optimizer state is restored in
            addition to model weights. Currently unused.
        deserialize_fn: Callable intended to apply a loaded checkpoint to
            ``train_state``; defaults to ``deserialize_train_state``.
            Currently unused.
        log_info: Presumably enables info-level progress logging. Currently
            unused.

    Raises:
        NotImplementedError: Always; the message includes ``checkpoint_path``.
    """
    # TODO: implement. The earlier (pre-``train_state``) resume logic did:
    #   * if ``checkpoint_path`` is not a file: log an error and raise
    #     FileNotFoundError;
    #   * ``torch.load(checkpoint_path, map_location='cpu')``;
    #   * for a dict checkpoint containing 'state_dict': strip a leading
    #     'module.' prefix from each key (presumably left by DataParallel /
    #     DDP wrapping), load it into the model, then load the 'optimizer'
    #     entry and the AMP loss-scaler entry (keyed by
    #     ``loss_scaler.state_dict_key``) when present;
    #   * read 'epoch', adding 1 when 'version' > 1 (older checkpoints
    #     incremented the epoch before saving);
    #   * otherwise load the whole checkpoint as a bare model state dict;
    #   * return the resume epoch (None when the checkpoint has no 'epoch').
    # That code used ``model`` / ``optimizer`` / ``loss_scaler`` directly,
    # none of which are parameters here, so it must be rewritten against
    # ``train_state`` / ``deserialize_fn`` rather than uncommented.
    # NOTE: it tested ``k.startswith('module')`` but sliced off 7 characters
    # ('module.'), which mangles keys like 'modules_x'; match the full
    # 'module.' prefix when porting.
    raise NotImplementedError(
        'resume_train_checkpoint is not implemented yet '
        '(checkpoint_path={!r})'.format(checkpoint_path))
|