cosine_lr docstring, type hints

pull/1113/head
ayasyrev 4 years ago
parent 07379c6d5d
commit 6734cf56ed

@@ -6,7 +6,7 @@ Hacked together by / Copyright 2021 Ross Wightman
"""
import logging
import math
import numpy as np
from typing import List, Optional, Union
import torch
from .scheduler import Scheduler
@@ -24,6 +24,25 @@ class CosineLRScheduler(Scheduler):
https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py
k-decay option based on `k-decay: A New Method For Learning Rate Schedule` - https://arxiv.org/abs/2004.05909
Args:
optimizer (torch.optim.Optimizer): torch optimizer to schedule
t_initial (int): Number of epochs in the initial (first) cycle.
lr_min (float, optional): Minimum learning rate to use during the scheduling. Defaults to 0.
cycle_mul (float, optional): Multiplier for cycle length. Defaults to 1.
cycle_decay (float, optional): Factor by which to decay lr at each new cycle. Defaults to 1.
cycle_limit (int, optional): Number of cycles. Defaults to 1.
warmup_t (int, optional): Number of epochs to warmup. Defaults to 0.
warmup_lr_init (float, optional): Initial learning rate during warmup. Defaults to 0.
warmup_prefix (bool, optional): If True, annealing starts from the initial LR after warmup ends. Defaults to False.
t_in_epochs (bool, optional): If True, the schedule is stepped per epoch; if False, per update step. Defaults to True.
noise_range_t (Union[int, float, List[Union[int, float]]], optional): Epoch when noise starts.
    If a list or tuple, the epoch range over which noise is applied. Defaults to None.
noise_pct (float, optional): Percentage of noise to add. Defaults to 0.67.
noise_std (float, optional): Noise standard deviation. Defaults to 1.0.
noise_seed (int, optional): Seed to use to add random noise. Defaults to 42.
k_decay (float, optional): Power for k_decay. Defaults to 1.0.
initialize (bool, optional): Add initial_{field_name} to optimizer param group. Defaults to True.
"""
def __init__(self,
@@ -33,16 +52,17 @@ class CosineLRScheduler(Scheduler):
cycle_mul: float = 1.,
cycle_decay: float = 1.,
cycle_limit: int = 1,
warmup_t=0,
warmup_lr_init=0,
warmup_prefix=False,
t_in_epochs=True,
noise_range_t=None,
noise_pct=0.67,
noise_std=1.0,
noise_seed=42,
k_decay=1.0,
initialize=True) -> None:
warmup_t: int = 0,
warmup_lr_init: float = 0.,
warmup_prefix: bool = False,
t_in_epochs: bool = True,
noise_range_t: Optional[Union[int, float, List[Union[int, float]]]] = None,
noise_pct: float = 0.67,
noise_std: float = 1.0,
noise_seed: int = 42,
k_decay: float = 1.0,
initialize: bool = True) -> None:
super().__init__(
optimizer, param_group_field="lr",
noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
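
For orientation, here is a minimal usage sketch of the scheduler configured per the docstring above. It assumes the timm API at this commit (step(epoch) is called once per epoch when t_in_epochs=True); the model, optimizer, and hyperparameter values are placeholders, not taken from this diff:

import torch
from timm.scheduler.cosine_lr import CosineLRScheduler

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Three cosine cycles of 10 epochs each, halving the peak lr per cycle,
# with a 2-epoch linear warmup from 1e-4 up to the base lr.
scheduler = CosineLRScheduler(
    optimizer,
    t_initial=10,
    lr_min=1e-5,
    cycle_decay=0.5,
    cycle_limit=3,
    warmup_t=2,
    warmup_lr_init=1e-4,
)

num_epochs = scheduler.get_cycle_length()  # 3 cycles * 10 epochs = 30
for epoch in range(num_epochs):
    # ... train for one epoch ...
    scheduler.step(epoch + 1)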
@@ -111,7 +131,15 @@ class CosineLRScheduler(Scheduler):
else:
return None
def get_cycle_length(self, cycles=0):
def get_cycle_length(self, cycles: int = 0) -> int:
"""Return total number of epochs.
Args:
cycles (int, optional): Number of cycles. If 0, the scheduler's cycle_limit is used. Defaults to 0.
Returns:
int: Total number of epochs
"""
cycles = max(1, cycles or self.cycle_limit)
if self.cycle_mul == 1.0:
return self.t_initial * cycles
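
The branch shown covers cycle_mul == 1.0, where every cycle has the same length. For cycle_mul != 1.0 each successive cycle is cycle_mul times longer, so the total length is the geometric sum t_initial * (cycle_mul**cycles - 1) / (cycle_mul - 1). A quick sanity check with hypothetical values (not taken from this diff):

# t_initial=10, cycle_mul=2.0, cycles=3 -> cycle lengths 10, 20, 40
total = int(10 * (2.0 ** 3 - 1) / (2.0 - 1))
assert total == 70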
