From 88129b2569dec4725a84c8a072c7613327ee25cb Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Tue, 2 Jun 2020 21:06:10 -0700
Subject: [PATCH] Add set_layer_config contextmgr to adjust all layer configs
 at once, use in create_model with new args. Remove a few old warning causing
 constant annotations for jit.

---
 timm/models/dpn.py                 |  2 +-
 timm/models/factory.py             | 20 +++++++++----
 timm/models/inception_resnet_v2.py |  1 -
 timm/models/layers/__init__.py     |  3 +-
 timm/models/layers/cond_conv2d.py  |  2 +-
 timm/models/layers/config.py       | 47 ++++++++++++++++++++++++++++--
 timm/models/layers/pool2d_same.py  |  2 +-
 validate.py                        |  6 ++--
 8 files changed, 65 insertions(+), 18 deletions(-)

diff --git a/timm/models/dpn.py b/timm/models/dpn.py
index 1f45095d..fa4e39fb 100644
--- a/timm/models/dpn.py
+++ b/timm/models/dpn.py
@@ -10,7 +10,7 @@ from __future__ import division
 from __future__ import print_function
 
 from collections import OrderedDict
-from typing import Union, Optional, List, Tuple
+from typing import Tuple
 
 import torch
 import torch.nn as nn
diff --git a/timm/models/factory.py b/timm/models/factory.py
index fbcd004d..03d8cc1f 100644
--- a/timm/models/factory.py
+++ b/timm/models/factory.py
@@ -1,5 +1,6 @@
 from .registry import is_model, is_model_in_modules, model_entrypoint
 from .helpers import load_checkpoint
+from .layers import set_layer_config
 
 
 def create_model(
@@ -8,6 +9,9 @@ def create_model(
         num_classes=1000,
         in_chans=3,
         checkpoint_path='',
+        scriptable=None,
+        exportable=None,
+        no_jit=None,
         **kwargs):
     """Create a model
 
@@ -17,13 +21,16 @@ def create_model(
         num_classes (int): number of classes for final fully connected layer (default: 1000)
         in_chans (int): number of input channels / colors (default: 3)
         checkpoint_path (str): path of checkpoint to load after model is initialized
+        scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet)
+        exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet)
+        no_jit (bool): set layer config so that model doesn't utilize jit scripted layers (so far activations only)
 
     Keyword Args:
         drop_rate (float): dropout rate for training (default: 0.0)
         global_pool (str): global pool type (default: 'avg')
         **: other kwargs are model specific
     """
-    margs = dict(pretrained=pretrained, num_classes=num_classes, in_chans=in_chans)
+    model_args = dict(pretrained=pretrained, num_classes=num_classes, in_chans=in_chans)
 
     # Only EfficientNet and MobileNetV3 models have support for batchnorm params or drop_connect_rate passed as args
     is_efficientnet = is_model_in_modules(model_name, ['efficientnet', 'mobilenetv3'])
@@ -47,11 +54,12 @@ def create_model(
     if kwargs.get('drop_path_rate', None) is None:
         kwargs.pop('drop_path_rate', None)
 
-    if is_model(model_name):
-        create_fn = model_entrypoint(model_name)
-        model = create_fn(**margs, **kwargs)
-    else:
-        raise RuntimeError('Unknown model (%s)' % model_name)
+    with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
+        if is_model(model_name):
+            create_fn = model_entrypoint(model_name)
+            model = create_fn(**model_args, **kwargs)
+        else:
+            raise RuntimeError('Unknown model (%s)' % model_name)
 
     if checkpoint_path:
         load_checkpoint(model, checkpoint_path)
diff --git a/timm/models/inception_resnet_v2.py b/timm/models/inception_resnet_v2.py
index 34b14570..f8772cc8 100644
--- a/timm/models/inception_resnet_v2.py
+++ b/timm/models/inception_resnet_v2.py
@@ -193,7 +193,6 @@ class Mixed_7a(nn.Module):
 
 
 class Block8(nn.Module):
-    __constants__ = ['relu']  # for pre 1.4 torchscript compat
 
     def __init__(self, scale=1.0, no_relu=False):
         super(Block8, self).__init__()
diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py
index b9c26fea..1ebc4be0 100644
--- a/timm/models/layers/__init__.py
+++ b/timm/models/layers/__init__.py
@@ -4,7 +4,8 @@ from .adaptive_avgmax_pool import \
 from .anti_aliasing import AntiAliasDownsampleLayer
 from .blur_pool import BlurPool2d
 from .cond_conv2d import CondConv2d, get_condconv_initializer
-from .config import is_exportable, is_scriptable, set_exportable, set_scriptable, is_no_jit, set_no_jit
+from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\
+    set_layer_config
 from .conv2d_same import Conv2dSame
 from .conv_bn_act import ConvBnAct
 from .create_act import create_act_layer, get_act_layer, get_act_fn
diff --git a/timm/models/layers/cond_conv2d.py b/timm/models/layers/cond_conv2d.py
index b1759d99..df98f71a 100644
--- a/timm/models/layers/cond_conv2d.py
+++ b/timm/models/layers/cond_conv2d.py
@@ -38,7 +38,7 @@ class CondConv2d(nn.Module):
     Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion:
     https://github.com/pytorch/pytorch/issues/17983
     """
-    __constants__ = ['bias', 'in_channels', 'out_channels', 'dynamic_padding']
+    __constants__ = ['in_channels', 'out_channels', 'dynamic_padding']
 
     def __init__(self, in_channels, out_channels, kernel_size=3,
                  stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4):
diff --git a/timm/models/layers/config.py b/timm/models/layers/config.py
index 2c0faf23..f07b9d78 100644
--- a/timm/models/layers/config.py
+++ b/timm/models/layers/config.py
@@ -1,13 +1,18 @@
-""" Model / Layer Config Singleton
+""" Model / Layer Config singleton state
 """
-from typing import Any
+from typing import Any, Optional
 
-__all__ = ['is_exportable', 'is_scriptable', 'set_exportable', 'set_scriptable', 'is_no_jit', 'set_no_jit']
+__all__ = [
+    'is_exportable', 'is_scriptable', 'is_no_jit',
+    'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config'
+]
 
 # Set to True if prefer to have layers with no jit optimization (includes activations)
 _NO_JIT = False
 
 # Set to True if prefer to have activation layers with no jit optimization
+# NOTE not currently used as no difference between no_jit and no_activation jit as only layers obeying
+# the jit flags so far are activations. This will change as more layers are updated and/or added.
 _NO_ACTIVATION_JIT = False
 
 # Set to True if exporting a model with Same padding via ONNX
@@ -72,3 +77,39 @@ class set_scriptable:
         global _SCRIPTABLE
         _SCRIPTABLE = self.prev
         return False
+
+
+class set_layer_config:
+    """ Layer config context manager that allows setting all layer config flags at once.
+    If a flag arg is None, it will not change the current value.
+ """ + def __init__( + self, + scriptable: Optional[bool] = None, + exportable: Optional[bool] = None, + no_jit: Optional[bool] = None, + no_activation_jit: Optional[bool] = None): + global _SCRIPTABLE + global _EXPORTABLE + global _NO_JIT + global _NO_ACTIVATION_JIT + self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT + if scriptable is not None: + _SCRIPTABLE = scriptable + if exportable is not None: + _EXPORTABLE = exportable + if no_jit is not None: + _NO_JIT = no_jit + if no_activation_jit is not None: + _NO_ACTIVATION_JIT = no_activation_jit + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _SCRIPTABLE + global _EXPORTABLE + global _NO_JIT + global _NO_ACTIVATION_JIT + _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev + return False diff --git a/timm/models/layers/pool2d_same.py b/timm/models/layers/pool2d_same.py index 7135f831..51242619 100644 --- a/timm/models/layers/pool2d_same.py +++ b/timm/models/layers/pool2d_same.py @@ -5,7 +5,7 @@ Hacked together by Ross Wightman import torch import torch.nn as nn import torch.nn.functional as F -from typing import Union, List, Tuple, Optional +from typing import List, Tuple, Optional from .helpers import tup_pair from .padding import pad_same, get_padding_value diff --git a/validate.py b/validate.py index ca031263..50010cce 100755 --- a/validate.py +++ b/validate.py @@ -85,15 +85,13 @@ def validate(args): args.pretrained = args.pretrained or not args.checkpoint args.prefetcher = not args.no_prefetcher - if args.torchscript: - set_scriptable(True) - # create model model = create_model( args.model, + pretrained=args.pretrained, num_classes=args.num_classes, in_chans=3, - pretrained=args.pretrained) + scriptable=args.torchscript) if args.checkpoint: load_checkpoint(model, args.checkpoint, args.use_ema)