Merge bc25e17bed into a2727c1bf7 (commit 00dfd2cb6d)

@@ -0,0 +1,456 @@
#!/usr/bin/env python3
""" ImageNet Validation Script

This is intended to be a lean and easily modifiable ImageNet validation script for evaluating pretrained
models or training checkpoints against ImageNet or similarly organized image datasets. It prioritizes
canonical PyTorch, standard Python style, and good performance. Repurpose as you see fit.

Hacked together by Ross Wightman (https://github.com/rwightman)
"""
import argparse
import os
import csv
import glob
import time
import logging
import torch
import torch.nn as nn
import torch.nn.parallel
from collections import OrderedDict
from contextlib import suppress

import torch.quantization
import torch.quantization.quantize_fx as quantize_fx
import copy

# currently, quantization only runs on CPUs
os.environ['CUDA_VISIBLE_DEVICES'] = ""

from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_legacy

# has_apex = False
# try:
#     from apex import amp
#     has_apex = True
# except ImportError:
#     pass

# has_native_amp = False
# try:
#     if getattr(torch.cuda.amp, 'autocast') is not None:
#         has_native_amp = True
# except AttributeError:
#     pass

torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('validate')


parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('--dataset', '-d', metavar='NAME', default='',
                    help='dataset type (default: ImageFolder/ImageTar if empty)')
# argument for calibration dataset
parser.add_argument('--calib-data', metavar='DIR',
                    help='path to calibration dataset')
# quantization option (weight_only, dynamic, static)
parser.add_argument('--quant_option', metavar='NAME', default='static',
                    help='quantization option (weight_only, dynamic, static) (default: static)')
parser.add_argument('--split', metavar='NAME', default='validation',
                    help='dataset split (default: validation)')
parser.add_argument('--model', '-m', metavar='NAME', default='dpn92',
                    help='model architecture (default: dpn92)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--img-size', default=None, type=int,
                    metavar='N', help='Input image dimension, uses model default if empty')
parser.add_argument('--input-size', default=None, nargs=3, type=int,
                    metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
parser.add_argument('--crop-pct', default=None, type=float,
                    metavar='N', help='Input image center crop pct')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                    help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
                    help='Override std deviation of dataset')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
                    help='Image resize interpolation type (overrides model)')
parser.add_argument('--num-classes', type=int, default=None,
                    help='Number of classes in dataset')
parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
                    help='path to class to idx mapping file (default: "")')
parser.add_argument('--gp', default=None, type=str, metavar='POOL',
                    help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
parser.add_argument('--log-freq', default=10, type=int,
                    metavar='N', help='batch logging frequency (default: 10)')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
# parser.add_argument('--num-gpu', type=int, default=1,
#                     help='Number of GPUS to use')
# num-gpu is set to zero (no gpu usage)
parser.add_argument('--num-gpu', type=int, default=0,
                    help='Number of GPUS to use')
parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true',
                    help='disable test time pool')
parser.add_argument('--no-prefetcher', action='store_true', default=False,
                    help='disable fast prefetcher')
parser.add_argument('--pin-mem', action='store_true', default=False,
                    help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--channels-last', action='store_true', default=False,
                    help='Use channels_last memory layout')
# parser.add_argument('--amp', action='store_true', default=False,
#                     help='Use AMP mixed precision. Defaults to Apex, fallback to native Torch AMP.')
# parser.add_argument('--apex-amp', action='store_true', default=False,
#                     help='Use NVIDIA Apex AMP mixed precision')
# parser.add_argument('--native-amp', action='store_true', default=False,
#                     help='Use Native Torch AMP mixed precision')
parser.add_argument('--tf-preprocessing', action='store_true', default=False,
                    help='Use Tensorflow preprocessing pipeline (requires CPU TF installed)')
parser.add_argument('--use-ema', dest='use_ema', action='store_true',
                    help='use ema version of weights if present')
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
                    help='convert model to torchscript for inference')
parser.add_argument('--legacy-jit', dest='legacy_jit', action='store_true',
                    help='use legacy jit mode for pytorch 1.5/1.5.1/1.6 to get back fusion performance')
parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
                    help='Output csv file for validation results (summary)')
parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME',
                    help='Real labels JSON file for imagenet evaluation')
parser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME',
                    help='Valid label indices txt file for validation of partial label space')


def validate(args):
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher
    # amp_autocast = suppress  # do nothing
    # if args.amp:
    #     if has_native_amp:
    #         args.native_amp = True
    #     elif has_apex:
    #         args.apex_amp = True
    #     else:
    #         _logger.warning("Neither APEX or Native Torch AMP is available.")
    # assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
    # if args.native_amp:
    #     amp_autocast = torch.cuda.amp.autocast
    #     _logger.info('Validating in mixed precision with native PyTorch AMP.')
    # elif args.apex_amp:
    #     _logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
    # else:
    #     _logger.info('Validating in float32. AMP not enabled.')

    if args.legacy_jit:
        set_jit_legacy()

    # create model
    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        in_chans=3,
        global_pool=args.gp,
        scriptable=args.torchscript)
    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum([m.numel() for m in model.parameters()])
    _logger.info('Model %s created, param count: %d' % (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model, use_test_size=True)
    test_time_pool = False
    if not args.no_test_pool:
        model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True)

    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    # model = model.cuda()
    # if args.apex_amp:
    #     model = amp.initialize(model, opt_level='O1')

    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    # if args.num_gpu > 1:
    #     model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))

    # criterion = nn.CrossEntropyLoss().cuda()
    criterion = nn.CrossEntropyLoss()
    dataset = create_dataset(
        root=args.data, name=args.dataset, split=args.split,
        load_bytes=args.tf_preprocessing, class_map=args.class_map)

    # added for post quantization calibration; falls back to the validation data if
    # no dedicated calibration set is given
    calib_dataset = create_dataset(
        root=args.calib_data or args.data, name=args.dataset, split=args.split,
        load_bytes=args.tf_preprocessing, class_map=args.class_map)

    if args.valid_labels:
        with open(args.valid_labels, 'r') as f:
            valid_labels = {int(line.rstrip()) for line in f}
            valid_labels = [i in valid_labels for i in range(args.num_classes)]
    else:
        valid_labels = None

    if args.real_labels:
        real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels)
    else:
        real_labels = None

    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing)

    # Also create loader for calibration dataset
    calib_loader = create_loader(
        calib_dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        tf_preprocessing=args.tf_preprocessing)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    print('Start calibration of quantization observers before post-quantization')
    model_to_quantize = copy.deepcopy(model)
    model_to_quantize.eval()

    # post training static quantization
    if args.quant_option == 'static':
        qconfig_dict = {"": torch.quantization.get_default_qconfig('qnnpack')}
        # prepare
        model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_dict)
        # calibrate
        with torch.no_grad():
            # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
            input = torch.randn((args.batch_size,) + tuple(data_config['input_size']))
            if args.channels_last:
                input = input.contiguous(memory_format=torch.channels_last)
            model_prepared(input)
            end = time.time()
            for batch_idx, (input, target) in enumerate(calib_loader):
                if args.channels_last:
                    input = input.contiguous(memory_format=torch.channels_last)

                # forward pass drives the calibration observers
                output = model_prepared(input)
                if valid_labels is not None:
                    output = output[:, valid_labels]
                loss = criterion(output, target)

                if real_labels is not None:
                    real_labels.add_result(output)

                # measure accuracy and record loss
                acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
                losses.update(loss.item(), input.size(0))
                top1.update(acc1.item(), input.size(0))
                top5.update(acc5.item(), input.size(0))

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                if batch_idx % args.log_freq == 0:
                    _logger.info(
                        'Test: [{0:>4d}/{1}] '
                        'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                        'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                        'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
                        'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                            batch_idx, len(calib_loader), batch_time=batch_time,
                            rate_avg=input.size(0) / batch_time.avg,
                            loss=losses, top1=top1, top5=top5))
        # quantize
        model_quantized = quantize_fx.convert_fx(model_prepared)
    # post training dynamic/weight_only quantization
    elif args.quant_option == 'dynamic':
        qconfig_dict = {"": torch.quantization.default_dynamic_qconfig}
        # prepare
        model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_dict)
        # no calibration needed when we only have dynamic/weight_only quantization
        # quantize
        model_quantized = quantize_fx.convert_fx(model_prepared)
    else:
        _logger.warning("Invalid quantization option. Set option to default (static)")

    #
    # fusion
    #
    model_to_quantize = copy.deepcopy(model)
    model_to_quantize.eval()
    model_fused = quantize_fx.fuse_fx(model_to_quantize)

    # NOTE: the fused fp32 model is what the loop below evaluates
    model = model_fused

    # reset meters so the reported results reflect only the final model
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    with torch.no_grad():
        # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
        # input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).cuda()
        input = torch.randn((args.batch_size,) + tuple(data_config['input_size']))
        if args.channels_last:
            input = input.contiguous(memory_format=torch.channels_last)
        model(input)
        end = time.time()
        for batch_idx, (input, target) in enumerate(loader):
            # if args.no_prefetcher:
            #     target = target.cuda()
            #     input = input.cuda()
            if args.channels_last:
                input = input.contiguous(memory_format=torch.channels_last)

            # compute output
            # with amp_autocast():
            #     output = model(input)
            output = model(input)
            if valid_labels is not None:
                output = output[:, valid_labels]
            loss = criterion(output, target)

            if real_labels is not None:
                real_labels.add_result(output)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(acc1.item(), input.size(0))
            top5.update(acc5.item(), input.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if batch_idx % args.log_freq == 0:
                _logger.info(
                    'Test: [{0:>4d}/{1}] '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                    'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
                    'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                        batch_idx, len(loader), batch_time=batch_time,
                        rate_avg=input.size(0) / batch_time.avg,
                        loss=losses, top1=top1, top5=top5))

    if real_labels is not None:
        # real labels mode replaces topk values at the end
        top1a, top5a = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5)
    else:
        top1a, top5a = top1.avg, top5.avg
    results = OrderedDict(
        top1=round(top1a, 4), top1_err=round(100 - top1a, 4),
        top5=round(top5a, 4), top5_err=round(100 - top5a, 4),
        param_count=round(param_count / 1e6, 2),
        img_size=data_config['input_size'][-1],
        crop_pct=crop_pct,
        interpolation=data_config['interpolation'])

    _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['top5'], results['top5_err']))

    return results


def main():
    setup_default_logging()
    args = parser.parse_args()
    model_cfgs = []
    model_names = []
    if os.path.isdir(args.checkpoint):
        # validate all checkpoints in a path with same model
        checkpoints = glob.glob(args.checkpoint + '/*.pth.tar')
        checkpoints += glob.glob(args.checkpoint + '/*.pth')
        model_names = list_models(args.model)
        model_cfgs = [(args.model, c) for c in sorted(checkpoints, key=natural_key)]
    else:
        if args.model == 'all':
            # validate all models in a list of names with pretrained checkpoints
            args.pretrained = True
            model_names = list_models(pretrained=True, exclude_filters=['*in21k'])
            model_cfgs = [(n, '') for n in model_names]
        elif not is_model(args.model):
            # model name doesn't exist, try as wildcard filter
            model_names = list_models(args.model)
            model_cfgs = [(n, '') for n in model_names]

    if len(model_cfgs):
        results_file = args.results_file or './results-all.csv'
        _logger.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names)))
        results = []
        try:
            start_batch_size = args.batch_size
            for m, c in model_cfgs:
                batch_size = start_batch_size
                args.model = m
                args.checkpoint = c
                result = OrderedDict(model=args.model)
                r = {}
                while not r and batch_size >= args.num_gpu:
                    # torch.cuda.empty_cache()  # not needed here, CUDA is disabled for quantization
                    try:
                        args.batch_size = batch_size
                        print('Validating with batch size: %d' % args.batch_size)
                        r = validate(args)
                    except RuntimeError as e:
                        if batch_size <= args.num_gpu:
                            print("Validation failed with no ability to reduce batch size. Exiting.")
                            raise e
                        batch_size = max(batch_size // 2, args.num_gpu)
                        print("Validation failed, reducing batch size by 50%")
                result.update(r)
                if args.checkpoint:
                    result['checkpoint'] = args.checkpoint
                results.append(result)
        except KeyboardInterrupt as e:
            pass
        results = sorted(results, key=lambda x: x['top1'], reverse=True)
        if len(results):
            write_results(results_file, results)
    else:
        validate(args)


def write_results(results_file, results):
    with open(results_file, mode='w') as cf:
        dw = csv.DictWriter(cf, fieldnames=results[0].keys())
        dw.writeheader()
        for r in results:
            dw.writerow(r)
        cf.flush()


if __name__ == '__main__':
    main()
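For reference, the FX graph mode flow used in validate() above (prepare_fx, calibrate, convert_fx) can be exercised in isolation. The sketch below is illustrative only: it assumes the two-argument prepare_fx signature used above and a qnnpack backend, and the toy model and random calibration batches are stand-ins, not part of the script.

import copy
import torch
import torch.nn as nn
import torch.quantization
import torch.quantization.quantize_fx as quantize_fx

torch.backends.quantized.engine = 'qnnpack'  # match the qconfig chosen above

float_model = nn.Sequential(
    nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 10)).eval()

qconfig_dict = {"": torch.quantization.get_default_qconfig('qnnpack')}
prepared = quantize_fx.prepare_fx(copy.deepcopy(float_model), qconfig_dict)

with torch.no_grad():  # calibration: forward passes feed the observers
    for _ in range(8):
        prepared(torch.randn(4, 3, 32, 32))

quantized = quantize_fx.convert_fx(prepared)  # int8 model for CPU inference
print(quantized(torch.randn(1, 3, 32, 32)).shape)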
@@ -0,0 +1,252 @@
""" Loader Factory, Fast Collate, CUDA Prefetcher

Prefetcher and Fast Collate inspired by NVIDIA APEX example at
https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#diff-cf86c282ff7fba81fad27a559379d5bf

Hacked together by / Copyright 2020 Ross Wightman
"""

import torch.utils.data
import numpy as np

from .transforms_factory import create_transform
from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .distributed_sampler import OrderedDistributedSampler
from .random_erasing import RandomErasing
from .mixup import FastCollateMixup


def fast_collate(batch):
    """ A fast collation function optimized for uint8 images (np array or torch) and int64 targets (labels)"""
    assert isinstance(batch[0], tuple)
    batch_size = len(batch)
    if isinstance(batch[0][0], tuple):
        # This branch 'deinterleaves' and flattens tuples of input tensors into one tensor ordered by position
        # such that all tuple of position n will end up in a torch.split(tensor, batch_size) in nth position
        inner_tuple_size = len(batch[0][0])
        flattened_batch_size = batch_size * inner_tuple_size
        targets = torch.zeros(flattened_batch_size, dtype=torch.int64)
        tensor = torch.zeros((flattened_batch_size, *batch[0][0][0].shape), dtype=torch.uint8)
        for i in range(batch_size):
            assert len(batch[i][0]) == inner_tuple_size  # all input tensor tuples must be same length
            for j in range(inner_tuple_size):
                targets[i + j * batch_size] = batch[i][1]
                tensor[i + j * batch_size] += torch.from_numpy(batch[i][0][j])
        return tensor, targets
    elif isinstance(batch[0][0], np.ndarray):
        targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
        assert len(targets) == batch_size
        tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
        for i in range(batch_size):
            tensor[i] += torch.from_numpy(batch[i][0])
        return tensor, targets
    elif isinstance(batch[0][0], torch.Tensor):
        targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
        assert len(targets) == batch_size
        tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
        for i in range(batch_size):
            tensor[i].copy_(batch[i][0])
        return tensor, targets
    else:
        assert False


class PrefetchLoader:

    def __init__(self,
                 loader,
                 mean=IMAGENET_DEFAULT_MEAN,
                 std=IMAGENET_DEFAULT_STD,
                 fp16=False,
                 re_prob=0.,
                 re_mode='const',
                 re_count=1,
                 re_num_splits=0):
        self.loader = loader
        self.mean = torch.tensor([x * 255 for x in mean]).view(1, 3, 1, 1)
        self.std = torch.tensor([x * 255 for x in std]).view(1, 3, 1, 1)
        self.fp16 = fp16
        if fp16:
            self.mean = self.mean.half()
            self.std = self.std.half()
        if re_prob > 0.:
            self.random_erasing = RandomErasing(
                probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits)
        else:
            self.random_erasing = None

    def __iter__(self):
        first = True

        for next_input, next_target in self.loader:
            if self.fp16:
                next_input = next_input.half().sub_(self.mean).div_(self.std)
            else:
                next_input = next_input.float().sub_(self.mean).div_(self.std)
            if self.random_erasing is not None:
                next_input = self.random_erasing(next_input)

            if not first:
                yield input, target
            else:
                first = False

            input = next_input
            target = next_target

        yield input, target

    def __len__(self):
        return len(self.loader)

    @property
    def sampler(self):
        return self.loader.sampler

    @property
    def dataset(self):
        return self.loader.dataset

    @property
    def mixup_enabled(self):
        if isinstance(self.loader.collate_fn, FastCollateMixup):
            return self.loader.collate_fn.mixup_enabled
        else:
            return False

    @mixup_enabled.setter
    def mixup_enabled(self, x):
        if isinstance(self.loader.collate_fn, FastCollateMixup):
            self.loader.collate_fn.mixup_enabled = x


def create_loader(
        dataset,
        input_size,
        batch_size,
        is_training=False,
        use_prefetcher=True,
        no_aug=False,
        re_prob=0.,
        re_mode='const',
        re_count=1,
        re_split=False,
        scale=None,
        ratio=None,
        hflip=0.5,
        vflip=0.,
        color_jitter=0.4,
        auto_augment=None,
        num_aug_splits=0,
        interpolation='bilinear',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        num_workers=1,
        distributed=False,
        crop_pct=None,
        collate_fn=None,
        pin_memory=False,
        fp16=False,
        tf_preprocessing=False,
        use_multi_epochs_loader=False
):
    re_num_splits = 0
    if re_split:
        # apply RE to second half of batch if no aug split otherwise line up with aug split
        re_num_splits = num_aug_splits or 2
    dataset.transform = create_transform(
        input_size,
        is_training=is_training,
        use_prefetcher=use_prefetcher,
        no_aug=no_aug,
        scale=scale,
        ratio=ratio,
        hflip=hflip,
        vflip=vflip,
        color_jitter=color_jitter,
        auto_augment=auto_augment,
        interpolation=interpolation,
        mean=mean,
        std=std,
        crop_pct=crop_pct,
        tf_preprocessing=tf_preprocessing,
        re_prob=re_prob,
        re_mode=re_mode,
        re_count=re_count,
        re_num_splits=re_num_splits,
        separate=num_aug_splits > 0,
    )

    sampler = None
    if distributed:
        if is_training:
            sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        else:
            # This will add extra duplicate entries to result in equal num
            # of samples per-process, will slightly alter validation results
            sampler = OrderedDistributedSampler(dataset)

    if collate_fn is None:
        collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate

    loader_class = torch.utils.data.DataLoader

    if use_multi_epochs_loader:
        loader_class = MultiEpochsDataLoader

    loader = loader_class(
        dataset,
        batch_size=batch_size,
        shuffle=sampler is None and is_training,
        num_workers=num_workers,
        sampler=sampler,
        collate_fn=collate_fn,
        pin_memory=pin_memory,
        drop_last=is_training,
    )
    if use_prefetcher:
        prefetch_re_prob = re_prob if is_training and not no_aug else 0.
        loader = PrefetchLoader(
            loader,
            mean=mean,
            std=std,
            fp16=fp16,
            re_prob=prefetch_re_prob,
            re_mode=re_mode,
            re_count=re_count,
            re_num_splits=re_num_splits
        )

    return loader


class MultiEpochsDataLoader(torch.utils.data.DataLoader):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._DataLoader__initialized = False
        self.batch_sampler = _RepeatSampler(self.batch_sampler)
        self._DataLoader__initialized = True
        self.iterator = super().__iter__()

    def __len__(self):
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        for i in range(len(self)):
            yield next(self.iterator)


class _RepeatSampler(object):
    """ Sampler that repeats forever.

    Args:
        sampler (Sampler)
    """

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            yield from iter(self.sampler)
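A small usage sketch of the pieces above: fast_collate stacks uint8 images without conversion, and the CPU PrefetchLoader variant defined here converts and normalizes them on the fly. The toy dataset is a stand-in, and the import path assumes this module is installed as timm.data.loader with the CPU variant shown above.

import numpy as np
import torch
from timm.data.loader import fast_collate, PrefetchLoader  # assumed install path

class ToyDataset(torch.utils.data.Dataset):
    def __len__(self):
        return 16
    def __getitem__(self, idx):
        # CHW uint8 image, as produced by the transforms when use_prefetcher=True
        img = np.random.randint(0, 255, (3, 224, 224), dtype=np.uint8)
        return img, idx % 10

base_loader = torch.utils.data.DataLoader(ToyDataset(), batch_size=4, collate_fn=fast_collate)
loader = PrefetchLoader(base_loader, fp16=False, re_prob=0.)

for images, targets in loader:
    # images are float32 and mean/std normalized by the time they reach the model
    print(images.dtype, images.shape, targets.shape)
    break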
@@ -0,0 +1,10 @@
from .efficientnet import *
from .mobilenetv3 import *
from .rexnet import *

from timm.models.factory import create_model
from timm.models.helpers import load_checkpoint, resume_checkpoint
from .layers import TestTimePoolHead, apply_test_time_pool
from .layers import convert_splitbn_model
from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit
from timm.models.registry import *
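The re-exports above provide the usual timm entry points for this trimmed model set. A brief usage sketch, assuming this package is importable as timm.models and that the listed architectures are registered:

from timm.models import create_model, is_model, list_models

print(list_models('mobilenetv3*')[:5])  # wildcard query against the model registry
assert is_model('mobilenetv3_large_100')
model = create_model('mobilenetv3_large_100', pretrained=False, num_classes=10)
print(sum(p.numel() for p in model.parameters()))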
File diff suppressed because it is too large
@@ -0,0 +1,423 @@
""" EfficientNet, MobileNetV3, etc Blocks

Hacked together by / Copyright 2020 Ross Wightman
"""

import torch
import torch.nn as nn
from torch.nn import functional as F

from .layers import create_conv2d, drop_path, get_act_layer
from .layers.activations import sigmoid, HardSigmoid

# Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per
# papers and TF reference implementations. PT momentum equiv for TF decay is (1 - TF decay)
# NOTE: momentum varies btw .99 and .9997 depending on source
# .99 in official TF TPU impl
# .9997 (/w .999 in search space) for paper
BN_MOMENTUM_TF_DEFAULT = 1 - 0.99
BN_EPS_TF_DEFAULT = 1e-3
_BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT)


def get_bn_args_tf():
    return _BN_ARGS_TF.copy()


def resolve_bn_args(kwargs):
    bn_args = get_bn_args_tf() if kwargs.pop('bn_tf', False) else {}
    bn_momentum = kwargs.pop('bn_momentum', None)
    if bn_momentum is not None:
        bn_args['momentum'] = bn_momentum
    bn_eps = kwargs.pop('bn_eps', None)
    if bn_eps is not None:
        bn_args['eps'] = bn_eps
    return bn_args


_SE_ARGS_DEFAULT = dict(
    gate_fn=sigmoid,
    act_layer=None,
    reduce_mid=False,
    divisor=1)


def resolve_se_args(kwargs, in_chs, act_layer=None):
    se_kwargs = kwargs.copy() if kwargs is not None else {}
    # fill in args that aren't specified with the defaults
    for k, v in _SE_ARGS_DEFAULT.items():
        se_kwargs.setdefault(k, v)
    # some models, like MobileNetV3, calculate SE reduction chs from the containing block's mid_ch instead of in_ch
    if not se_kwargs.pop('reduce_mid'):
        se_kwargs['reduced_base_chs'] = in_chs
    # act_layer override, if it remains None, the containing block's act_layer will be used
    if se_kwargs['act_layer'] is None:
        assert act_layer is not None
        se_kwargs['act_layer'] = act_layer
    return se_kwargs


def resolve_act_layer(kwargs, default='relu'):
    act_layer = kwargs.pop('act_layer', default)
    if isinstance(act_layer, str):
        act_layer = get_act_layer(act_layer)
    return act_layer


def make_divisible(v, divisor=8, min_value=None):
    min_value = min_value or divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None):
    """Round number of filters based on depth multiplier."""
    if not multiplier:
        return channels
    channels *= multiplier
    return make_divisible(channels, divisor, channel_min)


class ChannelShuffle(nn.Module):
    # FIXME haven't used yet
    def __init__(self, groups):
        super(ChannelShuffle, self).__init__()
        self.groups = groups

    def forward(self, x):
        """Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,W] -> [N,C,H,W]"""
        N, C, H, W = x.size()
        g = self.groups
        assert C % g == 0, "Incompatible group size {} for input channel {}".format(
            g, C
        )
        return (
            x.view(N, g, int(C / g), H, W)
            .permute(0, 2, 1, 3, 4)
            .contiguous()
            .view(N, C, H, W)
        )


class SqueezeExcite(nn.Module):
    def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None,
                 act_layer=nn.ReLU, gate_fn=sigmoid, divisor=1, **_):
        super(SqueezeExcite, self).__init__()
        self.gate_fn = gate_fn
        reduced_chs = make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
        self.act1 = act_layer(inplace=True)
        self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
        # quantization-friendly multiply for the SE gating product
        self.quant_mul = nn.quantized.FloatFunctional()
        if gate_fn == HardSigmoid:
            self.gate_fn = HardSigmoid()

    def forward(self, x):
        x_se = self.avg_pool(x)
        x_se = self.conv_reduce(x_se)
        x_se = self.act1(x_se)
        x_se = self.conv_expand(x_se)
        x = self.quant_mul.mul(x, self.gate_fn(x_se))
        return x

    def fuse_module(self):
        if type(self.act1) == nn.ReLU:
            modules_to_fuse = ['conv_reduce', 'act1']
            torch.quantization.fuse_modules(self, modules_to_fuse, inplace=True)


class ConvBnAct(nn.Module):
    def __init__(self, in_chs, out_chs, kernel_size,
                 stride=1, dilation=1, pad_type='', act_layer=nn.ReLU,
                 norm_layer=nn.BatchNorm2d, norm_kwargs=None):
        super(ConvBnAct, self).__init__()
        norm_kwargs = norm_kwargs or {}
        self.conv = create_conv2d(in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, padding=pad_type)
        self.bn1 = norm_layer(out_chs, **norm_kwargs)
        self.act1 = act_layer(inplace=True)

    def feature_info(self, location):
        if location == 'expansion':  # output of conv after act, same as block output
            info = dict(module='act1', hook_type='forward', num_chs=self.conv.out_channels)
        else:  # location == 'bottleneck', block output
            info = dict(module='', hook_type='', num_chs=self.conv.out_channels)
        return info

    def forward(self, x):
        x = self.conv(x)
        x = self.bn1(x)
        x = self.act1(x)
        return x

    def fuse_module(self):
        modules_to_fuse = ['conv', 'bn1']
        if type(self.act1) == nn.ReLU:
            modules_to_fuse.append('act1')
        torch.quantization.fuse_modules(self, modules_to_fuse, inplace=True)


class DepthwiseSeparableConv(nn.Module):
    """ DepthwiseSeparable block
    Used for DS convs in MobileNet-V1 and in the place of IR blocks that have no expansion
    (factor of 1.0). This is an alternative to having an IR with an optional first pw conv.
    """
    def __init__(self, in_chs, out_chs, dw_kernel_size=3,
                 stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
                 pw_kernel_size=1, pw_act=False, se_ratio=0., se_kwargs=None,
                 norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_path_rate=0.):
        super(DepthwiseSeparableConv, self).__init__()
        norm_kwargs = norm_kwargs or {}
        has_se = se_ratio is not None and se_ratio > 0.
        self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
        self.has_pw_act = pw_act  # activation after point-wise conv
        self.drop_path_rate = drop_path_rate

        self.conv_dw = create_conv2d(
            in_chs, in_chs, dw_kernel_size, stride=stride, dilation=dilation, padding=pad_type, depthwise=True)
        self.bn1 = norm_layer(in_chs, **norm_kwargs)
        self.act1 = act_layer(inplace=True)

        # Squeeze-and-excitation
        if has_se:
            se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
            self.se = SqueezeExcite(in_chs, se_ratio=se_ratio, **se_kwargs)
        else:
            self.se = None

        self.conv_pw = create_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
        self.bn2 = norm_layer(out_chs, **norm_kwargs)
        self.act2 = act_layer(inplace=True) if self.has_pw_act else nn.Identity()
        # quantization-friendly residual add
        self.skip_add = nn.quantized.FloatFunctional()

    def feature_info(self, location):
        if location == 'expansion':  # after SE, input to PW
            info = dict(module='conv_pw', hook_type='forward_pre', num_chs=self.conv_pw.in_channels)
        else:  # location == 'bottleneck', block output
            info = dict(module='', hook_type='', num_chs=self.conv_pw.out_channels)
        return info

    def forward(self, x):
        residual = x

        x = self.conv_dw(x)
        x = self.bn1(x)
        x = self.act1(x)

        if self.se is not None:
            x = self.se(x)

        x = self.conv_pw(x)
        x = self.bn2(x)
        x = self.act2(x)

        if self.has_residual:
            if self.drop_path_rate > 0.:
                x = drop_path(x, self.drop_path_rate, self.training)
            x = self.skip_add.add(x, residual)
        return x

    def fuse_module(self):
        modules_to_fuse = [['conv_dw', 'bn1'], ['conv_pw', 'bn2']]
        if type(self.act1) == nn.ReLU:
            modules_to_fuse[0].append('act1')
        if type(self.act2) == nn.ReLU:
            modules_to_fuse[1].append('act2')
        torch.quantization.fuse_modules(self, modules_to_fuse, inplace=True)


class InvertedResidual(nn.Module):
    """ Inverted residual block w/ optional SE and CondConv routing"""

    def __init__(self, in_chs, out_chs, dw_kernel_size=3,
                 stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
                 exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
                 se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
                 conv_kwargs=None, drop_path_rate=0.):
        super(InvertedResidual, self).__init__()
        norm_kwargs = norm_kwargs or {}
        conv_kwargs = conv_kwargs or {}
        mid_chs = make_divisible(in_chs * exp_ratio)
        has_se = se_ratio is not None and se_ratio > 0.
        self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
        self.drop_path_rate = drop_path_rate

        # Point-wise expansion
        self.conv_pw = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs)
        self.bn1 = norm_layer(mid_chs, **norm_kwargs)
        self.act1 = act_layer(inplace=True)

        # Depth-wise convolution
        self.conv_dw = create_conv2d(
            mid_chs, mid_chs, dw_kernel_size, stride=stride, dilation=dilation,
            padding=pad_type, depthwise=True, **conv_kwargs)
        self.bn2 = norm_layer(mid_chs, **norm_kwargs)
        self.act2 = act_layer(inplace=True)

        # Squeeze-and-excitation
        if has_se:
            se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
            self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs)
        else:
            self.se = None

        # Point-wise linear projection
        self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
        self.bn3 = norm_layer(out_chs, **norm_kwargs)

        # quantization-friendly residual add
        self.skip_add = nn.quantized.FloatFunctional()

    def feature_info(self, location):
        if location == 'expansion':  # after SE, input to PWL
            info = dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels)
        else:  # location == 'bottleneck', block output
            info = dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels)
        return info

    def forward(self, x):
        residual = x

        # Point-wise expansion
        x = self.conv_pw(x)
        x = self.bn1(x)
        x = self.act1(x)

        # Depth-wise convolution
        x = self.conv_dw(x)
        x = self.bn2(x)
        x = self.act2(x)

        # Squeeze-and-excitation
        if self.se is not None:
            x = self.se(x)

        # Point-wise linear projection
        x = self.conv_pwl(x)
        x = self.bn3(x)

        if self.has_residual:
            if self.drop_path_rate > 0.:
                x = drop_path(x, self.drop_path_rate, self.training)
            x = self.skip_add.add(x, residual)

        return x

    def fuse_module(self):
        modules_to_fuse = [['conv_pw', 'bn1'], ['conv_dw', 'bn2'], ['conv_pwl', 'bn3']]
        if type(self.act1) == nn.ReLU:
            modules_to_fuse[0].append('act1')
        if type(self.act2) == nn.ReLU:
            modules_to_fuse[1].append('act2')

        torch.quantization.fuse_modules(self, modules_to_fuse, inplace=True)


class CondConvResidual(InvertedResidual):
    """ Inverted residual block w/ CondConv routing"""

    def __init__(self, in_chs, out_chs, dw_kernel_size=3,
                 stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False,
                 exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
                 se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
                 num_experts=0, drop_path_rate=0.):

        self.num_experts = num_experts
        conv_kwargs = dict(num_experts=self.num_experts)

        super(CondConvResidual, self).__init__(
            in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, dilation=dilation, pad_type=pad_type,
            act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size,
            pw_kernel_size=pw_kernel_size, se_ratio=se_ratio, se_kwargs=se_kwargs,
            norm_layer=norm_layer, norm_kwargs=norm_kwargs, conv_kwargs=conv_kwargs,
            drop_path_rate=drop_path_rate)

        self.routing_fn = nn.Linear(in_chs, self.num_experts)

    def forward(self, x):
        residual = x

        # CondConv routing
        pooled_inputs = F.adaptive_avg_pool2d(x, 1).flatten(1)
        routing_weights = torch.sigmoid(self.routing_fn(pooled_inputs))

        # Point-wise expansion
        x = self.conv_pw(x, routing_weights)
        x = self.bn1(x)
        x = self.act1(x)

        # Depth-wise convolution
        x = self.conv_dw(x, routing_weights)
        x = self.bn2(x)
        x = self.act2(x)

        # Squeeze-and-excitation
        if self.se is not None:
            x = self.se(x)

        # Point-wise linear projection
        x = self.conv_pwl(x, routing_weights)
        x = self.bn3(x)

        if self.has_residual:
            if self.drop_path_rate > 0.:
                x = drop_path(x, self.drop_path_rate, self.training)
            x += residual
        return x


class EdgeResidual(nn.Module):
    """ Residual block with expansion convolution followed by pointwise-linear w/ stride"""

    def __init__(self, in_chs, out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_chs=0,
                 stride=1, dilation=1, pad_type='', act_layer=nn.ReLU, noskip=False, pw_kernel_size=1,
                 se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
                 drop_path_rate=0.):
        super(EdgeResidual, self).__init__()
        norm_kwargs = norm_kwargs or {}
        if fake_in_chs > 0:
            mid_chs = make_divisible(fake_in_chs * exp_ratio)
        else:
            mid_chs = make_divisible(in_chs * exp_ratio)
        has_se = se_ratio is not None and se_ratio > 0.
        self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
        self.drop_path_rate = drop_path_rate

        # Expansion convolution
        self.conv_exp = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type)
        self.bn1 = norm_layer(mid_chs, **norm_kwargs)
        self.act1 = act_layer(inplace=True)

        # Squeeze-and-excitation
        if has_se:
            se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer)
            self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs)
        else:
            self.se = None

        # Point-wise linear projection
        self.conv_pwl = create_conv2d(
            mid_chs, out_chs, pw_kernel_size, stride=stride, dilation=dilation, padding=pad_type)
        self.bn2 = norm_layer(out_chs, **norm_kwargs)

    def feature_info(self, location):
        if location == 'expansion':  # after SE, before PWL
            info = dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels)
        else:  # location == 'bottleneck', block output
            info = dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels)
        return info

    def forward(self, x):
        residual = x

        # Expansion convolution
        x = self.conv_exp(x)
        x = self.bn1(x)
        x = self.act1(x)

        # Squeeze-and-excitation
        if self.se is not None:
            x = self.se(x)

        # Point-wise linear projection
        x = self.conv_pwl(x)
        x = self.bn2(x)

        if self.has_residual:
            if self.drop_path_rate > 0.:
                x = drop_path(x, self.drop_path_rate, self.training)
            x += residual

        return x
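The fuse_module() hooks and nn.quantized.FloatFunctional wrappers added to the blocks above exist so that eager-mode post-training quantization sees fusable Conv+BN(+ReLU) patterns and quantizable add/mul ops. The standalone sketch below uses a hypothetical TinyConvBnAct stand-in rather than the real blocks, and simply shows what torch.quantization.fuse_modules does to such a pattern.

import torch
import torch.nn as nn

class TinyConvBnAct(nn.Module):
    # minimal stand-in mirroring ConvBnAct's fuse_module() above
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(8)
        self.act1 = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.act1(self.bn1(self.conv(x)))

    def fuse_module(self):
        torch.quantization.fuse_modules(self, ['conv', 'bn1', 'act1'], inplace=True)

m = TinyConvBnAct().eval()  # fusion requires eval mode
m.fuse_module()
print(type(m.conv))  # fused ConvReLU2d, with the BN statistics folded into the conv weights
print(type(m.bn1), type(m.act1))  # both replaced by nn.Identity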
@@ -0,0 +1,414 @@
|
|||||||
|
""" EfficientNet, MobileNetV3, etc Builder
|
||||||
|
|
||||||
|
Assembles EfficieNet and related network feature blocks from string definitions.
|
||||||
|
Handles stride, dilation calculations, and selects feature extraction points.
|
||||||
|
|
||||||
|
Hacked together by / Copyright 2020 Ross Wightman
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
import re
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from .efficientnet_blocks import *
|
||||||
|
from .layers import CondConv2d, get_condconv_initializer
|
||||||
|
|
||||||
|
__all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights"]
|
||||||
|
|
||||||
|
_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _log_info_if(msg, condition):
|
||||||
|
if condition:
|
||||||
|
_logger.info(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_ksize(ss):
|
||||||
|
if ss.isdigit():
|
||||||
|
return int(ss)
|
||||||
|
else:
|
||||||
|
return [int(k) for k in ss.split('.')]
|
||||||
|
|
||||||
|
|
||||||
|
def _decode_block_str(block_str):
|
||||||
|
""" Decode block definition string
|
||||||
|
|
||||||
|
Gets a list of block arg (dicts) through a string notation of arguments.
|
||||||
|
E.g. ir_r2_k3_s2_e1_i32_o16_se0.25_noskip
|
||||||
|
|
||||||
|
All args can exist in any order with the exception of the leading string which
|
||||||
|
is assumed to indicate the block type.
|
||||||
|
|
||||||
|
leading string - block type (
|
||||||
|
ir = InvertedResidual, ds = DepthwiseSep, dsa = DeptwhiseSep with pw act, cn = ConvBnAct)
|
||||||
|
r - number of repeat blocks,
|
||||||
|
k - kernel size,
|
||||||
|
s - strides (1-9),
|
||||||
|
e - expansion ratio,
|
||||||
|
c - output channels,
|
||||||
|
se - squeeze/excitation ratio
|
||||||
|
n - activation fn ('re', 'r6', 'hs', or 'sw')
|
||||||
|
Args:
|
||||||
|
block_str: a string representation of block arguments.
|
||||||
|
Returns:
|
||||||
|
A list of block args (dicts)
|
||||||
|
Raises:
|
||||||
|
ValueError: if the string def not properly specified (TODO)
|
||||||
|
"""
|
||||||
|
assert isinstance(block_str, str)
|
||||||
|
ops = block_str.split('_')
|
||||||
|
block_type = ops[0] # take the block type off the front
|
||||||
|
ops = ops[1:]
|
||||||
|
options = {}
|
||||||
|
noskip = False
|
||||||
|
for op in ops:
|
||||||
|
# string options being checked on individual basis, combine if they grow
|
||||||
|
if op == 'noskip':
|
||||||
|
noskip = True
|
||||||
|
elif op.startswith('n'):
|
||||||
|
# activation fn
|
||||||
|
key = op[0]
|
||||||
|
v = op[1:]
|
||||||
|
if v == 're':
|
||||||
|
value = get_act_layer('relu')
|
||||||
|
elif v == 'r6':
|
||||||
|
value = get_act_layer('relu6')
|
||||||
|
elif v == 'hs':
|
||||||
|
value = get_act_layer('hard_swish')
|
||||||
|
elif v == 'sw':
|
||||||
|
value = get_act_layer('swish')
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
options[key] = value
|
||||||
|
else:
|
||||||
|
# all numeric options
|
||||||
|
splits = re.split(r'(\d.*)', op)
|
||||||
|
if len(splits) >= 2:
|
||||||
|
key, value = splits[:2]
|
||||||
|
options[key] = value
|
||||||
|
|
||||||
|
# if act_layer is None, the model default (passed to model init) will be used
|
||||||
|
act_layer = options['n'] if 'n' in options else None
|
||||||
|
exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
|
||||||
|
pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
|
||||||
|
fake_in_chs = int(options['fc']) if 'fc' in options else 0 # FIXME hack to deal with in_chs issue in TPU def
|
||||||
|
|
||||||
|
num_repeat = int(options['r'])
|
||||||
|
# each type of block has different valid arguments, fill accordingly
|
||||||
|
if block_type == 'ir':
|
||||||
|
block_args = dict(
|
||||||
|
block_type=block_type,
|
||||||
|
dw_kernel_size=_parse_ksize(options['k']),
|
||||||
|
exp_kernel_size=exp_kernel_size,
|
||||||
|
pw_kernel_size=pw_kernel_size,
|
||||||
|
out_chs=int(options['c']),
|
||||||
|
exp_ratio=float(options['e']),
|
||||||
|
            se_ratio=float(options['se']) if 'se' in options else None,
            stride=int(options['s']),
            act_layer=act_layer,
            noskip=noskip,
        )
        if 'cc' in options:
            block_args['num_experts'] = int(options['cc'])
    elif block_type == 'ds' or block_type == 'dsa':
        block_args = dict(
            block_type=block_type,
            dw_kernel_size=_parse_ksize(options['k']),
            pw_kernel_size=pw_kernel_size,
            out_chs=int(options['c']),
            se_ratio=float(options['se']) if 'se' in options else None,
            stride=int(options['s']),
            act_layer=act_layer,
            pw_act=block_type == 'dsa',
            noskip=block_type == 'dsa' or noskip,
        )
    elif block_type == 'er':
        block_args = dict(
            block_type=block_type,
            exp_kernel_size=_parse_ksize(options['k']),
            pw_kernel_size=pw_kernel_size,
            out_chs=int(options['c']),
            exp_ratio=float(options['e']),
            fake_in_chs=fake_in_chs,
            se_ratio=float(options['se']) if 'se' in options else None,
            stride=int(options['s']),
            act_layer=act_layer,
            noskip=noskip,
        )
    elif block_type == 'cn':
        block_args = dict(
            block_type=block_type,
            kernel_size=int(options['k']),
            out_chs=int(options['c']),
            stride=int(options['s']),
            act_layer=act_layer,
        )
    else:
        assert False, 'Unknown block type (%s)' % block_type

    return block_args, num_repeat


def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'):
    """ Per-stage depth scaling
    Scales the block repeats in each stage. This depth scaling impl maintains
    compatibility with the EfficientNet scaling method, while allowing sensible
    scaling for other models that may have multiple block arg definitions in each stage.
    """

    # We scale the total repeat count for each stage, there may be multiple
    # block arg defs per stage so we need to sum.
    num_repeat = sum(repeats)
    if depth_trunc == 'round':
        # Truncating to int by rounding allows stages with few repeats to remain
        # proportionally smaller for longer. This is a good choice when stage definitions
        # include single repeat stages that we'd prefer to keep that way as long as possible
        num_repeat_scaled = max(1, round(num_repeat * depth_multiplier))
    else:
        # The default for EfficientNet truncates repeats to int via 'ceil'.
        # Any multiplier > 1.0 will result in an increased depth for every stage.
        num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier))

    # Proportionally distribute repeat count scaling to each block definition in the stage.
    # Allocation is done in reverse as it results in the first block being less likely to be scaled.
    # The first block makes less sense to repeat in most of the arch definitions.
    repeats_scaled = []
    for r in repeats[::-1]:
        rs = max(1, round((r / num_repeat * num_repeat_scaled)))
        repeats_scaled.append(rs)
        num_repeat -= r
        num_repeat_scaled -= rs
    repeats_scaled = repeats_scaled[::-1]

    # Apply the calculated scaling to each block arg in the stage
    sa_scaled = []
    for ba, rep in zip(stack_args, repeats_scaled):
        sa_scaled.extend([deepcopy(ba) for _ in range(rep)])
    return sa_scaled
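
# Worked example (illustrative, not from the original source): with repeats=[1, 2, 2],
# depth_multiplier=1.2 and depth_trunc='ceil', num_repeat=5 scales to ceil(5 * 1.2) = 6.
# Walking the repeats in reverse: r=2 -> rs=round(2/5*6)=2, r=2 -> rs=round(2/3*4)=3,
# r=1 -> rs=round(1/1*1)=1, giving repeats_scaled=[1, 3, 2] (sums to 6), so the first
# block arg of the stage ends up repeated the least.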


def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_multiplier=1, fix_first_last=False):
    arch_args = []
    for stack_idx, block_strings in enumerate(arch_def):
        assert isinstance(block_strings, list)
        stack_args = []
        repeats = []
        for block_str in block_strings:
            assert isinstance(block_str, str)
            ba, rep = _decode_block_str(block_str)
            if ba.get('num_experts', 0) > 0 and experts_multiplier > 1:
                ba['num_experts'] *= experts_multiplier
            stack_args.append(ba)
            repeats.append(rep)
        if fix_first_last and (stack_idx == 0 or stack_idx == len(arch_def) - 1):
            arch_args.append(_scale_stage_depth(stack_args, repeats, 1.0, depth_trunc))
        else:
            arch_args.append(_scale_stage_depth(stack_args, repeats, depth_multiplier, depth_trunc))
    return arch_args
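
# Minimal usage sketch (illustrative; the block strings below are assumed examples in the
# usual 'type_rN_kN_sN_...' format parsed by _decode_block_str, not taken from a specific model def):
#   arch_def = [['ds_r1_k3_s1_e1_c16'], ['ir_r2_k3_s2_e6_c24_se0.25']]
#   block_args = decode_arch_def(arch_def, depth_multiplier=1.0)
#   # -> two stages; the second contains two copies ('r2') of the same inverted residual arg dict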


class EfficientNetBuilder:
    """ Build Trunk Blocks

    This ended up being somewhat of a cross between
    https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py
    and
    https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py

    """
    def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
                 output_stride=32, pad_type='', act_layer=None, se_kwargs=None,
                 norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_path_rate=0., feature_location='',
                 verbose=False):
        self.channel_multiplier = channel_multiplier
        self.channel_divisor = channel_divisor
        self.channel_min = channel_min
        self.output_stride = output_stride
        self.pad_type = pad_type
        self.act_layer = act_layer
        self.se_kwargs = se_kwargs
        self.norm_layer = norm_layer
        self.norm_kwargs = norm_kwargs
        self.drop_path_rate = drop_path_rate
        if feature_location == 'depthwise':
            # old 'depthwise' mode renamed 'expansion' to match TF impl, old expansion mode didn't make sense
            _logger.warning("feature_location=='depthwise' is deprecated, using 'expansion'")
            feature_location = 'expansion'
        self.feature_location = feature_location
        assert feature_location in ('bottleneck', 'expansion', '')
        self.verbose = verbose

        # state updated during build, consumed by model
        self.in_chs = None
        self.features = []

    def _round_channels(self, chs):
        return round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min)

    def _make_block(self, ba, block_idx, block_count):
        drop_path_rate = self.drop_path_rate * block_idx / block_count
        bt = ba.pop('block_type')
        ba['in_chs'] = self.in_chs
        ba['out_chs'] = self._round_channels(ba['out_chs'])
        if 'fake_in_chs' in ba and ba['fake_in_chs']:
            # FIXME this is a hack to work around mismatch in origin impl input filters
            ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs'])
        ba['norm_layer'] = self.norm_layer
        ba['norm_kwargs'] = self.norm_kwargs
        ba['pad_type'] = self.pad_type
        # block act fn overrides the model default
        ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer
        assert ba['act_layer'] is not None
        if bt == 'ir':
            ba['drop_path_rate'] = drop_path_rate
            ba['se_kwargs'] = self.se_kwargs
            _log_info_if(' InvertedResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
            if ba.get('num_experts', 0) > 0:
                block = CondConvResidual(**ba)
            else:
                block = InvertedResidual(**ba)
        elif bt == 'ds' or bt == 'dsa':
            ba['drop_path_rate'] = drop_path_rate
            ba['se_kwargs'] = self.se_kwargs
            _log_info_if(' DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
            block = DepthwiseSeparableConv(**ba)
        elif bt == 'er':
            ba['drop_path_rate'] = drop_path_rate
            ba['se_kwargs'] = self.se_kwargs
            _log_info_if(' EdgeResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
            block = EdgeResidual(**ba)
        elif bt == 'cn':
            _log_info_if(' ConvBnAct {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
            block = ConvBnAct(**ba)
        else:
            assert False, 'Unknown block type (%s) while building model.' % bt
        self.in_chs = ba['out_chs']  # update in_chs for arg of next block

        return block

    def __call__(self, in_chs, model_block_args):
        """ Build the blocks
        Args:
            in_chs: Number of input-channels passed to first block
            model_block_args: A list of lists, outer list defines stages, inner
                list contains strings defining block configuration(s)
        Return:
            List of block stacks (each stack wrapped in nn.Sequential)
        """
        _log_info_if('Building model trunk with %d stages...' % len(model_block_args), self.verbose)
        self.in_chs = in_chs
        total_block_count = sum([len(x) for x in model_block_args])
        total_block_idx = 0
        current_stride = 2
        current_dilation = 1
        stages = []
        if model_block_args[0][0]['stride'] > 1:
            # if the first block starts with a stride, we need to extract first level feat from stem
            feature_info = dict(
                module='act1', num_chs=in_chs, stage=0, reduction=current_stride,
                hook_type='forward' if self.feature_location != 'bottleneck' else '')
            self.features.append(feature_info)

        # outer list of block_args defines the stacks
        for stack_idx, stack_args in enumerate(model_block_args):
            last_stack = stack_idx + 1 == len(model_block_args)
            _log_info_if('Stack: {}'.format(stack_idx), self.verbose)
            assert isinstance(stack_args, list)

            blocks = []
            # each stack (stage of blocks) contains a list of block arguments
            for block_idx, block_args in enumerate(stack_args):
                last_block = block_idx + 1 == len(stack_args)
                _log_info_if(' Block: {}'.format(block_idx), self.verbose)

                assert block_args['stride'] in (1, 2)
                if block_idx >= 1:   # only the first block in any stack can have a stride > 1
                    block_args['stride'] = 1

                extract_features = False
                if last_block:
                    next_stack_idx = stack_idx + 1
                    extract_features = next_stack_idx >= len(model_block_args) or \
                        model_block_args[next_stack_idx][0]['stride'] > 1

                next_dilation = current_dilation
                if block_args['stride'] > 1:
                    next_output_stride = current_stride * block_args['stride']
                    if next_output_stride > self.output_stride:
                        next_dilation = current_dilation * block_args['stride']
                        block_args['stride'] = 1
                        _log_info_if('  Converting stride to dilation to maintain output_stride=={}'.format(
                            self.output_stride), self.verbose)
                    else:
                        current_stride = next_output_stride
                block_args['dilation'] = current_dilation
                if next_dilation != current_dilation:
                    current_dilation = next_dilation

                # create the block
                block = self._make_block(block_args, total_block_idx, total_block_count)
                blocks.append(block)

                # stash feature module name and channel info for model feature extraction
                if extract_features:
                    feature_info = dict(
                        stage=stack_idx + 1, reduction=current_stride, **block.feature_info(self.feature_location))
                    module_name = f'blocks.{stack_idx}.{block_idx}'
                    leaf_name = feature_info.get('module', '')
                    feature_info['module'] = '.'.join([module_name, leaf_name]) if leaf_name else module_name
                    self.features.append(feature_info)

                total_block_idx += 1  # incr global block idx (across all stacks)
            stages.append(nn.Sequential(*blocks))
        return stages
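
# Rough usage sketch (illustrative; assumes the arch_def example above and a 32-channel stem):
#   builder = EfficientNetBuilder(channel_multiplier=1.0, act_layer=nn.ReLU)
#   blocks = builder(in_chs=32, model_block_args=decode_arch_def(arch_def))
#   trunk = nn.Sequential(*blocks)   # builder.features now holds feature extraction metadata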


def _init_weight_goog(m, n='', fix_group_fanout=True):
    """ Weight initialization as per Tensorflow official implementations.

    Args:
        m (nn.Module): module to init
        n (str): module name
        fix_group_fanout (bool): enable correct (matching Tensorflow TPU impl) fanout calculation w/ group convs

    Handles layers in EfficientNet, EfficientNet-CondConv, MixNet, MnasNet, MobileNetV3, etc:
    * https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py
    * https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
    """
    if isinstance(m, CondConv2d):
        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        if fix_group_fanout:
            fan_out //= m.groups
        init_weight_fn = get_condconv_initializer(
            lambda w: w.data.normal_(0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape)
        init_weight_fn(m.weight)
        if m.bias is not None:
            m.bias.data.zero_()
    elif isinstance(m, nn.Conv2d):
        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        if fix_group_fanout:
            fan_out //= m.groups
        m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
        if m.bias is not None:
            m.bias.data.zero_()
    elif isinstance(m, nn.BatchNorm2d):
        m.weight.data.fill_(1.0)
        m.bias.data.zero_()
    elif isinstance(m, nn.Linear):
        fan_out = m.weight.size(0)  # fan-out
        fan_in = 0
        if 'routing_fn' in n:
            fan_in = m.weight.size(1)
        init_range = 1.0 / math.sqrt(fan_in + fan_out)
        m.weight.data.uniform_(-init_range, init_range)
        m.bias.data.zero_()


def efficientnet_init_weights(model: nn.Module, init_fn=None):
    init_fn = init_fn or _init_weight_goog
    for n, m in model.named_modules():
        init_fn(m, n)

@@ -0,0 +1,31 @@
from .activations import *
from .adaptive_avgmax_pool import \
    adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
from .anti_aliasing import AntiAliasDownsampleLayer
from .blur_pool import BlurPool2d
from .classifier import ClassifierHead, create_classifier
from .cond_conv2d import CondConv2d, get_condconv_initializer
from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\
    set_layer_config
from .conv2d_same import Conv2dSame
from .conv_bn_act import ConvBnAct
from .create_act import create_act_layer, get_act_layer

from .create_conv2d import create_conv2d
from .create_norm_act import create_norm_act, get_norm_act_layer
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path

from .evo_norm import EvoNormBatch2d, EvoNormSample2d
from .inplace_abn import InplaceAbn
from .mixed_conv2d import MixedConv2d
from .norm_act import BatchNormAct2d
from .padding import get_padding
from .pool2d_same import AvgPool2dSame, create_pool2d

from .selective_kernel import SelectiveKernelConv
from .separable_conv import SeparableConv2d, SeparableConvBnAct
from .space_to_depth import SpaceToDepthModule
from .split_attn import SplitAttnConv2d
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
from .weight_init import trunc_normal_
@@ -0,0 +1,66 @@
""" Activations

A collection of activations fn and modules with a common interface so that they can
easily be swapped. All have an `inplace` arg even if not used.

Hacked together by / Copyright 2020 Ross Wightman
"""

import torch
from torch import nn as nn
from torch.nn import functional as F


class Swish(nn.Module):
    def __init__(self, inplace: bool = False):
        super(Swish, self).__init__()
        self.sig = nn.Sigmoid()
        # FloatFunctional so the multiply can be replaced by a quantized op at convert time
        self.skip = nn.quantized.FloatFunctional()

    def forward(self, x):
        out = self.sig(x)
        return self.skip.mul(x, out)


def sigmoid(x, inplace: bool = False):
    return x.sigmoid_() if inplace else x.sigmoid()


# PyTorch has this, but not with a consistent inplace argument interface
class Sigmoid(nn.Module):
    def __init__(self, inplace: bool = False):
        super(Sigmoid, self).__init__()
        self.sig = nn.Sigmoid()

    def forward(self, x):
        out = self.sig(x)
        return out


class HardSwish(nn.Module):
    def __init__(self, inplace: bool = False):
        super(HardSwish, self).__init__()
        self.relu6 = nn.ReLU6(inplace)
        self.quant_mul1 = nn.quantized.FloatFunctional()
        self.quant_mul2 = nn.quantized.FloatFunctional()
        self.quant_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        out = self.quant_add.add_scalar(x, 3.0)
        out = self.relu6(out)
        out = self.quant_mul1.mul(x, out)
        out = self.quant_mul2.mul_scalar(out, 1/6)
        return out


class HardSigmoid(nn.Module):
    def __init__(self, inplace: bool = False):
        super(HardSigmoid, self).__init__()
        self.relu6 = nn.ReLU6(inplace)
        self.quant_add = nn.quantized.FloatFunctional()
        self.quant_mul = nn.quantized.FloatFunctional()

    def forward(self, x):
        out = self.quant_add.add_scalar(x, 3.0)
        out = self.relu6(out)
        out = self.quant_mul.mul_scalar(out, 1/6)
        return out
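
# Note / illustrative sketch (not part of the original file): the FloatFunctional members above are what
# make these activations compatible with eager-mode static quantization; they are swapped for quantized
# ops at convert time. An assumed calibration flow around a model that uses them might look like:
#   model.eval()
#   model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
#   prepared = torch.quantization.prepare(model)
#   prepared(calibration_batch)                      # observe activation ranges
#   quantized = torch.quantization.convert(prepared)
# The model is expected to wrap inputs/outputs with QuantStub/DeQuantStub for this to work end to end.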
@@ -0,0 +1,90 @@
""" Activations

A collection of jit-scripted activations fn and modules with a common interface so that they can
easily be swapped. All have an `inplace` arg even if not used.

All jit scripted activations are lacking in-place variations on purpose, scripted kernel fusion does not
currently work across in-place op boundaries, thus performance is equal to or less than the non-scripted
versions if they contain in-place ops.

Hacked together by / Copyright 2020 Ross Wightman
"""

import torch
from torch import nn as nn
from torch.nn import functional as F


@torch.jit.script
def swish_jit(x, inplace: bool = False):
    """Swish - Described in: https://arxiv.org/abs/1710.05941
    """
    return x.mul(x.sigmoid())


@torch.jit.script
def mish_jit(x, _inplace: bool = False):
    """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
    """
    return x.mul(F.softplus(x).tanh())


class SwishJit(nn.Module):
    def __init__(self, inplace: bool = False):
        super(SwishJit, self).__init__()

    def forward(self, x):
        return swish_jit(x)


class MishJit(nn.Module):
    def __init__(self, inplace: bool = False):
        super(MishJit, self).__init__()

    def forward(self, x):
        return mish_jit(x)


@torch.jit.script
def hard_sigmoid_jit(x, inplace: bool = False):
    # return F.relu6(x + 3.) / 6.
    return (x + 3).clamp(min=0, max=6).div(6.)  # clamp seems ever so slightly faster?


class HardSigmoidJit(nn.Module):
    def __init__(self, inplace: bool = False):
        super(HardSigmoidJit, self).__init__()

    def forward(self, x):
        return hard_sigmoid_jit(x)


@torch.jit.script
def hard_swish_jit(x, inplace: bool = False):
    # return x * (F.relu6(x + 3.) / 6)
    return x * (x + 3).clamp(min=0, max=6).div(6.)  # clamp seems ever so slightly faster?


class HardSwishJit(nn.Module):
    def __init__(self, inplace: bool = False):
        super(HardSwishJit, self).__init__()

    def forward(self, x):
        return hard_swish_jit(x)


@torch.jit.script
def hard_mish_jit(x, inplace: bool = False):
    """ Hard Mish
    Experimental, based on notes by Mish author Diganta Misra at
      https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
    """
    return 0.5 * x * (x + 2).clamp(min=0, max=2)


class HardMishJit(nn.Module):
    def __init__(self, inplace: bool = False):
        super(HardMishJit, self).__init__()

    def forward(self, x):
        return hard_mish_jit(x)
@@ -0,0 +1,208 @@
""" Activations (memory-efficient w/ custom autograd)

A collection of activations fn and modules with a common interface so that they can
easily be swapped. All have an `inplace` arg even if not used.

These activations are not compatible with jit scripting or ONNX export of the model, please use either
the JIT or basic versions of the activations.

Hacked together by / Copyright 2020 Ross Wightman
"""

import torch
from torch import nn as nn
from torch.nn import functional as F


@torch.jit.script
def swish_jit_fwd(x):
    return x.mul(torch.sigmoid(x))


@torch.jit.script
def swish_jit_bwd(x, grad_output):
    x_sigmoid = torch.sigmoid(x)
    return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))


class SwishJitAutoFn(torch.autograd.Function):
    """ torch.jit.script optimised Swish w/ memory-efficient checkpoint
    Inspired by conversation btw Jeremy Howard & Adam Paszke
    https://twitter.com/jeremyphoward/status/1188251041835315200
    """

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return swish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        return swish_jit_bwd(x, grad_output)


def swish_me(x, inplace=False):
    return SwishJitAutoFn.apply(x)


class SwishMe(nn.Module):
    def __init__(self, inplace: bool = False):
        super(SwishMe, self).__init__()

    def forward(self, x):
        return SwishJitAutoFn.apply(x)


@torch.jit.script
def mish_jit_fwd(x):
    return x.mul(torch.tanh(F.softplus(x)))


@torch.jit.script
def mish_jit_bwd(x, grad_output):
    x_sigmoid = torch.sigmoid(x)
    x_tanh_sp = F.softplus(x).tanh()
    return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))


class MishJitAutoFn(torch.autograd.Function):
    """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
    A memory efficient, jit scripted variant of Mish
    """
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return mish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        return mish_jit_bwd(x, grad_output)


def mish_me(x, inplace=False):
    return MishJitAutoFn.apply(x)


class MishMe(nn.Module):
    def __init__(self, inplace: bool = False):
        super(MishMe, self).__init__()

    def forward(self, x):
        return MishJitAutoFn.apply(x)


@torch.jit.script
def hard_sigmoid_jit_fwd(x, inplace: bool = False):
    return (x + 3).clamp(min=0, max=6).div(6.)


@torch.jit.script
def hard_sigmoid_jit_bwd(x, grad_output):
    m = torch.ones_like(x) * ((x >= -3.) & (x <= 3.)) / 6.
    return grad_output * m


class HardSigmoidJitAutoFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return hard_sigmoid_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        return hard_sigmoid_jit_bwd(x, grad_output)


def hard_sigmoid_me(x, inplace: bool = False):
    return HardSigmoidJitAutoFn.apply(x)


class HardSigmoidMe(nn.Module):
    def __init__(self, inplace: bool = False):
        super(HardSigmoidMe, self).__init__()

    def forward(self, x):
        return HardSigmoidJitAutoFn.apply(x)


@torch.jit.script
def hard_swish_jit_fwd(x):
    return x * (x + 3).clamp(min=0, max=6).div(6.)


@torch.jit.script
def hard_swish_jit_bwd(x, grad_output):
    m = torch.ones_like(x) * (x >= 3.)
    m = torch.where((x >= -3.) & (x <= 3.), x / 3. + .5, m)
    return grad_output * m


class HardSwishJitAutoFn(torch.autograd.Function):
    """A memory efficient, jit-scripted HardSwish activation"""
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return hard_swish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        return hard_swish_jit_bwd(x, grad_output)


def hard_swish_me(x, inplace=False):
    return HardSwishJitAutoFn.apply(x)


class HardSwishMe(nn.Module):
    def __init__(self, inplace: bool = False):
        super(HardSwishMe, self).__init__()

    def forward(self, x):
        return HardSwishJitAutoFn.apply(x)


@torch.jit.script
def hard_mish_jit_fwd(x):
    return 0.5 * x * (x + 2).clamp(min=0, max=2)


@torch.jit.script
def hard_mish_jit_bwd(x, grad_output):
    m = torch.ones_like(x) * (x >= -2.)
    m = torch.where((x >= -2.) & (x <= 0.), x + 1., m)
    return grad_output * m


class HardMishJitAutoFn(torch.autograd.Function):
    """ A memory efficient, jit scripted variant of Hard Mish
    Experimental, based on notes by Mish author Diganta Misra at
      https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
    """
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return hard_mish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        return hard_mish_jit_bwd(x, grad_output)


def hard_mish_me(x, inplace: bool = False):
    return HardMishJitAutoFn.apply(x)


class HardMishMe(nn.Module):
    def __init__(self, inplace: bool = False):
        super(HardMishMe, self).__init__()

    def forward(self, x):
        return HardMishJitAutoFn.apply(x)

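# Optional sanity check (illustrative, not part of the original file): the hand-written backward passes
# above can be compared against autograd numerics with torch.autograd.gradcheck, e.g.
#   x = torch.randn(8, dtype=torch.double, requires_grad=True)
#   torch.autograd.gradcheck(SwishJitAutoFn.apply, (x,))
#   torch.autograd.gradcheck(MishJitAutoFn.apply, (x,))
# (The hard_* variants are piecewise linear, so gradcheck can flag points right at their kinks.)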
@@ -0,0 +1,106 @@
""" PyTorch selectable adaptive pooling
Adaptive pooling with the ability to select the type of pooling from:
    * 'avg' - Average pooling
    * 'max' - Max pooling
    * 'avgmax' - Sum of average and max pooling re-scaled by 0.5
    * 'catavgmax' - Concatenation of average and max pooling along feature dim, doubles feature dim

Both a functional and a nn.Module version of the pooling is provided.

Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


def adaptive_pool_feat_mult(pool_type='avg'):
    if pool_type == 'catavgmax':
        return 2
    else:
        return 1


def adaptive_avgmax_pool2d(x, output_size=1):
    x_avg = F.adaptive_avg_pool2d(x, output_size)
    x_max = F.adaptive_max_pool2d(x, output_size)
    return 0.5 * (x_avg + x_max)


def adaptive_catavgmax_pool2d(x, output_size=1):
    x_avg = F.adaptive_avg_pool2d(x, output_size)
    x_max = F.adaptive_max_pool2d(x, output_size)
    return torch.cat((x_avg, x_max), 1)


def select_adaptive_pool2d(x, pool_type='avg', output_size=1):
    """Selectable global pooling function with dynamic input kernel size
    """
    if pool_type == 'avg':
        x = F.adaptive_avg_pool2d(x, output_size)
    elif pool_type == 'avgmax':
        x = adaptive_avgmax_pool2d(x, output_size)
    elif pool_type == 'catavgmax':
        x = adaptive_catavgmax_pool2d(x, output_size)
    elif pool_type == 'max':
        x = F.adaptive_max_pool2d(x, output_size)
    else:
        assert False, 'Invalid pool type: %s' % pool_type
    return x


class AdaptiveAvgMaxPool2d(nn.Module):
    def __init__(self, output_size=1):
        super(AdaptiveAvgMaxPool2d, self).__init__()
        self.output_size = output_size

    def forward(self, x):
        return adaptive_avgmax_pool2d(x, self.output_size)


class AdaptiveCatAvgMaxPool2d(nn.Module):
    def __init__(self, output_size=1):
        super(AdaptiveCatAvgMaxPool2d, self).__init__()
        self.output_size = output_size

    def forward(self, x):
        return adaptive_catavgmax_pool2d(x, self.output_size)


class SelectAdaptivePool2d(nn.Module):
    """Selectable global pooling layer with dynamic input kernel size
    """
    def __init__(self, output_size=1, pool_type='avg', flatten=False):
        super(SelectAdaptivePool2d, self).__init__()
        self.pool_type = pool_type or ''  # convert other falsy values to empty string for consistent TS typing
        self.flatten = flatten
        if pool_type == '':
            self.pool = nn.Identity()  # pass through
        elif pool_type == 'avg':
            self.pool = nn.AdaptiveAvgPool2d(output_size)
        elif pool_type == 'avgmax':
            self.pool = AdaptiveAvgMaxPool2d(output_size)
        elif pool_type == 'catavgmax':
            self.pool = AdaptiveCatAvgMaxPool2d(output_size)
        elif pool_type == 'max':
            self.pool = nn.AdaptiveMaxPool2d(output_size)
        else:
            assert False, 'Invalid pool type: %s' % pool_type

    def is_identity(self):
        return self.pool_type == ''

    def forward(self, x):
        x = self.pool(x)
        if self.flatten:
            x = x.flatten(1)
        return x

    def feat_mult(self):
        return adaptive_pool_feat_mult(self.pool_type)

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + 'pool_type=' + self.pool_type \
               + ', flatten=' + str(self.flatten) + ')'
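
# Shape example (illustrative, not part of the original file):
#   pool = SelectAdaptivePool2d(pool_type='catavgmax', flatten=True)
#   pool(torch.randn(2, 512, 7, 7)).shape   # -> torch.Size([2, 1024]); pool.feat_mult() == 2
#   SelectAdaptivePool2d(pool_type='avg', flatten=True)(torch.randn(2, 512, 7, 7)).shape  # -> [2, 512]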
@@ -0,0 +1,60 @@
import torch
import torch.nn.parallel
import torch.nn as nn
import torch.nn.functional as F


class AntiAliasDownsampleLayer(nn.Module):
    def __init__(self, channels: int = 0, filt_size: int = 3, stride: int = 2, no_jit: bool = False):
        super(AntiAliasDownsampleLayer, self).__init__()
        if no_jit:
            self.op = Downsample(channels, filt_size, stride)
        else:
            self.op = DownsampleJIT(channels, filt_size, stride)

        # FIXME I should probably override _apply and clear DownsampleJIT filter cache for .cuda(), .half(), etc calls

    def forward(self, x):
        return self.op(x)


@torch.jit.script
class DownsampleJIT(object):
    def __init__(self, channels: int = 0, filt_size: int = 3, stride: int = 2):
        self.channels = channels
        self.stride = stride
        self.filt_size = filt_size
        assert self.filt_size == 3
        assert stride == 2
        self.filt = {}  # lazy init by device for DataParallel compat

    def _create_filter(self, like: torch.Tensor):
        filt = torch.tensor([1., 2., 1.], dtype=like.dtype, device=like.device)
        filt = filt[:, None] * filt[None, :]
        filt = filt / torch.sum(filt)
        return filt[None, None, :, :].repeat((self.channels, 1, 1, 1))

    def __call__(self, input: torch.Tensor):
        input_pad = F.pad(input, (1, 1, 1, 1), 'reflect')
        filt = self.filt.get(str(input.device), self._create_filter(input))
        return F.conv2d(input_pad, filt, stride=2, padding=0, groups=input.shape[1])


class Downsample(nn.Module):
    def __init__(self, channels=None, filt_size=3, stride=2):
        super(Downsample, self).__init__()
        self.channels = channels
        self.filt_size = filt_size
        self.stride = stride

        assert self.filt_size == 3
        filt = torch.tensor([1., 2., 1.])
        filt = filt[:, None] * filt[None, :]
        filt = filt / torch.sum(filt)

        # self.filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1))
        self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1)))

    def forward(self, input):
        input_pad = F.pad(input, (1, 1, 1, 1), 'reflect')
        return F.conv2d(input_pad, self.filt, stride=self.stride, padding=0, groups=input.shape[1])
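
# Shape example (illustrative, not part of the original file): with the fixed 3x3 binomial filter and
# stride 2, the layer halves spatial resolution while leaving channels untouched:
#   aa = AntiAliasDownsampleLayer(channels=64, no_jit=True)
#   aa(torch.randn(1, 64, 56, 56)).shape   # -> torch.Size([1, 64, 28, 28])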
@@ -0,0 +1,58 @@
"""
BlurPool layer inspired by
 - Kornia's Max_BlurPool2d
 - Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar`

FIXME merge this impl with those in `anti_aliasing.py`

Hacked together by Chris Ha and Ross Wightman
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Dict
from .padding import get_padding


class BlurPool2d(nn.Module):
    r"""Creates a module that blurs and downsamples a given feature map.
    See :cite:`zhang2019shiftinvar` for more details.
    Corresponds to the Downsample class, which does blurring and subsampling

    Args:
        channels (int): Number of input channels
        filt_size (int): binomial filter size for blurring. currently supports 3 (default) and 5.
        stride (int): downsampling filter stride

    Returns:
        torch.Tensor: the transformed tensor.
    """
    filt: Dict[str, torch.Tensor]

    def __init__(self, channels, filt_size=3, stride=2) -> None:
        super(BlurPool2d, self).__init__()
        assert filt_size > 1
        self.channels = channels
        self.filt_size = filt_size
        self.stride = stride
        pad_size = [get_padding(filt_size, stride, dilation=1)] * 4
        self.padding = nn.ReflectionPad2d(pad_size)
        self._coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs)  # for torchscript compat
        self.filt = {}  # lazy init by device for DataParallel compat

    def _create_filter(self, like: torch.Tensor):
        blur_filter = (self._coeffs[:, None] * self._coeffs[None, :]).to(dtype=like.dtype, device=like.device)
        return blur_filter[None, None, :, :].repeat(self.channels, 1, 1, 1)

    def _apply(self, fn):
        # override nn.Module _apply, reset filter cache if used
        self.filt = {}
        return super(BlurPool2d, self)._apply(fn)  # return the module, matching nn.Module._apply

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        C = input_tensor.shape[1]
        blur_filt = self.filt.get(str(input_tensor.device), self._create_filter(input_tensor))
        return F.conv2d(
            self.padding(input_tensor), blur_filt, stride=self.stride, groups=C)
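
# Filter/shape example (illustrative, not part of the original file): filt_size=3 gives the binomial
# coefficients [0.25, 0.5, 0.25] (from (0.5 + 0.5x)**2) whose outer product forms the 2D blur kernel,
# and stride=2 halves the spatial dims:
#   bp = BlurPool2d(channels=32, filt_size=3, stride=2)
#   bp(torch.randn(1, 32, 56, 56)).shape   # -> torch.Size([1, 32, 28, 28])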
@@ -0,0 +1,41 @@
""" Classifier head and layer factory

Hacked together by / Copyright 2020 Ross Wightman
"""
from torch import nn as nn
from torch.nn import functional as F

from .adaptive_avgmax_pool import SelectAdaptivePool2d


def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False):
    flatten = not use_conv  # flatten when we use a Linear layer after pooling
    if not pool_type:
        assert num_classes == 0 or use_conv,\
            'Pooling can only be disabled if classifier is also removed or conv classifier is used'
        flatten = False  # disable flattening if pooling is pass-through (no pooling)
    global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=flatten)
    num_pooled_features = num_features * global_pool.feat_mult()
    if num_classes <= 0:
        fc = nn.Identity()  # pass-through (no classifier)
    elif use_conv:
        fc = nn.Conv2d(num_pooled_features, num_classes, 1, bias=True)
    else:
        fc = nn.Linear(num_pooled_features, num_classes, bias=True)
    return global_pool, fc


class ClassifierHead(nn.Module):
    """Classifier head w/ configurable global pooling and dropout."""

    def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0.):
        super(ClassifierHead, self).__init__()
        self.drop_rate = drop_rate
        self.global_pool, self.fc = create_classifier(in_chs, num_classes, pool_type=pool_type)

    def forward(self, x):
        x = self.global_pool(x)
        if self.drop_rate:
            x = F.dropout(x, p=float(self.drop_rate), training=self.training)
        x = self.fc(x)
        return x
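
# Usage example (illustrative, not part of the original file):
#   global_pool, fc = create_classifier(num_features=2048, num_classes=1000, pool_type='avg')
#   feats = torch.randn(2, 2048, 7, 7)
#   logits = fc(global_pool(feats))        # -> torch.Size([2, 1000])
# ClassifierHead bundles the same two steps plus optional dropout.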
@@ -0,0 +1,122 @@
""" PyTorch Conditionally Parameterized Convolution (CondConv)

Paper: CondConv: Conditionally Parameterized Convolutions for Efficient Inference
(https://arxiv.org/abs/1904.04971)

Hacked together by / Copyright 2020 Ross Wightman
"""

import math
from functools import partial
import numpy as np
import torch
from torch import nn as nn
from torch.nn import functional as F

from .helpers import tup_pair
from .conv2d_same import conv2d_same
from .padding import get_padding_value


def get_condconv_initializer(initializer, num_experts, expert_shape):
    def condconv_initializer(weight):
        """CondConv initializer function."""
        num_params = np.prod(expert_shape)
        if (len(weight.shape) != 2 or weight.shape[0] != num_experts or
                weight.shape[1] != num_params):
            raise (ValueError(
                'CondConv variables must have shape [num_experts, num_params]'))
        for i in range(num_experts):
            initializer(weight[i].view(expert_shape))
    return condconv_initializer


class CondConv2d(nn.Module):
    """ Conditionally Parameterized Convolution
    Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py

    Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion:
    https://github.com/pytorch/pytorch/issues/17983
    """
    __constants__ = ['in_channels', 'out_channels', 'dynamic_padding']

    def __init__(self, in_channels, out_channels, kernel_size=3,
                 stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4):
        super(CondConv2d, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = tup_pair(kernel_size)
        self.stride = tup_pair(stride)
        padding_val, is_padding_dynamic = get_padding_value(
            padding, kernel_size, stride=stride, dilation=dilation)
        self.dynamic_padding = is_padding_dynamic  # if in forward to work with torchscript
        self.padding = tup_pair(padding_val)
        self.dilation = tup_pair(dilation)
        self.groups = groups
        self.num_experts = num_experts

        self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size
        weight_num_param = 1
        for wd in self.weight_shape:
            weight_num_param *= wd
        self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param))

        if bias:
            self.bias_shape = (self.out_channels,)
            self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        init_weight = get_condconv_initializer(
            partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape)
        init_weight(self.weight)
        if self.bias is not None:
            fan_in = np.prod(self.weight_shape[1:])
            bound = 1 / math.sqrt(fan_in)
            init_bias = get_condconv_initializer(
                partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape)
            init_bias(self.bias)

    def forward(self, x, routing_weights):
        B, C, H, W = x.shape
        weight = torch.matmul(routing_weights, self.weight)
        new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size
        weight = weight.view(new_weight_shape)
        bias = None
        if self.bias is not None:
            bias = torch.matmul(routing_weights, self.bias)
            bias = bias.view(B * self.out_channels)
        # move batch elements with channels so each batch element can be efficiently convolved with separate kernel
        x = x.view(1, B * C, H, W)
        if self.dynamic_padding:
            out = conv2d_same(
                x, weight, bias, stride=self.stride, padding=self.padding,
                dilation=self.dilation, groups=self.groups * B)
        else:
            out = F.conv2d(
                x, weight, bias, stride=self.stride, padding=self.padding,
                dilation=self.dilation, groups=self.groups * B)
        out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1])

        # Literal port (from TF definition)
        # x = torch.split(x, 1, 0)
        # weight = torch.split(weight, 1, 0)
        # if self.bias is not None:
        #     bias = torch.matmul(routing_weights, self.bias)
        #     bias = torch.split(bias, 1, 0)
        # else:
        #     bias = [None] * B
        # out = []
        # for xi, wi, bi in zip(x, weight, bias):
        #     wi = wi.view(*self.weight_shape)
        #     if bi is not None:
        #         bi = bi.view(*self.bias_shape)
        #     out.append(self.conv_fn(
        #         xi, wi, bi, stride=self.stride, padding=self.padding,
        #         dilation=self.dilation, groups=self.groups))
        # out = torch.cat(out, 0)
        return out
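
# Routing example (illustrative, not part of the original file): per-sample expert mixing weights are
# supplied by the caller, typically softmax outputs of a small routing layer (named `routing_fn` elsewhere
# in this codebase):
#   conv = CondConv2d(16, 32, kernel_size=3, num_experts=4)     # default symmetric padding keeps spatial dims
#   x = torch.randn(2, 16, 8, 8)
#   routing_weights = torch.softmax(torch.randn(2, 4), dim=1)   # one weight vector per sample
#   conv(x, routing_weights).shape   # -> torch.Size([2, 32, 8, 8])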
@@ -0,0 +1,115 @@
""" Model / Layer Config singleton state
"""
from typing import Any, Optional

__all__ = [
    'is_exportable', 'is_scriptable', 'is_no_jit',
    'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config'
]

# Set to True if prefer to have layers with no jit optimization (includes activations)
_NO_JIT = False

# Set to True if prefer to have activation layers with no jit optimization
# NOTE not currently used as no difference between no_jit and no_activation jit as only layers obeying
# the jit flags so far are activations. This will change as more layers are updated and/or added.
_NO_ACTIVATION_JIT = False

# Set to True if exporting a model with Same padding via ONNX
_EXPORTABLE = False

# Set to True if wanting to use torch.jit.script on a model
_SCRIPTABLE = False


def is_no_jit():
    return _NO_JIT


class set_no_jit:
    def __init__(self, mode: bool) -> None:
        global _NO_JIT
        self.prev = _NO_JIT
        _NO_JIT = mode

    def __enter__(self) -> None:
        pass

    def __exit__(self, *args: Any) -> bool:
        global _NO_JIT
        _NO_JIT = self.prev
        return False


def is_exportable():
    return _EXPORTABLE


class set_exportable:
    def __init__(self, mode: bool) -> None:
        global _EXPORTABLE
        self.prev = _EXPORTABLE
        _EXPORTABLE = mode

    def __enter__(self) -> None:
        pass

    def __exit__(self, *args: Any) -> bool:
        global _EXPORTABLE
        _EXPORTABLE = self.prev
        return False


def is_scriptable():
    return _SCRIPTABLE


class set_scriptable:
    def __init__(self, mode: bool) -> None:
        global _SCRIPTABLE
        self.prev = _SCRIPTABLE
        _SCRIPTABLE = mode

    def __enter__(self) -> None:
        pass

    def __exit__(self, *args: Any) -> bool:
        global _SCRIPTABLE
        _SCRIPTABLE = self.prev
        return False


class set_layer_config:
    """ Layer config context manager that allows setting all layer config flags at once.
    If a flag arg is None, it will not change the current value.
    """
    def __init__(
            self,
            scriptable: Optional[bool] = None,
            exportable: Optional[bool] = None,
            no_jit: Optional[bool] = None,
            no_activation_jit: Optional[bool] = None):
        global _SCRIPTABLE
        global _EXPORTABLE
        global _NO_JIT
        global _NO_ACTIVATION_JIT
        self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT
        if scriptable is not None:
            _SCRIPTABLE = scriptable
        if exportable is not None:
            _EXPORTABLE = exportable
        if no_jit is not None:
            _NO_JIT = no_jit
        if no_activation_jit is not None:
            _NO_ACTIVATION_JIT = no_activation_jit

    def __enter__(self) -> None:
        pass

    def __exit__(self, *args: Any) -> bool:
        global _SCRIPTABLE
        global _EXPORTABLE
        global _NO_JIT
        global _NO_ACTIVATION_JIT
        _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev
        return False
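
# Usage example (illustrative, not part of the original file): the context managers restore the previous
# flag values on exit, e.g.
#   with set_layer_config(scriptable=True, no_jit=False):
#       assert is_scriptable() and not is_no_jit()
#   # flags are back to their prior values here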
@@ -0,0 +1,42 @@
""" Conv2d w/ Same Padding

Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional

from .padding import pad_same, get_padding_value


def conv2d_same(
        x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1),
        padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1):
    x = pad_same(x, weight.shape[-2:], stride, dilation)
    return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups)


class Conv2dSame(nn.Conv2d):
    """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(Conv2dSame, self).__init__(
            in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)

    def forward(self, x):
        return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)


def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
    padding = kwargs.pop('padding', '')
    kwargs.setdefault('bias', False)
    padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs)
    if is_dynamic:
        return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)
    else:
        return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)

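# Shape example (illustrative, not part of the original file): 'SAME'-style padding keeps the output
# spatial size at ceil(input / stride) regardless of kernel size, mirroring Tensorflow behaviour:
#   conv = Conv2dSame(16, 32, kernel_size=3, stride=2)
#   conv(torch.randn(1, 16, 224, 224)).shape   # -> torch.Size([1, 32, 112, 112])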
@@ -0,0 +1,40 @@
""" Conv2d + BN + Act

Hacked together by / Copyright 2020 Ross Wightman
"""
from torch import nn as nn

from .create_conv2d import create_conv2d
from .create_norm_act import convert_norm_act_type


class ConvBnAct(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1,
                 norm_layer=nn.BatchNorm2d, norm_kwargs=None, act_layer=nn.ReLU, apply_act=True,
                 drop_block=None, aa_layer=None):
        super(ConvBnAct, self).__init__()
        use_aa = aa_layer is not None

        self.conv = create_conv2d(
            in_channels, out_channels, kernel_size, stride=1 if use_aa else stride,
            padding=padding, dilation=dilation, groups=groups, bias=False)

        # NOTE for backwards compatibility with models that use separate norm and act layer definitions
        norm_act_layer, norm_act_args = convert_norm_act_type(norm_layer, act_layer, norm_kwargs)
        self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block, **norm_act_args)
        self.aa = aa_layer(channels=out_channels) if stride == 2 and use_aa else None

    @property
    def in_channels(self):
        return self.conv.in_channels

    @property
    def out_channels(self):
        return self.conv.out_channels

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        if self.aa is not None:
            x = self.aa(x)
        return x
@@ -0,0 +1,36 @@
""" Activation Factory
Hacked together by / Copyright 2020 Ross Wightman
"""
from .activations import *
from .activations_jit import *
from .activations_me import *
from .config import is_exportable, is_scriptable, is_no_jit

_ACT_LAYER_DEFAULT = dict(
    swish=Swish,
    relu=nn.ReLU,
    relu6=nn.ReLU6,
    sigmoid=Sigmoid,
    hard_sigmoid=HardSigmoid,
    hard_swish=HardSwish,
)


def get_act_layer(name='relu'):
    """ Activation Layer Factory
    Fetching activation layers by name with this function allows export or torch script friendly
    functions to be returned dynamically based on current config.
    """
    if not name:
        return None

    return _ACT_LAYER_DEFAULT[name]


def create_act_layer(name, inplace=False, **kwargs):
    act_layer = get_act_layer(name)
    if act_layer is not None:
        return act_layer(inplace=inplace, **kwargs)
    else:
        return None
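
# Example (illustrative, not part of the original file): with the table above, names map straight to the
# quantization-friendly modules defined in activations.py:
#   act = create_act_layer('hard_swish', inplace=True)    # -> HardSwish instance
#   get_act_layer('swish') is Swish                        # -> True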
@@ -0,0 +1,30 @@
""" Create Conv2d Factory Method

Hacked together by / Copyright 2020 Ross Wightman
"""

from .mixed_conv2d import MixedConv2d
from .cond_conv2d import CondConv2d
from .conv2d_same import create_conv2d_pad


def create_conv2d(in_channels, out_channels, kernel_size, **kwargs):
    """ Select a 2d convolution implementation based on arguments
    Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv2d, or CondConv2d.

    Used extensively by EfficientNet, MobileNetv3 and related networks.
    """
    if isinstance(kernel_size, list):
        assert 'num_experts' not in kwargs  # MixNet + CondConv combo not supported currently
        assert 'groups' not in kwargs  # MixedConv groups are defined by kernel list
        # We're going to use only lists for defining the MixedConv2d kernel groups,
        # ints, tuples, other iterables will continue to pass to normal conv and specify h, w.
        m = MixedConv2d(in_channels, out_channels, kernel_size, **kwargs)
    else:
        depthwise = kwargs.pop('depthwise', False)
        groups = out_channels if depthwise else kwargs.pop('groups', 1)
        if 'num_experts' in kwargs and kwargs['num_experts'] > 0:
            m = CondConv2d(in_channels, out_channels, kernel_size, groups=groups, **kwargs)
        else:
            m = create_conv2d_pad(in_channels, out_channels, kernel_size, groups=groups, **kwargs)
    return m
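
# Dispatch examples (illustrative, not part of the original file):
#   create_conv2d(32, 64, kernel_size=3)                      # plain nn.Conv2d with static padding
#   create_conv2d(32, 32, kernel_size=3, depthwise=True)      # groups=32 depthwise conv
#   create_conv2d(32, 64, kernel_size=[3, 5, 7])              # MixedConv2d, kernel sizes split across channel groups
#   create_conv2d(32, 64, kernel_size=3, num_experts=4)       # CondConv2d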
@@ -0,0 +1,74 @@
""" NormAct (Normalization + Activation Layer) Factory

Create norm + act combo modules that attempt to be backwards compatible with separate norm + act
instances in models. Where these are used it will be possible to swap separate BN + act layers with
combined modules like IABN or EvoNorms.

Hacked together by / Copyright 2020 Ross Wightman
"""
import types
import functools

import torch
import torch.nn as nn

from .evo_norm import EvoNormBatch2d, EvoNormSample2d
from .norm_act import BatchNormAct2d, GroupNormAct
from .inplace_abn import InplaceAbn

_NORM_ACT_TYPES = {BatchNormAct2d, GroupNormAct, EvoNormBatch2d, EvoNormSample2d, InplaceAbn}
_NORM_ACT_REQUIRES_ARG = {BatchNormAct2d, GroupNormAct, InplaceAbn}  # requires act_layer arg to define act type


def get_norm_act_layer(layer_class):
    layer_class = layer_class.replace('_', '').lower()
    if layer_class.startswith("batchnorm"):
        layer = BatchNormAct2d
    elif layer_class.startswith("groupnorm"):
        layer = GroupNormAct
    elif layer_class == "evonormbatch":
        layer = EvoNormBatch2d
    elif layer_class == "evonormsample":
        layer = EvoNormSample2d
    elif layer_class == "iabn" or layer_class == "inplaceabn":
        layer = InplaceAbn
    else:
        assert False, "Invalid norm_act layer (%s)" % layer_class
    return layer


def create_norm_act(layer_type, num_features, apply_act=True, jit=False, **kwargs):
    layer_parts = layer_type.split('-')  # e.g. batchnorm-leaky_relu
    assert len(layer_parts) in (1, 2)
    layer = get_norm_act_layer(layer_parts[0])
    #activation_class = layer_parts[1].lower() if len(layer_parts) > 1 else ''  # FIXME support string act selection?
    layer_instance = layer(num_features, apply_act=apply_act, **kwargs)
    if jit:
        layer_instance = torch.jit.script(layer_instance)
    return layer_instance


def convert_norm_act_type(norm_layer, act_layer, norm_kwargs=None):
    assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial))
    assert act_layer is None or isinstance(act_layer, (type, str, types.FunctionType, functools.partial))
    norm_act_args = norm_kwargs.copy() if norm_kwargs else {}
    if isinstance(norm_layer, str):
        norm_act_layer = get_norm_act_layer(norm_layer)
    elif norm_layer in _NORM_ACT_TYPES:
        norm_act_layer = norm_layer
    elif isinstance(norm_layer, (types.FunctionType, functools.partial)):
        # assuming this is a lambda/fn/bound partial that creates norm_act layer
        norm_act_layer = norm_layer
    else:
        type_name = norm_layer.__name__.lower()
        if type_name.startswith('batchnorm'):
            norm_act_layer = BatchNormAct2d
        elif type_name.startswith('groupnorm'):
            norm_act_layer = GroupNormAct
        else:
            assert False, f"No equivalent norm_act layer for {type_name}"
    if norm_act_layer in _NORM_ACT_REQUIRES_ARG:
        # Must pass `act_layer` through for backwards compat where `act_layer=None` implies no activation.
        # In the future, may force use of `apply_act` with `act_layer` arg bound to relevant NormAct types
        # It is intended that functions/partial does not trigger this, they should define act.
        norm_act_args.update(dict(act_layer=act_layer))
    return norm_act_layer, norm_act_args
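
# Example (illustrative, not part of the original file): mapping a plain norm layer + act to a fused
# norm-act type:
#   norm_act_layer, norm_act_args = convert_norm_act_type(nn.BatchNorm2d, nn.ReLU)
#   # -> (BatchNormAct2d, {'act_layer': nn.ReLU})
#   bn_act = norm_act_layer(64, **norm_act_args)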
@@ -0,0 +1,167 @@
""" DropBlock, DropPath

PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers.

Papers:
DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890)

Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382)

Code:
DropBlock impl inspired by two Tensorflow impl that I liked:
 - https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74
 - https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py

Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


def drop_block_2d(
        x, drop_prob: float = 0.1, block_size: int = 7, gamma_scale: float = 1.0,
        with_noise: bool = False, inplace: bool = False, batchwise: bool = False):
    """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf

    DropBlock with an experimental gaussian noise option. This layer has been tested on a few training
    runs with success, but needs further validation and possibly optimization for lower runtime impact.
    """
    B, C, H, W = x.shape
    total_size = W * H
    clipped_block_size = min(block_size, min(W, H))
    # seed_drop_rate, the gamma parameter
    gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
        (W - block_size + 1) * (H - block_size + 1))

    # Forces the block to be inside the feature map.
    w_i, h_i = torch.meshgrid(torch.arange(W).to(x.device), torch.arange(H).to(x.device))
    valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \
                  ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2))
    valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype)

    if batchwise:
        # one mask for whole batch, quite a bit faster
        uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device)
    else:
        uniform_noise = torch.rand_like(x)
    block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype)
    block_mask = -F.max_pool2d(
        -block_mask,
        kernel_size=clipped_block_size,  # block_size,
        stride=1,
        padding=clipped_block_size // 2)

    if with_noise:
        normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
        if inplace:
            x.mul_(block_mask).add_(normal_noise * (1 - block_mask))
        else:
            x = x * block_mask + normal_noise * (1 - block_mask)
    else:
        normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype)
        if inplace:
            x.mul_(block_mask * normalize_scale)
        else:
            x = x * block_mask * normalize_scale
    return x
|
||||||
|
|
||||||
|
def drop_block_fast_2d(
|
||||||
|
x: torch.Tensor, drop_prob: float = 0.1, block_size: int = 7,
|
||||||
|
gamma_scale: float = 1.0, with_noise: bool = False, inplace: bool = False, batchwise: bool = False):
|
||||||
|
""" DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
|
||||||
|
|
||||||
|
DropBlock with an experimental gaussian noise option. Simplied from above without concern for valid
|
||||||
|
block mask at edges.
|
||||||
|
"""
|
||||||
|
B, C, H, W = x.shape
|
||||||
|
total_size = W * H
|
||||||
|
clipped_block_size = min(block_size, min(W, H))
|
||||||
|
gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
|
||||||
|
(W - block_size + 1) * (H - block_size + 1))
|
||||||
|
|
||||||
|
if batchwise:
|
||||||
|
# one mask for whole batch, quite a bit faster
|
||||||
|
block_mask = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) < gamma
|
||||||
|
else:
|
||||||
|
# mask per batch element
|
||||||
|
block_mask = torch.rand_like(x) < gamma
|
||||||
|
block_mask = F.max_pool2d(
|
||||||
|
block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2)
|
||||||
|
|
||||||
|
if with_noise:
|
||||||
|
normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
|
||||||
|
if inplace:
|
||||||
|
x.mul_(1. - block_mask).add_(normal_noise * block_mask)
|
||||||
|
else:
|
||||||
|
x = x * (1. - block_mask) + normal_noise * block_mask
|
||||||
|
else:
|
||||||
|
block_mask = 1 - block_mask
|
||||||
|
normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(dtype=x.dtype)
|
||||||
|
if inplace:
|
||||||
|
x.mul_(block_mask * normalize_scale)
|
||||||
|
else:
|
||||||
|
x = x * block_mask * normalize_scale
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class DropBlock2d(nn.Module):
|
||||||
|
""" DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
drop_prob=0.1,
|
||||||
|
block_size=7,
|
||||||
|
gamma_scale=1.0,
|
||||||
|
with_noise=False,
|
||||||
|
inplace=False,
|
||||||
|
batchwise=False,
|
||||||
|
fast=True):
|
||||||
|
super(DropBlock2d, self).__init__()
|
||||||
|
self.drop_prob = drop_prob
|
||||||
|
self.gamma_scale = gamma_scale
|
||||||
|
self.block_size = block_size
|
||||||
|
self.with_noise = with_noise
|
||||||
|
self.inplace = inplace
|
||||||
|
self.batchwise = batchwise
|
||||||
|
self.fast = fast # FIXME finish comparisons of fast vs not
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if not self.training or not self.drop_prob:
|
||||||
|
return x
|
||||||
|
if self.fast:
|
||||||
|
return drop_block_fast_2d(
|
||||||
|
x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise)
|
||||||
|
else:
|
||||||
|
return drop_block_2d(
|
||||||
|
x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise)
|
||||||
|
|
||||||
|
|
||||||
|
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
||||||
|
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||||
|
|
||||||
|
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
|
||||||
|
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
||||||
|
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
|
||||||
|
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
|
||||||
|
'survival rate' as the argument.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if drop_prob == 0. or not training:
|
||||||
|
return x
|
||||||
|
keep_prob = 1 - drop_prob
|
||||||
|
random_tensor = keep_prob + torch.rand((x.size()[0], 1, 1, 1), dtype=x.dtype, device=x.device)
|
||||||
|
random_tensor.floor_() # binarize
|
||||||
|
output = x.div(keep_prob) * random_tensor
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class DropPath(nn.Module):
|
||||||
|
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||||
|
"""
|
||||||
|
def __init__(self, drop_prob=None):
|
||||||
|
super(DropPath, self).__init__()
|
||||||
|
self.drop_prob = drop_prob
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return drop_path(x, self.drop_prob, self.training)
|
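# --- Illustrative usage sketch (editor addition, not part of the original patch) ---
# DropPath sits on the residual branch so whole samples skip the block while training;
# DropBlock2d zeroes contiguous spatial regions and rescales the survivors.
if __name__ == '__main__':
    x = torch.randn(4, 8, 16, 16)
    branch = nn.Conv2d(8, 8, 3, padding=1)
    drop_path_layer = DropPath(drop_prob=0.2).train()
    out = x + drop_path_layer(branch(x))  # stochastic depth applied to the branch output
    drop_block = DropBlock2d(drop_prob=0.1, block_size=3).train()
    masked = drop_block(x)  # same shape as x, surviving activations rescaled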
@ -0,0 +1,83 @@
"""EvoNormB0 (Batched) and EvoNormS0 (Sample) in PyTorch

An attempt at getting decent performing EvoNorms running in PyTorch.
While currently faster than other impl, still quite a ways off the built-in BN
in terms of memory usage and throughput (roughly 5x mem, 1/2 - 1/3x speed).

Still very much a WIP, fiddling with buffer usage, in-place/jit optimizations, and layouts.

Hacked together by / Copyright 2020 Ross Wightman
"""

import torch
import torch.nn as nn


class EvoNormBatch2d(nn.Module):
    def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, drop_block=None):
        super(EvoNormBatch2d, self).__init__()
        self.apply_act = apply_act  # apply activation (non-linearity)
        self.momentum = momentum
        self.eps = eps
        param_shape = (1, num_features, 1, 1)
        self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True)
        self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True)
        if apply_act:
            self.v = nn.Parameter(torch.ones(param_shape), requires_grad=True)
        self.register_buffer('running_var', torch.ones(1, num_features, 1, 1))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)
        if self.apply_act:
            nn.init.ones_(self.v)

    def forward(self, x):
        assert x.dim() == 4, 'expected 4D input'
        x_type = x.dtype
        if self.training:
            var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
            n = x.numel() / x.shape[1]
            self.running_var.copy_(
                var.detach() * self.momentum * (n / (n - 1)) + self.running_var * (1 - self.momentum))
        else:
            var = self.running_var

        if self.apply_act:
            v = self.v.to(dtype=x_type)
            d = x * v + (x.var(dim=(2, 3), unbiased=False, keepdim=True) + self.eps).sqrt().to(dtype=x_type)
            d = d.max((var + self.eps).sqrt().to(dtype=x_type))
            x = x / d
        return x * self.weight + self.bias


class EvoNormSample2d(nn.Module):
    def __init__(self, num_features, apply_act=True, groups=8, eps=1e-5, drop_block=None):
        super(EvoNormSample2d, self).__init__()
        self.apply_act = apply_act  # apply activation (non-linearity)
        self.groups = groups
        self.eps = eps
        param_shape = (1, num_features, 1, 1)
        self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True)
        self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True)
        if apply_act:
            self.v = nn.Parameter(torch.ones(param_shape), requires_grad=True)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)
        if self.apply_act:
            nn.init.ones_(self.v)

    def forward(self, x):
        assert x.dim() == 4, 'expected 4D input'
        B, C, H, W = x.shape
        assert C % self.groups == 0
        if self.apply_act:
            n = x * (x * self.v).sigmoid()
            x = x.reshape(B, self.groups, -1)
            x = n.reshape(B, self.groups, -1) / (x.var(dim=-1, unbiased=False, keepdim=True) + self.eps).sqrt()
            x = x.reshape(B, C, H, W)
        return x * self.weight + self.bias
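# --- Illustrative usage sketch (editor addition, not part of the original patch) ---
# Either EvoNorm variant replaces a BatchNorm2d + activation pair; for the sample-based
# variant the channel count must be divisible by `groups`.
if __name__ == '__main__':
    x = torch.randn(2, 32, 8, 8)
    print(EvoNormBatch2d(32)(x).shape)             # torch.Size([2, 32, 8, 8])
    print(EvoNormSample2d(32, groups=8)(x).shape)  # torch.Size([2, 32, 8, 8])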
@ -0,0 +1,27 @@
""" Layer/Module Helpers

Hacked together by / Copyright 2020 Ross Wightman
"""
from itertools import repeat
try:
    from torch._six import container_abcs
except ImportError:
    # torch._six was removed in newer PyTorch releases; fall back to the stdlib module
    import collections.abc as container_abcs


# From PyTorch internals
def _ntuple(n):
    def parse(x):
        if isinstance(x, container_abcs.Iterable):
            return x
        return tuple(repeat(x, n))
    return parse


tup_single = _ntuple(1)
tup_pair = _ntuple(2)
tup_triple = _ntuple(3)
tup_quadruple = _ntuple(4)
ntup = _ntuple
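# --- Illustrative usage sketch (editor addition, not part of the original patch) ---
# The tup_* helpers normalize scalar or iterable args (kernel sizes, strides, padding).
if __name__ == '__main__':
    assert tup_pair(3) == (3, 3)
    assert tup_pair((3, 5)) == (3, 5)       # iterables pass through unchanged
    assert tup_quadruple(1) == (1, 1, 1, 1)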
@ -0,0 +1,87 @@
import torch
from torch import nn as nn

try:
    from inplace_abn.functions import inplace_abn, inplace_abn_sync
    has_iabn = True
except ImportError:
    has_iabn = False

    def inplace_abn(x, weight, bias, running_mean, running_var,
                    training=True, momentum=0.1, eps=1e-05, activation="leaky_relu", activation_param=0.01):
        raise ImportError(
            "Please install InplaceABN:'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.11'")

    def inplace_abn_sync(**kwargs):
        inplace_abn(**kwargs)


class InplaceAbn(nn.Module):
    """Activated Batch Normalization

    This gathers a BatchNorm and an activation function in a single module

    Parameters
    ----------
    num_features : int
        Number of feature channels in the input and output.
    eps : float
        Small constant to prevent numerical issues.
    momentum : float
        Momentum factor applied to compute running statistics.
    affine : bool
        If `True` apply learned scale and shift transformation after normalization.
    act_layer : str or nn.Module type
        Name or type of the activation functions, one of: `leaky_relu`, `elu`
    act_param : float
        Negative slope for the `leaky_relu` activation.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, apply_act=True,
                 act_layer="leaky_relu", act_param=0.01, drop_block=None):
        super(InplaceAbn, self).__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps
        self.momentum = momentum
        if apply_act:
            if isinstance(act_layer, str):
                assert act_layer in ('leaky_relu', 'elu', 'identity', '')
                self.act_name = act_layer if act_layer else 'identity'
            else:
                # convert act layer passed as type to string
                if act_layer == nn.ELU:
                    self.act_name = 'elu'
                elif act_layer == nn.LeakyReLU:
                    self.act_name = 'leaky_relu'
                elif act_layer == nn.Identity:
                    self.act_name = 'identity'
                else:
                    assert False, f'Invalid act layer {act_layer.__name__} for IABN'
        else:
            self.act_name = 'identity'
        self.act_param = act_param
        if self.affine:
            self.weight = nn.Parameter(torch.ones(num_features))
            self.bias = nn.Parameter(torch.zeros(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.constant_(self.running_mean, 0)
        nn.init.constant_(self.running_var, 1)
        if self.affine:
            nn.init.constant_(self.weight, 1)
            nn.init.constant_(self.bias, 0)

    def forward(self, x):
        output = inplace_abn(
            x, self.weight, self.bias, self.running_mean, self.running_var,
            self.training, self.momentum, self.eps, self.act_name, self.act_param)
        if isinstance(output, tuple):
            output = output[0]
        return output
@ -0,0 +1,49 @@
""" Median Pool
Hacked together by / Copyright 2020 Ross Wightman
"""
import torch.nn as nn
import torch.nn.functional as F
from .helpers import tup_pair, tup_quadruple


class MedianPool2d(nn.Module):
    """ Median pool (usable as median filter when stride=1) module.

    Args:
         kernel_size: size of pooling kernel, int or 2-tuple
         stride: pool stride, int or 2-tuple
         padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad
         same: override padding and enforce same padding, boolean
    """
    def __init__(self, kernel_size=3, stride=1, padding=0, same=False):
        super(MedianPool2d, self).__init__()
        self.k = tup_pair(kernel_size)
        self.stride = tup_pair(stride)
        self.padding = tup_quadruple(padding)  # convert to l, r, t, b
        self.same = same

    def _padding(self, x):
        if self.same:
            ih, iw = x.size()[2:]
            if ih % self.stride[0] == 0:
                ph = max(self.k[0] - self.stride[0], 0)
            else:
                ph = max(self.k[0] - (ih % self.stride[0]), 0)
            if iw % self.stride[1] == 0:
                pw = max(self.k[1] - self.stride[1], 0)
            else:
                pw = max(self.k[1] - (iw % self.stride[1]), 0)
            pl = pw // 2
            pr = pw - pl
            pt = ph // 2
            pb = ph - pt
            padding = (pl, pr, pt, pb)
        else:
            padding = self.padding
        return padding

    def forward(self, x):
        x = F.pad(x, self._padding(x), mode='reflect')
        x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1])
        x = x.contiguous().view(x.size()[:4] + (-1,)).median(dim=-1)[0]
        return x
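# --- Illustrative usage sketch (editor addition, not part of the original patch) ---
# With stride=1 and same=True, MedianPool2d acts as a shape-preserving median filter.
if __name__ == '__main__':
    import torch
    x = torch.randn(1, 3, 32, 32)
    print(MedianPool2d(kernel_size=3, stride=1, same=True)(x).shape)  # torch.Size([1, 3, 32, 32])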
@ -0,0 +1,51 @@
""" PyTorch Mixed Convolution

Paper: MixConv: Mixed Depthwise Convolutional Kernels (https://arxiv.org/abs/1907.09595)

Hacked together by / Copyright 2020 Ross Wightman
"""

import torch
from torch import nn as nn

from .conv2d_same import create_conv2d_pad


def _split_channels(num_chan, num_groups):
    split = [num_chan // num_groups for _ in range(num_groups)]
    split[0] += num_chan - sum(split)
    return split


class MixedConv2d(nn.ModuleDict):
    """ Mixed Grouped Convolution

    Based on MDConv and GroupedConv in MixNet impl:
      https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py
    """
    def __init__(self, in_channels, out_channels, kernel_size=3,
                 stride=1, padding='', dilation=1, depthwise=False, **kwargs):
        super(MixedConv2d, self).__init__()

        kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size]
        num_groups = len(kernel_size)
        in_splits = _split_channels(in_channels, num_groups)
        out_splits = _split_channels(out_channels, num_groups)
        self.in_channels = sum(in_splits)
        self.out_channels = sum(out_splits)
        for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)):
            conv_groups = out_ch if depthwise else 1
            # use add_module to keep key space clean
            self.add_module(
                str(idx),
                create_conv2d_pad(
                    in_ch, out_ch, k, stride=stride,
                    padding=padding, dilation=dilation, groups=conv_groups, **kwargs)
            )
        self.splits = in_splits

    def forward(self, x):
        x_split = torch.split(x, self.splits, 1)
        x_out = [c(x_split[i]) for i, c in enumerate(self.values())]
        x = torch.cat(x_out, 1)
        return x
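# --- Illustrative usage sketch (editor addition, not part of the original patch) ---
# Channels are split (roughly evenly) across one branch per kernel size, then re-concatenated.
if __name__ == '__main__':
    x = torch.randn(2, 32, 16, 16)
    mconv = MixedConv2d(32, 64, kernel_size=[3, 5, 7])
    print(mconv(x).shape)  # torch.Size([2, 64, 16, 16])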
@ -0,0 +1,86 @@
""" Normalization + Activation Layers
"""
import torch
from torch import nn as nn
from torch.nn import functional as F

from .create_act import get_act_layer


class BatchNormAct2d(nn.BatchNorm2d):
    """BatchNorm + Activation

    This module performs BatchNorm + Activation in a manner that will remain backwards
    compatible with weights trained with separate bn, act. This is why we inherit from BN
    instead of composing it as a .bn member.
    """
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True,
                 apply_act=True, act_layer=nn.ReLU, inplace=True, drop_block=None):
        super(BatchNormAct2d, self).__init__(
            num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
        if isinstance(act_layer, str):
            act_layer = get_act_layer(act_layer)
        if act_layer is not None and apply_act:
            act_args = dict(inplace=True) if inplace else {}
            self.act = act_layer(**act_args)
        else:
            self.act = None

    def _forward_jit(self, x):
        """ A cut & paste of the contents of the PyTorch BatchNorm2d forward function
        """
        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting this when it is None
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        x = F.batch_norm(
            x, self.running_mean, self.running_var, self.weight, self.bias,
            self.training or not self.track_running_stats,
            exponential_average_factor, self.eps)
        return x

    @torch.jit.ignore
    def _forward_python(self, x):
        return super(BatchNormAct2d, self).forward(x)

    def forward(self, x):
        # FIXME cannot call parent forward() and maintain jit.script compatibility?
        if torch.jit.is_scripting():
            x = self._forward_jit(x)
        else:
            x = self._forward_python(x)
        if self.act is not None:
            x = self.act(x)
        return x


class GroupNormAct(nn.GroupNorm):

    def __init__(self, num_groups, num_channels, eps=1e-5, affine=True,
                 apply_act=True, act_layer=nn.ReLU, inplace=True, drop_block=None):
        super(GroupNormAct, self).__init__(num_groups, num_channels, eps=eps, affine=affine)
        if isinstance(act_layer, str):
            act_layer = get_act_layer(act_layer)
        if act_layer is not None and apply_act:
            self.act = act_layer(inplace=inplace)
        else:
            self.act = None

    def forward(self, x):
        x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
        if self.act is not None:
            x = self.act(x)
        return x
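# --- Illustrative usage sketch (editor addition, not part of the original patch) ---
# Fused norm + act keeps state_dict keys compatible with a plain BatchNorm2d/GroupNorm.
if __name__ == '__main__':
    x = torch.randn(2, 16, 8, 8)
    bn_act = BatchNormAct2d(16, act_layer=nn.ReLU)
    gn_act = GroupNormAct(num_groups=4, num_channels=16, act_layer=nn.ReLU)
    print(bn_act(x).shape, gn_act(x).shape)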
@ -0,0 +1,56 @@
""" Padding Helpers

Hacked together by / Copyright 2020 Ross Wightman
"""
import math
from typing import List, Tuple

import torch.nn.functional as F


# Calculate symmetric padding for a convolution
def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
    padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
    return padding


# Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution
def get_same_padding(x: int, k: int, s: int, d: int):
    return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0)


# Can SAME padding for given args be done statically?
def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):
    return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0


# Dynamically pad input x with 'SAME' padding for conv with specified args
def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0):
    ih, iw = x.size()[-2:]
    pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1])
    if pad_h > 0 or pad_w > 0:
        x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value)
    return x


def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:
    dynamic = False
    if isinstance(padding, str):
        # for any string padding, the padding will be calculated for you, one of three ways
        padding = padding.lower()
        if padding == 'same':
            # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
            if is_static_pad(kernel_size, **kwargs):
                # static case, no extra overhead
                padding = get_padding(kernel_size, **kwargs)
            else:
                # dynamic 'SAME' padding, has runtime/GPU memory overhead
                padding = 0
                dynamic = True
        elif padding == 'valid':
            # 'VALID' padding, same as padding=0
            padding = 0
        else:
            # Default to PyTorch style 'same'-ish symmetric padding
            padding = get_padding(kernel_size, **kwargs)
    return padding, dynamic
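# --- Worked example of the padding helpers above (editor addition, not part of the original patch) ---
#   get_padding(3, stride=1, dilation=1) -> 1       symmetric PyTorch-style padding for a 3x3 conv
#   get_same_padding(224, k=3, s=2, d=1) -> 1       total 'SAME' padding needed along one dim
#   is_static_pad(3, stride=2)           -> False   so get_padding_value('same', 3, stride=2)
#                                                   returns (0, True) and pad_same() runs at runtime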
@ -0,0 +1,71 @@
""" AvgPool2d w/ Same Padding

Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Optional

from .helpers import tup_pair
from .padding import pad_same, get_padding_value


def avg_pool2d_same(x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0),
                    ceil_mode: bool = False, count_include_pad: bool = True):
    # FIXME how to deal with count_include_pad vs not for external padding?
    x = pad_same(x, kernel_size, stride)
    return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad)


class AvgPool2dSame(nn.AvgPool2d):
    """ Tensorflow like 'SAME' wrapper for 2D average pooling
    """
    def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True):
        kernel_size = tup_pair(kernel_size)
        stride = tup_pair(stride)
        super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad)

    def forward(self, x):
        return avg_pool2d_same(
            x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad)


def max_pool2d_same(
        x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0),
        dilation: List[int] = (1, 1), ceil_mode: bool = False):
    x = pad_same(x, kernel_size, stride, value=-float('inf'))
    return F.max_pool2d(x, kernel_size, stride, (0, 0), dilation, ceil_mode)


class MaxPool2dSame(nn.MaxPool2d):
    """ Tensorflow like 'SAME' wrapper for 2D max pooling
    """
    def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False, count_include_pad=True):
        kernel_size = tup_pair(kernel_size)
        stride = tup_pair(stride)
        dilation = tup_pair(dilation)
        # pass ceil_mode by keyword; nn.MaxPool2d has no count_include_pad arg, so positional
        # passing here would silently land in return_indices/ceil_mode
        super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode=ceil_mode)

    def forward(self, x):
        return max_pool2d_same(x, self.kernel_size, self.stride, self.padding, self.dilation, self.ceil_mode)


def create_pool2d(pool_type, kernel_size, stride=None, **kwargs):
    stride = stride or kernel_size
    padding = kwargs.pop('padding', '')
    padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, **kwargs)
    if is_dynamic:
        if pool_type == 'avg':
            return AvgPool2dSame(kernel_size, stride=stride, **kwargs)
        elif pool_type == 'max':
            return MaxPool2dSame(kernel_size, stride=stride, **kwargs)
        else:
            assert False, f'Unsupported pool type {pool_type}'
    else:
        if pool_type == 'avg':
            return nn.AvgPool2d(kernel_size, stride=stride, padding=padding, **kwargs)
        elif pool_type == 'max':
            return nn.MaxPool2d(kernel_size, stride=stride, padding=padding, **kwargs)
        else:
            assert False, f'Unsupported pool type {pool_type}'
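# --- Illustrative usage sketch (editor addition, not part of the original patch) ---
# With padding='same' and stride 2 the dynamic 'SAME' wrapper is selected.
if __name__ == '__main__':
    x = torch.randn(1, 8, 7, 7)
    pool = create_pool2d('max', kernel_size=3, stride=2, padding='same')
    print(type(pool).__name__, pool(x).shape)  # MaxPool2dSame torch.Size([1, 8, 4, 4])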
@ -0,0 +1,120 @@
""" Selective Kernel Convolution/Attention

Paper: Selective Kernel Networks (https://arxiv.org/abs/1903.06586)

Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
from torch import nn as nn

from .conv_bn_act import ConvBnAct


def _kernel_valid(k):
    if isinstance(k, (list, tuple)):
        for ki in k:
            # validate every branch kernel, not just the first
            _kernel_valid(ki)
        return
    assert k >= 3 and k % 2


class SelectiveKernelAttn(nn.Module):
    def __init__(self, channels, num_paths=2, attn_channels=32,
                 act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
        """ Selective Kernel Attention Module

        Selective Kernel attention mechanism factored out into its own module.

        """
        super(SelectiveKernelAttn, self).__init__()
        self.num_paths = num_paths
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False)
        self.bn = norm_layer(attn_channels)
        self.act = act_layer(inplace=True)
        self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False)

    def forward(self, x):
        assert x.shape[1] == self.num_paths
        x = torch.sum(x, dim=1)
        x = self.pool(x)
        x = self.fc_reduce(x)
        x = self.bn(x)
        x = self.act(x)
        x = self.fc_select(x)
        B, C, H, W = x.shape
        x = x.view(B, self.num_paths, C // self.num_paths, H, W)
        x = torch.softmax(x, dim=1)
        return x


class SelectiveKernelConv(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size=None, stride=1, dilation=1, groups=1,
                 attn_reduction=16, min_attn_channels=32, keep_3x3=True, split_input=False,
                 drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None):
        """ Selective Kernel Convolution Module

        As described in Selective Kernel Networks (https://arxiv.org/abs/1903.06586) with some modifications.

        Largest change is the input split, which divides the input channels across each convolution path. This can
        be viewed as a grouping of sorts, but the output channel counts expand to the module level value. This keeps
        the parameter count from ballooning when the convolutions themselves don't have groups, but still provides
        a noteworthy increase in performance over similar param count models without this attention layer. -Ross W

        Args:
            in_channels (int): module input (feature) channel count
            out_channels (int): module output (feature) channel count
            kernel_size (int, list): kernel size for each convolution branch
            stride (int): stride for convolutions
            dilation (int): dilation for module as a whole, impacts dilation of each branch
            groups (int): number of groups for each branch
            attn_reduction (int, float): reduction factor for attention features
            min_attn_channels (int): minimum attention feature channels
            keep_3x3 (bool): keep all branch convolution kernels as 3x3, changing larger kernels for dilations
            split_input (bool): split input channels evenly across each convolution branch, keeps param count lower,
                can be viewed as grouping by path, output expands to module out_channels count
            drop_block (nn.Module): drop block module
            act_layer (nn.Module): activation layer to use
            norm_layer (nn.Module): batchnorm/norm layer to use
        """
        super(SelectiveKernelConv, self).__init__()
        kernel_size = kernel_size or [3, 5]  # default to one 3x3 and one 5x5 branch. 5x5 -> 3x3 + dilation
        _kernel_valid(kernel_size)
        if not isinstance(kernel_size, list):
            kernel_size = [kernel_size] * 2
        if keep_3x3:
            dilation = [dilation * (k - 1) // 2 for k in kernel_size]
            kernel_size = [3] * len(kernel_size)
        else:
            dilation = [dilation] * len(kernel_size)
        self.num_paths = len(kernel_size)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.split_input = split_input
        if self.split_input:
            assert in_channels % self.num_paths == 0
            in_channels = in_channels // self.num_paths
        groups = min(out_channels, groups)

        conv_kwargs = dict(
            stride=stride, groups=groups, drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer,
            aa_layer=aa_layer)
        self.paths = nn.ModuleList([
            ConvBnAct(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs)
            for k, d in zip(kernel_size, dilation)])

        attn_channels = max(int(out_channels / attn_reduction), min_attn_channels)
        self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels)
        self.drop_block = drop_block

    def forward(self, x):
        if self.split_input:
            x_split = torch.split(x, self.in_channels // self.num_paths, 1)
            x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)]
        else:
            x_paths = [op(x) for op in self.paths]
        x = torch.stack(x_paths, dim=1)
        x_attn = self.attn(x)
        x = x * x_attn
        x = torch.sum(x, dim=1)
        return x
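# --- Illustrative usage sketch (editor addition, not part of the original patch) ---
# Default config: two 3x3 paths (the 5x5 replaced by dilation), attention-weighted and summed.
if __name__ == '__main__':
    x = torch.randn(2, 32, 16, 16)
    sk = SelectiveKernelConv(32, 64, split_input=True)
    print(sk(x).shape)  # torch.Size([2, 64, 16, 16])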
@ -0,0 +1,74 @@
""" Depthwise Separable Conv Modules

Basic DWS convs. Other variations of DWS exist with batch norm or activations between the
DW and PW convs such as the Depthwise modules in MobileNetV2 / EfficientNet and Xception.

Hacked together by / Copyright 2020 Ross Wightman
"""
from torch import nn as nn

from .create_conv2d import create_conv2d
from .create_norm_act import convert_norm_act_type


class SeparableConvBnAct(nn.Module):
    """ Separable Conv w/ trailing Norm and Activation
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False,
                 channel_multiplier=1.0, pw_kernel_size=1, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
                 act_layer=nn.ReLU, apply_act=True, drop_block=None):
        super(SeparableConvBnAct, self).__init__()
        norm_kwargs = norm_kwargs or {}

        self.conv_dw = create_conv2d(
            in_channels, int(in_channels * channel_multiplier), kernel_size,
            stride=stride, dilation=dilation, padding=padding, depthwise=True)

        self.conv_pw = create_conv2d(
            int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias)

        norm_act_layer, norm_act_args = convert_norm_act_type(norm_layer, act_layer, norm_kwargs)
        self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block, **norm_act_args)

    @property
    def in_channels(self):
        return self.conv_dw.in_channels

    @property
    def out_channels(self):
        return self.conv_pw.out_channels

    def forward(self, x):
        x = self.conv_dw(x)
        x = self.conv_pw(x)
        if self.bn is not None:
            x = self.bn(x)
        return x


class SeparableConv2d(nn.Module):
    """ Separable Conv
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False,
                 channel_multiplier=1.0, pw_kernel_size=1):
        super(SeparableConv2d, self).__init__()

        self.conv_dw = create_conv2d(
            in_channels, int(in_channels * channel_multiplier), kernel_size,
            stride=stride, dilation=dilation, padding=padding, depthwise=True)

        self.conv_pw = create_conv2d(
            int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias)

    @property
    def in_channels(self):
        return self.conv_dw.in_channels

    @property
    def out_channels(self):
        return self.conv_pw.out_channels

    def forward(self, x):
        x = self.conv_dw(x)
        x = self.conv_pw(x)
        return x
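# --- Illustrative usage sketch (editor addition, not part of the original patch) ---
# Depthwise 3x3 followed by pointwise 1x1; this variant has no norm/act between them.
if __name__ == '__main__':
    import torch
    x = torch.randn(2, 32, 16, 16)
    print(SeparableConv2d(32, 64, kernel_size=3)(x).shape)  # torch.Size([2, 64, 16, 16])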
@ -0,0 +1,53 @@
import torch
import torch.nn as nn


class SpaceToDepth(nn.Module):
    def __init__(self, block_size=4):
        super().__init__()
        assert block_size == 4
        self.bs = block_size

    def forward(self, x):
        N, C, H, W = x.size()
        x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs)  # (N, C, H//bs, bs, W//bs, bs)
        x = x.permute(0, 3, 5, 1, 2, 4).contiguous()  # (N, bs, bs, C, H//bs, W//bs)
        x = x.view(N, C * (self.bs ** 2), H // self.bs, W // self.bs)  # (N, C*bs^2, H//bs, W//bs)
        return x


@torch.jit.script
class SpaceToDepthJit(object):
    def __call__(self, x: torch.Tensor):
        # assuming hard-coded that block_size==4 for acceleration
        N, C, H, W = x.size()
        x = x.view(N, C, H // 4, 4, W // 4, 4)  # (N, C, H//bs, bs, W//bs, bs)
        x = x.permute(0, 3, 5, 1, 2, 4).contiguous()  # (N, bs, bs, C, H//bs, W//bs)
        x = x.view(N, C * 16, H // 4, W // 4)  # (N, C*bs^2, H//bs, W//bs)
        return x


class SpaceToDepthModule(nn.Module):
    def __init__(self, no_jit=False):
        super().__init__()
        if not no_jit:
            self.op = SpaceToDepthJit()
        else:
            self.op = SpaceToDepth()

    def forward(self, x):
        return self.op(x)


class DepthToSpace(nn.Module):

    def __init__(self, block_size):
        super().__init__()
        self.bs = block_size

    def forward(self, x):
        N, C, H, W = x.size()
        x = x.view(N, self.bs, self.bs, C // (self.bs ** 2), H, W)  # (N, bs, bs, C//bs^2, H, W)
        x = x.permute(0, 3, 4, 1, 5, 2).contiguous()  # (N, C//bs^2, H, bs, W, bs)
        x = x.view(N, C // (self.bs ** 2), H * self.bs, W * self.bs)  # (N, C//bs^2, H * bs, W * bs)
        return x
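# --- Illustrative usage sketch (editor addition, not part of the original patch) ---
# SpaceToDepth folds each 4x4 spatial block into channels; DepthToSpace inverts it.
if __name__ == '__main__':
    x = torch.randn(1, 3, 224, 224)
    y = SpaceToDepthModule(no_jit=True)(x)   # (1, 48, 56, 56)
    z = DepthToSpace(block_size=4)(y)        # back to (1, 3, 224, 224)
    print(y.shape, z.shape)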
@ -0,0 +1,88 @@
""" Split Attention Conv2d (for ResNeSt Models)

Paper: `ResNeSt: Split-Attention Networks` - https://arxiv.org/abs/2004.08955

Adapted from original PyTorch impl at https://github.com/zhanghang1989/ResNeSt

Modified for torchscript compat, performance, and consistency with timm by Ross Wightman
"""
import torch
import torch.nn.functional as F
from torch import nn


class RadixSoftmax(nn.Module):
    def __init__(self, radix, cardinality):
        super(RadixSoftmax, self).__init__()
        self.radix = radix
        self.cardinality = cardinality

    def forward(self, x):
        batch = x.size(0)
        if self.radix > 1:
            x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
            x = F.softmax(x, dim=1)
            x = x.reshape(batch, -1)
        else:
            x = torch.sigmoid(x)
        return x


class SplitAttnConv2d(nn.Module):
    """Split-Attention Conv2d
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 dilation=1, groups=1, bias=False, radix=2, reduction_factor=4,
                 act_layer=nn.ReLU, norm_layer=None, drop_block=None, **kwargs):
        super(SplitAttnConv2d, self).__init__()
        self.radix = radix
        self.drop_block = drop_block
        mid_chs = out_channels * radix
        attn_chs = max(in_channels * radix // reduction_factor, 32)

        self.conv = nn.Conv2d(
            in_channels, mid_chs, kernel_size, stride, padding, dilation,
            groups=groups * radix, bias=bias, **kwargs)
        self.bn0 = norm_layer(mid_chs) if norm_layer is not None else None
        self.act0 = act_layer(inplace=True)
        self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups)
        self.bn1 = norm_layer(attn_chs) if norm_layer is not None else None
        self.act1 = act_layer(inplace=True)
        self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups)
        self.rsoftmax = RadixSoftmax(radix, groups)

    @property
    def in_channels(self):
        return self.conv.in_channels

    @property
    def out_channels(self):
        return self.fc1.out_channels

    def forward(self, x):
        x = self.conv(x)
        if self.bn0 is not None:
            x = self.bn0(x)
        if self.drop_block is not None:
            x = self.drop_block(x)
        x = self.act0(x)

        B, RC, H, W = x.shape
        if self.radix > 1:
            x = x.reshape((B, self.radix, RC // self.radix, H, W))
            x_gap = x.sum(dim=1)
        else:
            x_gap = x
        x_gap = F.adaptive_avg_pool2d(x_gap, 1)
        x_gap = self.fc1(x_gap)
        if self.bn1 is not None:
            x_gap = self.bn1(x_gap)
        x_gap = self.act1(x_gap)
        x_attn = self.fc2(x_gap)

        x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1)
        if self.radix > 1:
            out = (x * x_attn.reshape((B, self.radix, RC // self.radix, 1, 1))).sum(dim=1)
        else:
            out = x * x_attn
        return out.contiguous()
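# --- Illustrative usage sketch (editor addition, not part of the original patch) ---
# radix=2: the conv produces radix*out_channels, attention re-weights the splits and sums them.
if __name__ == '__main__':
    x = torch.randn(2, 64, 16, 16)
    sa = SplitAttnConv2d(64, 64, kernel_size=3, padding=1, radix=2, norm_layer=nn.BatchNorm2d)
    print(sa(x).shape)  # torch.Size([2, 64, 16, 16])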
@ -0,0 +1,75 @@
""" Split BatchNorm

A PyTorch BatchNorm layer that splits input batch into N equal parts and passes each through
a separate BN layer. The first split is passed through the parent BN layers with weight/bias
keys the same as the original BN. All other splits pass through BN sub-layers under the '.aux_bn'
namespace.

This allows easily removing the auxiliary BN layers after training to efficiently
achieve the 'Auxiliary BatchNorm' as described in the AdvProp Paper, section 4.2,
'Disentangled Learning via An Auxiliary BN'

Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
import torch.nn as nn


class SplitBatchNorm2d(torch.nn.BatchNorm2d):

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True, num_splits=2):
        super().__init__(num_features, eps, momentum, affine, track_running_stats)
        assert num_splits > 1, 'Should have at least one aux BN layer (num_splits at least 2)'
        self.num_splits = num_splits
        self.aux_bn = nn.ModuleList([
            nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats) for _ in range(num_splits - 1)])

    def forward(self, input: torch.Tensor):
        if self.training:  # aux BN only relevant while training
            split_size = input.shape[0] // self.num_splits
            assert input.shape[0] == split_size * self.num_splits, "batch size must be evenly divisible by num_splits"
            split_input = input.split(split_size)
            x = [super().forward(split_input[0])]
            for i, a in enumerate(self.aux_bn):
                x.append(a(split_input[i + 1]))
            return torch.cat(x, dim=0)
        else:
            return super().forward(input)


def convert_splitbn_model(module, num_splits=2):
    """
    Recursively traverse module and its children to replace all instances of
    ``torch.nn.modules.batchnorm._BatchNorm`` with `SplitBatchnorm2d`.
    Args:
        module (torch.nn.Module): input module
        num_splits: number of separate batchnorm layers to split input across
    Example::
        >>> # model is an instance of torch.nn.Module
        >>> model = timm.models.convert_splitbn_model(model, num_splits=2)
    """
    mod = module
    if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm):
        return module
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        mod = SplitBatchNorm2d(
            module.num_features, module.eps, module.momentum, module.affine,
            module.track_running_stats, num_splits=num_splits)
        mod.running_mean = module.running_mean
        mod.running_var = module.running_var
        mod.num_batches_tracked = module.num_batches_tracked
        if module.affine:
            mod.weight.data = module.weight.data.clone().detach()
            mod.bias.data = module.bias.data.clone().detach()
        for aux in mod.aux_bn:
            aux.running_mean = module.running_mean.clone()
            aux.running_var = module.running_var.clone()
            aux.num_batches_tracked = module.num_batches_tracked.clone()
            if module.affine:
                aux.weight.data = module.weight.data.clone().detach()
                aux.bias.data = module.bias.data.clone().detach()
    for name, child in module.named_children():
        mod.add_module(name, convert_splitbn_model(child, num_splits=num_splits))
    del module
    return mod
@ -0,0 +1,50 @@
""" Test Time Pooling (Average-Max Pool)

Hacked together by / Copyright 2020 Ross Wightman
"""

import logging
from torch import nn
import torch.nn.functional as F

from .adaptive_avgmax_pool import adaptive_avgmax_pool2d


_logger = logging.getLogger(__name__)


class TestTimePoolHead(nn.Module):
    def __init__(self, base, original_pool=7):
        super(TestTimePoolHead, self).__init__()
        self.base = base
        self.original_pool = original_pool
        base_fc = self.base.get_classifier()
        if isinstance(base_fc, nn.Conv2d):
            self.fc = base_fc
        else:
            self.fc = nn.Conv2d(
                self.base.num_features, self.base.num_classes, kernel_size=1, bias=True)
            self.fc.weight.data.copy_(base_fc.weight.data.view(self.fc.weight.size()))
            self.fc.bias.data.copy_(base_fc.bias.data.view(self.fc.bias.size()))
        self.base.reset_classifier(0)  # delete original fc layer

    def forward(self, x):
        x = self.base.forward_features(x)
        x = F.avg_pool2d(x, kernel_size=self.original_pool, stride=1)
        x = self.fc(x)
        x = adaptive_avgmax_pool2d(x, 1)
        return x.view(x.size(0), -1)


def apply_test_time_pool(model, config, args):
    test_time_pool = False
    if not hasattr(model, 'default_cfg') or not model.default_cfg:
        return model, False
    if not args.no_test_pool and \
            config['input_size'][-1] > model.default_cfg['input_size'][-1] and \
            config['input_size'][-2] > model.default_cfg['input_size'][-2]:
        _logger.info('Target input size %s > pretrained default %s, using test time pooling' %
                     (str(config['input_size'][-2:]), str(model.default_cfg['input_size'][-2:])))
        model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size'])
        test_time_pool = True
    return model, test_time_pool
@ -0,0 +1,60 @@
import torch
import math
import warnings


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.
    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
@ -0,0 +1,462 @@
|
|||||||
|
|
||||||
|
""" MobileNet V3
|
||||||
|
|
||||||
|
A PyTorch impl of MobileNet-V3, compatible with TF weights from official impl.
|
||||||
|
|
||||||
|
Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244
|
||||||
|
|
||||||
|
Hacked together by / Copyright 2020 Ross Wightman
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||||
|
from .efficientnet_blocks import round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
|
||||||
|
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights
|
||||||
|
from timm.models.features import FeatureInfo, FeatureHooks
|
||||||
|
from timm.models.helpers import build_model_with_cfg
|
||||||
|
from .layers import SelectAdaptivePool2d, create_conv2d, HardSigmoid
|
||||||
|
from timm.models.registry import register_model
|
||||||
|
from .efficientnet_blocks import ConvBnAct, SqueezeExcite, DepthwiseSeparableConv, InvertedResidual
|
||||||
|
__all__ = ['MobileNetV3']
|
||||||
|
|
||||||
|
|
||||||
|
def _cfg(url='', **kwargs):
|
||||||
|
return {
|
||||||
|
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (1, 1),
|
||||||
|
'crop_pct': 0.875, 'interpolation': 'bilinear',
|
||||||
|
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||||
|
'first_conv': 'conv_stem', 'classifier': 'classifier',
|
||||||
|
**kwargs
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
default_cfgs = {
|
||||||
|
'mobilenetv3_large_075': _cfg(url=''),
|
||||||
|
'mobilenetv3_large_100': _cfg(
|
||||||
|
interpolation='bicubic',
|
||||||
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth'),
|
||||||
|
'mobilenetv3_small_075': _cfg(url=''),
|
||||||
|
'mobilenetv3_small_100': _cfg(url=''),
|
||||||
|
'mobilenetv3_rw': _cfg(
|
||||||
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth',
|
||||||
|
interpolation='bicubic'),
|
||||||
|
'tf_mobilenetv3_large_075': _cfg(
|
||||||
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth',
|
||||||
|
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
|
||||||
|
'tf_mobilenetv3_large_100': _cfg(
|
||||||
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth',
|
||||||
|
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
|
||||||
|
'tf_mobilenetv3_large_minimal_100': _cfg(
|
||||||
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth',
|
||||||
|
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
|
||||||
|
'tf_mobilenetv3_small_075': _cfg(
|
||||||
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth',
|
||||||
|
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
|
||||||
|
'tf_mobilenetv3_small_100': _cfg(
|
||||||
|
url= 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth',
|
||||||
|
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
|
||||||
|
'tf_mobilenetv3_small_minimal_100': _cfg(
|
||||||
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth',
|
||||||
|
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
|
||||||
|
}
|
||||||
|
|
||||||
|
_DEBUG = False
|
||||||
|
|
||||||
|
|
||||||
|
class MobileNetV3(nn.Module):
    """ MobileNet-V3

    Based on my EfficientNet implementation and building blocks, this model utilizes the MobileNet-v3 specific
    'efficient head', where global pooling is done before the head convolution, without a final batch-norm
    layer before the classifier.

    Paper: https://arxiv.org/abs/1905.02244
    """

    def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_features=1280, head_bias=True,
                 channel_multiplier=1.0, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0.,
                 se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, global_pool='avg'):
        super(MobileNetV3, self).__init__()

        # temporary fix to support tf_* models with torch quantization: force default padding (not 'same')
        pad_type = ''
        norm_kwargs = norm_kwargs or {}  # may be None when the class is constructed directly

        self.num_classes = num_classes
        self.num_features = num_features
        self.drop_rate = drop_rate

        # Stem
        stem_size = round_channels(stem_size, channel_multiplier)
        self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
        self.bn1 = norm_layer(stem_size, **norm_kwargs)
        self.act1 = act_layer(inplace=True)

        # Middle stages (IR/ER/DS Blocks)
        builder = EfficientNetBuilder(
            channel_multiplier, 8, None, 32, pad_type, act_layer, se_kwargs,
            norm_layer, norm_kwargs, drop_path_rate, verbose=_DEBUG)
        self.blocks = nn.Sequential(*builder(stem_size, block_args))
        self.feature_info = builder.features
        head_chs = builder.in_chs

        # Head + Pooling
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        num_pooled_chs = head_chs * self.global_pool.feat_mult()
        self.conv_head = create_conv2d(num_pooled_chs, self.num_features, 1, padding=pad_type, bias=head_bias)
        self.act2 = act_layer(inplace=True)
        self.classifier = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        # Quantization Stubs: mark where tensors enter/leave the quantized part of the graph
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

        efficientnet_init_weights(self)

    def as_sequential(self):
        layers = [self.conv_stem, self.bn1, self.act1]
        layers.extend(self.blocks)
        layers.extend([self.global_pool, self.conv_head, self.act2])
        layers.extend([nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier])
        return nn.Sequential(*layers)

    def get_classifier(self):
        return self.classifier

    def reset_classifier(self, num_classes, global_pool='avg'):
        self.num_classes = num_classes
        # cannot meaningfully change pooling of efficient head after creation
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        self.classifier = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        x = self.conv_stem(x)
        x = self.bn1(x)
        x = self.act1(x)
        x = self.blocks(x)
        x = self.global_pool(x)
        x = self.conv_head(x)
        x = self.act2(x)
        return x

    def forward(self, x):
        x = self.quant(x)
        x = self.forward_features(x)
        if not self.global_pool.is_identity():
            x = x.flatten(1)
        if self.drop_rate > 0.:
            x = F.dropout(x, p=self.drop_rate, training=self.training)
        x = self.classifier(x)
        x = self.dequant(x)
        return x

    def fuse_model(self):
        # fuse stem conv+bn(+relu) and head conv+relu, then let each block fuse itself
        modules_to_fuse = [['conv_stem', 'bn1']]
        if type(self.act1) == nn.ReLU:
            modules_to_fuse[0].append('act1')
        if type(self.act2) == nn.ReLU:
            modules_to_fuse.append(['conv_head', 'act2'])
        torch.quantization.fuse_modules(self, modules_to_fuse, inplace=True)
        for m in self.modules():
            if type(m) in [ConvBnAct, SqueezeExcite, DepthwiseSeparableConv, InvertedResidual]:
                m.fuse_module()


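# The QuantStub/DeQuantStub pair and fuse_model() above target eager-mode post-training static
# quantization (CPU only). A minimal sketch of that flow, assuming a small calibration DataLoader;
# `_example_static_quantize` and `calib_loader` are illustrative names, not used elsewhere in this file.
def _example_static_quantize(model, calib_loader, backend='fbgemm'):
    model.eval()
    model.fuse_model()                                               # fold conv+bn(+relu) patterns
    model.qconfig = torch.quantization.get_default_qconfig(backend)
    torch.quantization.prepare(model, inplace=True)                  # insert observers
    with torch.no_grad():
        for images, _ in calib_loader:                               # calibrate activation ranges
            model(images)
    torch.quantization.convert(model, inplace=True)                  # swap float modules for int8 ones
    return model

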
class MobileNetV3Features(nn.Module):
    """ MobileNetV3 Feature Extractor

    A work-in-progress feature extraction module for MobileNet-V3 to use as a backbone for segmentation
    and object detection models.
    """

    def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck',
                 in_chans=3, stem_size=16, channel_multiplier=1.0, output_stride=32, pad_type='',
                 act_layer=nn.ReLU, drop_rate=0., drop_path_rate=0., se_kwargs=None,
                 norm_layer=nn.BatchNorm2d, norm_kwargs=None):
        super(MobileNetV3Features, self).__init__()
        norm_kwargs = norm_kwargs or {}
        self.drop_rate = drop_rate

        # Stem
        stem_size = round_channels(stem_size, channel_multiplier)
        self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
        self.bn1 = norm_layer(stem_size, **norm_kwargs)
        self.act1 = act_layer(inplace=True)

        # Middle stages (IR/ER/DS Blocks)
        builder = EfficientNetBuilder(
            channel_multiplier, 8, None, output_stride, pad_type, act_layer, se_kwargs,
            norm_layer, norm_kwargs, drop_path_rate, feature_location=feature_location, verbose=_DEBUG)
        self.blocks = nn.Sequential(*builder(stem_size, block_args))
        self.feature_info = FeatureInfo(builder.features, out_indices)
        self._stage_out_idx = {v['stage']: i for i, v in enumerate(self.feature_info) if i in out_indices}

        efficientnet_init_weights(self)

        # Register feature extraction hooks with FeatureHooks helper
        self.feature_hooks = None
        if feature_location != 'bottleneck':
            hooks = self.feature_info.get_dicts(keys=('module', 'hook_type'))
            self.feature_hooks = FeatureHooks(hooks, self.named_modules())

    def forward(self, x) -> List[torch.Tensor]:
        x = self.conv_stem(x)
        x = self.bn1(x)
        x = self.act1(x)
        if self.feature_hooks is None:
            features = []
            if 0 in self._stage_out_idx:
                features.append(x)  # add stem out
            for i, b in enumerate(self.blocks):
                x = b(x)
                if i + 1 in self._stage_out_idx:
                    features.append(x)
            return features
        else:
            self.blocks(x)
            out = self.feature_hooks.get_output(x.device)
            return list(out.values())


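# MobileNetV3Features is selected by _create_mnv3() below when `features_only=True` is passed through
# the factory. A hedged usage sketch (the registered entrypoint names are defined at the end of this file):
#
#   import timm
#   backbone = timm.create_model('quant_mobilenetv3_large_100', features_only=True)
#   feats = backbone(torch.randn(1, 3, 224, 224))   # list of feature maps, one per out_index

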
def _create_mnv3(model_kwargs, variant, pretrained=False):
    if model_kwargs.pop('features_only', False):
        load_strict = False
        model_kwargs.pop('num_classes', 0)
        model_kwargs.pop('num_features', 0)
        model_kwargs.pop('head_conv', None)
        model_kwargs.pop('head_bias', None)
        model_cls = MobileNetV3Features
    else:
        load_strict = True
        model_cls = MobileNetV3
    return build_model_with_cfg(
        model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
        pretrained_strict=load_strict, **model_kwargs)


def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
    """Creates a MobileNet-V3 model.

    Ref impl: ?
    Paper: https://arxiv.org/abs/1905.02244

    Args:
      channel_multiplier: multiplier to number of channels per layer.
    """
    arch_def = [
        # stage 0, 112x112 in
        ['ds_r1_k3_s1_e1_c16_nre_noskip'],  # relu
        # stage 1, 112x112 in
        ['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'],  # relu
        # stage 2, 56x56 in
        ['ir_r3_k5_s2_e3_c40_se0.25_nre'],  # relu
        # stage 3, 28x28 in
        ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'],  # hard-swish
        # stage 4, 14x14 in
        ['ir_r2_k3_s1_e6_c112_se0.25'],  # hard-swish
        # stage 5, 14x14 in
        ['ir_r3_k5_s2_e6_c160_se0.25'],  # hard-swish
        # stage 6, 7x7 in
        ['cn_r1_k1_s1_c960'],  # hard-swish
    ]
    model_kwargs = dict(
        block_args=decode_arch_def(arch_def),
        head_bias=False,
        channel_multiplier=channel_multiplier,
        norm_kwargs=resolve_bn_args(kwargs),
        act_layer=resolve_act_layer(kwargs, 'hard_swish'),
        # HardSigmoid SE gate from .layers, matching _gen_mobilenet_v3 below
        se_kwargs=dict(gate_fn=HardSigmoid, reduce_mid=True, divisor=1),
        **kwargs,
    )
    model = _create_mnv3(model_kwargs, variant, pretrained)
    return model


def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
    """Creates a MobileNet-V3 model.

    Ref impl: ?
    Paper: https://arxiv.org/abs/1905.02244

    Args:
      channel_multiplier: multiplier to number of channels per layer.
    """
    if 'small' in variant:
        num_features = 1024
        if 'minimal' in variant:
            act_layer = resolve_act_layer(kwargs, 'relu')
            arch_def = [
                # stage 0, 112x112 in
                ['ds_r1_k3_s2_e1_c16'],
                # stage 1, 56x56 in
                ['ir_r1_k3_s2_e4.5_c24', 'ir_r1_k3_s1_e3.67_c24'],
                # stage 2, 28x28 in
                ['ir_r1_k3_s2_e4_c40', 'ir_r2_k3_s1_e6_c40'],
                # stage 3, 14x14 in
                ['ir_r2_k3_s1_e3_c48'],
                # stage 4, 14x14 in
                ['ir_r3_k3_s2_e6_c96'],
                # stage 6, 7x7 in
                ['cn_r1_k1_s1_c576'],
            ]
        else:
            act_layer = resolve_act_layer(kwargs, 'hard_swish')
            arch_def = [
                # stage 0, 112x112 in
                ['ds_r1_k3_s2_e1_c16_se0.25_nre'],  # relu
                # stage 1, 56x56 in
                ['ir_r1_k3_s2_e4.5_c24_nre', 'ir_r1_k3_s1_e3.67_c24_nre'],  # relu
                # stage 2, 28x28 in
                ['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r2_k5_s1_e6_c40_se0.25'],  # hard-swish
                # stage 3, 14x14 in
                ['ir_r2_k5_s1_e3_c48_se0.25'],  # hard-swish
                # stage 4, 14x14 in
                ['ir_r3_k5_s2_e6_c96_se0.25'],  # hard-swish
                # stage 6, 7x7 in
                ['cn_r1_k1_s1_c576'],  # hard-swish
            ]
    else:
        num_features = 1280
        if 'minimal' in variant:
            act_layer = resolve_act_layer(kwargs, 'relu')
            arch_def = [
                # stage 0, 112x112 in
                ['ds_r1_k3_s1_e1_c16'],
                # stage 1, 112x112 in
                ['ir_r1_k3_s2_e4_c24', 'ir_r1_k3_s1_e3_c24'],
                # stage 2, 56x56 in
                ['ir_r3_k3_s2_e3_c40'],
                # stage 3, 28x28 in
                ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'],
                # stage 4, 14x14 in
                ['ir_r2_k3_s1_e6_c112'],
                # stage 5, 14x14 in
                ['ir_r3_k3_s2_e6_c160'],
                # stage 6, 7x7 in
                ['cn_r1_k1_s1_c960'],
            ]
        else:
            act_layer = resolve_act_layer(kwargs, 'hard_swish')
            arch_def = [
                # stage 0, 112x112 in
                ['ds_r1_k3_s1_e1_c16_nre'],  # relu
                # stage 1, 112x112 in
                ['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'],  # relu
                # stage 2, 56x56 in
                ['ir_r3_k5_s2_e3_c40_se0.25_nre'],  # relu
                # stage 3, 28x28 in
                ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'],  # hard-swish
                # stage 4, 14x14 in
                ['ir_r2_k3_s1_e6_c112_se0.25'],  # hard-swish
                # stage 5, 14x14 in
                ['ir_r3_k5_s2_e6_c160_se0.25'],  # hard-swish
                # stage 6, 7x7 in
                ['cn_r1_k1_s1_c960'],  # hard-swish
            ]

    model_kwargs = dict(
        block_args=decode_arch_def(arch_def),
        num_features=num_features,
        stem_size=16,
        channel_multiplier=channel_multiplier,
        norm_kwargs=resolve_bn_args(kwargs),
        act_layer=act_layer,
        se_kwargs=dict(act_layer=nn.ReLU, gate_fn=HardSigmoid, reduce_mid=True, divisor=8),
        **kwargs,
    )
    model = _create_mnv3(model_kwargs, variant, pretrained)
    return model


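# For reference, the arch_def strings above follow timm's decode_arch_def() block notation, roughly:
#   <type>_r<repeats>_k<kernel>_s<stride>_e<expansion>_c<channels>[_se<ratio>][_nre][_noskip]
# e.g. 'ir_r1_k3_s2_e4_c24_nre' -> one InvertedResidual block, 3x3 kernel, stride 2, expansion 4,
# 24 output channels, forced ReLU activation ('nre'); 'ds' = DepthwiseSeparableConv,
# 'cn' = ConvBnAct, 'se0.25' = squeeze-excite with a 0.25 reduction ratio.

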
@register_model
def quant_mobilenetv3_large_075(pretrained=False, **kwargs):
    """ MobileNet V3 """
    model = _gen_mobilenet_v3('mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs)
    return model


@register_model
def quant_mobilenetv3_large_100(pretrained=False, **kwargs):
    """ MobileNet V3 """
    model = _gen_mobilenet_v3('mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs)
    return model


@register_model
def quant_mobilenetv3_small_075(pretrained=False, **kwargs):
    """ MobileNet V3 """
    model = _gen_mobilenet_v3('mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs)
    return model


@register_model
def quant_mobilenetv3_small_100(pretrained=False, **kwargs):
    """ MobileNet V3 """
    model = _gen_mobilenet_v3('mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs)
    return model


@register_model
def quant_mobilenetv3_rw(pretrained=False, **kwargs):
    """ MobileNet V3 """
    if pretrained:
        # pretrained model trained with non-default BN epsilon
        kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    model = _gen_mobilenet_v3_rw('mobilenetv3_rw', 1.0, pretrained=pretrained, **kwargs)
    return model


@register_model
def quant_tf_mobilenetv3_large_075(pretrained=False, **kwargs):
    """ MobileNet V3 """
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_mobilenet_v3('tf_mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs)
    return model


@register_model
def quant_tf_mobilenetv3_large_100(pretrained=False, **kwargs):
    """ MobileNet V3 """
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_mobilenet_v3('tf_mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs)
    return model


@register_model
def quant_tf_mobilenetv3_large_minimal_100(pretrained=False, **kwargs):
    """ MobileNet V3 """
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_mobilenet_v3('tf_mobilenetv3_large_minimal_100', 1.0, pretrained=pretrained, **kwargs)
    return model


@register_model
def quant_tf_mobilenetv3_small_075(pretrained=False, **kwargs):
    """ MobileNet V3 """
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_mobilenet_v3('tf_mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs)
    return model


@register_model
def quant_tf_mobilenetv3_small_100(pretrained=False, **kwargs):
    """ MobileNet V3 """
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_mobilenet_v3('tf_mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs)
    return model


@register_model
def quant_tf_mobilenetv3_small_minimal_100(pretrained=False, **kwargs):
    """ MobileNet V3 """
    kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_mobilenet_v3('tf_mobilenetv3_small_minimal_100', 1.0, pretrained=pretrained, **kwargs)
    return model
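# With the @register_model entrypoints above, the quantization-ready variants are available through the
# usual timm factory, e.g. (sketch):
#
#   import timm
#   model = timm.create_model('quant_mobilenetv3_large_100', pretrained=True)
#   model.eval().fuse_model()   # fuse before preparing for static quantization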
@ -0,0 +1,323 @@
""" ReXNet

A PyTorch impl of `ReXNet: Diminishing Representational Bottleneck on Convolutional Neural Network` -
https://arxiv.org/abs/2007.00992

Adapted from original impl at https://github.com/clovaai/rexnet
Copyright (c) 2020-present NAVER Corp. MIT license

Changes for timm, feature extraction, and rounded channel variant hacked together by Ross Wightman
Copyright 2020 Ross Wightman
"""
import torch
import torch.nn as nn
from math import ceil

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.helpers import build_model_with_cfg
from .layers import ClassifierHead, create_act_layer, create_conv2d
from timm.models.registry import register_model
from .layers.activations import sigmoid, Swish, HardSwish, HardSigmoid


def _cfg(url=''):
    return {
        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': 0.875, 'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'stem.conv', 'classifier': 'head.fc',
    }


default_cfgs = dict(
    rexnet_100=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rexnet/rexnetv1_100-1b4dddf4.pth'),
    rexnet_130=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rexnet/rexnetv1_130-590d768e.pth'),
    rexnet_150=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rexnet/rexnetv1_150-bd1a6aa8.pth'),
    rexnet_200=_cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rexnet/rexnetv1_200-8c0b7f2d.pth'),
    rexnetr_100=_cfg(
        url=''),
    rexnetr_130=_cfg(
        url=''),
    rexnetr_150=_cfg(
        url=''),
    rexnetr_200=_cfg(
        url=''),
)


def make_divisible(v, divisor=8, min_value=None):
    min_value = min_value or divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    return new_v


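# e.g. make_divisible(38.4) -> 40 and make_divisible(19) -> 16 with the default divisor=8;
# the result is always a multiple of `divisor` and at least `min_value`.

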
class ConvBn(nn.Module):
    def __init__(self, in_chs, out_chs, kernel_size,
                 stride=1, dilation=1, pad_type='',
                 norm_layer=nn.BatchNorm2d, groups=1, norm_kwargs=None):
        super(ConvBn, self).__init__()
        norm_kwargs = norm_kwargs or {}
        self.conv = create_conv2d(
            in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, groups=groups, padding=pad_type)
        self.bn1 = norm_layer(out_chs, **norm_kwargs)
        self.out_channels = out_chs

    def feature_info(self, location):
        if location == 'expansion':  # output of conv after act, same as block output
            info = dict(module='act1', hook_type='forward', num_chs=self.conv.out_channels)
        else:  # location == 'bottleneck', block output
            info = dict(module='', hook_type='', num_chs=self.conv.out_channels)
        return info

    def forward(self, x):
        x = self.conv(x)
        x = self.bn1(x)
        return x

    def fuse_module(self):
        modules_to_fuse = ['conv', 'bn1']
        torch.quantization.fuse_modules(self, modules_to_fuse, inplace=True)


class ConvBnAct(nn.Module):
    def __init__(self, in_chs, out_chs, kernel_size,
                 stride=1, dilation=1, pad_type='', act_layer=nn.ReLU,
                 norm_layer=nn.BatchNorm2d, norm_kwargs=None):
        super(ConvBnAct, self).__init__()
        norm_kwargs = norm_kwargs or {}
        self.conv = create_conv2d(in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, padding=pad_type)
        self.bn1 = norm_layer(out_chs, **norm_kwargs)
        self.act1 = act_layer(inplace=True)
        self.out_channels = out_chs

    def feature_info(self, location):
        if location == 'expansion':  # output of conv after act, same as block output
            info = dict(module='act1', hook_type='forward', num_chs=self.conv.out_channels)
        else:  # location == 'bottleneck', block output
            info = dict(module='', hook_type='', num_chs=self.conv.out_channels)
        return info

    def forward(self, x):
        x = self.conv(x)
        x = self.bn1(x)
        x = self.act1(x)
        return x

    def fuse_module(self):
        modules_to_fuse = ['conv', 'bn1']
        if type(self.act1) == nn.ReLU:
            modules_to_fuse.append('act1')
        torch.quantization.fuse_modules(self, modules_to_fuse, inplace=True)


class SEWithNorm(nn.Module):

    def __init__(self, channels, reduction=16, act_layer=nn.ReLU, divisor=1, reduction_channels=None,
                 gate_layer='sigmoid'):
        super(SEWithNorm, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        reduction_channels = reduction_channels or make_divisible(channels // reduction, divisor=divisor)
        self.fc1 = nn.Conv2d(
            channels, reduction_channels, kernel_size=1, padding=0, bias=True)
        self.bn = nn.BatchNorm2d(reduction_channels)
        self.act = act_layer(inplace=True)
        self.fc2 = nn.Conv2d(
            reduction_channels, channels, kernel_size=1, padding=0, bias=True)
        self.gate = create_act_layer(gate_layer)
        # quantization-friendly elementwise multiply for the SE gating
        self.quant_mul = nn.quantized.FloatFunctional()

    def forward(self, x):
        x_se = self.avg_pool(x)
        x_se = self.fc1(x_se)
        x_se = self.bn(x_se)
        x_se = self.act(x_se)
        x_se = self.fc2(x_se)
        return self.quant_mul.mul(x, self.gate(x_se))

    def fuse_module(self):
        modules_to_fuse = ['fc1', 'bn', 'act']
        torch.quantization.fuse_modules(self, modules_to_fuse, inplace=True)


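# Note on nn.quantized.FloatFunctional (used by SEWithNorm above and LinearBottleneck below):
# eager-mode quantization cannot observe bare tensor ops such as `a * b` or `a + b`, so those ops are
# routed through FloatFunctional modules, which carry their own observers and quantization parameters.
# A minimal, hypothetical sketch of the pattern:
#
#   class QuantFriendlyResidual(nn.Module):
#       def __init__(self):
#           super().__init__()
#           self.skip_add = nn.quantized.FloatFunctional()
#
#       def forward(self, x, shortcut):
#           return self.skip_add.add(x, shortcut)   # instead of `x + shortcut`

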
class LinearBottleneck(nn.Module):
    def __init__(self, in_chs, out_chs, stride, exp_ratio=1.0, use_se=True, se_rd=12, ch_div=1):
        super(LinearBottleneck, self).__init__()
        self.use_shortcut = stride == 1 and in_chs <= out_chs
        self.in_channels = in_chs
        self.out_channels = out_chs

        if exp_ratio != 1.:
            dw_chs = make_divisible(round(in_chs * exp_ratio), divisor=ch_div)
            self.conv_exp = ConvBnAct(in_chs, dw_chs, 1, act_layer=Swish)
        else:
            dw_chs = in_chs
            self.conv_exp = None

        self.conv_dw = ConvBn(dw_chs, dw_chs, 3, stride=stride, groups=dw_chs)
        self.se = SEWithNorm(dw_chs, reduction=se_rd, divisor=ch_div) if use_se else None
        self.act_dw = nn.ReLU6()

        self.conv_pwl = ConvBn(dw_chs, out_chs, 1)
        if self.use_shortcut:
            # quantization-friendly elementwise add for the partial residual below
            self.skip_add = nn.quantized.FloatFunctional()

    def feat_channels(self, exp=False):
        return self.conv_dw.out_channels if exp else self.out_channels

    def forward(self, x):
        shortcut = x
        if self.conv_exp is not None:
            x = self.conv_exp(x)
        x = self.conv_dw(x)
        if self.se is not None:
            x = self.se(x)
        x = self.act_dw(x)
        x = self.conv_pwl(x)
        if self.use_shortcut:
            # ReXNet adds the shortcut onto only the first in_channels channels of the output
            x[:, 0:self.in_channels] = self.skip_add.add(x[:, 0:self.in_channels], shortcut)
        return x


def _block_cfg(width_mult=1.0, depth_mult=1.0, initial_chs=16, final_chs=180, use_se=True, ch_div=1):
    layers = [1, 2, 2, 3, 3, 5]
    strides = [1, 2, 2, 2, 1, 2]
    layers = [ceil(element * depth_mult) for element in layers]
    strides = sum([[element] + [1] * (layers[idx] - 1) for idx, element in enumerate(strides)], [])
    exp_ratios = [1] * layers[0] + [6] * sum(layers[1:])
    depth = sum(layers[:]) * 3
    base_chs = initial_chs / width_mult if width_mult < 1.0 else initial_chs

    # The following channel configuration is a simple instance to make each layer become an expand layer.
    out_chs_list = []
    for i in range(depth // 3):
        out_chs_list.append(make_divisible(round(base_chs * width_mult), divisor=ch_div))
        base_chs += final_chs / (depth // 3 * 1.0)

    if use_se:
        use_ses = [False] * (layers[0] + layers[1]) + [True] * sum(layers[2:])
    else:
        use_ses = [False] * sum(layers[:])

    return zip(out_chs_list, exp_ratios, strides, use_ses)


def _build_blocks(block_cfg, prev_chs, width_mult, se_rd=12, ch_div=1, feature_location='bottleneck'):
    feat_exp = feature_location == 'expansion'
    feat_chs = [prev_chs]
    feature_info = []
    curr_stride = 2
    features = []
    for block_idx, (chs, exp_ratio, stride, se) in enumerate(block_cfg):
        if stride > 1:
            fname = 'stem' if block_idx == 0 else f'features.{block_idx - 1}'
            if block_idx > 0 and feat_exp:
                fname += '.act_dw'
            feature_info += [dict(num_chs=feat_chs[-1], reduction=curr_stride, module=fname)]
            curr_stride *= stride
        features.append(LinearBottleneck(
            in_chs=prev_chs, out_chs=chs, exp_ratio=exp_ratio, stride=stride, use_se=se, se_rd=se_rd, ch_div=ch_div))
        prev_chs = chs
        feat_chs += [features[-1].feat_channels(feat_exp)]
    pen_chs = make_divisible(1280 * width_mult, divisor=ch_div)
    feature_info += [dict(
        num_chs=pen_chs if feat_exp else feat_chs[-1], reduction=curr_stride,
        module=f'features.{len(features) - int(not feat_exp)}')]
    features.append(ConvBnAct(prev_chs, pen_chs, 1, act_layer=Swish))
    return features, feature_info


class ReXNetV1(nn.Module):
    def __init__(self, in_chans=3, num_classes=1000, global_pool='avg', output_stride=32,
                 initial_chs=16, final_chs=180, width_mult=1.0, depth_mult=1.0, use_se=True,
                 se_rd=12, ch_div=1, drop_rate=0.2, feature_location='bottleneck'):
        super(ReXNetV1, self).__init__()
        self.drop_rate = drop_rate

        assert output_stride == 32  # FIXME support dilation
        stem_base_chs = 32 / width_mult if width_mult < 1.0 else 32
        stem_chs = make_divisible(round(stem_base_chs * width_mult), divisor=ch_div)
        self.stem = ConvBnAct(in_chans, stem_chs, 3, stride=2, act_layer=Swish)

        block_cfg = _block_cfg(width_mult, depth_mult, initial_chs, final_chs, use_se, ch_div)
        features, self.feature_info = _build_blocks(
            block_cfg, stem_chs, width_mult, se_rd, ch_div, feature_location)
        self.num_features = features[-1].out_channels
        self.features = nn.Sequential(*features)

        self.head = ClassifierHead(self.num_features, num_classes, global_pool, drop_rate)

        # Quantization Stubs
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()
        # FIXME weight init, the original appears to use PyTorch defaults

    def get_classifier(self):
        return self.head.fc

    def reset_classifier(self, num_classes, global_pool='avg'):
        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)

    def forward_features(self, x):
        x = self.stem(x)
        x = self.features(x)
        return x

    def forward(self, x):
        x = self.quant(x)
        x = self.forward_features(x)
        x = self.head(x)
        x = self.dequant(x)
        return x

    def fuse_model(self):
        for m in self.modules():
            if type(m) in [ConvBnAct, ConvBn, SEWithNorm]:
                m.fuse_module()


def _create_rexnet(variant, pretrained, **kwargs):
    feature_cfg = dict(flatten_sequential=True)
    if kwargs.get('feature_location', '') == 'expansion':
        feature_cfg['feature_cls'] = 'hook'
    return build_model_with_cfg(
        ReXNetV1, variant, pretrained, default_cfg=default_cfgs[variant], feature_cfg=feature_cfg, **kwargs)


@register_model
def quant_rexnet_100(pretrained=False, **kwargs):
    """ReXNet V1 1.0x"""
    return _create_rexnet('rexnet_100', pretrained, **kwargs)


@register_model
def quant_rexnet_130(pretrained=False, **kwargs):
    """ReXNet V1 1.3x"""
    return _create_rexnet('rexnet_130', pretrained, width_mult=1.3, **kwargs)


@register_model
def quant_rexnet_150(pretrained=False, **kwargs):
    """ReXNet V1 1.5x"""
    return _create_rexnet('rexnet_150', pretrained, width_mult=1.5, **kwargs)


@register_model
def quant_rexnet_200(pretrained=False, **kwargs):
    """ReXNet V1 2.0x"""
    return _create_rexnet('rexnet_200', pretrained, width_mult=2.0, **kwargs)


@register_model
def quant_rexnetr_100(pretrained=False, **kwargs):
    """ReXNet V1 1.0x w/ rounded (mod 8) channels"""
    return _create_rexnet('rexnetr_100', pretrained, ch_div=8, **kwargs)


@register_model
def quant_rexnetr_130(pretrained=False, **kwargs):
    """ReXNet V1 1.3x w/ rounded (mod 8) channels"""
    return _create_rexnet('rexnetr_130', pretrained, width_mult=1.3, ch_div=8, **kwargs)


@register_model
def quant_rexnetr_150(pretrained=False, **kwargs):
    """ReXNet V1 1.5x w/ rounded (mod 8) channels"""
    return _create_rexnet('rexnetr_150', pretrained, width_mult=1.5, ch_div=8, **kwargs)


@register_model
def quant_rexnetr_200(pretrained=False, **kwargs):
    """ReXNet V1 2.0x w/ rounded (mod 8) channels"""
    return _create_rexnet('rexnetr_200', pretrained, width_mult=2.0, ch_div=8, **kwargs)
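# The entrypoints above follow the same eager-mode static quantization recipe as the MobileNet-V3
# variants (QuantStub/DeQuantStub + fuse_model + prepare/convert). As a lighter alternative that needs
# no calibration data, the Linear classifier can be dynamically quantized (hedged sketch):
#
#   import torch
#   m = quant_rexnet_100(pretrained=True).eval()
#   m_dq = torch.quantization.quantize_dynamic(m, {torch.nn.Linear}, dtype=torch.qint8)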