From b5255960d9baefb884f2dedc0551f1e02c9717ca Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Wed, 13 Feb 2019 23:11:09 -0800
Subject: [PATCH] Tweaking tanh scheduler, senet weight init (for BN), transform defaults

---
 models/senet.py      | 33 ++++++++++++++++++++++++++
 models/transforms.py |  2 +-
 scheduler/tanh_lr.py | 55 +++++++++++++++++++++++++++++---------------
 train.py             | 33 +++++++++++++-------------
 4 files changed, 87 insertions(+), 36 deletions(-)

diff --git a/models/senet.py b/models/senet.py
index dc40c1e8..e0907ebf 100644
--- a/models/senet.py
+++ b/models/senet.py
@@ -104,6 +104,18 @@ pretrained_config = {
 }
 
 
+def _weight_init(m, n='', ll=''):
+    print(m, n, ll)
+    if isinstance(m, nn.Conv2d):
+        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+    elif isinstance(m, nn.BatchNorm2d):
+        if ll and n == ll:
+            nn.init.constant_(m.weight, 0.)
+        else:
+            nn.init.constant_(m.weight, 1.)
+        nn.init.constant_(m.bias, 0.)
+
+
 class SEModule(nn.Module):
 
     def __init__(self, channels, reduction):
@@ -116,6 +128,9 @@ class SEModule(nn.Module):
             channels // reduction, channels, kernel_size=1, padding=0)
         self.sigmoid = nn.Sigmoid()
 
+        for m in self.modules():
+            _weight_init(m)
+
     def forward(self, x):
         module_input = x
         x = self.avg_pool(x)
@@ -176,6 +191,9 @@ class SEBottleneck(Bottleneck):
         self.downsample = downsample
         self.stride = stride
 
+        for n, m in self.named_modules():
+            _weight_init(m, n, ll='bn3')
+
 
 class SEResNetBottleneck(Bottleneck):
     """
@@ -201,6 +219,9 @@ class SEResNetBottleneck(Bottleneck):
         self.downsample = downsample
         self.stride = stride
 
+        for n, m in self.named_modules():
+            _weight_init(m, n, ll='bn3')
+
 
 class SEResNeXtBottleneck(Bottleneck):
     """
@@ -225,6 +246,9 @@ class SEResNeXtBottleneck(Bottleneck):
         self.downsample = downsample
         self.stride = stride
 
+        for n, m in self.named_modules():
+            _weight_init(m, n, ll='bn3')
+
 
 class SEResNetBlock(nn.Module):
     expansion = 1
@@ -242,6 +266,9 @@ class SEResNetBlock(nn.Module):
         self.downsample = downsample
         self.stride = stride
 
+        for n, m in self.named_modules():
+            _weight_init(m, n, ll='bn2')
+
     def forward(self, x):
         residual = x
 
@@ -378,6 +405,12 @@ class SENet(nn.Module):
         self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None
         self.last_linear = nn.Linear(512 * block.expansion, num_classes)
 
+        for n, m in self.named_children():
+            if n == 'layer0':
+                m.apply(_weight_init)
+            else:
+                _weight_init(m)
+
     def _make_layer(self, block, planes, blocks, groups, reduction, stride=1,
                     downsample_kernel_size=1, downsample_padding=0):
         downsample = None
diff --git a/models/transforms.py b/models/transforms.py
index cdb84456..6d54e891 100644
--- a/models/transforms.py
+++ b/models/transforms.py
@@ -21,7 +21,7 @@ class LeNormalize(object):
         return tensor
 
 
-def transforms_imagenet_train(model_name, img_size=224, scale=(0.08, 1.0), color_jitter=(0.3, 0.3, 0.3)):
+def transforms_imagenet_train(model_name, img_size=224, scale=(0.1, 1.0), color_jitter=(0.333, 0.333, 0.333)):
     if 'dpn' in model_name:
         normalize = transforms.Normalize(
             mean=IMAGENET_DPN_MEAN,
diff --git a/scheduler/tanh_lr.py b/scheduler/tanh_lr.py
index fbb6ccf4..af47412a 100644
--- a/scheduler/tanh_lr.py
+++ b/scheduler/tanh_lr.py
@@ -23,14 +23,20 @@ class TanhLRScheduler(Scheduler):
                  t_mul: float = 1.,
                  lr_min: float = 0.,
                  decay_rate: float = 1.,
-                 warmup_updates=0,
+                 warmup_t=0,
                  warmup_lr_init=0,
+                 warmup_prefix=False,
                  cycle_limit=0,
+                 t_in_epochs=False,
                  initialize=True) -> None:
         super().__init__(optimizer, param_group_field="lr", initialize=initialize)
 
         assert t_initial > 0
         assert lr_min >= 0
+        assert lb < ub
+        assert cycle_limit >= 0
+        assert warmup_t >= 0
+        assert warmup_lr_init >= 0
         self.lb = lb
         self.ub = ub
         self.t_initial = t_initial
@@ -38,33 +44,33 @@ class TanhLRScheduler(Scheduler):
         self.lr_min = lr_min
         self.decay_rate = decay_rate
         self.cycle_limit = cycle_limit
-        self.warmup_updates = warmup_updates
+        self.warmup_t = warmup_t
         self.warmup_lr_init = warmup_lr_init
-        if self.warmup_updates:
-            self.warmup_steps = [(v - warmup_lr_init) / self.warmup_updates for v in self.base_values]
+        self.warmup_prefix = warmup_prefix
+        self.t_in_epochs = t_in_epochs
+        if self.warmup_t:
+            t_v = self.base_values if self.warmup_prefix else self._get_lr(self.warmup_t)
+            print(t_v)
+            self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in t_v]
+            super().update_groups(self.warmup_lr_init)
         else:
             self.warmup_steps = [1 for _ in self.base_values]
-        if self.warmup_lr_init:
-            super().update_groups(self.warmup_lr_init)
-
-    def get_epoch_values(self, epoch: int):
-        # this scheduler doesn't update on epoch
-        return None
 
-    def get_update_values(self, num_updates: int):
-        if num_updates < self.warmup_updates:
-            lrs = [self.warmup_lr_init + num_updates * s for s in self.warmup_steps]
+    def _get_lr(self, t):
+        if t < self.warmup_t:
+            lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
         else:
-            curr_updates = num_updates - self.warmup_updates
+            if self.warmup_prefix:
+                t = t - self.warmup_t
 
             if self.t_mul != 1:
-                i = math.floor(math.log(1 - curr_updates / self.t_initial * (1 - self.t_mul), self.t_mul))
+                i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul))
                 t_i = self.t_mul ** i * self.t_initial
-                t_curr = curr_updates - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial
+                t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial
             else:
-                i = curr_updates // self.t_initial
+                i = t // self.t_initial
                 t_i = self.t_initial
-                t_curr = curr_updates - (self.t_initial * i)
+                t_curr = t - (self.t_initial * i)
 
             if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit):
                 gamma = self.decay_rate ** i
@@ -78,5 +84,16 @@
                 ]
             else:
                 lrs = [self.lr_min * (self.decay_rate ** self.cycle_limit) for _ in self.base_values]
-
         return lrs
+
+    def get_epoch_values(self, epoch: int):
+        if self.t_in_epochs:
+            return self._get_lr(epoch)
+        else:
+            return None
+
+    def get_update_values(self, num_updates: int):
+        if not self.t_in_epochs:
+            return self._get_lr(num_updates)
+        else:
+            return None
diff --git a/train.py b/train.py
index 2adc54b3..96d5cc9e 100644
--- a/train.py
+++ b/train.py
@@ -162,7 +162,7 @@ def main():
     if args.opt.lower() == 'sgd':
         optimizer = optim.SGD(
             model.parameters(), lr=args.lr,
-            momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
+            momentum=args.momentum, weight_decay=args.weight_decay, nesterov=False)
     elif args.opt.lower() == 'adam':
         optimizer = optim.Adam(
             model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
@@ -183,15 +183,27 @@
     if optimizer_state is not None:
         optimizer.load_state_dict(optimizer_state)
 
+    updates_per_epoch = len(loader_train)
     if args.sched == 'cosine':
         lr_scheduler = scheduler.CosineLRScheduler(
             optimizer,
-            t_initial=13 * len(loader_train),
-            t_mul=2.0,
+            t_initial=100 * updates_per_epoch,
+            t_mul=1.0,
             lr_min=0,
             decay_rate=0.5,
             warmup_lr_init=1e-4,
-            warmup_updates=len(loader_train)
+            warmup_updates=1 * updates_per_epoch
+        )
+    elif args.sched == 'tanh':
+        lr_scheduler = scheduler.TanhLRScheduler(
+            optimizer,
+            t_initial=80 * updates_per_epoch,
+            t_mul=1.0,
+            lr_min=1e-5,
+            decay_rate=0.5,
+            warmup_lr_init=.001,
+            warmup_t=5 * updates_per_epoch,
+            cycle_limit=1
         )
     else:
         lr_scheduler = scheduler.StepLRScheduler(
@@ -354,7 +366,7 @@ def validate(model, loader, loss_fn, args):
             losses_m.update(loss.item(), input.size(0))
 
             # metrics
-            prec1, prec5 = accuracy(output, target, topk=(1, 3))
+            prec1, prec5 = accuracy(output, target, topk=(1, 5))
             prec1_m.update(prec1.item(), output.size(0))
             prec5_m.update(prec5.item(), output.size(0))
 
@@ -375,16 +387,5 @@ def validate(model, loader, loss_fn, args):
     return metrics
 
 
-def update_summary(epoch, train_metrics, eval_metrics, output_dir, write_header=False):
-    rowd = OrderedDict(epoch=epoch)
-    rowd.update(train_metrics)
-    rowd.update(eval_metrics)
-    with open(os.path.join(output_dir, 'summary.csv'), mode='a') as cf:
-        dw = csv.DictWriter(cf, fieldnames=rowd.keys())
-        if write_header:  # first iteration (epoch == 1 can't be used)
-            dw.writeheader()
-        dw.writerow(rowd)
-
-
 if __name__ == '__main__':
     main()
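
A minimal usage sketch of the reworked TanhLRScheduler interface touched by this patch, assuming the repo's scheduler package exposes it the way train.py above does; the model, optimizer, and updates_per_epoch values here are illustrative stand-ins, not values taken from the patch.

import torch
import scheduler  # the repo's scheduler package, as referenced in train.py

model = torch.nn.Linear(10, 10)           # stand-in model for illustration
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
updates_per_epoch = 1000                   # hypothetical optimizer steps per epoch

lr_scheduler = scheduler.TanhLRScheduler(
    optimizer,
    t_initial=80 * updates_per_epoch,      # length of the single cycle, counted in updates
    lr_min=1e-5,
    decay_rate=0.5,
    warmup_lr_init=1e-3,
    warmup_t=5 * updates_per_epoch,        # linear warmup from warmup_lr_init
    cycle_limit=1,
    t_in_epochs=False)                     # index the schedule by update count, not epoch

# With t_in_epochs=False, get_update_values() returns the per-group lrs for a given
# update index and get_epoch_values() returns None; t_in_epochs=True flips that.
for t in (0, updates_per_epoch, 10 * updates_per_epoch):
    print(t, lr_scheduler.get_update_values(t))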