You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
27 lines
619 B
27 lines
619 B
from dataclasses import dataclass, field, InitVar
|
|
|
|
import torch
|
|
try:
|
|
import deepspeed as ds
|
|
except ImportError as e:
|
|
ds = None
|
|
|
|
from .updater import Updater
|
|
|
|
|
|
@dataclass
class UpdaterDeepSpeed(Updater):
    """Updater whose model is a DeepSpeed engine.

    The backward pass, the optimizer step and gradient zeroing are all
    delegated to ``self.model``, which must be a ``deepspeed.DeepSpeedEngine``.
    """

    def __post_init__(self):
        """Run the base-class setup, then validate the wrapped model.

        Raises:
            RuntimeError: If the ``deepspeed`` package could not be imported.
            TypeError: If ``self.model`` is not a ``DeepSpeedEngine``.
        """
        super().__post_init__()
        # The module-level import falls back to `ds = None` when deepspeed is
        # not installed; fail with a clear message instead of an
        # AttributeError on `None.DeepSpeedEngine`.
        if ds is None:
            raise RuntimeError(
                "UpdaterDeepSpeed requires the 'deepspeed' package, "
                "but it could not be imported."
            )
        # FIXME not sure how to deal with model.module / grad clipping w/ DS engine interface
        # Explicit raise rather than `assert` so the check survives `python -O`.
        if not isinstance(self.model, ds.DeepSpeedEngine):
            raise TypeError(
                "UpdaterDeepSpeed expects model to be a deepspeed.DeepSpeedEngine, "
                f"got {type(self.model).__name__}."
            )

    def reset(self):
        """Zero the gradients via the engine's ``zero_grad()``."""
        self.model.zero_grad()

    def apply(self, loss: torch.Tensor, accumulate: bool = False) -> None:
        """Backpropagate ``loss``, step the engine, then reset gradients.

        Args:
            loss: Loss tensor handed to the engine's ``backward()``.
            accumulate: Accepted for signature compatibility; this
                implementation does not read it.
        """
        self.model.backward(loss)
        self.model.step()
        self.reset()
|