Add set_layer_config contextmgr to adjust all layer configs at once, use in create_module with new args. Remove a few old warning causing constant annotations for jit.

pull/155/head
Ross Wightman 4 years ago
parent f28170df3f
commit 88129b2569

@ -10,7 +10,7 @@ from __future__ import division
from __future__ import print_function
from collections import OrderedDict
from typing import Union, Optional, List, Tuple
from typing import Tuple
import torch
import torch.nn as nn

@ -1,5 +1,6 @@
from .registry import is_model, is_model_in_modules, model_entrypoint
from .helpers import load_checkpoint
from .layers import set_layer_config
def create_model(
@ -8,6 +9,9 @@ def create_model(
num_classes=1000,
in_chans=3,
checkpoint_path='',
scriptable=None,
exportable=None,
no_jit=None,
**kwargs):
"""Create a model
@ -17,13 +21,16 @@ def create_model(
num_classes (int): number of classes for final fully connected layer (default: 1000)
in_chans (int): number of input channels / colors (default: 3)
checkpoint_path (str): path of checkpoint to load after model is initialized
scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet)
exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet)
no_jit (bool): set layer config so that model doesn't utilize jit scripted layers (so far activations only)
Keyword Args:
drop_rate (float): dropout rate for training (default: 0.0)
global_pool (str): global pool type (default: 'avg')
**: other kwargs are model specific
"""
margs = dict(pretrained=pretrained, num_classes=num_classes, in_chans=in_chans)
model_args = dict(pretrained=pretrained, num_classes=num_classes, in_chans=in_chans)
# Only EfficientNet and MobileNetV3 models have support for batchnorm params or drop_connect_rate passed as args
is_efficientnet = is_model_in_modules(model_name, ['efficientnet', 'mobilenetv3'])
@ -47,11 +54,12 @@ def create_model(
if kwargs.get('drop_path_rate', None) is None:
kwargs.pop('drop_path_rate', None)
if is_model(model_name):
create_fn = model_entrypoint(model_name)
model = create_fn(**margs, **kwargs)
else:
raise RuntimeError('Unknown model (%s)' % model_name)
with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
if is_model(model_name):
create_fn = model_entrypoint(model_name)
model = create_fn(**model_args, **kwargs)
else:
raise RuntimeError('Unknown model (%s)' % model_name)
if checkpoint_path:
load_checkpoint(model, checkpoint_path)

@ -193,7 +193,6 @@ class Mixed_7a(nn.Module):
class Block8(nn.Module):
__constants__ = ['relu'] # for pre 1.4 torchscript compat
def __init__(self, scale=1.0, no_relu=False):
super(Block8, self).__init__()

@ -4,7 +4,8 @@ from .adaptive_avgmax_pool import \
from .anti_aliasing import AntiAliasDownsampleLayer
from .blur_pool import BlurPool2d
from .cond_conv2d import CondConv2d, get_condconv_initializer
from .config import is_exportable, is_scriptable, set_exportable, set_scriptable, is_no_jit, set_no_jit
from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\
set_layer_config
from .conv2d_same import Conv2dSame
from .conv_bn_act import ConvBnAct
from .create_act import create_act_layer, get_act_layer, get_act_fn

@ -38,7 +38,7 @@ class CondConv2d(nn.Module):
Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion:
https://github.com/pytorch/pytorch/issues/17983
"""
__constants__ = ['bias', 'in_channels', 'out_channels', 'dynamic_padding']
__constants__ = ['in_channels', 'out_channels', 'dynamic_padding']
def __init__(self, in_channels, out_channels, kernel_size=3,
stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4):

@ -1,13 +1,18 @@
""" Model / Layer Config Singleton
""" Model / Layer Config singleton state
"""
from typing import Any
from typing import Any, Optional
__all__ = ['is_exportable', 'is_scriptable', 'set_exportable', 'set_scriptable', 'is_no_jit', 'set_no_jit']
__all__ = [
'is_exportable', 'is_scriptable', 'is_no_jit',
'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config'
]
# Set to True if prefer to have layers with no jit optimization (includes activations)
_NO_JIT = False
# Set to True if prefer to have activation layers with no jit optimization
# NOTE not currently used as no difference between no_jit and no_activation jit as only layers obeying
# the jit flags so far are activations. This will change as more layers are updated and/or added.
_NO_ACTIVATION_JIT = False
# Set to True if exporting a model with Same padding via ONNX
@ -72,3 +77,39 @@ class set_scriptable:
global _SCRIPTABLE
_SCRIPTABLE = self.prev
return False
class set_layer_config:
""" Layer config context manager that allows setting all layer config flags at once.
If a flag arg is None, it will not change the current value.
"""
def __init__(
self,
scriptable: Optional[bool] = None,
exportable: Optional[bool] = None,
no_jit: Optional[bool] = None,
no_activation_jit: Optional[bool] = None):
global _SCRIPTABLE
global _EXPORTABLE
global _NO_JIT
global _NO_ACTIVATION_JIT
self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT
if scriptable is not None:
_SCRIPTABLE = scriptable
if exportable is not None:
_EXPORTABLE = exportable
if no_jit is not None:
_NO_JIT = no_jit
if no_activation_jit is not None:
_NO_ACTIVATION_JIT = no_activation_jit
def __enter__(self) -> None:
pass
def __exit__(self, *args: Any) -> bool:
global _SCRIPTABLE
global _EXPORTABLE
global _NO_JIT
global _NO_ACTIVATION_JIT
_SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev
return False

@ -5,7 +5,7 @@ Hacked together by Ross Wightman
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Union, List, Tuple, Optional
from typing import List, Tuple, Optional
from .helpers import tup_pair
from .padding import pad_same, get_padding_value

@ -85,15 +85,13 @@ def validate(args):
args.pretrained = args.pretrained or not args.checkpoint
args.prefetcher = not args.no_prefetcher
if args.torchscript:
set_scriptable(True)
# create model
model = create_model(
args.model,
pretrained=args.pretrained,
num_classes=args.num_classes,
in_chans=3,
pretrained=args.pretrained)
scriptable=args.torchscript)
if args.checkpoint:
load_checkpoint(model, args.checkpoint, args.use_ema)

Loading…
Cancel
Save