From 6441e9cc1b6545fd68b35d1d7eecebd96d9a5266 Mon Sep 17 00:00:00 2001
From: Ross Wightman <rwightman@gmail.com>
Date: Fri, 22 May 2020 16:16:45 -0700
Subject: [PATCH] Fix memory_efficient mode for DenseNets. Add AntiAliasing
 (Blur) support for DenseNets and create one test model. Add lr cycle/mul
 params to train args.

---
 timm/models/densenet.py             | 24 +++++++++++++++++++-----
 timm/scheduler/scheduler_factory.py |  8 ++++----
 train.py                            |  4 ++++
 3 files changed, 27 insertions(+), 9 deletions(-)

diff --git a/timm/models/densenet.py b/timm/models/densenet.py
index b9f9853c..539d5012 100644
--- a/timm/models/densenet.py
+++ b/timm/models/densenet.py
@@ -14,7 +14,7 @@ from torch.jit.annotations import List
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import load_pretrained
-from .layers import SelectAdaptivePool2d, BatchNormAct2d, create_norm_act
+from .layers import SelectAdaptivePool2d, BatchNormAct2d, create_norm_act, BlurPool2d
 from .registry import register_model
 
 __all__ = ['DenseNet']
@@ -71,9 +71,9 @@ class DenseLayer(nn.Module):
     def call_checkpoint_bottleneck(self, x):
         # type: (List[torch.Tensor]) -> torch.Tensor
         def closure(*xs):
-            return self.bottleneck_fn(*xs)
+            return self.bottleneck_fn(xs)
 
-        return cp.checkpoint(closure, x)
+        return cp.checkpoint(closure, *x)
 
     @torch.jit._overload_method  # noqa: F811
     def forward(self, x):
@@ -132,12 +132,15 @@ class DenseBlock(nn.ModuleDict):
 
 
 class DenseTransition(nn.Sequential):
-    def __init__(self, num_input_features, num_output_features, norm_act_layer=nn.BatchNorm2d):
+    def __init__(self, num_input_features, num_output_features, norm_act_layer=nn.BatchNorm2d, aa_layer=None):
         super(DenseTransition, self).__init__()
         self.add_module('norm', norm_act_layer(num_input_features))
         self.add_module('conv', nn.Conv2d(
             num_input_features, num_output_features, kernel_size=1, stride=1, bias=False))
-        self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
+        if aa_layer is not None:
+            self.add_module('pool', aa_layer(num_output_features, stride=2))
+        else:
+            self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
 
 
 class DenseNet(nn.Module):
@@ -301,6 +304,17 @@ def densenet121(pretrained=False, **kwargs):
     return model
 
 
+@register_model
+def densenetblur121d(pretrained=False, **kwargs):
+    r"""Densenet-121 model from
+    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
+    """
+    model = _densenet(
+        'densenet121', growth_rate=32, block_config=(6, 12, 24, 16), pretrained=pretrained, stem_type='deep',
+        aa_layer=BlurPool2d, **kwargs)
+    return model
+
+
 @register_model
 def densenet121d(pretrained=False, **kwargs):
     r"""Densenet-121 model from
diff --git a/timm/scheduler/scheduler_factory.py b/timm/scheduler/scheduler_factory.py
index 2320c96b..ee4220ec 100644
--- a/timm/scheduler/scheduler_factory.py
+++ b/timm/scheduler/scheduler_factory.py
@@ -23,12 +23,12 @@ def create_scheduler(args, optimizer):
         lr_scheduler = CosineLRScheduler(
             optimizer,
             t_initial=num_epochs,
-            t_mul=1.0,
+            t_mul=args.lr_cycle_mul,
             lr_min=args.min_lr,
             decay_rate=args.decay_rate,
             warmup_lr_init=args.warmup_lr,
             warmup_t=args.warmup_epochs,
-            cycle_limit=1,
+            cycle_limit=args.lr_cycle_limit,
             t_in_epochs=True,
             noise_range_t=noise_range,
             noise_pct=args.lr_noise_pct,
@@ -40,11 +40,11 @@ def create_scheduler(args, optimizer):
         lr_scheduler = TanhLRScheduler(
             optimizer,
             t_initial=num_epochs,
-            t_mul=1.0,
+            t_mul=args.lr_cycle_mul,
             lr_min=args.min_lr,
             warmup_lr_init=args.warmup_lr,
             warmup_t=args.warmup_epochs,
-            cycle_limit=1,
+            cycle_limit=args.lr_cycle_limit,
             t_in_epochs=True,
             noise_range_t=noise_range,
             noise_pct=args.lr_noise_pct,
diff --git a/train.py b/train.py
index 899c6984..7f8d4a26 100755
--- a/train.py
+++ b/train.py
@@ -111,6 +111,10 @@ parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT
                     help='learning rate noise limit percent (default: 0.67)')
 parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                     help='learning rate noise std-dev (default: 1.0)')
+parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
+                    help='learning rate cycle len multiplier (default: 1.0)')
+parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
+                    help='learning rate cycle limit')
 parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
                     help='warmup learning rate (default: 0.0001)')
 parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
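
Note (editor addition, not part of the patch): a minimal standalone sketch of why the
memory_efficient fix above works. torch.utils.checkpoint.checkpoint only tracks gradients
for tensors it receives as positional arguments, so the List[Tensor] input must be splatted
at the checkpoint call and re-packed into a sequence inside the closure; the old code did
the reverse, passing the list as one untracked argument and then splatting it into
bottleneck_fn, which expects a single sequence. The bottleneck_fn below is a hypothetical
stand-in for DenseLayer.bottleneck_fn, not the timm implementation.

import torch
import torch.utils.checkpoint as cp


def bottleneck_fn(xs):
    # hypothetical stand-in for DenseLayer.bottleneck_fn: concatenate the
    # accumulated feature maps, then apply some transform
    return torch.cat(xs, 1) * 2.0


def call_checkpoint_bottleneck(x):
    # x is a List[torch.Tensor], as in the patched DenseLayer
    def closure(*xs):
        return bottleneck_fn(xs)  # re-pack the splatted tensors into a tuple

    # splat the list so checkpoint sees and tracks each tensor individually
    return cp.checkpoint(closure, *x)


features = [torch.randn(2, 8, 4, 4, requires_grad=True) for _ in range(3)]
out = call_checkpoint_bottleneck(features)
out.sum().backward()
assert all(f.grad is not None for f in features)  # gradients reach every input

With the patch applied, the new pieces can be exercised via the names the diff itself adds:
timm.create_model('densenetblur121d') for the blur-pool test model, and the
--lr-cycle-mul / --lr-cycle-limit flags on train.py (other train.py arguments elided here)
to get repeating cosine or tanh schedule cycles instead of the previously hard-coded single cycle.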