Add set_layer_config contextmgr to adjust all layer configs at once, use in create_module with new args. Remove a few old warning causing constant annotations for jit.

pull/155/head
Ross Wightman 4 years ago
parent f28170df3f
commit 88129b2569

@ -10,7 +10,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from collections import OrderedDict from collections import OrderedDict
from typing import Union, Optional, List, Tuple from typing import Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn

@ -1,5 +1,6 @@
from .registry import is_model, is_model_in_modules, model_entrypoint from .registry import is_model, is_model_in_modules, model_entrypoint
from .helpers import load_checkpoint from .helpers import load_checkpoint
from .layers import set_layer_config
def create_model( def create_model(
@ -8,6 +9,9 @@ def create_model(
num_classes=1000, num_classes=1000,
in_chans=3, in_chans=3,
checkpoint_path='', checkpoint_path='',
scriptable=None,
exportable=None,
no_jit=None,
**kwargs): **kwargs):
"""Create a model """Create a model
@ -17,13 +21,16 @@ def create_model(
num_classes (int): number of classes for final fully connected layer (default: 1000) num_classes (int): number of classes for final fully connected layer (default: 1000)
in_chans (int): number of input channels / colors (default: 3) in_chans (int): number of input channels / colors (default: 3)
checkpoint_path (str): path of checkpoint to load after model is initialized checkpoint_path (str): path of checkpoint to load after model is initialized
scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet)
exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet)
no_jit (bool): set layer config so that model doesn't utilize jit scripted layers (so far activations only)
Keyword Args: Keyword Args:
drop_rate (float): dropout rate for training (default: 0.0) drop_rate (float): dropout rate for training (default: 0.0)
global_pool (str): global pool type (default: 'avg') global_pool (str): global pool type (default: 'avg')
**: other kwargs are model specific **: other kwargs are model specific
""" """
margs = dict(pretrained=pretrained, num_classes=num_classes, in_chans=in_chans) model_args = dict(pretrained=pretrained, num_classes=num_classes, in_chans=in_chans)
# Only EfficientNet and MobileNetV3 models have support for batchnorm params or drop_connect_rate passed as args # Only EfficientNet and MobileNetV3 models have support for batchnorm params or drop_connect_rate passed as args
is_efficientnet = is_model_in_modules(model_name, ['efficientnet', 'mobilenetv3']) is_efficientnet = is_model_in_modules(model_name, ['efficientnet', 'mobilenetv3'])
@ -47,9 +54,10 @@ def create_model(
if kwargs.get('drop_path_rate', None) is None: if kwargs.get('drop_path_rate', None) is None:
kwargs.pop('drop_path_rate', None) kwargs.pop('drop_path_rate', None)
with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
if is_model(model_name): if is_model(model_name):
create_fn = model_entrypoint(model_name) create_fn = model_entrypoint(model_name)
model = create_fn(**margs, **kwargs) model = create_fn(**model_args, **kwargs)
else: else:
raise RuntimeError('Unknown model (%s)' % model_name) raise RuntimeError('Unknown model (%s)' % model_name)

@ -193,7 +193,6 @@ class Mixed_7a(nn.Module):
class Block8(nn.Module): class Block8(nn.Module):
__constants__ = ['relu'] # for pre 1.4 torchscript compat
def __init__(self, scale=1.0, no_relu=False): def __init__(self, scale=1.0, no_relu=False):
super(Block8, self).__init__() super(Block8, self).__init__()

@ -4,7 +4,8 @@ from .adaptive_avgmax_pool import \
from .anti_aliasing import AntiAliasDownsampleLayer from .anti_aliasing import AntiAliasDownsampleLayer
from .blur_pool import BlurPool2d from .blur_pool import BlurPool2d
from .cond_conv2d import CondConv2d, get_condconv_initializer from .cond_conv2d import CondConv2d, get_condconv_initializer
from .config import is_exportable, is_scriptable, set_exportable, set_scriptable, is_no_jit, set_no_jit from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\
set_layer_config
from .conv2d_same import Conv2dSame from .conv2d_same import Conv2dSame
from .conv_bn_act import ConvBnAct from .conv_bn_act import ConvBnAct
from .create_act import create_act_layer, get_act_layer, get_act_fn from .create_act import create_act_layer, get_act_layer, get_act_fn

@ -38,7 +38,7 @@ class CondConv2d(nn.Module):
Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion: Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion:
https://github.com/pytorch/pytorch/issues/17983 https://github.com/pytorch/pytorch/issues/17983
""" """
__constants__ = ['bias', 'in_channels', 'out_channels', 'dynamic_padding'] __constants__ = ['in_channels', 'out_channels', 'dynamic_padding']
def __init__(self, in_channels, out_channels, kernel_size=3, def __init__(self, in_channels, out_channels, kernel_size=3,
stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4): stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4):

@ -1,13 +1,18 @@
""" Model / Layer Config Singleton """ Model / Layer Config singleton state
""" """
from typing import Any from typing import Any, Optional
__all__ = ['is_exportable', 'is_scriptable', 'set_exportable', 'set_scriptable', 'is_no_jit', 'set_no_jit'] __all__ = [
'is_exportable', 'is_scriptable', 'is_no_jit',
'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config'
]
# Set to True if prefer to have layers with no jit optimization (includes activations) # Set to True if prefer to have layers with no jit optimization (includes activations)
_NO_JIT = False _NO_JIT = False
# Set to True if prefer to have activation layers with no jit optimization # Set to True if prefer to have activation layers with no jit optimization
# NOTE not currently used as no difference between no_jit and no_activation jit as only layers obeying
# the jit flags so far are activations. This will change as more layers are updated and/or added.
_NO_ACTIVATION_JIT = False _NO_ACTIVATION_JIT = False
# Set to True if exporting a model with Same padding via ONNX # Set to True if exporting a model with Same padding via ONNX
@ -72,3 +77,39 @@ class set_scriptable:
global _SCRIPTABLE global _SCRIPTABLE
_SCRIPTABLE = self.prev _SCRIPTABLE = self.prev
return False return False
class set_layer_config:
""" Layer config context manager that allows setting all layer config flags at once.
If a flag arg is None, it will not change the current value.
"""
def __init__(
self,
scriptable: Optional[bool] = None,
exportable: Optional[bool] = None,
no_jit: Optional[bool] = None,
no_activation_jit: Optional[bool] = None):
global _SCRIPTABLE
global _EXPORTABLE
global _NO_JIT
global _NO_ACTIVATION_JIT
self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT
if scriptable is not None:
_SCRIPTABLE = scriptable
if exportable is not None:
_EXPORTABLE = exportable
if no_jit is not None:
_NO_JIT = no_jit
if no_activation_jit is not None:
_NO_ACTIVATION_JIT = no_activation_jit
def __enter__(self) -> None:
pass
def __exit__(self, *args: Any) -> bool:
global _SCRIPTABLE
global _EXPORTABLE
global _NO_JIT
global _NO_ACTIVATION_JIT
_SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev
return False

@ -5,7 +5,7 @@ Hacked together by Ross Wightman
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from typing import Union, List, Tuple, Optional from typing import List, Tuple, Optional
from .helpers import tup_pair from .helpers import tup_pair
from .padding import pad_same, get_padding_value from .padding import pad_same, get_padding_value

@ -85,15 +85,13 @@ def validate(args):
args.pretrained = args.pretrained or not args.checkpoint args.pretrained = args.pretrained or not args.checkpoint
args.prefetcher = not args.no_prefetcher args.prefetcher = not args.no_prefetcher
if args.torchscript:
set_scriptable(True)
# create model # create model
model = create_model( model = create_model(
args.model, args.model,
pretrained=args.pretrained,
num_classes=args.num_classes, num_classes=args.num_classes,
in_chans=3, in_chans=3,
pretrained=args.pretrained) scriptable=args.torchscript)
if args.checkpoint: if args.checkpoint:
load_checkpoint(model, args.checkpoint, args.use_ema) load_checkpoint(model, args.checkpoint, args.use_ema)

Loading…
Cancel
Save