diff --git a/timm/optim/__init__.py b/timm/optim/__init__.py
index 3432baf0..2df9fb12 100644
--- a/timm/optim/__init__.py
+++ b/timm/optim/__init__.py
@@ -3,6 +3,7 @@ from .adamw import AdamW
 from .adafactor import Adafactor
 from .adahessian import Adahessian
 from .lookahead import Lookahead
+from .madgrad import MADGRAD
 from .nadam import Nadam
 from .nvnovograd import NvNovoGrad
 from .radam import RAdam
diff --git a/timm/optim/madgrad.py b/timm/optim/madgrad.py
index 941dbf7b..f9ab24e3 100644
--- a/timm/optim/madgrad.py
+++ b/timm/optim/madgrad.py
@@ -1,3 +1,9 @@
+""" PyTorch MADGRAD optimizer
+
+MADGRAD: https://arxiv.org/abs/2101.11075
+
+Code from: https://github.com/facebookresearch/madgrad
+"""
 # Copyright (c) Facebook, Inc. and its affiliates.
 #
 # This source code is licensed under the MIT license found in the
diff --git a/timm/optim/optim_factory.py b/timm/optim/optim_factory.py
index 315c02b3..dd7f02b8 100644
--- a/timm/optim/optim_factory.py
+++ b/timm/optim/optim_factory.py
@@ -1,12 +1,11 @@
 """ Optimizer Factory w/ Custom Weight Decay
-Hacked together by / Copyright 2020 Ross Wightman
+Hacked together by / Copyright 2021 Ross Wightman
 """
 from typing import Optional
 import torch
 import torch.nn as nn
 import torch.optim as optim
-from torch.optim.optimizer import required
 from .adabelief import AdaBelief
 from .adafactor import Adafactor
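
With this change MADGRAD is re-exported from timm.optim. A minimal usage sketch follows; it assumes the constructor keeps the upstream facebookresearch/madgrad signature (lr, momentum, weight_decay, eps), and it constructs the optimizer directly rather than through the optim_factory registration, which is not shown in the hunks above.

# Minimal sketch: use the newly exported MADGRAD directly from timm.optim.
# The (lr, momentum, weight_decay, eps) arguments assume the upstream
# facebookresearch/madgrad constructor signature.
import torch
import torch.nn as nn
from timm.optim import MADGRAD

model = nn.Linear(10, 2)
optimizer = MADGRAD(model.parameters(), lr=1e-2, momentum=0.9, weight_decay=0.0, eps=1e-6)

# One dummy training step to show the usual step/zero_grad cycle.
x, y = torch.randn(8, 10), torch.randn(8, 2)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()
optimizer.zero_grad()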