pull/240/merge
Kim, Taehoon 4 years ago committed by GitHub
commit 00dfd2cb6d

@@ -0,0 +1,456 @@
#!/usr/bin/env python3
""" ImageNet Validation Script
This is intended to be a lean and easily modifiable ImageNet validation script for evaluating pretrained
models or training checkpoints against ImageNet or similarly organized image datasets. It prioritizes
canonical PyTorch, standard Python style, and good performance. Repurpose as you see fit.
Hacked together by Ross Wightman (https://github.com/rwightman)
"""
import argparse
import os
import csv
import glob
import time
import logging
import torch
import torch.nn as nn
import torch.nn.parallel
from collections import OrderedDict
from contextlib import suppress
import torch.quantization
import torch.quantization.quantize_fx as quantize_fx
import copy
#currently, quantization only runs on CPUs
os.environ['CUDA_VISIBLE_DEVICES'] = ""
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_legacy
#has_apex = False
#try:
# from apex import amp
# has_apex = True
#except ImportError:
# pass
#has_native_amp = False
#try:
# if getattr(torch.cuda.amp, 'autocast') is not None:
# has_native_amp = True
#except AttributeError:
# pass
torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('validate')
parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')
parser.add_argument('data', metavar='DIR',
help='path to dataset')
parser.add_argument('--dataset', '-d', metavar='NAME', default='',
help='dataset type (default: ImageFolder/ImageTar if empty)')
#argument for calibration dataset
parser.add_argument('--calib-data', metavar='DIR',
help='path to calibration dataset')
# quantization option(weight only, dynamic, static)
parser.add_argument('--quant_option', metavar='NAME', default='static',
help='quantization option (weight_only, dynamic, static) (default: static)')
parser.add_argument('--split', metavar='NAME', default='validation',
help='dataset split (default: validation)')
parser.add_argument('--model', '-m', metavar='NAME', default='dpn92',
help='model architecture (default: dpn92)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
help='number of data loading workers (default: 4)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--img-size', default=None, type=int,
metavar='N', help='Input image dimension, uses model default if empty')
parser.add_argument('--input-size', default=None, nargs=3, type=int,
metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
parser.add_argument('--crop-pct', default=None, type=float,
metavar='N', help='Input image center crop pct')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
help='Override std deviation of dataset')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
help='Image resize interpolation type (overrides model)')
parser.add_argument('--num-classes', type=int, default=None,
help='Number classes in dataset')
parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
help='path to class to idx mapping file (default: "")')
parser.add_argument('--gp', default=None, type=str, metavar='POOL',
help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
parser.add_argument('--log-freq', default=10, type=int,
metavar='N', help='batch logging frequency (default: 10)')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
help='path to latest checkpoint (default: none)')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
help='use pre-trained model')
#parser.add_argument('--num-gpu', type=int, default=1,
# help='Number of GPUS to use')
# num-gpu defaults to zero (no GPU usage; quantized inference runs on CPU)
parser.add_argument('--num-gpu', type=int, default=0,
help='Number of GPUS to use')
parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true',
help='disable test time pool')
parser.add_argument('--no-prefetcher', action='store_true', default=False,
help='disable fast prefetcher')
parser.add_argument('--pin-mem', action='store_true', default=False,
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--channels-last', action='store_true', default=False,
help='Use channels_last memory layout')
#parser.add_argument('--amp', action='store_true', default=False,
# help='Use AMP mixed precision. Defaults to Apex, fallback to native Torch AMP.')
#parser.add_argument('--apex-amp', action='store_true', default=False,
# help='Use NVIDIA Apex AMP mixed precision')
#parser.add_argument('--native-amp', action='store_true', default=False,
# help='Use Native Torch AMP mixed precision')
parser.add_argument('--tf-preprocessing', action='store_true', default=False,
help='Use Tensorflow preprocessing pipeline (requires CPU TF installed)')
parser.add_argument('--use-ema', dest='use_ema', action='store_true',
help='use ema version of weights if present')
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
help='convert model torchscript for inference')
parser.add_argument('--legacy-jit', dest='legacy_jit', action='store_true',
help='use legacy jit mode for pytorch 1.5/1.5.1/1.6 to get back fusion performance')
parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
help='Output csv file for validation results (summary)')
parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME',
help='Real labels JSON file for imagenet evaluation')
parser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME',
help='Valid label indices txt file for validation of partial label space')
def validate(args):
# might as well try to validate something
args.pretrained = args.pretrained or not args.checkpoint
args.prefetcher = not args.no_prefetcher
# amp_autocast = suppress # do nothing
# if args.amp:
# if has_native_amp:
# args.native_amp = True
# elif has_apex:
# args.apex_amp = True
# else:
# _logger.warning("Neither APEX or Native Torch AMP is available.")
# assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
# if args.native_amp:
# amp_autocast = torch.cuda.amp.autocast
# _logger.info('Validating in mixed precision with native PyTorch AMP.')
# elif args.apex_amp:
# _logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
# else:
# _logger.info('Validating in float32. AMP not enabled.')
if args.legacy_jit:
set_jit_legacy()
# create model
model = create_model(
args.model,
pretrained=args.pretrained,
num_classes=args.num_classes,
in_chans=3,
global_pool=args.gp,
scriptable=args.torchscript)
if args.num_classes is None:
assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
args.num_classes = model.num_classes
if args.checkpoint:
load_checkpoint(model, args.checkpoint, args.use_ema)
param_count = sum([m.numel() for m in model.parameters()])
_logger.info('Model %s created, param count: %d' % (args.model, param_count))
data_config = resolve_data_config(vars(args), model=model, use_test_size=True)
test_time_pool = False
if not args.no_test_pool:
model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True)
if args.torchscript:
torch.jit.optimized_execution(True)
model = torch.jit.script(model)
# model = model.cuda()
# if args.apex_amp:
# model = amp.initialize(model, opt_level='O1')
if args.channels_last:
model = model.to(memory_format=torch.channels_last)
# if args.num_gpu > 1:
# model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))
# criterion = nn.CrossEntropyLoss().cuda()
criterion = nn.CrossEntropyLoss()
dataset = create_dataset(
root=args.data, name=args.dataset, split=args.split,
load_bytes=args.tf_preprocessing, class_map=args.class_map)
# added for post-training quantization calibration; fall back to the eval data if --calib-data is not given
calib_dataset = create_dataset(
root=args.calib_data or args.data, name=args.dataset, split=args.split,
load_bytes=args.tf_preprocessing, class_map=args.class_map)
if args.valid_labels:
with open(args.valid_labels, 'r') as f:
valid_labels = {int(line.rstrip()) for line in f}
valid_labels = [i in valid_labels for i in range(args.num_classes)]
else:
valid_labels = None
if args.real_labels:
real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels)
else:
real_labels = None
crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
loader = create_loader(
dataset,
input_size=data_config['input_size'],
batch_size=args.batch_size,
use_prefetcher=args.prefetcher,
interpolation=data_config['interpolation'],
mean=data_config['mean'],
std=data_config['std'],
num_workers=args.workers,
crop_pct=crop_pct,
pin_memory=args.pin_mem,
tf_preprocessing=args.tf_preprocessing)
#Also create loader for calibration dataset
calib_loader = create_loader(
calib_dataset,
input_size=data_config['input_size'],
batch_size=args.batch_size,
use_prefetcher=args.prefetcher,
interpolation=data_config['interpolation'],
mean=data_config['mean'],
std=data_config['std'],
num_workers=args.workers,
crop_pct=crop_pct,
pin_memory=args.pin_mem,
tf_preprocessing=args.tf_preprocessing)
batch_time = AverageMeter()
losses = AverageMeter()
top1 = AverageMeter()
top5 = AverageMeter()
print('Start calibration of quantization observers before post-quantization')
model_to_quantize = copy.deepcopy(model)
model_to_quantize.eval()
# post training static quantization
if args.quant_option == 'static':
# quantized inference runs on CPU here, so use the qnnpack backend qconfig;
# model_to_quantize was already deep-copied and set to eval() above
qconfig_dict = {"": torch.quantization.get_default_qconfig('qnnpack')}
# prepare
model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_dict)
# calibrate
with torch.no_grad():
# warmup, reduce variability of first batch time, especially for comparing torchscript vs non
input = torch.randn((args.batch_size,) + tuple(data_config['input_size']))
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)
model(input)
end = time.time()
for batch_idx, (input, target) in enumerate(calib_loader):
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)
# run the observer-instrumented model so activation statistics are collected
output = model_prepared(input)
if valid_labels is not None:
output = output[:, valid_labels]
loss = criterion(output, target)
if real_labels is not None:
real_labels.add_result(output)
# measure accuracy and record loss
acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
losses.update(loss.item(), input.size(0))
top1.update(acc1.item(), input.size(0))
top5.update(acc5.item(), input.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if batch_idx % args.log_freq == 0:
_logger.info(
'Test: [{0:>4d}/{1}] '
'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
batch_idx, len(calib_loader), batch_time=batch_time,
rate_avg=input.size(0) / batch_time.avg,
loss=losses, top1=top1, top5=top5))
# quantize
model_quantized = quantize_fx.convert_fx(model_prepared)
# post training dynamic/weight-only quantization
elif args.quant_option in ('dynamic', 'weight_only'):
qconfig_dict = {"": torch.quantization.default_dynamic_qconfig}
# prepare
model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_dict)
# no calibration needed for dynamic/weight-only quantization
# quantize
model_quantized = quantize_fx.convert_fx(model_prepared)
else:
_logger.warning("Invalid quantization option '%s', evaluating the float model instead." % args.quant_option)
model_quantized = model_to_quantize
# evaluate the quantized model; FX graph mode already fuses conv/bn/relu during prepare_fx,
# so no separate fuse_fx pass is needed here
model = model_quantized
with torch.no_grad():
# warmup, reduce variability of first batch time, especially for comparing torchscript vs non
# input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).cuda()
input = torch.randn((args.batch_size,) + tuple(data_config['input_size']))
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)
model(input)
end = time.time()
for batch_idx, (input, target) in enumerate(loader):
# if args.no_prefetcher:
# target = target.cuda()
# input = input.cuda()
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)
# compute output
# with amp_autocast():
output = model(input)
if valid_labels is not None:
output = output[:, valid_labels]
loss = criterion(output, target)
if real_labels is not None:
real_labels.add_result(output)
# measure accuracy and record loss
acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
losses.update(loss.item(), input.size(0))
top1.update(acc1.item(), input.size(0))
top5.update(acc5.item(), input.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if batch_idx % args.log_freq == 0:
_logger.info(
'Test: [{0:>4d}/{1}] '
'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
batch_idx, len(loader), batch_time=batch_time,
rate_avg=input.size(0) / batch_time.avg,
loss=losses, top1=top1, top5=top5))
if real_labels is not None:
# real labels mode replaces topk values at the end
top1a, top5a = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5)
else:
top1a, top5a = top1.avg, top5.avg
results = OrderedDict(
top1=round(top1a, 4), top1_err=round(100 - top1a, 4),
top5=round(top5a, 4), top5_err=round(100 - top5a, 4),
param_count=round(param_count / 1e6, 2),
img_size=data_config['input_size'][-1],
crop_pct=crop_pct,
interpolation=data_config['interpolation'])
_logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
results['top1'], results['top1_err'], results['top5'], results['top5_err']))
return results
def main():
setup_default_logging()
args = parser.parse_args()
model_cfgs = []
model_names = []
if os.path.isdir(args.checkpoint):
# validate all checkpoints in a path with same model
checkpoints = glob.glob(args.checkpoint + '/*.pth.tar')
checkpoints += glob.glob(args.checkpoint + '/*.pth')
model_names = list_models(args.model)
model_cfgs = [(args.model, c) for c in sorted(checkpoints, key=natural_key)]
else:
if args.model == 'all':
# validate all models in a list of names with pretrained checkpoints
args.pretrained = True
model_names = list_models(pretrained=True, exclude_filters=['*in21k'])
model_cfgs = [(n, '') for n in model_names]
elif not is_model(args.model):
# model name doesn't exist, try as wildcard filter
model_names = list_models(args.model)
model_cfgs = [(n, '') for n in model_names]
if len(model_cfgs):
results_file = args.results_file or './results-all.csv'
_logger.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names)))
results = []
try:
start_batch_size = args.batch_size
for m, c in model_cfgs:
batch_size = start_batch_size
args.model = m
args.checkpoint = c
result = OrderedDict(model=args.model)
r = {}
while not r and batch_size >= args.num_gpu:
# torch.cuda.empty_cache()  # not needed here, CUDA is disabled for quantized CPU runs
try:
args.batch_size = batch_size
print('Validating with batch size: %d' % args.batch_size)
r = validate(args)
except RuntimeError as e:
if batch_size <= args.num_gpu:
print("Validation failed with no ability to reduce batch size. Exiting.")
raise e
batch_size = max(batch_size // 2, args.num_gpu)
print("Validation failed, reducing batch size by 50%")
result.update(r)
if args.checkpoint:
result['checkpoint'] = args.checkpoint
results.append(result)
except KeyboardInterrupt as e:
pass
results = sorted(results, key=lambda x: x['top1'], reverse=True)
if len(results):
write_results(results_file, results)
else:
validate(args)
def write_results(results_file, results):
with open(results_file, mode='w') as cf:
dw = csv.DictWriter(cf, fieldnames=results[0].keys())
dw.writeheader()
for r in results:
dw.writerow(r)
cf.flush()
if __name__ == '__main__':
main()

@@ -0,0 +1,252 @@
""" Loader Factory, Fast Collate, CUDA Prefetcher
Prefetcher and Fast Collate inspired by NVIDIA APEX example at
https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#diff-cf86c282ff7fba81fad27a559379d5bf
Hacked together by / Copyright 2020 Ross Wightman
"""
import torch.utils.data
import numpy as np
from .transforms_factory import create_transform
from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .distributed_sampler import OrderedDistributedSampler
from .random_erasing import RandomErasing
from .mixup import FastCollateMixup
def fast_collate(batch):
""" A fast collation function optimized for uint8 images (np array or torch) and int64 targets (labels)"""
assert isinstance(batch[0], tuple)
batch_size = len(batch)
if isinstance(batch[0][0], tuple):
# This branch 'deinterleaves' and flattens tuples of input tensors into one tensor ordered by position
# such that all tuple of position n will end up in a torch.split(tensor, batch_size) in nth position
inner_tuple_size = len(batch[0][0])
flattened_batch_size = batch_size * inner_tuple_size
targets = torch.zeros(flattened_batch_size, dtype=torch.int64)
tensor = torch.zeros((flattened_batch_size, *batch[0][0][0].shape), dtype=torch.uint8)
for i in range(batch_size):
assert len(batch[i][0]) == inner_tuple_size # all input tensor tuples must be same length
for j in range(inner_tuple_size):
targets[i + j * batch_size] = batch[i][1]
tensor[i + j * batch_size] += torch.from_numpy(batch[i][0][j])
return tensor, targets
elif isinstance(batch[0][0], np.ndarray):
targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
assert len(targets) == batch_size
tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
for i in range(batch_size):
tensor[i] += torch.from_numpy(batch[i][0])
return tensor, targets
elif isinstance(batch[0][0], torch.Tensor):
targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
assert len(targets) == batch_size
tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
for i in range(batch_size):
tensor[i].copy_(batch[i][0])
return tensor, targets
else:
assert False
class PrefetchLoader:
def __init__(self,
loader,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
fp16=False,
re_prob=0.,
re_mode='const',
re_count=1,
re_num_splits=0):
self.loader = loader
self.mean = torch.tensor([x * 255 for x in mean]).view(1, 3, 1, 1)
self.std = torch.tensor([x * 255 for x in std]).view(1, 3, 1, 1)
self.fp16 = fp16
if fp16:
self.mean = self.mean.half()
self.std = self.std.half()
if re_prob > 0.:
self.random_erasing = RandomErasing(
probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits)
else:
self.random_erasing = None
def __iter__(self):
first = True
for next_input, next_target in self.loader:
if self.fp16:
next_input = next_input.half().sub_(self.mean).div_(self.std)
else:
next_input = next_input.float().sub_(self.mean).div_(self.std)
if self.random_erasing is not None:
next_input = self.random_erasing(next_input)
if not first:
yield input, target
else:
first = False
input = next_input
target = next_target
yield input, target
def __len__(self):
return len(self.loader)
@property
def sampler(self):
return self.loader.sampler
@property
def dataset(self):
return self.loader.dataset
@property
def mixup_enabled(self):
if isinstance(self.loader.collate_fn, FastCollateMixup):
return self.loader.collate_fn.mixup_enabled
else:
return False
@mixup_enabled.setter
def mixup_enabled(self, x):
if isinstance(self.loader.collate_fn, FastCollateMixup):
self.loader.collate_fn.mixup_enabled = x
def create_loader(
dataset,
input_size,
batch_size,
is_training=False,
use_prefetcher=True,
no_aug=False,
re_prob=0.,
re_mode='const',
re_count=1,
re_split=False,
scale=None,
ratio=None,
hflip=0.5,
vflip=0.,
color_jitter=0.4,
auto_augment=None,
num_aug_splits=0,
interpolation='bilinear',
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
num_workers=1,
distributed=False,
crop_pct=None,
collate_fn=None,
pin_memory=False,
fp16=False,
tf_preprocessing=False,
use_multi_epochs_loader=False
):
re_num_splits = 0
if re_split:
# apply RE to second half of batch if no aug split otherwise line up with aug split
re_num_splits = num_aug_splits or 2
dataset.transform = create_transform(
input_size,
is_training=is_training,
use_prefetcher=use_prefetcher,
no_aug=no_aug,
scale=scale,
ratio=ratio,
hflip=hflip,
vflip=vflip,
color_jitter=color_jitter,
auto_augment=auto_augment,
interpolation=interpolation,
mean=mean,
std=std,
crop_pct=crop_pct,
tf_preprocessing=tf_preprocessing,
re_prob=re_prob,
re_mode=re_mode,
re_count=re_count,
re_num_splits=re_num_splits,
separate=num_aug_splits > 0,
)
sampler = None
if distributed:
if is_training:
sampler = torch.utils.data.distributed.DistributedSampler(dataset)
else:
# This will add extra duplicate entries to result in equal num
# of samples per-process, will slightly alter validation results
sampler = OrderedDistributedSampler(dataset)
if collate_fn is None:
collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate
loader_class = torch.utils.data.DataLoader
if use_multi_epochs_loader:
loader_class = MultiEpochsDataLoader
loader = loader_class(
dataset,
batch_size=batch_size,
shuffle=sampler is None and is_training,
num_workers=num_workers,
sampler=sampler,
collate_fn=collate_fn,
pin_memory=pin_memory,
drop_last=is_training,
)
if use_prefetcher:
prefetch_re_prob = re_prob if is_training and not no_aug else 0.
loader = PrefetchLoader(
loader,
mean=mean,
std=std,
fp16=fp16,
re_prob=prefetch_re_prob,
re_mode=re_mode,
re_count=re_count,
re_num_splits=re_num_splits
)
return loader
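# Example usage (illustrative; `my_dataset` is a placeholder for any dataset returned by create_dataset):
#   loader = create_loader(
#       my_dataset, input_size=(3, 224, 224), batch_size=64,
#       is_training=False, use_prefetcher=True, num_workers=4)
#   for images, targets in loader:
#       ...  # images are already normalized by PrefetchLoader when use_prefetcher=True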
class MultiEpochsDataLoader(torch.utils.data.DataLoader):
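# Keeps a single iterator (and its worker processes) alive across epochs by wrapping
# batch_sampler in _RepeatSampler, avoiding worker restart overhead between epochs.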
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._DataLoader__initialized = False
self.batch_sampler = _RepeatSampler(self.batch_sampler)
self._DataLoader__initialized = True
self.iterator = super().__iter__()
def __len__(self):
return len(self.batch_sampler.sampler)
def __iter__(self):
for i in range(len(self)):
yield next(self.iterator)
class _RepeatSampler(object):
""" Sampler that repeats forever.
Args:
sampler (Sampler)
"""
def __init__(self, sampler):
self.sampler = sampler
def __iter__(self):
while True:
yield from iter(self.sampler)

@@ -29,7 +29,13 @@ from .vision_transformer import *
from .vovnet import *
from .xception import *
from .xception_aligned import *
from .quantization.efficientnet import *
from .quantization.mobilenetv3 import *
from .quantization.rexnet import *
from .hardcorenas import *
from .factory import create_model, split_model_name, safe_model_name
from .helpers import load_checkpoint, resume_checkpoint, model_parameters

@@ -0,0 +1,10 @@
from .efficientnet import *
from .mobilenetv3 import *
from .rexnet import *
from timm.models.factory import create_model
from timm.models.helpers import load_checkpoint, resume_checkpoint
from .layers import TestTimePoolHead, apply_test_time_pool
from .layers import convert_splitbn_model
from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit
from timm.models.registry import *

File diff suppressed because it is too large.

@@ -0,0 +1,423 @@
""" EfficientNet, MobileNetV3, etc Blocks
Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
import torch.nn as nn
from torch.nn import functional as F
from .layers import create_conv2d, drop_path, get_act_layer
from .layers.activations import sigmoid, HardSigmoid
# Defaults used for Google/Tensorflow training of mobile networks w/ RMSprop as per
# papers and TF reference implementations. PT momentum equiv for TF decay is (1 - TF decay)
# NOTE: momentum varies btw .99 and .9997 depending on source
# .99 in official TF TPU impl
# .9997 (w/ .999 in search space) for paper
BN_MOMENTUM_TF_DEFAULT = 1 - 0.99
BN_EPS_TF_DEFAULT = 1e-3
_BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT)
def get_bn_args_tf():
return _BN_ARGS_TF.copy()
def resolve_bn_args(kwargs):
bn_args = get_bn_args_tf() if kwargs.pop('bn_tf', False) else {}
bn_momentum = kwargs.pop('bn_momentum', None)
if bn_momentum is not None:
bn_args['momentum'] = bn_momentum
bn_eps = kwargs.pop('bn_eps', None)
if bn_eps is not None:
bn_args['eps'] = bn_eps
return bn_args
_SE_ARGS_DEFAULT = dict(
gate_fn=sigmoid,
act_layer=None,
reduce_mid=False,
divisor=1)
def resolve_se_args(kwargs, in_chs, act_layer=None):
se_kwargs = kwargs.copy() if kwargs is not None else {}
# fill in args that aren't specified with the defaults
for k, v in _SE_ARGS_DEFAULT.items():
se_kwargs.setdefault(k, v)
# some models, like MobileNetV3, calculate SE reduction chs from the containing block's mid_ch instead of in_ch
if not se_kwargs.pop('reduce_mid'):
se_kwargs['reduced_base_chs'] = in_chs
# act_layer override, if it remains None, the containing block's act_layer will be used
if se_kwargs['act_layer'] is None:
assert act_layer is not None
se_kwargs['act_layer'] = act_layer
return se_kwargs
def resolve_act_layer(kwargs, default='relu'):
act_layer = kwargs.pop('act_layer', default)
if isinstance(act_layer, str):
act_layer = get_act_layer(act_layer)
return act_layer
def make_divisible(v, divisor=8, min_value=None):
min_value = min_value or divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
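# Illustrative values for make_divisible above (assuming the default divisor=8):
#   make_divisible(30) == 32, make_divisible(8) == 8, make_divisible(100) == 104
# i.e. v is rounded to a multiple of divisor (ties round up), then bumped up by one
# divisor if the rounded value fell below 90% of v.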
def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None):
"""Round number of filters based on depth multiplier."""
if not multiplier:
return channels
channels *= multiplier
return make_divisible(channels, divisor, channel_min)
class ChannelShuffle(nn.Module):
# FIXME haven't used yet
def __init__(self, groups):
super(ChannelShuffle, self).__init__()
self.groups = groups
def forward(self, x):
"""Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]"""
N, C, H, W = x.size()
g = self.groups
assert C % g == 0, "Incompatible group size {} for input channel {}".format(
g, C
)
return (
x.view(N, g, int(C / g), H, W)
.permute(0, 2, 1, 3, 4)
.contiguous()
.view(N, C, H, W)
)
class SqueezeExcite(nn.Module):
def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None,
act_layer=nn.ReLU, gate_fn=sigmoid, divisor=1, **_):
super(SqueezeExcite, self).__init__()
self.gate_fn = gate_fn
reduced_chs = make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
self.act1 = act_layer(inplace=True)
self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
self.quant_mul = nn.quantized.FloatFunctional()
if gate_fn == HardSigmoid:
self.gate_fn = HardSigmoid()
def forward(self, x):
x_se = self.avg_pool(x)
x_se = self.conv_reduce(x_se)
x_se = self.act1(x_se)
x_se = self.conv_expand(x_se)
x = self.quant_mul.mul(x, self.gate_fn(x_se))
return x
def fuse_module(self):
if type(self.act1) == nn.ReLU:
modules_to_fuse = ['conv_reduce','act1']
torch.quantization.fuse_modules(self, modules_to_fuse, inplace=True)
class ConvBnAct(nn.Module):
def __init__(self, in_chs, out_chs, kernel_size,
stride=1, dilation=1, pad_type='', act_layer=nn.ReLU,
norm_layer=nn.BatchNorm2d, norm_kwargs=None):
super(ConvBnAct, self).__init__()
norm_kwargs = norm_kwargs or {}
self.conv = create_conv2d(in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, padding=pad_type)
self.bn1 = norm_layer(out_chs, **norm_kwargs)
self.act1 = act_layer(inplace=True)
def feature_info(self, location):
if location == 'expansion': # output of conv after act, same as block output
info = dict(module='act1', hook_type='forward', num_chs=self.conv.out_channels)
else: # location == 'bottleneck', block output
info = dict(module='', hook_type='', num_chs=self.conv.out_channels)
return info
def forward(self, x):
x = self.conv(x)
x = self.bn1(x)
x = self.act1(x)
return x
def fuse_module(self):
modules_to_fuse = ['conv','bn1']
if type(self.act1) == nn.ReLU:
modules_to_fuse.append('act1')
torch.quantization.fuse_modules(self, modules_to_fuse, inplace=True)
class DepthwiseSeparableConv(nn.Module):
""" DepthwiseSeparable block
Used for DS convs in MobileNet-V1 and in the place of IR blocks that have no expansion
(factor of 1.0). This is an alternative to having a IR with an optional first pw conv.
"""
def __init__(self, in_chs, out_chs, dw_kernel_size=3,
stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
pw_kernel_size=1, pw_act=False, se_ratio=0., se_kwargs=None,
norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_path_rate=0.):
super(DepthwiseSeparableConv, self).__init__()
norm_kwargs = norm_kwargs or {}
has_se = se_ratio is not None and se_ratio > 0.
self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
self.has_pw_act = pw_act # activation after point-wise conv
self.drop_path_rate = drop_path_rate
self.conv_dw = create_conv2d(
in_chs, in_chs, dw_kernel_size, stride=stride, dilation=dilation, padding=pad_type, depthwise=True)
self.bn1 = norm_layer(in_chs, **norm_kwargs)
self.act1 = act_layer(inplace=True)
# Squeeze-and-excitation
if has_se:
se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
self.se = SqueezeExcite(in_chs, se_ratio=se_ratio, **se_kwargs)
else:
self.se = None
self.conv_pw = create_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
self.bn2 = norm_layer(out_chs, **norm_kwargs)
self.act2 = act_layer(inplace=True) if self.has_pw_act else nn.Identity()
self.skip_add = nn.quantized.FloatFunctional()
def feature_info(self, location):
if location == 'expansion': # after SE, input to PW
info = dict(module='conv_pw', hook_type='forward_pre', num_chs=self.conv_pw.in_channels)
else: # location == 'bottleneck', block output
info = dict(module='', hook_type='', num_chs=self.conv_pw.out_channels)
return info
def forward(self, x):
residual = x
x = self.conv_dw(x)
x = self.bn1(x)
x = self.act1(x)
if self.se is not None:
x = self.se(x)
x = self.conv_pw(x)
x = self.bn2(x)
x = self.act2(x)
if self.has_residual:
if self.drop_path_rate > 0.:
x = drop_path(x, self.drop_path_rate, self.training)
x = self.skip_add.add(x, residual)
return x
def fuse_module(self):
modules_to_fuse = [['conv_dw','bn1'],['conv_pw','bn2']]
if type(self.act1) == nn.ReLU:
modules_to_fuse[0].append('act1')
if type(self.act2) == nn.ReLU:
modules_to_fuse[1].append('act2')
torch.quantization.fuse_modules(self, modules_to_fuse, inplace=True)
class InvertedResidual(nn.Module):
""" Inverted residual block w/ optional SE and CondConv routing"""
def __init__(self, in_chs, out_chs, dw_kernel_size=3,
stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
conv_kwargs=None, drop_path_rate=0.):
super(InvertedResidual, self).__init__()
norm_kwargs = norm_kwargs or {}
conv_kwargs = conv_kwargs or {}
mid_chs = make_divisible(in_chs * exp_ratio)
has_se = se_ratio is not None and se_ratio > 0.
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
self.drop_path_rate = drop_path_rate
# Point-wise expansion
self.conv_pw = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs)
self.bn1 = norm_layer(mid_chs, **norm_kwargs)
self.act1 = act_layer(inplace=True)
# Depth-wise convolution
self.conv_dw = create_conv2d(
mid_chs, mid_chs, dw_kernel_size, stride=stride, dilation=dilation,
padding=pad_type, depthwise=True, **conv_kwargs)
self.bn2 = norm_layer(mid_chs, **norm_kwargs)
self.act2 = act_layer(inplace=True)
# Squeeze-and-excitation
if has_se:
se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs)
else:
self.se = None
# Point-wise linear projection
self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
self.bn3 = norm_layer(out_chs, **norm_kwargs)
self.skip_add = nn.quantized.FloatFunctional()
def feature_info(self, location):
if location == 'expansion': # after SE, input to PWL
info = dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels)
else: # location == 'bottleneck', block output
info = dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels)
return info
def forward(self, x):
residual = x
# Point-wise expansion
x = self.conv_pw(x)
x = self.bn1(x)
x = self.act1(x)
# Depth-wise convolution
x = self.conv_dw(x)
x = self.bn2(x)
x = self.act2(x)
# Squeeze-and-excitation
if self.se is not None:
x = self.se(x)
# Point-wise linear projection
x = self.conv_pwl(x)
x = self.bn3(x)
if self.has_residual:
if self.drop_path_rate > 0.:
x = drop_path(x, self.drop_path_rate, self.training)
x = self.skip_add.add(x, residual)
return x
def fuse_module(self):
modules_to_fuse = [['conv_pw','bn1'],['conv_dw','bn2'],['conv_pwl','bn3']]
if type(self.act1) == nn.ReLU:
modules_to_fuse[0].append('act1')
if type(self.act2) == nn.ReLU:
modules_to_fuse[1].append('act2')
torch.quantization.fuse_modules(self, modules_to_fuse, inplace=True)
class CondConvResidual(InvertedResidual):
""" Inverted residual block w/ CondConv routing"""
def __init__(self, in_chs, out_chs, dw_kernel_size=3,
stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
num_experts=0, drop_path_rate=0.):
self.num_experts = num_experts
conv_kwargs = dict(num_experts=self.num_experts)
super(CondConvResidual, self).__init__(
in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, dilation=dilation, pad_type=pad_type,
act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size,
pw_kernel_size=pw_kernel_size, se_ratio=se_ratio, se_kwargs=se_kwargs,
norm_layer=norm_layer, norm_kwargs=norm_kwargs, conv_kwargs=conv_kwargs,
drop_path_rate=drop_path_rate)
self.routing_fn = nn.Linear(in_chs, self.num_experts)
def forward(self, x):
residual = x
# CondConv routing
pooled_inputs = F.adaptive_avg_pool2d(x, 1).flatten(1)
routing_weights = torch.sigmoid(self.routing_fn(pooled_inputs))
# Point-wise expansion
x = self.conv_pw(x, routing_weights)
x = self.bn1(x)
x = self.act1(x)
# Depth-wise convolution
x = self.conv_dw(x, routing_weights)
x = self.bn2(x)
x = self.act2(x)
# Squeeze-and-excitation
if self.se is not None:
x = self.se(x)
# Point-wise linear projection
x = self.conv_pwl(x, routing_weights)
x = self.bn3(x)
if self.has_residual:
if self.drop_path_rate > 0.:
x = drop_path(x, self.drop_path_rate, self.training)
x += residual
return x
class EdgeResidual(nn.Module):
""" Residual block with expansion convolution followed by pointwise-linear w/ stride"""
def __init__(self, in_chs, out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_chs=0,
stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False, pw_kernel_size=1,
se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
drop_path_rate=0.):
super(EdgeResidual, self).__init__()
norm_kwargs = norm_kwargs or {}
if fake_in_chs > 0:
mid_chs = make_divisible(fake_in_chs * exp_ratio)
else:
mid_chs = make_divisible(in_chs * exp_ratio)
has_se = se_ratio is not None and se_ratio > 0.
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
self.drop_path_rate = drop_path_rate
# Expansion convolution
self.conv_exp = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type)
self.bn1 = norm_layer(mid_chs, **norm_kwargs)
self.act1 = act_layer(inplace=True)
# Squeeze-and-excitation
if has_se:
se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs)
else:
self.se = None
# Point-wise linear projection
self.conv_pwl = create_conv2d(
mid_chs, out_chs, pw_kernel_size, stride=stride, dilation=dilation, padding=pad_type)
self.bn2 = norm_layer(out_chs, **norm_kwargs)
def feature_info(self, location):
if location == 'expansion': # after SE, before PWL
info = dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels)
else: # location == 'bottleneck', block output
info = dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels)
return info
def forward(self, x):
residual = x
# Expansion convolution
x = self.conv_exp(x)
x = self.bn1(x)
x = self.act1(x)
# Squeeze-and-excitation
if self.se is not None:
x = self.se(x)
# Point-wise linear projection
x = self.conv_pwl(x)
x = self.bn2(x)
if self.has_residual:
if self.drop_path_rate > 0.:
x = drop_path(x, self.drop_path_rate, self.training)
x += residual
return x
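# Illustrative eager-mode fusion pass over a model built from these blocks (assumption:
# the model is a float model in eval() mode; FX graph mode quantization fuses automatically
# and does not need this):
#   for m in model.modules():
#       if hasattr(m, 'fuse_module'):
#           m.fuse_module()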

@@ -0,0 +1,414 @@
""" EfficientNet, MobileNetV3, etc Builder
Assembles EfficientNet and related network feature blocks from string definitions.
Handles stride, dilation calculations, and selects feature extraction points.
Hacked together by / Copyright 2020 Ross Wightman
"""
import logging
import math
import re
from copy import deepcopy
import torch.nn as nn
from .efficientnet_blocks import *
from .layers import CondConv2d, get_condconv_initializer
__all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights"]
_logger = logging.getLogger(__name__)
def _log_info_if(msg, condition):
if condition:
_logger.info(msg)
def _parse_ksize(ss):
if ss.isdigit():
return int(ss)
else:
return [int(k) for k in ss.split('.')]
def _decode_block_str(block_str):
""" Decode block definition string
Gets a list of block arg (dicts) through a string notation of arguments.
E.g. ir_r2_k3_s2_e1_i32_o16_se0.25_noskip
All args can exist in any order with the exception of the leading string which
is assumed to indicate the block type.
leading string - block type (
ir = InvertedResidual, ds = DepthwiseSep, dsa = DepthwiseSep with pw act, cn = ConvBnAct)
r - number of repeat blocks,
k - kernel size,
s - strides (1-9),
e - expansion ratio,
c - output channels,
se - squeeze/excitation ratio
n - activation fn ('re', 'r6', 'hs', or 'sw')
Args:
block_str: a string representation of block arguments.
Returns:
A list of block args (dicts)
Raises:
ValueError: if the string def not properly specified (TODO)
"""
assert isinstance(block_str, str)
ops = block_str.split('_')
block_type = ops[0] # take the block type off the front
ops = ops[1:]
options = {}
noskip = False
for op in ops:
# string options being checked on individual basis, combine if they grow
if op == 'noskip':
noskip = True
elif op.startswith('n'):
# activation fn
key = op[0]
v = op[1:]
if v == 're':
value = get_act_layer('relu')
elif v == 'r6':
value = get_act_layer('relu6')
elif v == 'hs':
value = get_act_layer('hard_swish')
elif v == 'sw':
value = get_act_layer('swish')
else:
continue
options[key] = value
else:
# all numeric options
splits = re.split(r'(\d.*)', op)
if len(splits) >= 2:
key, value = splits[:2]
options[key] = value
# if act_layer is None, the model default (passed to model init) will be used
act_layer = options['n'] if 'n' in options else None
exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
fake_in_chs = int(options['fc']) if 'fc' in options else 0 # FIXME hack to deal with in_chs issue in TPU def
num_repeat = int(options['r'])
# each type of block has different valid arguments, fill accordingly
if block_type == 'ir':
block_args = dict(
block_type=block_type,
dw_kernel_size=_parse_ksize(options['k']),
exp_kernel_size=exp_kernel_size,
pw_kernel_size=pw_kernel_size,
out_chs=int(options['c']),
exp_ratio=float(options['e']),
se_ratio=float(options['se']) if 'se' in options else None,
stride=int(options['s']),
act_layer=act_layer,
noskip=noskip,
)
if 'cc' in options:
block_args['num_experts'] = int(options['cc'])
elif block_type == 'ds' or block_type == 'dsa':
block_args = dict(
block_type=block_type,
dw_kernel_size=_parse_ksize(options['k']),
pw_kernel_size=pw_kernel_size,
out_chs=int(options['c']),
se_ratio=float(options['se']) if 'se' in options else None,
stride=int(options['s']),
act_layer=act_layer,
pw_act=block_type == 'dsa',
noskip=block_type == 'dsa' or noskip,
)
elif block_type == 'er':
block_args = dict(
block_type=block_type,
exp_kernel_size=_parse_ksize(options['k']),
pw_kernel_size=pw_kernel_size,
out_chs=int(options['c']),
exp_ratio=float(options['e']),
fake_in_chs=fake_in_chs,
se_ratio=float(options['se']) if 'se' in options else None,
stride=int(options['s']),
act_layer=act_layer,
noskip=noskip,
)
elif block_type == 'cn':
block_args = dict(
block_type=block_type,
kernel_size=int(options['k']),
out_chs=int(options['c']),
stride=int(options['s']),
act_layer=act_layer,
)
else:
assert False, 'Unknown block type (%s)' % block_type
return block_args, num_repeat
def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'):
""" Per-stage depth scaling
Scales the block repeats in each stage. This depth scaling impl maintains
compatibility with the EfficientNet scaling method, while allowing sensible
scaling for other models that may have multiple block arg definitions in each stage.
"""
# We scale the total repeat count for each stage, there may be multiple
# block arg defs per stage so we need to sum.
num_repeat = sum(repeats)
if depth_trunc == 'round':
# Truncating to int by rounding allows stages with few repeats to remain
# proportionally smaller for longer. This is a good choice when stage definitions
# include single repeat stages that we'd prefer to keep that way as long as possible
num_repeat_scaled = max(1, round(num_repeat * depth_multiplier))
else:
# The default for EfficientNet truncates repeats to int via 'ceil'.
# Any multiplier > 1.0 will result in an increased depth for every stage.
num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier))
# Proportionally distribute repeat count scaling to each block definition in the stage.
# Allocation is done in reverse as it results in the first block being less likely to be scaled.
# The first block makes less sense to repeat in most of the arch definitions.
repeats_scaled = []
for r in repeats[::-1]:
rs = max(1, round((r / num_repeat * num_repeat_scaled)))
repeats_scaled.append(rs)
num_repeat -= r
num_repeat_scaled -= rs
repeats_scaled = repeats_scaled[::-1]
# Apply the calculated scaling to each block arg in the stage
sa_scaled = []
for ba, rep in zip(stack_args, repeats_scaled):
sa_scaled.extend([deepcopy(ba) for _ in range(rep)])
return sa_scaled
def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_multiplier=1, fix_first_last=False):
arch_args = []
for stack_idx, block_strings in enumerate(arch_def):
assert isinstance(block_strings, list)
stack_args = []
repeats = []
for block_str in block_strings:
assert isinstance(block_str, str)
ba, rep = _decode_block_str(block_str)
if ba.get('num_experts', 0) > 0 and experts_multiplier > 1:
ba['num_experts'] *= experts_multiplier
stack_args.append(ba)
repeats.append(rep)
if fix_first_last and (stack_idx == 0 or stack_idx == len(arch_def) - 1):
arch_args.append(_scale_stage_depth(stack_args, repeats, 1.0, depth_trunc))
else:
arch_args.append(_scale_stage_depth(stack_args, repeats, depth_multiplier, depth_trunc))
return arch_args
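# Illustrative arch_def (abridged MobileNet-V2-style stages; the real definitions live in the
# individual model files): each inner list is one stage of block strings decoded above.
#   arch_def = [
#       ['ds_r1_k3_s1_c16'],
#       ['ir_r2_k3_s2_e6_c24'],
#       ['ir_r3_k3_s2_e6_c32'],
#   ]
#   block_args = decode_arch_def(arch_def, depth_multiplier=1.0)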
class EfficientNetBuilder:
""" Build Trunk Blocks
This ended up being somewhat of a cross between
https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py
and
https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py
"""
def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
output_stride=32, pad_type='', act_layer=None, se_kwargs=None,
norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_path_rate=0., feature_location='',
verbose=False):
self.channel_multiplier = channel_multiplier
self.channel_divisor = channel_divisor
self.channel_min = channel_min
self.output_stride = output_stride
self.pad_type = pad_type
self.act_layer = act_layer
self.se_kwargs = se_kwargs
self.norm_layer = norm_layer
self.norm_kwargs = norm_kwargs
self.drop_path_rate = drop_path_rate
if feature_location == 'depthwise':
# old 'depthwise' mode renamed 'expansion' to match TF impl, old expansion mode didn't make sense
_logger.warning("feature_location=='depthwise' is deprecated, using 'expansion'")
feature_location = 'expansion'
self.feature_location = feature_location
assert feature_location in ('bottleneck', 'expansion', '')
self.verbose = verbose
# state updated during build, consumed by model
self.in_chs = None
self.features = []
def _round_channels(self, chs):
return round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min)
def _make_block(self, ba, block_idx, block_count):
drop_path_rate = self.drop_path_rate * block_idx / block_count
bt = ba.pop('block_type')
ba['in_chs'] = self.in_chs
ba['out_chs'] = self._round_channels(ba['out_chs'])
if 'fake_in_chs' in ba and ba['fake_in_chs']:
# FIXME this is a hack to work around mismatch in origin impl input filters
ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs'])
ba['norm_layer'] = self.norm_layer
ba['norm_kwargs'] = self.norm_kwargs
ba['pad_type'] = self.pad_type
# block act fn overrides the model default
ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer
assert ba['act_layer'] is not None
if bt == 'ir':
ba['drop_path_rate'] = drop_path_rate
ba['se_kwargs'] = self.se_kwargs
_log_info_if(' InvertedResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
if ba.get('num_experts', 0) > 0:
block = CondConvResidual(**ba)
else:
block = InvertedResidual(**ba)
elif bt == 'ds' or bt == 'dsa':
ba['drop_path_rate'] = drop_path_rate
ba['se_kwargs'] = self.se_kwargs
_log_info_if(' DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
block = DepthwiseSeparableConv(**ba)
elif bt == 'er':
ba['drop_path_rate'] = drop_path_rate
ba['se_kwargs'] = self.se_kwargs
_log_info_if(' EdgeResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
block = EdgeResidual(**ba)
elif bt == 'cn':
_log_info_if(' ConvBnAct {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
block = ConvBnAct(**ba)
else:
assert False, 'Unknown block type (%s) while building model.' % bt
self.in_chs = ba['out_chs'] # update in_chs for arg of next block
return block
def __call__(self, in_chs, model_block_args):
""" Build the blocks
Args:
in_chs: Number of input-channels passed to first block
model_block_args: A list of lists, outer list defines stages, inner
list contains strings defining block configuration(s)
Return:
List of block stacks (each stack wrapped in nn.Sequential)
"""
_log_info_if('Building model trunk with %d stages...' % len(model_block_args), self.verbose)
self.in_chs = in_chs
total_block_count = sum([len(x) for x in model_block_args])
total_block_idx = 0
current_stride = 2
current_dilation = 1
stages = []
if model_block_args[0][0]['stride'] > 1:
# if the first block starts with a stride, we need to extract first level feat from stem
feature_info = dict(
module='act1', num_chs=in_chs, stage=0, reduction=current_stride,
hook_type='forward' if self.feature_location != 'bottleneck' else '')
self.features.append(feature_info)
# outer list of block_args defines the stacks
for stack_idx, stack_args in enumerate(model_block_args):
last_stack = stack_idx + 1 == len(model_block_args)
_log_info_if('Stack: {}'.format(stack_idx), self.verbose)
assert isinstance(stack_args, list)
blocks = []
# each stack (stage of blocks) contains a list of block arguments
for block_idx, block_args in enumerate(stack_args):
last_block = block_idx + 1 == len(stack_args)
_log_info_if(' Block: {}'.format(block_idx), self.verbose)
assert block_args['stride'] in (1, 2)
if block_idx >= 1: # only the first block in any stack can have a stride > 1
block_args['stride'] = 1
extract_features = False
if last_block:
next_stack_idx = stack_idx + 1
extract_features = next_stack_idx >= len(model_block_args) or \
model_block_args[next_stack_idx][0]['stride'] > 1
next_dilation = current_dilation
if block_args['stride'] > 1:
next_output_stride = current_stride * block_args['stride']
if next_output_stride > self.output_stride:
next_dilation = current_dilation * block_args['stride']
block_args['stride'] = 1
_log_info_if(' Converting stride to dilation to maintain output_stride=={}'.format(
self.output_stride), self.verbose)
else:
current_stride = next_output_stride
block_args['dilation'] = current_dilation
if next_dilation != current_dilation:
current_dilation = next_dilation
# create the block
block = self._make_block(block_args, total_block_idx, total_block_count)
blocks.append(block)
# stash feature module name and channel info for model feature extraction
if extract_features:
feature_info = dict(
stage=stack_idx + 1, reduction=current_stride, **block.feature_info(self.feature_location))
module_name = f'blocks.{stack_idx}.{block_idx}'
leaf_name = feature_info.get('module', '')
feature_info['module'] = '.'.join([module_name, leaf_name]) if leaf_name else module_name
self.features.append(feature_info)
total_block_idx += 1 # incr global block idx (across all stacks)
stages.append(nn.Sequential(*blocks))
return stages
def _init_weight_goog(m, n='', fix_group_fanout=True):
""" Weight initialization as per Tensorflow official implementations.
Args:
m (nn.Module): module to init
n (str): module name
fix_group_fanout (bool): enable correct (matching Tensorflow TPU impl) fanout calculation w/ group convs
Handles layers in EfficientNet, EfficientNet-CondConv, MixNet, MnasNet, MobileNetV3, etc:
* https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py
* https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
"""
if isinstance(m, CondConv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
if fix_group_fanout:
fan_out //= m.groups
init_weight_fn = get_condconv_initializer(
lambda w: w.data.normal_(0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape)
init_weight_fn(m.weight)
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
if fix_group_fanout:
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1.0)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
fan_out = m.weight.size(0) # fan-out
fan_in = 0
if 'routing_fn' in n:
fan_in = m.weight.size(1)
init_range = 1.0 / math.sqrt(fan_in + fan_out)
m.weight.data.uniform_(-init_range, init_range)
m.bias.data.zero_()
def efficientnet_init_weights(model: nn.Module, init_fn=None):
init_fn = init_fn or _init_weight_goog
for n, m in model.named_modules():
init_fn(m, n)

@@ -0,0 +1,31 @@
from .activations import *
from .adaptive_avgmax_pool import \
adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
from .anti_aliasing import AntiAliasDownsampleLayer
from .blur_pool import BlurPool2d
from .classifier import ClassifierHead, create_classifier
from .cond_conv2d import CondConv2d, get_condconv_initializer
from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\
set_layer_config
from .conv2d_same import Conv2dSame
from .conv_bn_act import ConvBnAct
from .create_act import create_act_layer, get_act_layer
from .create_conv2d import create_conv2d
from .create_norm_act import create_norm_act, get_norm_act_layer
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
from .evo_norm import EvoNormBatch2d, EvoNormSample2d
from .inplace_abn import InplaceAbn
from .mixed_conv2d import MixedConv2d
from .norm_act import BatchNormAct2d
from .padding import get_padding
from .pool2d_same import AvgPool2dSame, create_pool2d
from .selective_kernel import SelectiveKernelConv
from .separable_conv import SeparableConv2d, SeparableConvBnAct
from .space_to_depth import SpaceToDepthModule
from .split_attn import SplitAttnConv2d
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
from .weight_init import trunc_normal_

@@ -0,0 +1,66 @@
""" Activations
A collection of activations fn and modules with a common interface so that they can
easily be swapped. All have an `inplace` arg even if not used.
Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
from torch import nn as nn
from torch.nn import functional as F
class Swish(nn.Module):
def __init__(self, inplace: bool = False):
super(Swish, self).__init__()
self.sig = nn.Sigmoid()
self.skip = nn.quantized.FloatFunctional()
def forward(self, x):
out = self.sig(x)
return self.skip.mul(x, out)
def sigmoid(x, inplace: bool = False):
return x.sigmoid_() if inplace else x.sigmoid()
# PyTorch has this, but not with a consistent inplace argument interface
class Sigmoid(nn.Module):
def __init__(self, inplace: bool = False):
super(Sigmoid, self).__init__()
self.sig = nn.Sigmoid()
def forward(self, x):
return self.sig(x)
class HardSwish(nn.Module):
def __init__(self, inplace: bool = False):
super(HardSwish, self).__init__()
self.relu6 = nn.ReLU6(inplace)
self.quant_mul1 = nn.quantized.FloatFunctional()
self.quant_mul2 = nn.quantized.FloatFunctional()
self.quant_add = nn.quantized.FloatFunctional()
def forward(self, x):
out = self.quant_add.add_scalar(x, 3.0)
out = self.relu6(out)
out = self.quant_mul1.mul(x,out)
out = self.quant_mul2.mul_scalar(out, 1/6)
return out
class HardSigmoid(nn.Module):
def __init__(self, inplace: bool = False):
super(HardSigmoid, self).__init__()
self.relu6 = nn.ReLU6(inplace)
self.quant_add = nn.quantized.FloatFunctional()
self.quant_mul = nn.quantized.FloatFunctional()
def forward(self, x):
out = self.quant_add.add_scalar(x, 3.0)
out = self.relu6(out)
out = self.quant_mul.mul_scalar(out,1/6)
return out

@@ -0,0 +1,90 @@
""" Activations
A collection of jit-scripted activations fn and modules with a common interface so that they can
easily be swapped. All have an `inplace` arg even if not used.
All jit scripted activations are lacking in-place variations on purpose, scripted kernel fusion does not
currently work across in-place op boundaries, thus performance is equal to or less than the non-scripted
versions if they contain in-place ops.
Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
from torch import nn as nn
from torch.nn import functional as F
@torch.jit.script
def swish_jit(x, inplace: bool = False):
"""Swish - Described in: https://arxiv.org/abs/1710.05941
"""
return x.mul(x.sigmoid())
@torch.jit.script
def mish_jit(x, _inplace: bool = False):
"""Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
"""
return x.mul(F.softplus(x).tanh())
class SwishJit(nn.Module):
def __init__(self, inplace: bool = False):
super(SwishJit, self).__init__()
def forward(self, x):
return swish_jit(x)
class MishJit(nn.Module):
def __init__(self, inplace: bool = False):
super(MishJit, self).__init__()
def forward(self, x):
return mish_jit(x)
@torch.jit.script
def hard_sigmoid_jit(x, inplace: bool = False):
# return F.relu6(x + 3.) / 6.
return (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster?
class HardSigmoidJit(nn.Module):
def __init__(self, inplace: bool = False):
super(HardSigmoidJit, self).__init__()
def forward(self, x):
return hard_sigmoid_jit(x)
@torch.jit.script
def hard_swish_jit(x, inplace: bool = False):
# return x * (F.relu6(x + 3.) / 6)
return x * (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster?
class HardSwishJit(nn.Module):
def __init__(self, inplace: bool = False):
super(HardSwishJit, self).__init__()
def forward(self, x):
return hard_swish_jit(x)
@torch.jit.script
def hard_mish_jit(x, inplace: bool = False):
""" Hard Mish
Experimental, based on notes by Mish author Diganta Misra at
https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
"""
return 0.5 * x * (x + 2).clamp(min=0, max=2)
class HardMishJit(nn.Module):
def __init__(self, inplace: bool = False):
super(HardMishJit, self).__init__()
def forward(self, x):
return hard_mish_jit(x)
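For illustration (not part of the original file): under the assumption that scripting does not change numerics, the scripted modules should agree with the plain closed-form expressions.
import torch
_x = torch.randn(4, 16)
assert torch.allclose(SwishJit()(_x), _x * _x.sigmoid())
assert torch.allclose(MishJit()(_x), _x * torch.nn.functional.softplus(_x).tanh())
assert torch.allclose(HardSwishJit()(_x), _x * (_x + 3).clamp(0, 6) / 6)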

@ -0,0 +1,208 @@
""" Activations (memory-efficient w/ custom autograd)
A collection of activations fn and modules with a common interface so that they can
easily be swapped. All have an `inplace` arg even if not used.
These activations are not compatible with jit scripting or ONNX export of the model, please use either
the JIT or basic versions of the activations.
Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
from torch import nn as nn
from torch.nn import functional as F
@torch.jit.script
def swish_jit_fwd(x):
return x.mul(torch.sigmoid(x))
@torch.jit.script
def swish_jit_bwd(x, grad_output):
x_sigmoid = torch.sigmoid(x)
return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))
class SwishJitAutoFn(torch.autograd.Function):
""" torch.jit.script optimised Swish w/ memory-efficient checkpoint
Inspired by conversation between Jeremy Howard & Adam Paszke
https://twitter.com/jeremyphoward/status/1188251041835315200
"""
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return swish_jit_fwd(x)
@staticmethod
def backward(ctx, grad_output):
x = ctx.saved_tensors[0]
return swish_jit_bwd(x, grad_output)
def swish_me(x, inplace=False):
return SwishJitAutoFn.apply(x)
class SwishMe(nn.Module):
def __init__(self, inplace: bool = False):
super(SwishMe, self).__init__()
def forward(self, x):
return SwishJitAutoFn.apply(x)
@torch.jit.script
def mish_jit_fwd(x):
return x.mul(torch.tanh(F.softplus(x)))
@torch.jit.script
def mish_jit_bwd(x, grad_output):
x_sigmoid = torch.sigmoid(x)
x_tanh_sp = F.softplus(x).tanh()
return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))
class MishJitAutoFn(torch.autograd.Function):
""" Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
A memory efficient, jit scripted variant of Mish
"""
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return mish_jit_fwd(x)
@staticmethod
def backward(ctx, grad_output):
x = ctx.saved_tensors[0]
return mish_jit_bwd(x, grad_output)
def mish_me(x, inplace=False):
return MishJitAutoFn.apply(x)
class MishMe(nn.Module):
def __init__(self, inplace: bool = False):
super(MishMe, self).__init__()
def forward(self, x):
return MishJitAutoFn.apply(x)
@torch.jit.script
def hard_sigmoid_jit_fwd(x, inplace: bool = False):
return (x + 3).clamp(min=0, max=6).div(6.)
@torch.jit.script
def hard_sigmoid_jit_bwd(x, grad_output):
m = torch.ones_like(x) * ((x >= -3.) & (x <= 3.)) / 6.
return grad_output * m
class HardSigmoidJitAutoFn(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return hard_sigmoid_jit_fwd(x)
@staticmethod
def backward(ctx, grad_output):
x = ctx.saved_tensors[0]
return hard_sigmoid_jit_bwd(x, grad_output)
def hard_sigmoid_me(x, inplace: bool = False):
return HardSigmoidJitAutoFn.apply(x)
class HardSigmoidMe(nn.Module):
def __init__(self, inplace: bool = False):
super(HardSigmoidMe, self).__init__()
def forward(self, x):
return HardSigmoidJitAutoFn.apply(x)
@torch.jit.script
def hard_swish_jit_fwd(x):
return x * (x + 3).clamp(min=0, max=6).div(6.)
@torch.jit.script
def hard_swish_jit_bwd(x, grad_output):
m = torch.ones_like(x) * (x >= 3.)
m = torch.where((x >= -3.) & (x <= 3.), x / 3. + .5, m)
return grad_output * m
class HardSwishJitAutoFn(torch.autograd.Function):
"""A memory efficient, jit-scripted HardSwish activation"""
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return hard_swish_jit_fwd(x)
@staticmethod
def backward(ctx, grad_output):
x = ctx.saved_tensors[0]
return hard_swish_jit_bwd(x, grad_output)
def hard_swish_me(x, inplace=False):
return HardSwishJitAutoFn.apply(x)
class HardSwishMe(nn.Module):
def __init__(self, inplace: bool = False):
super(HardSwishMe, self).__init__()
def forward(self, x):
return HardSwishJitAutoFn.apply(x)
@torch.jit.script
def hard_mish_jit_fwd(x):
return 0.5 * x * (x + 2).clamp(min=0, max=2)
@torch.jit.script
def hard_mish_jit_bwd(x, grad_output):
m = torch.ones_like(x) * (x >= -2.)
m = torch.where((x >= -2.) & (x <= 0.), x + 1., m)
return grad_output * m
class HardMishJitAutoFn(torch.autograd.Function):
""" A memory efficient, jit scripted variant of Hard Mish
Experimental, based on notes by Mish author Diganta Misra at
https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
"""
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return hard_mish_jit_fwd(x)
@staticmethod
def backward(ctx, grad_output):
x = ctx.saved_tensors[0]
return hard_mish_jit_bwd(x, grad_output)
def hard_mish_me(x, inplace: bool = False):
return HardMishJitAutoFn.apply(x)
class HardMishMe(nn.Module):
def __init__(self, inplace: bool = False):
super(HardMishMe, self).__init__()
def forward(self, x):
return HardMishJitAutoFn.apply(x)
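As an illustrative check (not part of the original file), the hand-written backward passes above can be validated against numerical gradients with torch.autograd.gradcheck; double precision is used because gradcheck expects it. A sketch assuming only stock PyTorch:
import torch
_x = torch.randn(8, dtype=torch.double, requires_grad=True)
assert torch.autograd.gradcheck(SwishJitAutoFn.apply, (_x,))
assert torch.autograd.gradcheck(MishJitAutoFn.apply, (_x,))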

@ -0,0 +1,106 @@
""" PyTorch selectable adaptive pooling
Adaptive pooling with the ability to select the type of pooling from:
* 'avg' - Average pooling
* 'max' - Max pooling
* 'avgmax' - Sum of average and max pooling re-scaled by 0.5
* 'catavgmax' - Concatenation of average and max pooling along feature dim, doubles feature dim
Both a functional and a nn.Module version of the pooling is provided.
Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
def adaptive_pool_feat_mult(pool_type='avg'):
if pool_type == 'catavgmax':
return 2
else:
return 1
def adaptive_avgmax_pool2d(x, output_size=1):
x_avg = F.adaptive_avg_pool2d(x, output_size)
x_max = F.adaptive_max_pool2d(x, output_size)
return 0.5 * (x_avg + x_max)
def adaptive_catavgmax_pool2d(x, output_size=1):
x_avg = F.adaptive_avg_pool2d(x, output_size)
x_max = F.adaptive_max_pool2d(x, output_size)
return torch.cat((x_avg, x_max), 1)
def select_adaptive_pool2d(x, pool_type='avg', output_size=1):
"""Selectable global pooling function with dynamic input kernel size
"""
if pool_type == 'avg':
x = F.adaptive_avg_pool2d(x, output_size)
elif pool_type == 'avgmax':
x = adaptive_avgmax_pool2d(x, output_size)
elif pool_type == 'catavgmax':
x = adaptive_catavgmax_pool2d(x, output_size)
elif pool_type == 'max':
x = F.adaptive_max_pool2d(x, output_size)
else:
assert False, 'Invalid pool type: %s' % pool_type
return x
class AdaptiveAvgMaxPool2d(nn.Module):
def __init__(self, output_size=1):
super(AdaptiveAvgMaxPool2d, self).__init__()
self.output_size = output_size
def forward(self, x):
return adaptive_avgmax_pool2d(x, self.output_size)
class AdaptiveCatAvgMaxPool2d(nn.Module):
def __init__(self, output_size=1):
super(AdaptiveCatAvgMaxPool2d, self).__init__()
self.output_size = output_size
def forward(self, x):
return adaptive_catavgmax_pool2d(x, self.output_size)
class SelectAdaptivePool2d(nn.Module):
"""Selectable global pooling layer with dynamic input kernel size
"""
def __init__(self, output_size=1, pool_type='avg', flatten=False):
super(SelectAdaptivePool2d, self).__init__()
self.pool_type = pool_type or '' # convert other falsy values to empty string for consistent TS typing
self.flatten = flatten
if pool_type == '':
self.pool = nn.Identity() # pass through
elif pool_type == 'avg':
self.pool = nn.AdaptiveAvgPool2d(output_size)
elif pool_type == 'avgmax':
self.pool = AdaptiveAvgMaxPool2d(output_size)
elif pool_type == 'catavgmax':
self.pool = AdaptiveCatAvgMaxPool2d(output_size)
elif pool_type == 'max':
self.pool = nn.AdaptiveMaxPool2d(output_size)
else:
assert False, 'Invalid pool type: %s' % pool_type
def is_identity(self):
return self.pool_type == ''
def forward(self, x):
x = self.pool(x)
if self.flatten:
x = x.flatten(1)
return x
def feat_mult(self):
return adaptive_pool_feat_mult(self.pool_type)
def __repr__(self):
return self.__class__.__name__ + ' (' \
+ 'pool_type=' + self.pool_type \
+ ', flatten=' + str(self.flatten) + ')'
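Illustrative shape check (not part of the original file): 'catavgmax' doubles the channel dimension, which is exactly what feat_mult() reports so downstream heads can size their inputs.
import torch
_x = torch.randn(2, 64, 7, 7)
_pool = SelectAdaptivePool2d(pool_type='catavgmax', flatten=True)
assert _pool(_x).shape == (2, 64 * _pool.feat_mult())  # (2, 128)
assert SelectAdaptivePool2d(pool_type='avg', flatten=True)(_x).shape == (2, 64)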

@ -0,0 +1,60 @@
import torch
import torch.nn.parallel
import torch.nn as nn
import torch.nn.functional as F
class AntiAliasDownsampleLayer(nn.Module):
def __init__(self, channels: int = 0, filt_size: int = 3, stride: int = 2, no_jit: bool = False):
super(AntiAliasDownsampleLayer, self).__init__()
if no_jit:
self.op = Downsample(channels, filt_size, stride)
else:
self.op = DownsampleJIT(channels, filt_size, stride)
# FIXME I should probably override _apply and clear DownsampleJIT filter cache for .cuda(), .half(), etc calls
def forward(self, x):
return self.op(x)
@torch.jit.script
class DownsampleJIT(object):
def __init__(self, channels: int = 0, filt_size: int = 3, stride: int = 2):
self.channels = channels
self.stride = stride
self.filt_size = filt_size
assert self.filt_size == 3
assert stride == 2
self.filt = {} # lazy init by device for DataParallel compat
def _create_filter(self, like: torch.Tensor):
filt = torch.tensor([1., 2., 1.], dtype=like.dtype, device=like.device)
filt = filt[:, None] * filt[None, :]
filt = filt / torch.sum(filt)
return filt[None, None, :, :].repeat((self.channels, 1, 1, 1))
def __call__(self, input: torch.Tensor):
input_pad = F.pad(input, (1, 1, 1, 1), 'reflect')
filt = self.filt.get(str(input.device), self._create_filter(input))
return F.conv2d(input_pad, filt, stride=2, padding=0, groups=input.shape[1])
class Downsample(nn.Module):
def __init__(self, channels=None, filt_size=3, stride=2):
super(Downsample, self).__init__()
self.channels = channels
self.filt_size = filt_size
self.stride = stride
assert self.filt_size == 3
filt = torch.tensor([1., 2., 1.])
filt = filt[:, None] * filt[None, :]
filt = filt / torch.sum(filt)
# self.filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1))
self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1)))
def forward(self, input):
input_pad = F.pad(input, (1, 1, 1, 1), 'reflect')
return F.conv2d(input_pad, self.filt, stride=self.stride, padding=0, groups=input.shape[1])

@ -0,0 +1,58 @@
"""
BlurPool layer inspired by
- Kornia's Max_BlurPool2d
- Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar`
FIXME merge this impl with those in `anti_aliasing.py`
Hacked together by Chris Ha and Ross Wightman
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Dict
from .padding import get_padding
class BlurPool2d(nn.Module):
r"""Creates a module that computes blurs and downsample a given feature map.
See :cite:`zhang2019shiftinvar` for more details.
Corresponds to the Downsample class, which does blurring and subsampling
Args:
channels (int): Number of input channels
filt_size (int): binomial filter size for blurring. Currently supports 3 (default) and 5.
stride (int): downsampling filter stride
Returns:
torch.Tensor: the transformed tensor.
"""
filt: Dict[str, torch.Tensor]
def __init__(self, channels, filt_size=3, stride=2) -> None:
super(BlurPool2d, self).__init__()
assert filt_size > 1
self.channels = channels
self.filt_size = filt_size
self.stride = stride
pad_size = [get_padding(filt_size, stride, dilation=1)] * 4
self.padding = nn.ReflectionPad2d(pad_size)
self._coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs) # for torchscript compat
self.filt = {} # lazy init by device for DataParallel compat
def _create_filter(self, like: torch.Tensor):
blur_filter = (self._coeffs[:, None] * self._coeffs[None, :]).to(dtype=like.dtype, device=like.device)
return blur_filter[None, None, :, :].repeat(self.channels, 1, 1, 1)
def _apply(self, fn):
# override nn.Module _apply to reset the cached filters (device/dtype may have changed)
self.filt = {}
return super(BlurPool2d, self)._apply(fn)
def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
C = input_tensor.shape[1]
blur_filt = self.filt.get(str(input_tensor.device), self._create_filter(input_tensor))
return F.conv2d(
self.padding(input_tensor), blur_filt, stride=self.stride, groups=C)

@ -0,0 +1,41 @@
""" Classifier head and layer factory
Hacked together by / Copyright 2020 Ross Wightman
"""
from torch import nn as nn
from torch.nn import functional as F
from .adaptive_avgmax_pool import SelectAdaptivePool2d
def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False):
flatten = not use_conv # flatten when we use a Linear layer after pooling
if not pool_type:
assert num_classes == 0 or use_conv,\
'Pooling can only be disabled if classifier is also removed or conv classifier is used'
flatten = False # disable flattening if pooling is pass-through (no pooling)
global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=flatten)
num_pooled_features = num_features * global_pool.feat_mult()
if num_classes <= 0:
fc = nn.Identity() # pass-through (no classifier)
elif use_conv:
fc = nn.Conv2d(num_pooled_features, num_classes, 1, bias=True)
else:
fc = nn.Linear(num_pooled_features, num_classes, bias=True)
return global_pool, fc
class ClassifierHead(nn.Module):
"""Classifier head w/ configurable global pooling and dropout."""
def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0.):
super(ClassifierHead, self).__init__()
self.drop_rate = drop_rate
self.global_pool, self.fc = create_classifier(in_chs, num_classes, pool_type=pool_type)
def forward(self, x):
x = self.global_pool(x)
if self.drop_rate:
x = F.dropout(x, p=float(self.drop_rate), training=self.training)
x = self.fc(x)
return x
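Illustrative usage (not part of the original file): the pooled feature multiplier reported by the pooling layer sets the classifier input width, so both the factory and the head wrapper produce (batch, num_classes) logits. A sketch assuming only the code above:
import torch
_pool, _fc = create_classifier(num_features=512, num_classes=10, pool_type='avg')
assert _fc(_pool(torch.randn(4, 512, 7, 7))).shape == (4, 10)
_head = ClassifierHead(in_chs=512, num_classes=10, drop_rate=0.2)
assert _head(torch.randn(4, 512, 7, 7)).shape == (4, 10)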

@ -0,0 +1,122 @@
""" PyTorch Conditionally Parameterized Convolution (CondConv)
Paper: CondConv: Conditionally Parameterized Convolutions for Efficient Inference
(https://arxiv.org/abs/1904.04971)
Hacked together by / Copyright 2020 Ross Wightman
"""
import math
from functools import partial
import numpy as np
import torch
from torch import nn as nn
from torch.nn import functional as F
from .helpers import tup_pair
from .conv2d_same import conv2d_same
from .padding import get_padding_value
def get_condconv_initializer(initializer, num_experts, expert_shape):
def condconv_initializer(weight):
"""CondConv initializer function."""
num_params = np.prod(expert_shape)
if (len(weight.shape) != 2 or weight.shape[0] != num_experts or
weight.shape[1] != num_params):
raise (ValueError(
'CondConv variables must have shape [num_experts, num_params]'))
for i in range(num_experts):
initializer(weight[i].view(expert_shape))
return condconv_initializer
class CondConv2d(nn.Module):
""" Conditionally Parameterized Convolution
Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py
Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion:
https://github.com/pytorch/pytorch/issues/17983
"""
__constants__ = ['in_channels', 'out_channels', 'dynamic_padding']
def __init__(self, in_channels, out_channels, kernel_size=3,
stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4):
super(CondConv2d, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = tup_pair(kernel_size)
self.stride = tup_pair(stride)
padding_val, is_padding_dynamic = get_padding_value(
padding, kernel_size, stride=stride, dilation=dilation)
self.dynamic_padding = is_padding_dynamic # if in forward to work with torchscript
self.padding = tup_pair(padding_val)
self.dilation = tup_pair(dilation)
self.groups = groups
self.num_experts = num_experts
self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size
weight_num_param = 1
for wd in self.weight_shape:
weight_num_param *= wd
self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param))
if bias:
self.bias_shape = (self.out_channels,)
self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
init_weight = get_condconv_initializer(
partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape)
init_weight(self.weight)
if self.bias is not None:
fan_in = np.prod(self.weight_shape[1:])
bound = 1 / math.sqrt(fan_in)
init_bias = get_condconv_initializer(
partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape)
init_bias(self.bias)
def forward(self, x, routing_weights):
B, C, H, W = x.shape
weight = torch.matmul(routing_weights, self.weight)
new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size
weight = weight.view(new_weight_shape)
bias = None
if self.bias is not None:
bias = torch.matmul(routing_weights, self.bias)
bias = bias.view(B * self.out_channels)
# move batch elements with channels so each batch element can be efficiently convolved with separate kernel
x = x.view(1, B * C, H, W)
if self.dynamic_padding:
out = conv2d_same(
x, weight, bias, stride=self.stride, padding=self.padding,
dilation=self.dilation, groups=self.groups * B)
else:
out = F.conv2d(
x, weight, bias, stride=self.stride, padding=self.padding,
dilation=self.dilation, groups=self.groups * B)
out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1])
# Literal port (from TF definition)
# x = torch.split(x, 1, 0)
# weight = torch.split(weight, 1, 0)
# if self.bias is not None:
# bias = torch.matmul(routing_weights, self.bias)
# bias = torch.split(bias, 1, 0)
# else:
# bias = [None] * B
# out = []
# for xi, wi, bi in zip(x, weight, bias):
# wi = wi.view(*self.weight_shape)
# if bi is not None:
# bi = bi.view(*self.bias_shape)
# out.append(self.conv_fn(
# xi, wi, bi, stride=self.stride, padding=self.padding,
# dilation=self.dilation, groups=self.groups))
# out = torch.cat(out, 0)
return out
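Illustrative usage (not part of the original file): routing_weights are per-sample mixing coefficients over the experts; in the CondConv EfficientNet variants they come from a pooled-feature gating layer, whereas here softmaxed random values are used purely to show the expected shapes.
import torch
_cc = CondConv2d(16, 32, kernel_size=3, stride=1, padding='', num_experts=4)
_x = torch.randn(2, 16, 8, 8)
_routing = torch.softmax(torch.randn(2, 4), dim=1)  # (batch, num_experts)
assert _cc(_x, _routing).shape == (2, 32, 8, 8)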

@ -0,0 +1,115 @@
""" Model / Layer Config singleton state
"""
from typing import Any, Optional
__all__ = [
'is_exportable', 'is_scriptable', 'is_no_jit',
'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config'
]
# Set to True if prefer to have layers with no jit optimization (includes activations)
_NO_JIT = False
# Set to True if prefer to have activation layers with no jit optimization
# NOTE: not currently used, as there is no difference between no_jit and no_activation_jit while activations
# are the only layers obeying the jit flags. This will change as more layers are updated and/or added.
_NO_ACTIVATION_JIT = False
# Set to True if exporting a model with Same padding via ONNX
_EXPORTABLE = False
# Set to True if wanting to use torch.jit.script on a model
_SCRIPTABLE = False
def is_no_jit():
return _NO_JIT
class set_no_jit:
def __init__(self, mode: bool) -> None:
global _NO_JIT
self.prev = _NO_JIT
_NO_JIT = mode
def __enter__(self) -> None:
pass
def __exit__(self, *args: Any) -> bool:
global _NO_JIT
_NO_JIT = self.prev
return False
def is_exportable():
return _EXPORTABLE
class set_exportable:
def __init__(self, mode: bool) -> None:
global _EXPORTABLE
self.prev = _EXPORTABLE
_EXPORTABLE = mode
def __enter__(self) -> None:
pass
def __exit__(self, *args: Any) -> bool:
global _EXPORTABLE
_EXPORTABLE = self.prev
return False
def is_scriptable():
return _SCRIPTABLE
class set_scriptable:
def __init__(self, mode: bool) -> None:
global _SCRIPTABLE
self.prev = _SCRIPTABLE
_SCRIPTABLE = mode
def __enter__(self) -> None:
pass
def __exit__(self, *args: Any) -> bool:
global _SCRIPTABLE
_SCRIPTABLE = self.prev
return False
class set_layer_config:
""" Layer config context manager that allows setting all layer config flags at once.
If a flag arg is None, it will not change the current value.
"""
def __init__(
self,
scriptable: Optional[bool] = None,
exportable: Optional[bool] = None,
no_jit: Optional[bool] = None,
no_activation_jit: Optional[bool] = None):
global _SCRIPTABLE
global _EXPORTABLE
global _NO_JIT
global _NO_ACTIVATION_JIT
self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT
if scriptable is not None:
_SCRIPTABLE = scriptable
if exportable is not None:
_EXPORTABLE = exportable
if no_jit is not None:
_NO_JIT = no_jit
if no_activation_jit is not None:
_NO_ACTIVATION_JIT = no_activation_jit
def __enter__(self) -> None:
pass
def __exit__(self, *args: Any) -> bool:
global _SCRIPTABLE
global _EXPORTABLE
global _NO_JIT
global _NO_ACTIVATION_JIT
_SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev
return False
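Illustrative usage (not part of the original file; assumes the module-level defaults above are unchanged): set_layer_config flips all the flags for the duration of a with-block and restores the previous values on exit.
with set_layer_config(scriptable=True, no_jit=True):
    assert is_scriptable() and is_no_jit()
assert not is_scriptable() and not is_no_jit()  # previous (default) state restored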

@ -0,0 +1,42 @@
""" Conv2d w/ Same Padding
Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional
from .padding import pad_same, get_padding_value
def conv2d_same(
x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1),
padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1):
x = pad_same(x, weight.shape[-2:], stride, dilation)
return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups)
class Conv2dSame(nn.Conv2d):
""" Tensorflow like 'SAME' convolution wrapper for 2D convolutions
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True):
super(Conv2dSame, self).__init__(
in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
def forward(self, x):
return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
padding = kwargs.pop('padding', '')
kwargs.setdefault('bias', False)
padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs)
if is_dynamic:
return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)
else:
return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)
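An illustrative distinction (not part of the original file): 'same' padding at stride 1 can be expressed as a fixed symmetric pad, so a plain nn.Conv2d is returned, while stride 2 requires the dynamically padded Conv2dSame.
_static = create_conv2d_pad(8, 16, kernel_size=3, padding='same', stride=1)
_dynamic = create_conv2d_pad(8, 16, kernel_size=3, padding='same', stride=2)
assert isinstance(_static, nn.Conv2d) and not isinstance(_static, Conv2dSame)
assert isinstance(_dynamic, Conv2dSame)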

@ -0,0 +1,40 @@
""" Conv2d + BN + Act
Hacked together by / Copyright 2020 Ross Wightman
"""
from torch import nn as nn
from .create_conv2d import create_conv2d
from .create_norm_act import convert_norm_act_type
class ConvBnAct(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1,
norm_layer=nn.BatchNorm2d, norm_kwargs=None, act_layer=nn.ReLU, apply_act=True,
drop_block=None, aa_layer=None):
super(ConvBnAct, self).__init__()
use_aa = aa_layer is not None
self.conv = create_conv2d(
in_channels, out_channels, kernel_size, stride=1 if use_aa else stride,
padding=padding, dilation=dilation, groups=groups, bias=False)
# NOTE for backwards compatibility with models that use separate norm and act layer definitions
norm_act_layer, norm_act_args = convert_norm_act_type(norm_layer, act_layer, norm_kwargs)
self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block, **norm_act_args)
self.aa = aa_layer(channels=out_channels) if stride == 2 and use_aa else None
@property
def in_channels(self):
return self.conv.in_channels
@property
def out_channels(self):
return self.conv.out_channels
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
if self.aa is not None:
x = self.aa(x)
return x

@ -0,0 +1,36 @@
""" Activation Factory
Hacked together by / Copyright 2020 Ross Wightman
"""
from .activations import *
from .activations_jit import *
from .activations_me import *
from .config import is_exportable, is_scriptable, is_no_jit
_ACT_LAYER_DEFAULT = dict(
swish=Swish,
relu=nn.ReLU,
relu6=nn.ReLU6,
sigmoid=Sigmoid,
hard_sigmoid=HardSigmoid,
hard_swish=HardSwish,
)
def get_act_layer(name='relu'):
""" Activation Layer Factory
Fetching activation layers by name with this function allows export or torch script friendly
functions to be returned dynamically based on current config.
"""
if not name:
return None
return _ACT_LAYER_DEFAULT[name]
def create_act_layer(name, inplace=False, **kwargs):
act_layer = get_act_layer(name)
if act_layer is not None:
return act_layer(inplace=inplace, **kwargs)
else:
return None

@ -0,0 +1,30 @@
""" Create Conv2d Factory Method
Hacked together by / Copyright 2020 Ross Wightman
"""
from .mixed_conv2d import MixedConv2d
from .cond_conv2d import CondConv2d
from .conv2d_same import create_conv2d_pad
def create_conv2d(in_channels, out_channels, kernel_size, **kwargs):
""" Select a 2d convolution implementation based on arguments
Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv2d, or CondConv2d.
Used extensively by EfficientNet, MobileNetv3 and related networks.
"""
if isinstance(kernel_size, list):
assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently
assert 'groups' not in kwargs # MixedConv groups are defined by kernel list
# We're going to use only lists for defining the MixedConv2d kernel groups,
# ints, tuples, other iterables will continue to pass to normal conv and specify h, w.
m = MixedConv2d(in_channels, out_channels, kernel_size, **kwargs)
else:
depthwise = kwargs.pop('depthwise', False)
groups = out_channels if depthwise else kwargs.pop('groups', 1)
if 'num_experts' in kwargs and kwargs['num_experts'] > 0:
m = CondConv2d(in_channels, out_channels, kernel_size, groups=groups, **kwargs)
else:
m = create_conv2d_pad(in_channels, out_channels, kernel_size, groups=groups, **kwargs)
return m
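Illustrative dispatch (not part of the original file): a list kernel_size selects MixedConv2d, a positive num_experts selects CondConv2d, and anything else falls through to the padded Conv2d factory.
_mixed = create_conv2d(32, 64, kernel_size=[3, 5])
_cond = create_conv2d(32, 64, kernel_size=3, num_experts=4)
_plain = create_conv2d(32, 64, kernel_size=3)
assert isinstance(_mixed, MixedConv2d) and isinstance(_cond, CondConv2d)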

@ -0,0 +1,74 @@
""" NormAct (Normalizaiton + Activation Layer) Factory
Create norm + act combo modules that attempt to be backwards compatible with separate norm + act
isntances in models. Where these are used it will be possible to swap separate BN + act layers with
combined modules like IABN or EvoNorms.
Hacked together by / Copyright 2020 Ross Wightman
"""
import types
import functools
import torch
import torch.nn as nn
from .evo_norm import EvoNormBatch2d, EvoNormSample2d
from .norm_act import BatchNormAct2d, GroupNormAct
from .inplace_abn import InplaceAbn
_NORM_ACT_TYPES = {BatchNormAct2d, GroupNormAct, EvoNormBatch2d, EvoNormSample2d, InplaceAbn}
_NORM_ACT_REQUIRES_ARG = {BatchNormAct2d, GroupNormAct, InplaceAbn} # requires act_layer arg to define act type
def get_norm_act_layer(layer_class):
layer_class = layer_class.replace('_', '').lower()
if layer_class.startswith("batchnorm"):
layer = BatchNormAct2d
elif layer_class.startswith("groupnorm"):
layer = GroupNormAct
elif layer_class == "evonormbatch":
layer = EvoNormBatch2d
elif layer_class == "evonormsample":
layer = EvoNormSample2d
elif layer_class == "iabn" or layer_class == "inplaceabn":
layer = InplaceAbn
else:
assert False, "Invalid norm_act layer (%s)" % layer_class
return layer
def create_norm_act(layer_type, num_features, apply_act=True, jit=False, **kwargs):
layer_parts = layer_type.split('-') # e.g. batchnorm-leaky_relu
assert len(layer_parts) in (1, 2)
layer = get_norm_act_layer(layer_parts[0])
#activation_class = layer_parts[1].lower() if len(layer_parts) > 1 else '' # FIXME support string act selection?
layer_instance = layer(num_features, apply_act=apply_act, **kwargs)
if jit:
layer_instance = torch.jit.script(layer_instance)
return layer_instance
def convert_norm_act_type(norm_layer, act_layer, norm_kwargs=None):
assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial))
assert act_layer is None or isinstance(act_layer, (type, str, types.FunctionType, functools.partial))
norm_act_args = norm_kwargs.copy() if norm_kwargs else {}
if isinstance(norm_layer, str):
norm_act_layer = get_norm_act_layer(norm_layer)
elif norm_layer in _NORM_ACT_TYPES:
norm_act_layer = norm_layer
elif isinstance(norm_layer, (types.FunctionType, functools.partial)):
# assuming this is a lambda/fn/bound partial that creates norm_act layer
norm_act_layer = norm_layer
else:
type_name = norm_layer.__name__.lower()
if type_name.startswith('batchnorm'):
norm_act_layer = BatchNormAct2d
elif type_name.startswith('groupnorm'):
norm_act_layer = GroupNormAct
else:
assert False, f"No equivalent norm_act layer for {type_name}"
if norm_act_layer in _NORM_ACT_REQUIRES_ARG:
# Must pass `act_layer` through for backwards compat where `act_layer=None` implies no activation.
# In the future, may force use of `apply_act` with `act_layer` arg bound to relevant NormAct types
# It is intended that functions/partial does not trigger this, they should define act.
norm_act_args.update(dict(act_layer=act_layer))
return norm_act_layer, norm_act_args

@ -0,0 +1,167 @@
""" DropBlock, DropPath
PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers.
Papers:
DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890)
Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382)
Code:
DropBlock impl inspired by two Tensorflow impl that I liked:
- https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74
- https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py
Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
def drop_block_2d(
x, drop_prob: float = 0.1, block_size: int = 7, gamma_scale: float = 1.0,
with_noise: bool = False, inplace: bool = False, batchwise: bool = False):
""" DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
DropBlock with an experimental gaussian noise option. This layer has been tested on a few training
runs with success, but needs further validation and possibly optimization for lower runtime impact.
"""
B, C, H, W = x.shape
total_size = W * H
clipped_block_size = min(block_size, min(W, H))
# seed_drop_rate, the gamma parameter
gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
(W - block_size + 1) * (H - block_size + 1))
# Forces the block to be inside the feature map.
w_i, h_i = torch.meshgrid(torch.arange(W).to(x.device), torch.arange(H).to(x.device))
valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \
((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2))
valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype)
if batchwise:
# one mask for whole batch, quite a bit faster
uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device)
else:
uniform_noise = torch.rand_like(x)
block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype)
block_mask = -F.max_pool2d(
-block_mask,
kernel_size=clipped_block_size, # block_size,
stride=1,
padding=clipped_block_size // 2)
if with_noise:
normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
if inplace:
x.mul_(block_mask).add_(normal_noise * (1 - block_mask))
else:
x = x * block_mask + normal_noise * (1 - block_mask)
else:
normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype)
if inplace:
x.mul_(block_mask * normalize_scale)
else:
x = x * block_mask * normalize_scale
return x
def drop_block_fast_2d(
x: torch.Tensor, drop_prob: float = 0.1, block_size: int = 7,
gamma_scale: float = 1.0, with_noise: bool = False, inplace: bool = False, batchwise: bool = False):
""" DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
DropBlock with an experimental gaussian noise option. Simplified from the version above, without concern for a
valid block mask at edges.
"""
B, C, H, W = x.shape
total_size = W * H
clipped_block_size = min(block_size, min(W, H))
gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
(W - block_size + 1) * (H - block_size + 1))
if batchwise:
# one mask for whole batch, quite a bit faster
block_mask = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) < gamma
else:
# mask per batch element
block_mask = torch.rand_like(x) < gamma
block_mask = F.max_pool2d(
block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2)
if with_noise:
normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
if inplace:
x.mul_(1. - block_mask).add_(normal_noise * block_mask)
else:
x = x * (1. - block_mask) + normal_noise * block_mask
else:
block_mask = 1 - block_mask
normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(dtype=x.dtype)
if inplace:
x.mul_(block_mask * normalize_scale)
else:
x = x * block_mask * normalize_scale
return x
class DropBlock2d(nn.Module):
""" DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
"""
def __init__(self,
drop_prob=0.1,
block_size=7,
gamma_scale=1.0,
with_noise=False,
inplace=False,
batchwise=False,
fast=True):
super(DropBlock2d, self).__init__()
self.drop_prob = drop_prob
self.gamma_scale = gamma_scale
self.block_size = block_size
self.with_noise = with_noise
self.inplace = inplace
self.batchwise = batchwise
self.fast = fast # FIXME finish comparisons of fast vs not
def forward(self, x):
if not self.training or not self.drop_prob:
return x
if self.fast:
return drop_block_fast_2d(
x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise)
else:
return drop_block_2d(
x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise)
def drop_path(x, drop_prob: float = 0., training: bool = False):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
"""
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
random_tensor = keep_prob + torch.rand((x.size()[0], 1, 1, 1), dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
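Illustrative behaviour (not part of the original file): in training mode DropPath zeroes the residual-branch output for whole samples and rescales survivors by 1/keep_prob so the expectation is unchanged; in eval mode it is the identity. A sketch assuming only stock PyTorch:
import torch
_dp = DropPath(drop_prob=0.25)
_x = torch.ones(1000, 8, 4, 4)
_kept = (_dp.train()(_x).flatten(1).sum(1) > 0).float().mean()  # roughly 0.75 of samples survive
_dp.eval()
assert torch.equal(_dp(_x), _x)  # identity at inference time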

@ -0,0 +1,83 @@
"""EvoNormB0 (Batched) and EvoNormS0 (Sample) in PyTorch
An attempt at getting decent performing EvoNorms running in PyTorch.
While currently faster than other impl, still quite a ways off the built-in BN
in terms of memory usage and throughput (roughly 5x mem, 1/2 - 1/3x speed).
Still very much a WIP, fiddling with buffer usage, in-place/jit optimizations, and layouts.
Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
import torch.nn as nn
class EvoNormBatch2d(nn.Module):
def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, drop_block=None):
super(EvoNormBatch2d, self).__init__()
self.apply_act = apply_act # apply activation (non-linearity)
self.momentum = momentum
self.eps = eps
param_shape = (1, num_features, 1, 1)
self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True)
self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True)
if apply_act:
self.v = nn.Parameter(torch.ones(param_shape), requires_grad=True)
self.register_buffer('running_var', torch.ones(1, num_features, 1, 1))
self.reset_parameters()
def reset_parameters(self):
nn.init.ones_(self.weight)
nn.init.zeros_(self.bias)
if self.apply_act:
nn.init.ones_(self.v)
def forward(self, x):
assert x.dim() == 4, 'expected 4D input'
x_type = x.dtype
if self.training:
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
n = x.numel() / x.shape[1]
self.running_var.copy_(
var.detach() * self.momentum * (n / (n - 1)) + self.running_var * (1 - self.momentum))
else:
var = self.running_var
if self.apply_act:
v = self.v.to(dtype=x_type)
d = x * v + (x.var(dim=(2, 3), unbiased=False, keepdim=True) + self.eps).sqrt().to(dtype=x_type)
d = d.max((var + self.eps).sqrt().to(dtype=x_type))
x = x / d
return x * self.weight + self.bias
class EvoNormSample2d(nn.Module):
def __init__(self, num_features, apply_act=True, groups=8, eps=1e-5, drop_block=None):
super(EvoNormSample2d, self).__init__()
self.apply_act = apply_act # apply activation (non-linearity)
self.groups = groups
self.eps = eps
param_shape = (1, num_features, 1, 1)
self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True)
self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True)
if apply_act:
self.v = nn.Parameter(torch.ones(param_shape), requires_grad=True)
self.reset_parameters()
def reset_parameters(self):
nn.init.ones_(self.weight)
nn.init.zeros_(self.bias)
if self.apply_act:
nn.init.ones_(self.v)
def forward(self, x):
assert x.dim() == 4, 'expected 4D input'
B, C, H, W = x.shape
assert C % self.groups == 0
if self.apply_act:
n = x * (x * self.v).sigmoid()
x = x.reshape(B, self.groups, -1)
x = n.reshape(B, self.groups, -1) / (x.var(dim=-1, unbiased=False, keepdim=True) + self.eps).sqrt()
x = x.reshape(B, C, H, W)
return x * self.weight + self.bias

@ -0,0 +1,27 @@
""" Layer/Module Helpers
Hacked together by / Copyright 2020 Ross Wightman
"""
from itertools import repeat
from torch._six import container_abcs
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, container_abcs.Iterable):
return x
return tuple(repeat(x, n))
return parse
tup_single = _ntuple(1)
tup_pair = _ntuple(2)
tup_triple = _ntuple(3)
tup_quadruple = _ntuple(4)
ntup = _ntuple

@ -0,0 +1,87 @@
import torch
from torch import nn as nn
try:
from inplace_abn.functions import inplace_abn, inplace_abn_sync
has_iabn = True
except ImportError:
has_iabn = False
def inplace_abn(x, weight, bias, running_mean, running_var,
training=True, momentum=0.1, eps=1e-05, activation="leaky_relu", activation_param=0.01):
raise ImportError(
"Please install InplaceABN:'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.11'")
def inplace_abn_sync(**kwargs):
return inplace_abn(**kwargs)  # stub: raises the same ImportError
class InplaceAbn(nn.Module):
"""Activated Batch Normalization
This gathers a BatchNorm and an activation function in a single module
Parameters
----------
num_features : int
Number of feature channels in the input and output.
eps : float
Small constant to prevent numerical issues.
momentum : float
Momentum factor applied to compute running statistics.
affine : bool
If `True` apply learned scale and shift transformation after normalization.
act_layer : str or nn.Module type
Name or type of the activation functions, one of: `leaky_relu`, `elu`
act_param : float
Negative slope for the `leaky_relu` activation.
"""
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, apply_act=True,
act_layer="leaky_relu", act_param=0.01, drop_block=None):
super(InplaceAbn, self).__init__()
self.num_features = num_features
self.affine = affine
self.eps = eps
self.momentum = momentum
if apply_act:
if isinstance(act_layer, str):
assert act_layer in ('leaky_relu', 'elu', 'identity', '')
self.act_name = act_layer if act_layer else 'identity'
else:
# convert act layer passed as type to string
if act_layer == nn.ELU:
self.act_name = 'elu'
elif act_layer == nn.LeakyReLU:
self.act_name = 'leaky_relu'
elif act_layer == nn.Identity:
self.act_name = 'identity'
else:
assert False, f'Invalid act layer {act_layer.__name__} for IABN'
else:
self.act_name = 'identity'
self.act_param = act_param
if self.affine:
self.weight = nn.Parameter(torch.ones(num_features))
self.bias = nn.Parameter(torch.zeros(num_features))
else:
self.register_parameter('weight', None)
self.register_parameter('bias', None)
self.register_buffer('running_mean', torch.zeros(num_features))
self.register_buffer('running_var', torch.ones(num_features))
self.reset_parameters()
def reset_parameters(self):
nn.init.constant_(self.running_mean, 0)
nn.init.constant_(self.running_var, 1)
if self.affine:
nn.init.constant_(self.weight, 1)
nn.init.constant_(self.bias, 0)
def forward(self, x):
output = inplace_abn(
x, self.weight, self.bias, self.running_mean, self.running_var,
self.training, self.momentum, self.eps, self.act_name, self.act_param)
if isinstance(output, tuple):
output = output[0]
return output

@ -0,0 +1,49 @@
""" Median Pool
Hacked together by / Copyright 2020 Ross Wightman
"""
import torch.nn as nn
import torch.nn.functional as F
from .helpers import tup_pair, tup_quadruple
class MedianPool2d(nn.Module):
""" Median pool (usable as median filter when stride=1) module.
Args:
kernel_size: size of pooling kernel, int or 2-tuple
stride: pool stride, int or 2-tuple
padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad
same: override padding and enforce same padding, boolean
"""
def __init__(self, kernel_size=3, stride=1, padding=0, same=False):
super(MedianPool2d, self).__init__()
self.k = tup_pair(kernel_size)
self.stride = tup_pair(stride)
self.padding = tup_quadruple(padding) # convert to l, r, t, b
self.same = same
def _padding(self, x):
if self.same:
ih, iw = x.size()[2:]
if ih % self.stride[0] == 0:
ph = max(self.k[0] - self.stride[0], 0)
else:
ph = max(self.k[0] - (ih % self.stride[0]), 0)
if iw % self.stride[1] == 0:
pw = max(self.k[1] - self.stride[1], 0)
else:
pw = max(self.k[1] - (iw % self.stride[1]), 0)
pl = pw // 2
pr = pw - pl
pt = ph // 2
pb = ph - pt
padding = (pl, pr, pt, pb)
else:
padding = self.padding
return padding
def forward(self, x):
x = F.pad(x, self._padding(x), mode='reflect')
x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1])
x = x.contiguous().view(x.size()[:4] + (-1,)).median(dim=-1)[0]
return x

@ -0,0 +1,51 @@
""" PyTorch Mixed Convolution
Paper: MixConv: Mixed Depthwise Convolutional Kernels (https://arxiv.org/abs/1907.09595)
Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
from torch import nn as nn
from .conv2d_same import create_conv2d_pad
def _split_channels(num_chan, num_groups):
split = [num_chan // num_groups for _ in range(num_groups)]
split[0] += num_chan - sum(split)
return split
class MixedConv2d(nn.ModuleDict):
""" Mixed Grouped Convolution
Based on MDConv and GroupedConv in MixNet impl:
https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py
"""
def __init__(self, in_channels, out_channels, kernel_size=3,
stride=1, padding='', dilation=1, depthwise=False, **kwargs):
super(MixedConv2d, self).__init__()
kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size]
num_groups = len(kernel_size)
in_splits = _split_channels(in_channels, num_groups)
out_splits = _split_channels(out_channels, num_groups)
self.in_channels = sum(in_splits)
self.out_channels = sum(out_splits)
for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)):
conv_groups = out_ch if depthwise else 1
# use add_module to keep key space clean
self.add_module(
str(idx),
create_conv2d_pad(
in_ch, out_ch, k, stride=stride,
padding=padding, dilation=dilation, groups=conv_groups, **kwargs)
)
self.splits = in_splits
def forward(self, x):
x_split = torch.split(x, self.splits, 1)
x_out = [c(x_split[i]) for i, c in enumerate(self.values())]
x = torch.cat(x_out, 1)
return x

@ -0,0 +1,86 @@
""" Normalization + Activation Layers
"""
import torch
from torch import nn as nn
from torch.nn import functional as F
from .create_act import get_act_layer
class BatchNormAct2d(nn.BatchNorm2d):
"""BatchNorm + Activation
This module performs BatchNorm + Activation in a manner that will remain backwards
compatible with weights trained with separate bn, act. This is why we inherit from BN
instead of composing it as a .bn member.
"""
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True,
apply_act=True, act_layer=nn.ReLU, inplace=True, drop_block=None):
super(BatchNormAct2d, self).__init__(
num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
if isinstance(act_layer, str):
act_layer = get_act_layer(act_layer)
if act_layer is not None and apply_act:
act_args = dict(inplace=True) if inplace else {}
self.act = act_layer(**act_args)
else:
self.act = None
def _forward_jit(self, x):
""" A cut & paste of the contents of the PyTorch BatchNorm2d forward function
"""
# exponential_average_factor is set to self.momentum
# (when it is available) only so that it gets updated
# in the ONNX graph when this node is exported to ONNX.
if self.momentum is None:
exponential_average_factor = 0.0
else:
exponential_average_factor = self.momentum
if self.training and self.track_running_stats:
# TODO: if statement only here to tell the jit to skip emitting this when it is None
if self.num_batches_tracked is not None:
self.num_batches_tracked += 1
if self.momentum is None: # use cumulative moving average
exponential_average_factor = 1.0 / float(self.num_batches_tracked)
else: # use exponential moving average
exponential_average_factor = self.momentum
x = F.batch_norm(
x, self.running_mean, self.running_var, self.weight, self.bias,
self.training or not self.track_running_stats,
exponential_average_factor, self.eps)
return x
@torch.jit.ignore
def _forward_python(self, x):
return super(BatchNormAct2d, self).forward(x)
def forward(self, x):
# FIXME cannot call parent forward() and maintain jit.script compatibility?
if torch.jit.is_scripting():
x = self._forward_jit(x)
else:
x = self._forward_python(x)
if self.act is not None:
x = self.act(x)
return x
class GroupNormAct(nn.GroupNorm):
def __init__(self, num_groups, num_channels, eps=1e-5, affine=True,
apply_act=True, act_layer=nn.ReLU, inplace=True, drop_block=None):
super(GroupNormAct, self).__init__(num_groups, num_channels, eps=eps, affine=affine)
if isinstance(act_layer, str):
act_layer = get_act_layer(act_layer)
if act_layer is not None and apply_act:
self.act = act_layer(inplace=inplace)
else:
self.act = None
def forward(self, x):
x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
if self.act is not None:
x = self.act(x)
return x

@ -0,0 +1,56 @@
""" Padding Helpers
Hacked together by / Copyright 2020 Ross Wightman
"""
import math
from typing import List, Tuple
import torch.nn.functional as F
# Calculate symmetric padding for a convolution
def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
return padding
# Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution
def get_same_padding(x: int, k: int, s: int, d: int):
return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0)
# Can SAME padding for given args be done statically?
def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):
return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0
# Dynamically pad input x with 'SAME' padding for conv with specified args
def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0):
ih, iw = x.size()[-2:]
pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1])
if pad_h > 0 or pad_w > 0:
x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value)
return x
def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:
dynamic = False
if isinstance(padding, str):
# for any string padding, the padding will be calculated for you, one of three ways
padding = padding.lower()
if padding == 'same':
# TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
if is_static_pad(kernel_size, **kwargs):
# static case, no extra overhead
padding = get_padding(kernel_size, **kwargs)
else:
# dynamic 'SAME' padding, has runtime/GPU memory overhead
padding = 0
dynamic = True
elif padding == 'valid':
# 'VALID' padding, same as padding=0
padding = 0
else:
# Default to PyTorch style 'same'-ish symmetric padding
padding = get_padding(kernel_size, **kwargs)
return padding, dynamic
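For illustration (not part of the original file), the string modes resolve as follows for a 3x3 kernel; TF 'same' only becomes dynamic when it cannot be expressed as a fixed symmetric pad.
assert get_padding_value('', 3, stride=2) == (1, False)       # PyTorch-style symmetric pad
assert get_padding_value('valid', 3, stride=2) == (0, False)  # no padding
assert get_padding_value('same', 3, stride=2) == (0, True)    # dynamic pad_same at runtime
assert get_padding_value('same', 3, stride=1) == (1, False)   # static 'same' is possible here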

@ -0,0 +1,71 @@
""" AvgPool2d w/ Same Padding
Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Optional
from .helpers import tup_pair
from .padding import pad_same, get_padding_value
def avg_pool2d_same(x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0),
ceil_mode: bool = False, count_include_pad: bool = True):
# FIXME how to deal with count_include_pad vs not for external padding?
x = pad_same(x, kernel_size, stride)
return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
class AvgPool2dSame(nn.AvgPool2d):
""" Tensorflow like 'SAME' wrapper for 2D average pooling
"""
def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True):
kernel_size = tup_pair(kernel_size)
stride = tup_pair(stride)
super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
def forward(self, x):
return avg_pool2d_same(
x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad)
def max_pool2d_same(
x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0),
dilation: List[int] = (1, 1), ceil_mode: bool = False):
x = pad_same(x, kernel_size, stride, value=-float('inf'))
return F.max_pool2d(x, kernel_size, stride, (0, 0), dilation, ceil_mode)
class MaxPool2dSame(nn.MaxPool2d):
""" Tensorflow like 'SAME' wrapper for 2D max pooling
"""
def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False, count_include_pad=True):
kernel_size = tup_pair(kernel_size)
stride = tup_pair(stride)
dilation = tup_pair(dilation)
super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode=ceil_mode)  # count_include_pad is not an nn.MaxPool2d arg
def forward(self, x):
return max_pool2d_same(x, self.kernel_size, self.stride, self.padding, self.dilation, self.ceil_mode)
def create_pool2d(pool_type, kernel_size, stride=None, **kwargs):
stride = stride or kernel_size
padding = kwargs.pop('padding', '')
padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, **kwargs)
if is_dynamic:
if pool_type == 'avg':
return AvgPool2dSame(kernel_size, stride=stride, **kwargs)
elif pool_type == 'max':
return MaxPool2dSame(kernel_size, stride=stride, **kwargs)
else:
assert False, f'Unsupported pool type {pool_type}'
else:
if pool_type == 'avg':
return nn.AvgPool2d(kernel_size, stride=stride, padding=padding, **kwargs)
elif pool_type == 'max':
return nn.MaxPool2d(kernel_size, stride=stride, padding=padding, **kwargs)
else:
assert False, f'Unsupported pool type {pool_type}'
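Illustrative dispatch (not part of the original file), mirroring the conv factory above: 'same' at stride 2 yields the dynamically padded wrappers, while the default resolves to stock PyTorch pooling with a symmetric pad.
_p_same = create_pool2d('max', kernel_size=3, stride=2, padding='same')
_p_def = create_pool2d('avg', kernel_size=3, stride=2)
assert isinstance(_p_same, MaxPool2dSame)
assert isinstance(_p_def, nn.AvgPool2d) and not isinstance(_p_def, AvgPool2dSame)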

@ -0,0 +1,120 @@
""" Selective Kernel Convolution/Attention
Paper: Selective Kernel Networks (https://arxiv.org/abs/1903.06586)
Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
from torch import nn as nn
from .conv_bn_act import ConvBnAct
def _kernel_valid(k):
if isinstance(k, (list, tuple)):
for ki in k:
return _kernel_valid(ki)
assert k >= 3 and k % 2
class SelectiveKernelAttn(nn.Module):
def __init__(self, channels, num_paths=2, attn_channels=32,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
""" Selective Kernel Attention Module
Selective Kernel attention mechanism factored out into its own module.
"""
super(SelectiveKernelAttn, self).__init__()
self.num_paths = num_paths
self.pool = nn.AdaptiveAvgPool2d(1)
self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False)
self.bn = norm_layer(attn_channels)
self.act = act_layer(inplace=True)
self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False)
def forward(self, x):
assert x.shape[1] == self.num_paths
x = torch.sum(x, dim=1)
x = self.pool(x)
x = self.fc_reduce(x)
x = self.bn(x)
x = self.act(x)
x = self.fc_select(x)
B, C, H, W = x.shape
x = x.view(B, self.num_paths, C // self.num_paths, H, W)
x = torch.softmax(x, dim=1)
return x
class SelectiveKernelConv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=None, stride=1, dilation=1, groups=1,
attn_reduction=16, min_attn_channels=32, keep_3x3=True, split_input=False,
drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None):
""" Selective Kernel Convolution Module
As described in Selective Kernel Networks (https://arxiv.org/abs/1903.06586) with some modifications.
The largest change is the input split, which divides the input channels across each convolution path; this can
be viewed as a grouping of sorts, but the output channel counts expand to the module level value. This keeps
the parameter count from ballooning when the convolutions themselves don't have groups, but still provides
a noteworthy increase in performance over similar param count models without this attention layer. -Ross W
Args:
in_channels (int): module input (feature) channel count
out_channels (int): module output (feature) channel count
kernel_size (int, list): kernel size for each convolution branch
stride (int): stride for convolutions
dilation (int): dilation for module as a whole, impacts dilation of each branch
groups (int): number of groups for each branch
attn_reduction (int, float): reduction factor for attention features
min_attn_channels (int): minimum attention feature channels
keep_3x3 (bool): keep all branch convolution kernels as 3x3, changing larger kernels for dilations
split_input (bool): split input channels evenly across each convolution branch, keeps param count lower,
can be viewed as grouping by path, output expands to module out_channels count
drop_block (nn.Module): drop block module
act_layer (nn.Module): activation layer to use
norm_layer (nn.Module): batchnorm/norm layer to use
"""
super(SelectiveKernelConv, self).__init__()
kernel_size = kernel_size or [3, 5] # default to one 3x3 and one 5x5 branch. 5x5 -> 3x3 + dilation
_kernel_valid(kernel_size)
if not isinstance(kernel_size, list):
kernel_size = [kernel_size] * 2
if keep_3x3:
dilation = [dilation * (k - 1) // 2 for k in kernel_size]
kernel_size = [3] * len(kernel_size)
else:
dilation = [dilation] * len(kernel_size)
self.num_paths = len(kernel_size)
self.in_channels = in_channels
self.out_channels = out_channels
self.split_input = split_input
if self.split_input:
assert in_channels % self.num_paths == 0
in_channels = in_channels // self.num_paths
groups = min(out_channels, groups)
conv_kwargs = dict(
stride=stride, groups=groups, drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer,
aa_layer=aa_layer)
self.paths = nn.ModuleList([
ConvBnAct(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs)
for k, d in zip(kernel_size, dilation)])
attn_channels = max(int(out_channels / attn_reduction), min_attn_channels)
self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels)
self.drop_block = drop_block
def forward(self, x):
if self.split_input:
x_split = torch.split(x, self.in_channels // self.num_paths, 1)
x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)]
else:
x_paths = [op(x) for op in self.paths]
x = torch.stack(x_paths, dim=1)
x_attn = self.attn(x)
x = x * x_attn
x = torch.sum(x, dim=1)
return x
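Illustrative usage (not part of the original file): with the defaults, two 3x3 branches are built (the second dilated to stand in for 5x5); split_input halves the per-branch input channels while the output keeps the requested width and the stride-reduced spatial size.
import torch
_sk = SelectiveKernelConv(64, 128, stride=2, split_input=True)
assert _sk(torch.randn(2, 64, 16, 16)).shape == (2, 128, 8, 8)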

@ -0,0 +1,74 @@
""" Depthwise Separable Conv Modules
Basic DWS convs. Other variations of DWS exist with batch norm or activations between the
DW and PW convs such as the Depthwise modules in MobileNetV2 / EfficientNet and Xception.
Hacked together by / Copyright 2020 Ross Wightman
"""
from torch import nn as nn
from .create_conv2d import create_conv2d
from .create_norm_act import convert_norm_act_type
class SeparableConvBnAct(nn.Module):
""" Separable Conv w/ trailing Norm and Activation
"""
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False,
channel_multiplier=1.0, pw_kernel_size=1, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
act_layer=nn.ReLU, apply_act=True, drop_block=None):
super(SeparableConvBnAct, self).__init__()
norm_kwargs = norm_kwargs or {}
self.conv_dw = create_conv2d(
in_channels, int(in_channels * channel_multiplier), kernel_size,
stride=stride, dilation=dilation, padding=padding, depthwise=True)
self.conv_pw = create_conv2d(
int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias)
norm_act_layer, norm_act_args = convert_norm_act_type(norm_layer, act_layer, norm_kwargs)
self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block, **norm_act_args)
@property
def in_channels(self):
return self.conv_dw.in_channels
@property
def out_channels(self):
return self.conv_pw.out_channels
def forward(self, x):
x = self.conv_dw(x)
x = self.conv_pw(x)
if self.bn is not None:
x = self.bn(x)
return x
class SeparableConv2d(nn.Module):
""" Separable Conv
"""
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False,
channel_multiplier=1.0, pw_kernel_size=1):
super(SeparableConv2d, self).__init__()
self.conv_dw = create_conv2d(
in_channels, int(in_channels * channel_multiplier), kernel_size,
stride=stride, dilation=dilation, padding=padding, depthwise=True)
self.conv_pw = create_conv2d(
int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias)
@property
def in_channels(self):
return self.conv_dw.in_channels
@property
def out_channels(self):
return self.conv_pw.out_channels
def forward(self, x):
x = self.conv_dw(x)
x = self.conv_pw(x)
return x
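# Usage sketch (illustrative, not part of the original patch): a depthwise 3x3 conv
# followed by a pointwise 1x1 projection; only the pointwise conv changes the channel count.
# >>> import torch
# >>> dws = SeparableConvBnAct(32, 64, kernel_size=3, stride=2)
# >>> y = dws(torch.randn(1, 32, 56, 56))  # expected shape: (1, 64, 28, 28)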

@ -0,0 +1,53 @@
import torch
import torch.nn as nn
class SpaceToDepth(nn.Module):
def __init__(self, block_size=4):
super().__init__()
assert block_size == 4
self.bs = block_size
def forward(self, x):
N, C, H, W = x.size()
x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs) # (N, C, H//bs, bs, W//bs, bs)
x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs)
x = x.view(N, C * (self.bs ** 2), H // self.bs, W // self.bs) # (N, C*bs^2, H//bs, W//bs)
return x
@torch.jit.script
class SpaceToDepthJit(object):
def __call__(self, x: torch.Tensor):
# assuming hard-coded that block_size==4 for acceleration
N, C, H, W = x.size()
x = x.view(N, C, H // 4, 4, W // 4, 4) # (N, C, H//bs, bs, W//bs, bs)
x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs)
x = x.view(N, C * 16, H // 4, W // 4) # (N, C*bs^2, H//bs, W//bs)
return x
class SpaceToDepthModule(nn.Module):
def __init__(self, no_jit=False):
super().__init__()
if not no_jit:
self.op = SpaceToDepthJit()
else:
self.op = SpaceToDepth()
def forward(self, x):
return self.op(x)
class DepthToSpace(nn.Module):
def __init__(self, block_size):
super().__init__()
self.bs = block_size
def forward(self, x):
N, C, H, W = x.size()
x = x.view(N, self.bs, self.bs, C // (self.bs ** 2), H, W) # (N, bs, bs, C//bs^2, H, W)
x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # (N, C//bs^2, H, bs, W, bs)
x = x.view(N, C // (self.bs ** 2), H * self.bs, W * self.bs) # (N, C//bs^2, H * bs, W * bs)
return x
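# Usage sketch (illustrative, not part of the original patch): block_size=4 folds each
# 4x4 spatial patch into the channel dim, so (N, C, H, W) -> (N, 16*C, H/4, W/4);
# DepthToSpace(4) performs the inverse rearrangement.
# >>> s2d = SpaceToDepthModule(no_jit=True)
# >>> s2d(torch.randn(2, 3, 224, 224)).shape   # torch.Size([2, 48, 56, 56])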

@ -0,0 +1,88 @@
""" Split Attention Conv2d (for ResNeSt Models)
Paper: `ResNeSt: Split-Attention Networks` - https://arxiv.org/abs/2004.08955
Adapted from original PyTorch impl at https://github.com/zhanghang1989/ResNeSt
Modified for torchscript compat, performance, and consistency with timm by Ross Wightman
"""
import torch
import torch.nn.functional as F
from torch import nn
class RadixSoftmax(nn.Module):
def __init__(self, radix, cardinality):
super(RadixSoftmax, self).__init__()
self.radix = radix
self.cardinality = cardinality
def forward(self, x):
batch = x.size(0)
if self.radix > 1:
x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
x = F.softmax(x, dim=1)
x = x.reshape(batch, -1)
else:
x = torch.sigmoid(x)
return x
class SplitAttnConv2d(nn.Module):
"""Split-Attention Conv2d
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
dilation=1, groups=1, bias=False, radix=2, reduction_factor=4,
act_layer=nn.ReLU, norm_layer=None, drop_block=None, **kwargs):
super(SplitAttnConv2d, self).__init__()
self.radix = radix
self.drop_block = drop_block
mid_chs = out_channels * radix
attn_chs = max(in_channels * radix // reduction_factor, 32)
self.conv = nn.Conv2d(
in_channels, mid_chs, kernel_size, stride, padding, dilation,
groups=groups * radix, bias=bias, **kwargs)
self.bn0 = norm_layer(mid_chs) if norm_layer is not None else None
self.act0 = act_layer(inplace=True)
self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups)
self.bn1 = norm_layer(attn_chs) if norm_layer is not None else None
self.act1 = act_layer(inplace=True)
self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups)
self.rsoftmax = RadixSoftmax(radix, groups)
@property
def in_channels(self):
return self.conv.in_channels
@property
def out_channels(self):
return self.fc1.out_channels
def forward(self, x):
x = self.conv(x)
if self.bn0 is not None:
x = self.bn0(x)
if self.drop_block is not None:
x = self.drop_block(x)
x = self.act0(x)
B, RC, H, W = x.shape
if self.radix > 1:
x = x.reshape((B, self.radix, RC // self.radix, H, W))
x_gap = x.sum(dim=1)
else:
x_gap = x
x_gap = F.adaptive_avg_pool2d(x_gap, 1)
x_gap = self.fc1(x_gap)
if self.bn1 is not None:
x_gap = self.bn1(x_gap)
x_gap = self.act1(x_gap)
x_attn = self.fc2(x_gap)
x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1)
if self.radix > 1:
out = (x * x_attn.reshape((B, self.radix, RC // self.radix, 1, 1))).sum(dim=1)
else:
out = x * x_attn
return out.contiguous()
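# Usage sketch (illustrative, not part of the original patch): with radix=2 the grouped
# conv produces radix * out_channels features; the radix softmax re-weights the two
# splits and they are summed back down to out_channels.
# >>> sa = SplitAttnConv2d(64, 64, kernel_size=3, padding=1, radix=2, norm_layer=nn.BatchNorm2d)
# >>> y = sa(torch.randn(2, 64, 28, 28))   # expected shape: (2, 64, 28, 28)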

@ -0,0 +1,75 @@
""" Split BatchNorm
A PyTorch BatchNorm layer that splits input batch into N equal parts and passes each through
a separate BN layer. The first split is passed through the parent BN layers with weight/bias
keys the same as the original BN. All other splits pass through BN sub-layers under the '.aux_bn'
namespace.
This allows easily removing the auxiliary BN layers after training to efficiently
achieve the 'Auxiliary BatchNorm' as described in the AdvProp Paper, section 4.2,
'Disentangled Learning via An Auxiliary BN'
Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
import torch.nn as nn
class SplitBatchNorm2d(torch.nn.BatchNorm2d):
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
track_running_stats=True, num_splits=2):
super().__init__(num_features, eps, momentum, affine, track_running_stats)
assert num_splits > 1, 'Should have at least one aux BN layer (num_splits at least 2)'
self.num_splits = num_splits
self.aux_bn = nn.ModuleList([
nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats) for _ in range(num_splits - 1)])
def forward(self, input: torch.Tensor):
if self.training: # aux BN only relevant while training
split_size = input.shape[0] // self.num_splits
assert input.shape[0] == split_size * self.num_splits, "batch size must be evenly divisible by num_splits"
split_input = input.split(split_size)
x = [super().forward(split_input[0])]
for i, a in enumerate(self.aux_bn):
x.append(a(split_input[i + 1]))
return torch.cat(x, dim=0)
else:
return super().forward(input)
def convert_splitbn_model(module, num_splits=2):
"""
Recursively traverse module and its children to replace all instances of
``torch.nn.modules.batchnorm._BatchNorm`` with `SplitBatchNorm2d`.
Args:
module (torch.nn.Module): input module
num_splits: number of separate batchnorm layers to split input across
Example::
>>> # model is an instance of torch.nn.Module
>>> model = timm.models.convert_splitbn_model(model, num_splits=2)
"""
mod = module
if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm):
return module
if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
mod = SplitBatchNorm2d(
module.num_features, module.eps, module.momentum, module.affine,
module.track_running_stats, num_splits=num_splits)
mod.running_mean = module.running_mean
mod.running_var = module.running_var
mod.num_batches_tracked = module.num_batches_tracked
if module.affine:
mod.weight.data = module.weight.data.clone().detach()
mod.bias.data = module.bias.data.clone().detach()
for aux in mod.aux_bn:
aux.running_mean = module.running_mean.clone()
aux.running_var = module.running_var.clone()
aux.num_batches_tracked = module.num_batches_tracked.clone()
if module.affine:
aux.weight.data = module.weight.data.clone().detach()
aux.bias.data = module.bias.data.clone().detach()
for name, child in module.named_children():
mod.add_module(name, convert_splitbn_model(child, num_splits=num_splits))
del module
return mod
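# Usage sketch (illustrative, not part of the original patch): after conversion the parent
# BN keeps the original parameter keys, so the '.aux_bn' layers can simply be dropped from
# the state dict once training with auxiliary batches is done.
# >>> # model is any torch.nn.Module containing BatchNorm layers
# >>> model = convert_splitbn_model(model, num_splits=2)
# >>> # during training, order each batch as [main split | aux split] with equal sizes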

@ -0,0 +1,50 @@
""" Test Time Pooling (Average-Max Pool)
Hacked together by / Copyright 2020 Ross Wightman
"""
import logging
from torch import nn
import torch.nn.functional as F
from .adaptive_avgmax_pool import adaptive_avgmax_pool2d
_logger = logging.getLogger(__name__)
class TestTimePoolHead(nn.Module):
def __init__(self, base, original_pool=7):
super(TestTimePoolHead, self).__init__()
self.base = base
self.original_pool = original_pool
base_fc = self.base.get_classifier()
if isinstance(base_fc, nn.Conv2d):
self.fc = base_fc
else:
self.fc = nn.Conv2d(
self.base.num_features, self.base.num_classes, kernel_size=1, bias=True)
self.fc.weight.data.copy_(base_fc.weight.data.view(self.fc.weight.size()))
self.fc.bias.data.copy_(base_fc.bias.data.view(self.fc.bias.size()))
self.base.reset_classifier(0) # delete original fc layer
def forward(self, x):
x = self.base.forward_features(x)
x = F.avg_pool2d(x, kernel_size=self.original_pool, stride=1)
x = self.fc(x)
x = adaptive_avgmax_pool2d(x, 1)
return x.view(x.size(0), -1)
def apply_test_time_pool(model, config, args):
test_time_pool = False
if not hasattr(model, 'default_cfg') or not model.default_cfg:
return model, False
if not args.no_test_pool and \
config['input_size'][-1] > model.default_cfg['input_size'][-1] and \
config['input_size'][-2] > model.default_cfg['input_size'][-2]:
_logger.info('Target input size %s > pretrained default %s, using test time pooling' %
(str(config['input_size'][-2:]), str(model.default_cfg['input_size'][-2:])))
model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size'])
test_time_pool = True
return model, test_time_pool
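# Usage sketch (illustrative, not part of the original patch): `config` is the resolved
# data config dict (with an 'input_size' entry) and `args` an argparse namespace providing
# `no_test_pool`, as produced by the validation script.
# >>> model, use_ttp = apply_test_time_pool(model, config, args)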

@ -0,0 +1,60 @@
import torch
import math
import warnings
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1. + math.erf(x / math.sqrt(2.))) / 2.
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2)
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
# type: (Tensor, float, float, float, float) -> Tensor
r"""Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.
Args:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
std: the standard deviation of the normal distribution
a: the minimum cutoff value
b: the maximum cutoff value
Examples:
>>> w = torch.empty(3, 5)
>>> nn.init.trunc_normal_(w)
"""
return _no_grad_trunc_normal_(tensor, mean, std, a, b)

@ -0,0 +1,462 @@
""" MobileNet V3
A PyTorch impl of MobileNet-V3, compatible with TF weights from official impl.
Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244
Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
from timm.models.features import FeatureInfo, FeatureHooks
from timm.models.helpers import build_model_with_cfg
from .layers import SelectAdaptivePool2d, create_conv2d, HardSigmoid
from timm.models.registry import register_model
from .efficientnet_blocks import ConvBnAct, SqueezeExcite, DepthwiseSeparableConv, InvertedResidual
__all__ = ['MobileNetV3']
def _cfg(url='', **kwargs):
return {
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (1, 1),
'crop_pct': 0.875, 'interpolation': 'bilinear',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'conv_stem', 'classifier': 'classifier',
**kwargs
}
default_cfgs = {
'mobilenetv3_large_075': _cfg(url=''),
'mobilenetv3_large_100': _cfg(
interpolation='bicubic',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth'),
'mobilenetv3_small_075': _cfg(url=''),
'mobilenetv3_small_100': _cfg(url=''),
'mobilenetv3_rw': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth',
interpolation='bicubic'),
'tf_mobilenetv3_large_075': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
'tf_mobilenetv3_large_100': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
'tf_mobilenetv3_large_minimal_100': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
'tf_mobilenetv3_small_075': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
'tf_mobilenetv3_small_100': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
'tf_mobilenetv3_small_minimal_100': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
}
_DEBUG = False
class MobileNetV3(nn.Module):
""" MobiletNet-V3
Based on my EfficientNet implementation and building blocks, this model utilizes the MobileNet-v3 specific
'efficient head', where global pooling is done before the head convolution without a final batch-norm
layer before the classifier.
Paper: https://arxiv.org/abs/1905.02244
"""
def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_features=1280, head_bias=True,
channel_multiplier=1.0, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0.,
se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, global_pool='avg'):
super(MobileNetV3, self).__init__()
# temporary fix: force default padding so the tf_* models work with torch quantization
pad_type = ''
norm_kwargs = norm_kwargs or {}
self.num_classes = num_classes
self.num_features = num_features
self.drop_rate = drop_rate
# Stem
stem_size = round_channels(stem_size, channel_multiplier)
self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
self.bn1 = norm_layer(stem_size, **norm_kwargs)
self.act1 = act_layer(inplace=True)
# Middle stages (IR/ER/DS Blocks)
builder = EfficientNetBuilder(
channel_multiplier, 8, None, 32, pad_type, act_layer, se_kwargs,
norm_layer, norm_kwargs, drop_path_rate, verbose=_DEBUG)
self.blocks = nn.Sequential(*builder(stem_size, block_args))
self.feature_info = builder.features
head_chs = builder.in_chs
# Head + Pooling
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
num_pooled_chs = head_chs * self.global_pool.feat_mult()
self.conv_head = create_conv2d(num_pooled_chs, self.num_features, 1, padding=pad_type, bias=head_bias)
self.act2 = act_layer(inplace=True)
self.classifier = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
# Quantization Stubs
self.quant = torch.quantization.QuantStub()
self.dequant = torch.quantization.DeQuantStub()
efficientnet_init_weights(self)
def as_sequential(self):
layers = [self.conv_stem, self.bn1, self.act1]
layers.extend(self.blocks)
layers.extend([self.global_pool, self.conv_head, self.act2])
layers.extend([nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier])
return nn.Sequential(*layers)
def get_classifier(self):
return self.classifier
def reset_classifier(self, num_classes, global_pool='avg'):
self.num_classes = num_classes
# cannot meaningfully change pooling of efficient head after creation
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.classifier = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
x = self.conv_stem(x)
x = self.bn1(x)
x = self.act1(x)
x = self.blocks(x)
x = self.global_pool(x)
x = self.conv_head(x)
x = self.act2(x)
return x
def forward(self, x):
x = self.quant(x)
x = self.forward_features(x)
if not self.global_pool.is_identity():
x = x.flatten(1)
if self.drop_rate > 0.:
x = F.dropout(x, p=self.drop_rate, training=self.training)
x = self.classifier(x)
x = self.dequant(x)
return x
def fuse_model(self):
modules_to_fuse = [['conv_stem','bn1']]
if type(self.act1) == nn.ReLU:
modules_to_fuse[0].append('act1')
if type(self.act2) == nn.ReLU:
modules_to_fuse.append(['conv_head','act2'])
torch.quantization.fuse_modules(self, modules_to_fuse, inplace=True)
for m in self.modules():
if type(m) in [ConvBnAct, SqueezeExcite, DepthwiseSeparableConv, InvertedResidual]:
m.fuse_module()
class MobileNetV3Features(nn.Module):
""" MobileNetV3 Feature Extractor
A work-in-progress feature extraction module for MobileNet-V3 to use as a backbone for segmentation
and object detection models.
"""
def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck',
in_chans=3, stem_size=16, channel_multiplier=1.0, output_stride=32, pad_type='',
act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0., se_kwargs=None,
norm_layer=nn.BatchNorm2d, norm_kwargs=None):
super(MobileNetV3Features, self).__init__()
norm_kwargs = norm_kwargs or {}
self.drop_rate = drop_rate
# Stem
stem_size = round_channels(stem_size, channel_multiplier)
self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
self.bn1 = norm_layer(stem_size, **norm_kwargs)
self.act1 = act_layer(inplace=True)
# Middle stages (IR/ER/DS Blocks)
builder = EfficientNetBuilder(
channel_multiplier, 8, None, output_stride, pad_type, act_layer, se_kwargs,
norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG)
self.blocks = nn.Sequential(*builder(stem_size, block_args))
self.feature_info = FeatureInfo(builder.features, out_indices)
self._stage_out_idx = {v['stage']: i for i, v in enumerate(self.feature_info) if i in out_indices}
efficientnet_init_weights(self)
# Register feature extraction hooks with FeatureHooks helper
self.feature_hooks = None
if feature_location != 'bottleneck':
hooks = self.feature_info.get_dicts(keys=('module', 'hook_type'))
self.feature_hooks = FeatureHooks(hooks, self.named_modules())
def forward(self, x) -> List[torch.Tensor]:
x = self.conv_stem(x)
x = self.bn1(x)
x = self.act1(x)
if self.feature_hooks is None:
features = []
if 0 in self._stage_out_idx:
features.append(x) # add stem out
for i, b in enumerate(self.blocks):
x = b(x)
if i + 1 in self._stage_out_idx:
features.append(x)
return features
else:
self.blocks(x)
out = self.feature_hooks.get_output(x.device)
return list(out.values())
def _create_mnv3(model_kwargs, variant, pretrained=False):
if model_kwargs.pop('features_only', False):
load_strict = False
model_kwargs.pop('num_classes', 0)
model_kwargs.pop('num_features', 0)
model_kwargs.pop('head_conv', None)
model_kwargs.pop('head_bias', None)
model_cls = MobileNetV3Features
else:
load_strict = True
model_cls = MobileNetV3
return build_model_with_cfg(
model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
pretrained_strict=load_strict, **model_kwargs)
def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
"""Creates a MobileNet-V3 model.
Ref impl: ?
Paper: https://arxiv.org/abs/1905.02244
Args:
channel_multiplier: multiplier to number of channels per layer.
"""
arch_def = [
# stage 0, 112x112 in
['ds_r1_k3_s1_e1_c16_nre_noskip'], # relu
# stage 1, 112x112 in
['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu
# stage 2, 56x56 in
['ir_r3_k5_s2_e3_c40_se0.25_nre'], # relu
# stage 3, 28x28 in
['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish
# stage 4, 14x14in
['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish
# stage 5, 14x14in
['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish
# stage 6, 7x7 in
['cn_r1_k1_s1_c960'], # hard-swish
]
model_kwargs = dict(
block_args=decode_arch_def(arch_def),
head_bias=False,
channel_multiplier=channel_multiplier,
norm_kwargs=resolve_bn_args(kwargs),
act_layer=resolve_act_layer(kwargs, 'hard_swish'),
se_kwargs=dict(gate_fn=HardSigmoid, reduce_mid=True, divisor=1),
**kwargs,
)
model = _create_mnv3(model_kwargs, variant, pretrained)
return model
def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
"""Creates a MobileNet-V3 model.
Ref impl: ?
Paper: https://arxiv.org/abs/1905.02244
Args:
channel_multiplier: multiplier to number of channels per layer.
"""
if 'small' in variant:
num_features = 1024
if 'minimal' in variant:
act_layer = resolve_act_layer(kwargs, 'relu')
arch_def = [
# stage 0, 112x112 in
['ds_r1_k3_s2_e1_c16'],
# stage 1, 56x56 in
['ir_r1_k3_s2_e4.5_c24', 'ir_r1_k3_s1_e3.67_c24'],
# stage 2, 28x28 in
['ir_r1_k3_s2_e4_c40', 'ir_r2_k3_s1_e6_c40'],
# stage 3, 14x14 in
['ir_r2_k3_s1_e3_c48'],
# stage 4, 14x14in
['ir_r3_k3_s2_e6_c96'],
# stage 6, 7x7 in
['cn_r1_k1_s1_c576'],
]
else:
act_layer = resolve_act_layer(kwargs, 'hard_swish')
arch_def = [
# stage 0, 112x112 in
['ds_r1_k3_s2_e1_c16_se0.25_nre'], # relu
# stage 1, 56x56 in
['ir_r1_k3_s2_e4.5_c24_nre', 'ir_r1_k3_s1_e3.67_c24_nre'], # relu
# stage 2, 28x28 in
['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r2_k5_s1_e6_c40_se0.25'], # hard-swish
# stage 3, 14x14 in
['ir_r2_k5_s1_e3_c48_se0.25'], # hard-swish
# stage 4, 14x14in
['ir_r3_k5_s2_e6_c96_se0.25'], # hard-swish
# stage 6, 7x7 in
['cn_r1_k1_s1_c576'], # hard-swish
]
else:
num_features = 1280
if 'minimal' in variant:
act_layer = resolve_act_layer(kwargs, 'relu')
arch_def = [
# stage 0, 112x112 in
['ds_r1_k3_s1_e1_c16'],
# stage 1, 112x112 in
['ir_r1_k3_s2_e4_c24', 'ir_r1_k3_s1_e3_c24'],
# stage 2, 56x56 in
['ir_r3_k3_s2_e3_c40'],
# stage 3, 28x28 in
['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'],
# stage 4, 14x14in
['ir_r2_k3_s1_e6_c112'],
# stage 5, 14x14in
['ir_r3_k3_s2_e6_c160'],
# stage 6, 7x7 in
['cn_r1_k1_s1_c960'],
]
else:
act_layer = resolve_act_layer(kwargs, 'hard_swish')
arch_def = [
# stage 0, 112x112 in
['ds_r1_k3_s1_e1_c16_nre'], # relu
# stage 1, 112x112 in
['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu
# stage 2, 56x56 in
['ir_r3_k5_s2_e3_c40_se0.25_nre'], # relu
# stage 3, 28x28 in
['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish
# stage 4, 14x14in
['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish
# stage 5, 14x14in
['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish
# stage 6, 7x7 in
['cn_r1_k1_s1_c960'], # hard-swish
]
model_kwargs = dict(
block_args=decode_arch_def(arch_def),
num_features=num_features,
stem_size=16,
channel_multiplier=channel_multiplier,
norm_kwargs=resolve_bn_args(kwargs),
act_layer=act_layer,
se_kwargs=dict(act_layer=nn.ReLU, gate_fn=HardSigmoid, reduce_mid=True, divisor=8),
**kwargs,
)
model = _create_mnv3(model_kwargs, variant, pretrained)
return model
@register_model
def quant_mobilenetv3_large_075(pretrained=False, **kwargs):
""" MobileNet V3 """
model = _gen_mobilenet_v3('mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs)
return model
@register_model
def quant_mobilenetv3_large_100(pretrained=False, **kwargs):
""" MobileNet V3 """
model = _gen_mobilenet_v3('mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs)
return model
@register_model
def quant_mobilenetv3_small_075(pretrained=False, **kwargs):
""" MobileNet V3 """
model = _gen_mobilenet_v3('mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs)
return model
@register_model
def quant_mobilenetv3_small_100(pretrained=False, **kwargs):
""" MobileNet V3 """
model = _gen_mobilenet_v3('mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs)
return model
@register_model
def quant_mobilenetv3_rw(pretrained=False, **kwargs):
""" MobileNet V3 """
if pretrained:
# pretrained model trained with non-default BN epsilon
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
model = _gen_mobilenet_v3_rw('mobilenetv3_rw', 1.0, pretrained=pretrained, **kwargs)
return model
@register_model
def quant_tf_mobilenetv3_large_075(pretrained=False, **kwargs):
""" MobileNet V3 """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_mobilenet_v3('tf_mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs)
return model
@register_model
def quant_tf_mobilenetv3_large_100(pretrained=False, **kwargs):
""" MobileNet V3 """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_mobilenet_v3('tf_mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs)
return model
@register_model
def quant_tf_mobilenetv3_large_minimal_100(pretrained=False, **kwargs):
""" MobileNet V3 """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_mobilenet_v3('tf_mobilenetv3_large_minimal_100', 1.0, pretrained=pretrained, **kwargs)
return model
@register_model
def quant_tf_mobilenetv3_small_075(pretrained=False, **kwargs):
""" MobileNet V3 """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_mobilenet_v3('tf_mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs)
return model
@register_model
def quant_tf_mobilenetv3_small_100(pretrained=False, **kwargs):
""" MobileNet V3 """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_mobilenet_v3('tf_mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs)
return model
@register_model
def quant_tf_mobilenetv3_small_minimal_100(pretrained=False, **kwargs):
""" MobileNet V3 """
kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
kwargs['pad_type'] = 'same'
model = _gen_mobilenet_v3('tf_mobilenetv3_small_minimal_100', 1.0, pretrained=pretrained, **kwargs)
return model
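# Usage sketch for eager-mode post-training static quantization (illustrative, not part of
# the original patch); `calib_loader` is a hypothetical calibration DataLoader.
# >>> model = quant_mobilenetv3_large_100(pretrained=True).eval()
# >>> model.fuse_model()                                    # fold conv + bn (+ relu) groups
# >>> model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
# >>> torch.quantization.prepare(model, inplace=True)       # insert observers
# >>> with torch.no_grad():
# ...     for images, _ in calib_loader:
# ...         model(images)                                 # calibrate activation ranges
# >>> torch.quantization.convert(model, inplace=True)       # swap to quantized int8 modules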

@ -0,0 +1,323 @@
""" ReXNet
A PyTorch impl of `ReXNet: Diminishing Representational Bottleneck on Convolutional Neural Network` -
https://arxiv.org/abs/2007.00992
Adapted from original impl at https://github.com/clovaai/rexnet
Copyright (c) 2020-present NAVER Corp. MIT license
Changes for timm, feature extraction, and rounded channel variant hacked together by Ross Wightman
Copyright 2020 Ross Wightman
"""
import torch
import torch.nn as nn
from math import ceil
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.helpers import build_model_with_cfg
from .layers import ClassifierHead, create_act_layer, create_conv2d
from timm.models.registry import register_model
from .layers.activations import sigmoid, Swish, HardSwish, HardSigmoid
def _cfg(url=''):
return {
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem.conv', 'classifier': 'head.fc',
}
default_cfgs = dict(
rexnet_100=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rexnet/rexnetv1_100-1b4dddf4.pth'),
rexnet_130=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rexnet/rexnetv1_130-590d768e.pth'),
rexnet_150=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rexnet/rexnetv1_150-bd1a6aa8.pth'),
rexnet_200=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rexnet/rexnetv1_200-8c0b7f2d.pth'),
rexnetr_100=_cfg(
url=''),
rexnetr_130=_cfg(
url=''),
rexnetr_150=_cfg(
url=''),
rexnetr_200=_cfg(
url=''),
)
def make_divisible(v, divisor=8, min_value=None):
min_value = min_value or divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
return new_v
class ConvBn(nn.Module):
def __init__(self, in_chs, out_chs, kernel_size,
stride=1, dilation=1, pad_type='',
norm_layer=nn.BatchNorm2d, groups=1, norm_kwargs=None):
super(ConvBn, self).__init__()
norm_kwargs = norm_kwargs or {}
self.conv = create_conv2d(in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, groups=groups, padding=pad_type)
self.bn1 = norm_layer(out_chs, **norm_kwargs)
def feature_info(self, location):
if location == 'expansion': # output of conv after act, same as block output
info = dict(module='act1', hook_type='forward', num_chs=self.conv.out_channels)
else: # location == 'bottleneck', block output
info = dict(module='', hook_type='', num_chs=self.conv.out_channels)
return info
def forward(self, x):
x = self.conv(x)
x = self.bn1(x)
return x
def fuse_module(self):
modules_to_fuse = ['conv','bn1']
torch.quantization.fuse_modules(self, modules_to_fuse, inplace=True)
class ConvBnAct(nn.Module):
def __init__(self, in_chs, out_chs, kernel_size,
stride=1, dilation=1, pad_type='', act_layer=nn.ReLU,
norm_layer=nn.BatchNorm2d, norm_kwargs=None):
super(ConvBnAct, self).__init__()
norm_kwargs = norm_kwargs or {}
self.conv = create_conv2d(in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, padding=pad_type)
self.bn1 = norm_layer(out_chs, **norm_kwargs)
self.act1 = act_layer(inplace=True)
self.out_channels = out_chs
def feature_info(self, location):
if location == 'expansion': # output of conv after act, same as block output
info = dict(module='act1', hook_type='forward', num_chs=self.conv.out_channels)
else: # location == 'bottleneck', block output
info = dict(module='', hook_type='', num_chs=self.conv.out_channels)
return info
def forward(self, x):
x = self.conv(x)
x = self.bn1(x)
x = self.act1(x)
return x
def fuse_module(self):
modules_to_fuse = ['conv','bn1']
if type(self.act1) == nn.ReLU:
modules_to_fuse.append('act1')
torch.quantization.fuse_modules(self, modules_to_fuse, inplace=True)
class SEWithNorm(nn.Module):
def __init__(self, channels, reduction=16, act_layer=nn.ReLU, divisor=1, reduction_channels=None,
gate_layer='sigmoid'):
super(SEWithNorm, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
reduction_channels = reduction_channels or make_divisible(channels // reduction, divisor=divisor)
self.fc1 = nn.Conv2d(
channels, reduction_channels, kernel_size=1, padding=0, bias=True)
self.bn = nn.BatchNorm2d(reduction_channels)
self.act = act_layer(inplace=True)
self.fc2 = nn.Conv2d(
reduction_channels, channels, kernel_size=1, padding=0, bias=True)
self.gate = create_act_layer(gate_layer)
self.quant_mul = nn.quantized.FloatFunctional()
def forward(self, x):
x_se = self.avg_pool(x)
x_se = self.fc1(x_se)
x_se = self.bn(x_se)
x_se = self.act(x_se)
x_se = self.fc2(x_se)
return self.quant_mul.mul(x, self.gate(x_se))
def fuse_module(self):
modules_to_fuse = ['fc1','bn','act']
torch.quantization.fuse_modules(self, modules_to_fuse, inplace=True)
class LinearBottleneck(nn.Module):
def __init__(self, in_chs, out_chs, stride, exp_ratio=1.0, use_se=True, se_rd=12, ch_div=1):
super(LinearBottleneck, self).__init__()
self.use_shortcut = stride == 1 and in_chs <= out_chs
self.in_channels = in_chs
self.out_channels = out_chs
if exp_ratio != 1.:
dw_chs = make_divisible(round(in_chs * exp_ratio), divisor=ch_div)
self.conv_exp = ConvBnAct(in_chs, dw_chs, 1, act_layer=Swish)
else:
dw_chs = in_chs
self.conv_exp = None
self.conv_dw = ConvBn(dw_chs, dw_chs, 3, stride=stride, groups=dw_chs)
self.se = SEWithNorm(dw_chs, reduction=se_rd, divisor=ch_div) if use_se else None
self.act_dw = nn.ReLU6()
self.conv_pwl = ConvBn(dw_chs, out_chs, 1)
if self.use_shortcut:
self.skip_add = nn.quantized.FloatFunctional()
def feat_channels(self, exp=False):
return self.conv_dw.out_channels if exp else self.out_channels
def forward(self, x):
shortcut = x
if self.conv_exp is not None:
x = self.conv_exp(x)
x = self.conv_dw(x)
if self.se is not None:
x = self.se(x)
x = self.act_dw(x)
x = self.conv_pwl(x)
if self.use_shortcut:
x[:, 0:self.in_channels] = self.skip_add.add(x[:, 0:self.in_channels], shortcut)
return x
def _block_cfg(width_mult=1.0, depth_mult=1.0, initial_chs=16, final_chs=180, use_se=True, ch_div=1):
layers = [1, 2, 2, 3, 3, 5]
strides = [1, 2, 2, 2, 1, 2]
layers = [ceil(element * depth_mult) for element in layers]
strides = sum([[element] + [1] * (layers[idx] - 1) for idx, element in enumerate(strides)], [])
exp_ratios = [1] * layers[0] + [6] * sum(layers[1:])
depth = sum(layers[:]) * 3
base_chs = initial_chs / width_mult if width_mult < 1.0 else initial_chs
# The following channel schedule is a simple scheme that makes each layer an expansion layer.
out_chs_list = []
for i in range(depth // 3):
out_chs_list.append(make_divisible(round(base_chs * width_mult), divisor=ch_div))
base_chs += final_chs / (depth // 3 * 1.0)
if use_se:
use_ses = [False] * (layers[0] + layers[1]) + [True] * sum(layers[2:])
else:
use_ses = [False] * sum(layers[:])
return zip(out_chs_list, exp_ratios, strides, use_ses)
def _build_blocks(block_cfg, prev_chs, width_mult, se_rd=12, ch_div=1, feature_location='bottleneck'):
feat_exp = feature_location == 'expansion'
feat_chs = [prev_chs]
feature_info = []
curr_stride = 2
features = []
for block_idx, (chs, exp_ratio, stride, se) in enumerate(block_cfg):
if stride > 1:
fname = 'stem' if block_idx == 0 else f'features.{block_idx - 1}'
if block_idx > 0 and feat_exp:
fname += '.act_dw'
feature_info += [dict(num_chs=feat_chs[-1], reduction=curr_stride, module=fname)]
curr_stride *= stride
features.append(LinearBottleneck(
in_chs=prev_chs, out_chs=chs, exp_ratio=exp_ratio, stride=stride, use_se=se, se_rd=se_rd, ch_div=ch_div))
prev_chs = chs
feat_chs += [features[-1].feat_channels(feat_exp)]
pen_chs = make_divisible(1280 * width_mult, divisor=ch_div)
feature_info += [dict(
num_chs=pen_chs if feat_exp else feat_chs[-1], reduction=curr_stride,
module=f'features.{len(features) - int(not feat_exp)}')]
features.append(ConvBnAct(prev_chs, pen_chs, 1, act_layer=Swish))
return features, feature_info
class ReXNetV1(nn.Module):
def __init__(self, in_chans=3, num_classes=1000, global_pool='avg', output_stride=32,
initial_chs=16, final_chs=180, width_mult=1.0, depth_mult=1.0, use_se=True,
se_rd=12, ch_div=1, drop_rate=0.2, feature_location='bottleneck'):
super(ReXNetV1, self).__init__()
self.drop_rate = drop_rate
assert output_stride == 32 # FIXME support dilation
stem_base_chs = 32 / width_mult if width_mult < 1.0 else 32
stem_chs = make_divisible(round(stem_base_chs * width_mult), divisor=ch_div)
self.stem = ConvBnAct(in_chans, stem_chs, 3, stride=2, act_layer=Swish)
block_cfg = _block_cfg(width_mult, depth_mult, initial_chs, final_chs, use_se, ch_div)
features, self.feature_info = _build_blocks(
block_cfg, stem_chs, width_mult, se_rd, ch_div, feature_location)
self.num_features = features[-1].out_channels
self.features = nn.Sequential(*features)
self.head = ClassifierHead(self.num_features, num_classes, global_pool, drop_rate)
# Quantization Stubs
self.quant = torch.quantization.QuantStub()
self.dequant = torch.quantization.DeQuantStub()
# FIXME weight init, the original appears to use PyTorch defaults
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes, global_pool='avg'):
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
def forward_features(self, x):
x = self.stem(x)
x = self.features(x)
return x
def forward(self, x):
x = self.quant(x)
x = self.forward_features(x)
x = self.head(x)
x = self.dequant(x)
return x
def fuse_model(self):
for m in self.modules():
if type(m) in [ConvBnAct, ConvBn, SEWithNorm]:
m.fuse_module()
def _create_rexnet(variant, pretrained, **kwargs):
feature_cfg = dict(flatten_sequential=True)
if kwargs.get('feature_location', '') == 'expansion':
feature_cfg['feature_cls'] = 'hook'
return build_model_with_cfg(
ReXNetV1, variant, pretrained, default_cfg=default_cfgs[variant], feature_cfg=feature_cfg, **kwargs)
@register_model
def quant_rexnet_100(pretrained=False, **kwargs):
"""ReXNet V1 1.0x"""
return _create_rexnet('rexnet_100', pretrained, **kwargs)
@register_model
def quant_rexnet_130(pretrained=False, **kwargs):
"""ReXNet V1 1.3x"""
return _create_rexnet('rexnet_130', pretrained, width_mult=1.3, **kwargs)
@register_model
def quant_rexnet_150(pretrained=False, **kwargs):
"""ReXNet V1 1.5x"""
return _create_rexnet('rexnet_150', pretrained, width_mult=1.5, **kwargs)
@register_model
def quant_rexnet_200(pretrained=False, **kwargs):
"""ReXNet V1 2.0x"""
return _create_rexnet('rexnet_200', pretrained, width_mult=2.0, **kwargs)
@register_model
def quant_rexnetr_100(pretrained=False, **kwargs):
"""ReXNet V1 1.0x w/ rounded (mod 8) channels"""
return _create_rexnet('rexnetr_100', pretrained, ch_div=8, **kwargs)
@register_model
def quant_rexnetr_130(pretrained=False, **kwargs):
"""ReXNet V1 1.3x w/ rounded (mod 8) channels"""
return _create_rexnet('rexnetr_130', pretrained, width_mult=1.3, ch_div=8, **kwargs)
@register_model
def quant_rexnetr_150(pretrained=False, **kwargs):
"""ReXNet V1 1.5x w/ rounded (mod 8) channels"""
return _create_rexnet('rexnetr_150', pretrained, width_mult=1.5, ch_div=8, **kwargs)
@register_model
def quant_rexnetr_200(pretrained=False, **kwargs):
"""ReXNet V1 2.0x w/ rounded (mod 8) channels"""
return _create_rexnet('rexnetr_200', pretrained, width_mult=2.0, ch_div=8, **kwargs)
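# Usage sketch (illustrative, not part of the original patch): the quant_ variants are
# registered with timm's registry, so once this module is imported they can also be built
# via the factory and then fused/quantized exactly like the MobileNetV3 models above.
# >>> from timm.models import create_model
# >>> model = create_model('quant_rexnet_100', pretrained=True).eval()
# >>> model.fuse_model()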