You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
pytorch-image-models/timm/utils/cuda.py

48 lines
1.1 KiB

""" CUDA / AMP utils
Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
try:
from apex import amp
has_apex = True
except ImportError:
amp = None
has_apex = False
class ApexScaler:
    """Loss scaler that runs backward and the optimizer step through Apex AMP.

    Uses the module-level ``amp`` object imported from ``apex``; that name is
    ``None`` when the import failed, in which case ``__call__`` cannot work.
    """

    # Identifier exposed to callers -- presumably the checkpoint key under
    # which this scaler's state is stored; confirm against the caller.
    state_dict_key = "amp"

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False):
        """Backpropagate the scaled ``loss`` and step ``optimizer``.

        Args:
            loss: Loss tensor to scale and backpropagate.
            optimizer: Optimizer handed to ``amp.scale_loss`` and then stepped.
            clip_grad: Optional max gradient norm. When not ``None``, gradients
                are clipped with ``torch.nn.utils.clip_grad_norm_`` before the
                optimizer step.
            parameters: Unused; clipping is applied to
                ``amp.master_params(optimizer)``. Accepted so callers can pass
                the same keyword arguments to any scaler implementation.
            create_graph: Forwarded to ``backward``.
        """
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward(create_graph=create_graph)
        if clip_grad is not None:
            # Clip after leaving the scale_loss context, on the parameters
            # that Apex reports for this optimizer.
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), clip_grad)
        optimizer.step()

    def state_dict(self):
        """Return ``amp.state_dict()``, or ``None`` if ``amp`` lacks that function.

        ``None`` is also returned when ``amp`` itself is ``None``.
        """
        get_state = getattr(amp, 'state_dict', None)
        if get_state is None:
            return None
        return get_state()

    def load_state_dict(self, state_dict):
        """Pass ``state_dict`` to ``amp.load_state_dict`` when it is available.

        Does nothing if ``amp`` has no ``load_state_dict`` attribute.
        """
        load_state = getattr(amp, 'load_state_dict', None)
        if load_state is not None:
            load_state(state_dict)
class NativeScaler:
    """Loss scaler built on PyTorch's native ``torch.cuda.amp.GradScaler``."""

    # Identifier exposed to callers -- presumably the checkpoint key under
    # which this scaler's state is stored; confirm against the caller.
    state_dict_key = "amp_scaler"

    def __init__(self):
        """Create the wrapped ``GradScaler`` with its default settings."""
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False):
        """Backpropagate the scaled ``loss``, step ``optimizer``, update the scale.

        Args:
            loss: Loss tensor to scale and backpropagate.
            optimizer: Optimizer stepped through the ``GradScaler``.
            clip_grad: Optional max gradient norm. When not ``None``, gradients
                are unscaled and then clipped with
                ``torch.nn.utils.clip_grad_norm_`` before the step.
            parameters: Iterable of parameters to clip; required whenever
                ``clip_grad`` is given.
            create_graph: Forwarded to ``backward``.

        Raises:
            ValueError: If ``clip_grad`` is given without ``parameters``.
        """
        # Validate before touching gradients so a bad call has no side effects.
        if clip_grad is not None and parameters is None:
            raise ValueError("parameters must be provided when clip_grad is set")

        self._scaler.scale(loss).backward(create_graph=create_graph)
        if clip_grad is not None:
            # Unscale first so the clipping threshold is compared against the
            # true gradient magnitudes rather than the scaled ones.
            self._scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
        self._scaler.step(optimizer)
        self._scaler.update()

    def state_dict(self):
        """Return the wrapped ``GradScaler``'s state dict."""
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        """Load ``state_dict`` into the wrapped ``GradScaler``."""
        self._scaler.load_state_dict(state_dict)