You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
38 lines
1.4 KiB
38 lines
1.4 KiB
5 years ago
|
import torch
|
||
|
import torch.nn as nn
|
||
|
|
||
|
from .evo_norm import EvoNormBatch2d, EvoNormSample2d
|
||
|
from .norm_act import BatchNormAct2d
|
||
|
try:
|
||
|
from inplace_abn import InPlaceABN
|
||
|
has_iabn = True
|
||
|
except ImportError:
|
||
|
has_iabn = False
|
||
|
|
||
|
|
||
|
def create_norm_act(layer_type, num_features, jit=False, **kwargs):
    """Create a combined normalization + activation layer from a string name.

    Args:
        layer_type: Layer name of the form ``"<norm>"`` or ``"<norm>_<act>"``.
            The ``<norm>`` part is matched case-insensitively against
            ``batchnormact``, ``batchnormrelu``, ``evonormbatch``,
            ``evonormsample`` and ``iabn`` / ``inplaceabn``. An ``_<act>``
            suffix is accepted but currently ignored.
        num_features: Passed as the first positional argument to the
            selected layer's constructor.
        jit: If True, the created layer is wrapped with ``torch.jit.script``.
        **kwargs: Extra keyword arguments forwarded to the layer constructor.

    Returns:
        The constructed layer (TorchScript-scripted when ``jit`` is True).

    Raises:
        AssertionError: If ``layer_type`` contains more than one ``_``,
            names an unknown norm layer, or ``act_layer`` is supplied for
            ``batchnormrelu``.
        ImportError: If an InPlaceABN layer is requested but the
            ``inplace_abn`` package could not be imported.
    """
    layer_parts = layer_type.split('_')
    # Validation uses explicit raises instead of ``assert`` so it still runs
    # under ``python -O``; AssertionError is kept for backward compatibility
    # with callers that may catch it.
    if len(layer_parts) not in (1, 2):
        raise AssertionError("Invalid norm_act layer type (%s)" % layer_type)
    layer_class = layer_parts[0].lower()
    # FIXME: the optional activation suffix (layer_parts[1]) is not used yet;
    # string-based activation selection is unsupported.

    if layer_class == "batchnormact":
        # Activation defaults to ReLU if no kwargs override it.
        layer = BatchNormAct2d(num_features, **kwargs)
    elif layer_class == "batchnormrelu":
        # The activation is fixed to ReLU here, so an override is an error.
        if 'act_layer' in kwargs:
            raise AssertionError("act_layer cannot be overridden for batchnormrelu")
        layer = BatchNormAct2d(num_features, act_layer=nn.ReLU, **kwargs)
    elif layer_class == "evonormbatch":
        layer = EvoNormBatch2d(num_features, **kwargs)
    elif layer_class == "evonormsample":
        layer = EvoNormSample2d(num_features, **kwargs)
    elif layer_class in ("iabn", "inplaceabn"):
        # has_iabn is set by the guarded ``inplace_abn`` import at module top.
        if not has_iabn:
            raise ImportError(
                "Please install InplaceABN: "
                "'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.11'")
        layer = InPlaceABN(num_features, **kwargs)
    else:
        raise AssertionError("Invalid norm_act layer (%s)" % layer_class)

    if jit:
        layer = torch.jit.script(layer)
    return layer
|