From 5855b07ae00bf78325ea79a80bb5c7755e3f36e2 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 1 Feb 2019 21:48:56 -0800 Subject: [PATCH] Initial commit, puting some ol pieces together --- dataset.py | 91 +++ inference.py | 138 ++++ models/__init__.py | 1 + models/adaptive_avgmax_pool.py | 61 ++ models/dpn.py | 298 ++++++++ models/fbresnet200.py | 1254 ++++++++++++++++++++++++++++++++ models/inception_resnet_v2.py | 325 +++++++++ models/inception_v4.py | 294 ++++++++ models/median_pool.py | 48 ++ models/model_factory.py | 194 +++++ models/my_densenet.py | 184 +++++ models/my_resnet.py | 247 +++++++ models/pnasnet.py | 401 ++++++++++ models/senet.py | 517 +++++++++++++ models/wrn50_2.py | 393 ++++++++++ models/xception.py | 237 ++++++ optim/nadam.py | 85 +++ train.py | 407 +++++++++++ utils.py | 139 ++++ validate.py | 174 +++++ 20 files changed, 5488 insertions(+) create mode 100644 dataset.py create mode 100644 inference.py create mode 100644 models/__init__.py create mode 100644 models/adaptive_avgmax_pool.py create mode 100644 models/dpn.py create mode 100644 models/fbresnet200.py create mode 100644 models/inception_resnet_v2.py create mode 100644 models/inception_v4.py create mode 100644 models/median_pool.py create mode 100644 models/model_factory.py create mode 100644 models/my_densenet.py create mode 100644 models/my_resnet.py create mode 100644 models/pnasnet.py create mode 100644 models/senet.py create mode 100644 models/wrn50_2.py create mode 100644 models/xception.py create mode 100644 optim/nadam.py create mode 100644 train.py create mode 100644 utils.py create mode 100644 validate.py diff --git a/dataset.py b/dataset.py new file mode 100644 index 00000000..7191bb26 --- /dev/null +++ b/dataset.py @@ -0,0 +1,91 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import torch.utils.data as data + +import os +import re +import torch +from PIL import Image + +IMG_EXTENSIONS = ['.png', 
IMG_EXTENSIONS = ['.png', '.jpg', '.jpeg']


def natural_key(string_):
    """Build a sort key that compares digit runs in ``string_`` numerically.

    The string is lower-cased and split on runs of digits; each digit run
    becomes an ``int`` and every other piece stays a ``str``.

    See http://www.codinghorror.com/blog/archives/001018.html
    """
    return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]


def find_images_and_targets(folder, types=IMG_EXTENSIONS, class_to_idx=None, leaf_name_only=True, sort=True):
    """Collect image files under ``folder`` and pair each with a class index.

    The tree is walked bottom-up with ``os.walk``. A file is kept when its
    lower-cased extension is in ``types``; its label is derived from the
    directory that contains it.

    Args:
        folder: Root directory to scan.
        types: Accepted file extensions (lower case, including the dot).
        class_to_idx: Optional existing ``label -> index`` mapping. When it is
            ``None`` a mapping is built from every directory that has no
            sub-directories, with labels ordered by ``natural_key``.
        leaf_name_only: If true the label is the directory's base name,
            otherwise it is the path relative to ``folder`` with path
            separators replaced by ``'_'``.
        sort: If true the result is ordered by ``natural_key`` of the filename.

    Returns:
        ``(images_and_targets, classes, class_to_idx)`` when the mapping was
        built here, otherwise only ``images_and_targets``. In both cases
        ``images_and_targets`` is a list of ``(filename, class_index)`` tuples.

    Raises:
        KeyError: If a kept file's label is not a key of ``class_to_idx``.
    """
    build_class_idx = class_to_idx is None
    if build_class_idx:
        class_to_idx = dict()
    labels = []
    filenames = []
    for root, subdirs, files in os.walk(folder, topdown=False):
        rel_path = os.path.relpath(root, folder) if (root != folder) else ''
        label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_')
        if build_class_idx and not subdirs:
            # Only leaf directories define classes; the index is filled in below.
            class_to_idx[label] = None
        for f in files:
            ext = os.path.splitext(f)[1]
            if ext.lower() in types:
                filenames.append(os.path.join(root, f))
                labels.append(label)
    if build_class_idx:
        classes = sorted(class_to_idx.keys(), key=natural_key)
        for idx, c in enumerate(classes):
            class_to_idx[c] = idx
    # Materialize the pairs: on Python 3 a bare zip() is a one-shot iterator,
    # so with sort=False callers could neither take len() nor index the result.
    images_and_targets = list(zip(filenames, [class_to_idx[l] for l in labels]))
    if sort:
        images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0]))
    if build_class_idx:
        return images_and_targets, classes, class_to_idx
    return images_and_targets


class Dataset(data.Dataset):
    """Image-folder dataset built from ``find_images_and_targets``.

    Each item is ``(image, target)`` where the image is opened with PIL,
    converted to RGB and passed through ``transform`` when one is set.
    """

    def __init__(
            self,
            root,
            transform=None):
        """Scan ``root`` for images.

        Raises:
            RuntimeError: If no image with a supported extension is found.
        """
        imgs, _, _ = find_images_and_targets(root)
        if not imgs:
            raise RuntimeError(
                "Found 0 images in subfolders of: " + root + "\n"
                "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))
        self.root = root
        self.imgs = imgs
        self.transform = transform

    def __getitem__(self, index):
        """Load the image at ``index`` and return ``(image, target)``."""
        path, target = self.imgs[index]
        img = Image.open(path).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        if target is None:
            # Placeholder target for samples that carry no class index.
            target = torch.zeros(1).long()
        return img, target

    def __len__(self):
        return len(self.imgs)

    def set_transform(self, transform):
        """Replace the transform applied to every loaded image."""
        self.transform = transform

    def filenames(self, indices=None, basename=False):
        """Return image paths, optionally restricted to ``indices``.

        Args:
            indices: Sequence of sample indices; ``None`` or empty selects
                every sample in dataset order.
            basename: If true return only the file name of each path.
        """
        if indices:
            paths = [self.imgs[i][0] for i in indices]
        else:
            paths = [x[0] for x in self.imgs]
        if basename:
            return [os.path.basename(p) for p in paths]
        return paths
parser = argparse.ArgumentParser(description='PyTorch ImageNet Inference')
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('--output_dir', metavar='DIR', default='./',
                    help='path to output files')
parser.add_argument('--model', '-m', metavar='MODEL', default='dpn92',
                    help='model architecture (default: dpn92)')
parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                    help='number of data loading workers (default: 2)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--img-size', default=224, type=int,
                    metavar='N', help='Input image dimension')
parser.add_argument('--print-freq', '-p', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--restore-checkpoint', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--multi-gpu', dest='multi_gpu', action='store_true',
                    help='use multiple-gpus')
# The flag stores False into ``test_time_pool``; the help text used to be a
# copy of the --pretrained one.
parser.add_argument('--no-test-pool', dest='test_time_pool', action='store_false',
                    help='disable test time pooling')


def main():
    """Run top-5 inference over an image folder and write ``top5_ids.csv``.

    Builds the model named by ``--model``, optionally restores a checkpoint,
    runs every image under ``data`` through the model on the GPU and writes one
    CSV row per image: the file's base name followed by its five top-scoring
    class indices.

    Exits with status 1 when neither a readable checkpoint nor ``--pretrained``
    was given.
    """
    args = parser.parse_args()

    # create model
    num_classes = 1000
    model = model_factory.create_model(
        args.model,
        num_classes=num_classes,
        pretrained=args.pretrained,
        test_time_pool=args.test_time_pool)

    # resume from a checkpoint
    if args.restore_checkpoint and os.path.isfile(args.restore_checkpoint):
        print("=> loading checkpoint '{}'".format(args.restore_checkpoint))
        checkpoint = torch.load(args.restore_checkpoint)
        # Accept both a full training checkpoint and a bare state dict.
        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            model.load_state_dict(checkpoint['state_dict'])
        else:
            model.load_state_dict(checkpoint)
        print("=> loaded checkpoint '{}'".format(args.restore_checkpoint))
    elif not args.pretrained:
        print("=> no checkpoint found at '{}'".format(args.restore_checkpoint))
        # Same effect as the interactive ``exit`` helper, without depending on
        # the ``site`` module having installed it.
        raise SystemExit(1)

    if args.multi_gpu:
        model = torch.nn.DataParallel(model).cuda()
    else:
        model = model.cuda()

    transforms = model_factory.get_transforms_eval(
        args.model,
        args.img_size)

    dataset = Dataset(
        args.data,
        transforms)

    loader = data.DataLoader(
        dataset,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    model.eval()

    batch_time = AverageMeter()
    end = time.time()
    top5_ids = []
    with torch.no_grad():
        for batch_idx, (images, _) in enumerate(loader):
            images = images.cuda()
            logits = model(images)
            top5 = logits.topk(5)[1]
            top5_ids.append(top5.cpu().numpy())

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if batch_idx % args.print_freq == 0:
                print('Predict: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format(
                        batch_idx, len(loader), batch_time=batch_time))

    # Keep the (num_images, 5) layout. A squeeze() here would drop the image
    # axis when exactly one image was predicted and break the row loop below.
    top5_ids = np.concatenate(top5_ids, axis=0)

    with open(os.path.join(args.output_dir, './top5_ids.csv'), 'w') as out_file:
        filenames = dataset.filenames()
        for filename, label in zip(filenames, top5_ids):
            filename = os.path.basename(filename)
            out_file.write('{0},{1},{2},{3},{4},{5}\n'.format(
                filename, label[0], label[1], label[2], label[3], label[4]))


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the last value, running sum, sample count and average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted as ``n`` samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


if __name__ == '__main__':
    main()
def adaptive_avgmax_pool2d(x, pool_type='avg', padding=0, count_include_pad=False):
    """Selectable global pooling function with dynamic input kernel size

    The pooling kernel is set to the input's full spatial size
    (``x.size(2)`` by ``x.size(3)``).

    Args:
        x: 4-D input tensor.
        pool_type: ``'avg'``, ``'max'`` or ``'avgmax'`` (mean of the average
            and max pooled results). Any other value prints a warning and
            falls back to average pooling.
        padding: Padding forwarded to the pooling ops.
        count_include_pad: Forwarded to the average pooling op.

    NOTE(review): the module docstring also lists ``'avgmaxc'``
    (concatenation), but no such branch exists here, so it takes the
    warn-and-average fallback. Confirm the intended behaviour before adding it.
    """
    kernel_size = (x.size(2), x.size(3))
    if pool_type == 'avgmax':
        x_avg = F.avg_pool2d(
            x, kernel_size=kernel_size, padding=padding, count_include_pad=count_include_pad)
        x_max = F.max_pool2d(x, kernel_size=kernel_size, padding=padding)
        return 0.5 * (x_avg + x_max)
    if pool_type == 'max':
        return F.max_pool2d(x, kernel_size=kernel_size, padding=padding)
    if pool_type != 'avg':
        print('Invalid pool type %s specified. Defaulting to average pooling.' % pool_type)
    return F.avg_pool2d(
        x, kernel_size=kernel_size, padding=padding, count_include_pad=count_include_pad)


class AdaptiveAvgMaxPool2d(torch.nn.Module):
    """Selectable global pooling layer with dynamic input kernel size

    Wraps ``nn.AdaptiveAvgPool2d`` / ``nn.AdaptiveMaxPool2d``. ``pool_type``
    is ``'avg'``, ``'max'`` or ``'avgmax'`` (mean of both); any other value
    prints a warning at construction time and uses average pooling.
    """
    def __init__(self, output_size=1, pool_type='avg'):
        super(AdaptiveAvgMaxPool2d, self).__init__()
        self.output_size = output_size
        self.pool_type = pool_type
        if pool_type == 'avgmax':
            self.pool = nn.ModuleList([nn.AdaptiveAvgPool2d(output_size), nn.AdaptiveMaxPool2d(output_size)])
        elif pool_type == 'max':
            self.pool = nn.AdaptiveMaxPool2d(output_size)
        else:
            if pool_type != 'avg':
                print('Invalid pool type %s specified. Defaulting to average pooling.' % pool_type)
            self.pool = nn.AdaptiveAvgPool2d(output_size)

    def forward(self, x):
        """Pool ``x``; the batch and channel dimensions are preserved."""
        if self.pool_type == 'avgmax':
            # Average the two pooled results. The previous stack/sum followed
            # by squeeze(dim=0) removed the batch axis whenever batch size
            # was 1.
            pooled_avg, pooled_max = [p(x) for p in self.pool]
            return 0.5 * (pooled_avg + pooled_max)
        return self.pool(x)

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + 'output_size=' + str(self.output_size) \
               + ', pool_type=' + self.pool_type + ')'
# If anyone able to provide direct link hosting, more than happy to fill these out.. -rwightman
model_urls = {
    'dpn68': '',
    'dpn68b_extra': 'dpn68_extra-87733ef7.pth',
    'dpn92': '',
    'dpn92_extra': '',
    'dpn98': '',
    'dpn131': 'dpn131-89380fa2.pth',
    'dpn107_extra': 'dpn107_extra-fc014e8ec.pth'
}


def dpn68(num_classes=1000, pretrained=False, test_time_pool=7):
    """Build the ``dpn68`` configuration; ``pretrained`` loads ``model_urls['dpn68']``."""
    model = DPN(
        small=True, num_init_features=10, k_r=128, groups=32,
        k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64),
        num_classes=num_classes, test_time_pool=test_time_pool)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['dpn68']))
    return model


def dpn68b(num_classes=1000, pretrained=False, test_time_pool=7):
    """Build the ``dpn68b`` configuration (``b=True``); ``pretrained`` loads ``model_urls['dpn68b_extra']``."""
    model = DPN(
        small=True, num_init_features=10, k_r=128, groups=32,
        b=True, k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64),
        num_classes=num_classes, test_time_pool=test_time_pool)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['dpn68b_extra']))
    return model


def dpn92(num_classes=1000, pretrained=False, test_time_pool=7, extra=True):
    """Build the ``dpn92`` configuration.

    With ``pretrained`` set, ``extra`` selects between the ``'dpn92_extra'``
    and ``'dpn92'`` entries of ``model_urls``.
    """
    model = DPN(
        num_init_features=64, k_r=96, groups=32,
        k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128),
        num_classes=num_classes, test_time_pool=test_time_pool)
    if pretrained:
        if extra:
            model.load_state_dict(model_zoo.load_url(model_urls['dpn92_extra']))
        else:
            model.load_state_dict(model_zoo.load_url(model_urls['dpn92']))
    return model


def dpn98(num_classes=1000, pretrained=False, test_time_pool=7):
    """Build the ``dpn98`` configuration; ``pretrained`` loads ``model_urls['dpn98']``."""
    model = DPN(
        num_init_features=96, k_r=160, groups=40,
        k_sec=(3, 6, 20, 3), inc_sec=(16, 32, 32, 128),
        num_classes=num_classes, test_time_pool=test_time_pool)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['dpn98']))
    return model


def dpn131(num_classes=1000, pretrained=False, test_time_pool=7):
    """Build the ``dpn131`` configuration; ``pretrained`` loads ``model_urls['dpn131']``."""
    model = DPN(
        num_init_features=128, k_r=160, groups=40,
        k_sec=(4, 8, 28, 3), inc_sec=(16, 32, 32, 128),
        num_classes=num_classes, test_time_pool=test_time_pool)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['dpn131']))
    return model


def dpn107(num_classes=1000, pretrained=False, test_time_pool=7):
    """Build the ``dpn107`` configuration; ``pretrained`` loads ``model_urls['dpn107_extra']``."""
    model = DPN(
        num_init_features=128, k_r=200, groups=50,
        k_sec=(4, 8, 20, 3), inc_sec=(20, 64, 64, 128),
        num_classes=num_classes, test_time_pool=test_time_pool)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['dpn107_extra']))
    return model


class CatBnAct(nn.Module):
    """Concatenate a tuple input along dim 1, then apply BatchNorm and the activation."""

    def __init__(self, in_chs, activation_fn=nn.ReLU(inplace=True)):
        super(CatBnAct, self).__init__()
        self.bn = nn.BatchNorm2d(in_chs, eps=0.001)
        self.act = activation_fn

    def forward(self, x):
        x = torch.cat(x, dim=1) if isinstance(x, tuple) else x
        return self.act(self.bn(x))


class BnActConv2d(nn.Module):
    """BatchNorm, activation, then a bias-free convolution, in that order."""

    def __init__(self, in_chs, out_chs, kernel_size, stride,
                 padding=0, groups=1, activation_fn=nn.ReLU(inplace=True)):
        super(BnActConv2d, self).__init__()
        self.bn = nn.BatchNorm2d(in_chs, eps=0.001)
        self.act = activation_fn
        self.conv = nn.Conv2d(in_chs, out_chs, kernel_size, stride, padding, groups=groups, bias=False)

    def forward(self, x):
        return self.conv(self.act(self.bn(x)))


class InputBlock(nn.Module):
    """Network stem: stride-2 conv on 3 input channels, BatchNorm, activation, 3x3 stride-2 max pool."""

    def __init__(self, num_init_features, kernel_size=7,
                 padding=3, activation_fn=nn.ReLU(inplace=True)):
        super(InputBlock, self).__init__()
        self.conv = nn.Conv2d(
            3, num_init_features, kernel_size=kernel_size, stride=2, padding=padding, bias=False)
        self.bn = nn.BatchNorm2d(num_init_features, eps=0.001)
        self.act = activation_fn
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.act(x)
        x = self.pool(x)
        return x


class DualPathBlock(nn.Module):
    """One dual-path block producing a ``(resid, dense)`` pair.

    ``block_type`` selects the shortcut:
      * ``'proj'``   - 1x1 projection of the input, stride 1
      * ``'down'``   - 1x1 projection of the input, stride 2
      * ``'normal'`` - no projection; the input must already be a
        ``(resid, dense)`` pair

    With ``b=True`` the last stage is a shared BN/activation followed by two
    separate 1x1 convs; otherwise it is one 1x1 conv whose output is split at
    channel ``num_1x1_c``.
    """

    def __init__(
            self, in_chs, num_1x1_a, num_3x3_b, num_1x1_c, inc, groups, block_type='normal', b=False):
        super(DualPathBlock, self).__init__()
        self.num_1x1_c = num_1x1_c
        self.inc = inc
        self.b = b
        # Compare by value: identity ('is') only matched when the caller's
        # string happened to be the same interned object as the literal.
        if block_type == 'proj':
            self.key_stride = 1
            self.has_proj = True
        elif block_type == 'down':
            self.key_stride = 2
            self.has_proj = True
        else:
            assert block_type == 'normal'
            self.key_stride = 1
            self.has_proj = False

        if self.has_proj:
            # Using different member names here to allow easier parameter key matching for conversion
            if self.key_stride == 2:
                self.c1x1_w_s2 = BnActConv2d(
                    in_chs=in_chs, out_chs=num_1x1_c + 2 * inc, kernel_size=1, stride=2)
            else:
                self.c1x1_w_s1 = BnActConv2d(
                    in_chs=in_chs, out_chs=num_1x1_c + 2 * inc, kernel_size=1, stride=1)
        self.c1x1_a = BnActConv2d(in_chs=in_chs, out_chs=num_1x1_a, kernel_size=1, stride=1)
        self.c3x3_b = BnActConv2d(
            in_chs=num_1x1_a, out_chs=num_3x3_b, kernel_size=3,
            stride=self.key_stride, padding=1, groups=groups)
        if b:
            self.c1x1_c = CatBnAct(in_chs=num_3x3_b)
            self.c1x1_c1 = nn.Conv2d(num_3x3_b, num_1x1_c, kernel_size=1, bias=False)
            self.c1x1_c2 = nn.Conv2d(num_3x3_b, inc, kernel_size=1, bias=False)
        else:
            self.c1x1_c = BnActConv2d(in_chs=num_3x3_b, out_chs=num_1x1_c + inc, kernel_size=1, stride=1)

    def forward(self, x):
        """Return ``(resid, dense)`` for a tensor or ``(resid, dense)`` tuple input."""
        x_in = torch.cat(x, dim=1) if isinstance(x, tuple) else x
        if self.has_proj:
            if self.key_stride == 2:
                x_s = self.c1x1_w_s2(x_in)
            else:
                x_s = self.c1x1_w_s1(x_in)
            # Split the projection into its residual and dense parts.
            x_s1 = x_s[:, :self.num_1x1_c, :, :]
            x_s2 = x_s[:, self.num_1x1_c:, :, :]
        else:
            x_s1 = x[0]
            x_s2 = x[1]
        x_in = self.c1x1_a(x_in)
        x_in = self.c3x3_b(x_in)
        if self.b:
            x_in = self.c1x1_c(x_in)
            out1 = self.c1x1_c1(x_in)
            out2 = self.c1x1_c2(x_in)
        else:
            x_in = self.c1x1_c(x_in)
            out1 = x_in[:, :self.num_1x1_c, :, :]
            out2 = x_in[:, self.num_1x1_c:, :, :]
        resid = x_s1 + out1
        dense = torch.cat([x_s2, out2], dim=1)
        return resid, dense


class DPN(nn.Module):
    """Dual Path Network: stem, four stages of ``DualPathBlock`` and a 1x1-conv classifier.

    Args:
        small: Use a 3x3 stem conv and a bandwidth factor of 1 (otherwise 7x7 and 4).
        num_init_features: Output channels of the stem.
        k_r: Base width used to derive each stage's bottleneck width.
        groups: Group count of the 3x3 convs.
        b: Passed to every ``DualPathBlock``.
        k_sec: Number of blocks in each of the four stages.
        inc_sec: Dense-path channel increment of each stage.
        num_classes: Output channels of the classifier.
        test_time_pool: Kernel size of the extra average pool applied in eval
            mode; a falsy value disables it.
        fc_act: Activation of the final ``CatBnAct``.
    """

    def __init__(self, small=False, num_init_features=64, k_r=96, groups=32,
                 b=False, k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128),
                 num_classes=1000, test_time_pool=0, fc_act=nn.ELU(inplace=True)):
        super(DPN, self).__init__()
        self.num_classes = num_classes
        self.test_time_pool = test_time_pool
        self.b = b
        bw_factor = 1 if small else 4

        blocks = OrderedDict()

        # conv1
        if small:
            blocks['conv1_1'] = InputBlock(num_init_features, kernel_size=3, padding=1)
        else:
            blocks['conv1_1'] = InputBlock(num_init_features, kernel_size=7, padding=3)

        # conv2 .. conv5: (base bandwidth, type of the stage's first block).
        # Module names ('conv2_1', 'conv2_2', ...) and creation order are kept
        # exactly so state-dict keys are unchanged.
        stage_specs = ((64, 'proj'), (128, 'down'), (256, 'down'), (512, 'down'))
        in_chs = num_init_features
        for stage, (base_bw, first_type) in enumerate(stage_specs):
            prefix = 'conv' + str(stage + 2) + '_'
            bw = base_bw * bw_factor
            inc = inc_sec[stage]
            r = (k_r * bw) // (64 * bw_factor)
            blocks[prefix + '1'] = DualPathBlock(in_chs, r, r, bw, inc, groups, first_type, b)
            in_chs = bw + 3 * inc
            for i in range(2, k_sec[stage] + 1):
                blocks[prefix + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b)
                in_chs += inc
        blocks['conv5_bn_ac'] = CatBnAct(in_chs, activation_fn=fc_act)
        self.num_features = in_chs
        self.features = nn.Sequential(blocks)

        # Using 1x1 conv for the FC layer to allow the extra pooling scheme
        self.classifier = nn.Conv2d(in_chs, num_classes, kernel_size=1, bias=True)

    def get_classifier(self):
        return self.classifier

    def reset_classifier(self, num_classes):
        """Replace the classifier; a falsy ``num_classes`` sets it to ``None``."""
        self.num_classes = num_classes
        del self.classifier
        if num_classes:
            self.classifier = nn.Conv2d(self.num_features, num_classes, kernel_size=1, bias=True)
        else:
            self.classifier = None

    def forward_features(self, x, pool=True):
        """Run the feature extractor; with ``pool`` the result is average pooled and flattened."""
        x = self.features(x)
        if pool:
            x = adaptive_avgmax_pool2d(x, pool_type='avg')
            x = x.view(x.size(0), -1)
        return x

    def forward(self, x):
        x = self.features(x)
        if not self.training and self.test_time_pool:
            x = F.avg_pool2d(x, kernel_size=self.test_time_pool, stride=1)
            out = self.classifier(x)
            # The extra test time pool should be pooling an img_size//32 - 6 size patch
            out = adaptive_avgmax_pool2d(out, pool_type='avgmax')
        else:
            x = adaptive_avgmax_pool2d(x, pool_type='avg')
            out = self.classifier(x)
        return out.view(out.size(0), -1)


# Helper containers from models/fbresnet200.py (converted Torch model).
class LambdaBase(nn.Sequential):
    """``nn.Sequential`` that stores a callable and fans the input out to each child."""

    def __init__(self, fn, *args):
        super(LambdaBase, self).__init__(*args)
        self.lambda_func = fn

    def forward_prepare(self, input):
        """Apply every child module to ``input``; return ``input`` itself when there are none."""
        output = []
        for module in self._modules.values():
            output.append(module(input))
        return output if output else input


class Lambda(LambdaBase):
    """Apply ``lambda_func`` to the result of ``forward_prepare``."""

    def forward(self, input):
        return self.lambda_func(self.forward_prepare(input))
return self.lambda_func(self.forward_prepare(input)) + + +class LambdaMap(LambdaBase): + def forward(self, input): + return list(map(self.lambda_func, self.forward_prepare(input))) + + +class LambdaReduce(LambdaBase): + def forward(self, input): + return reduce(self.lambda_func, self.forward_prepare(input)) + + +def fbresnet200_features(activation_fn=nn.ReLU()): + return nn.Sequential( # Sequential, + nn.Conv2d(3, 64, (7, 7), (2, 2), (3, 3)), + nn.BatchNorm2d(64), + activation_fn, + nn.MaxPool2d((3, 3), (2, 2), (1, 1)), + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(64), + activation_fn, + nn.Conv2d(64, 64, (1, 1)), + nn.BatchNorm2d(64), + activation_fn, + nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(64), + activation_fn, + nn.Conv2d(64, 256, (1, 1)), + ), + nn.Sequential( # Sequential, + nn.Conv2d(64, 256, (1, 1)), + nn.BatchNorm2d(256), + ), + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 64, (1, 1)), + nn.BatchNorm2d(64), + activation_fn, + nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(64), + activation_fn, + nn.Conv2d(64, 256, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 64, (1, 1)), + nn.BatchNorm2d(64), + activation_fn, + nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(64), + activation_fn, + nn.Conv2d(64, 256, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + ), + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + 
nn.Sequential( # Sequential, + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 128, (1, 1)), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 128, (3, 3), (2, 2), (1, 1)), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 512, (1, 1)), + ), + nn.Sequential( # Sequential, + nn.Conv2d(256, 512, (1, 1), (2, 2)), + nn.BatchNorm2d(512), + ), + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(512), + activation_fn, + nn.Conv2d(512, 128, (1, 1)), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 512, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(512), + activation_fn, + nn.Conv2d(512, 128, (1, 1)), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 512, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(512), + activation_fn, + nn.Conv2d(512, 128, (1, 1)), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 512, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(512), + activation_fn, + nn.Conv2d(512, 128, (1, 1)), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(128), + 
activation_fn, + nn.Conv2d(128, 512, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(512), + activation_fn, + nn.Conv2d(512, 128, (1, 1)), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 512, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(512), + activation_fn, + nn.Conv2d(512, 128, (1, 1)), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 512, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(512), + activation_fn, + nn.Conv2d(512, 128, (1, 1)), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 512, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(512), + activation_fn, + nn.Conv2d(512, 128, (1, 1)), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 512, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(512), + 
activation_fn, + nn.Conv2d(512, 128, (1, 1)), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 512, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(512), + activation_fn, + nn.Conv2d(512, 128, (1, 1)), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 512, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(512), + activation_fn, + nn.Conv2d(512, 128, (1, 1)), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 512, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(512), + activation_fn, + nn.Conv2d(512, 128, (1, 1)), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 512, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(512), + activation_fn, + nn.Conv2d(512, 128, (1, 1)), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 512, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x 
+ y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(512), + activation_fn, + nn.Conv2d(512, 128, (1, 1)), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 512, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(512), + activation_fn, + nn.Conv2d(512, 128, (1, 1)), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 512, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(512), + activation_fn, + nn.Conv2d(512, 128, (1, 1)), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 512, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(512), + activation_fn, + nn.Conv2d(512, 128, (1, 1)), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 512, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(512), + activation_fn, + nn.Conv2d(512, 128, (1, 1)), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 
1)), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 512, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(512), + activation_fn, + nn.Conv2d(512, 128, (1, 1)), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 512, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(512), + activation_fn, + nn.Conv2d(512, 128, (1, 1)), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 512, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(512), + activation_fn, + nn.Conv2d(512, 128, (1, 1)), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 512, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(512), + activation_fn, + nn.Conv2d(512, 128, (1, 1)), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 512, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + 
nn.BatchNorm2d(512), + activation_fn, + nn.Conv2d(512, 128, (1, 1)), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 512, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + ), + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(512), + activation_fn, + nn.Conv2d(512, 256, (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 256, (3, 3), (2, 2), (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 1024, (1, 1)), + ), + nn.Sequential( # Sequential, + nn.Conv2d(512, 1024, (1, 1), (2, 2)), + nn.BatchNorm2d(1024), + ), + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(1024), + activation_fn, + nn.Conv2d(1024, 256, (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 1024, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(1024), + activation_fn, + nn.Conv2d(1024, 256, (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 1024, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(1024), + activation_fn, + nn.Conv2d(1024, 256, (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), + 
nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 1024, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(1024), + activation_fn, + nn.Conv2d(1024, 256, (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 1024, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(1024), + activation_fn, + nn.Conv2d(1024, 256, (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 1024, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(1024), + activation_fn, + nn.Conv2d(1024, 256, (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 1024, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(1024), + activation_fn, + nn.Conv2d(1024, 256, (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 1024, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # 
Sequential, + nn.BatchNorm2d(1024), + activation_fn, + nn.Conv2d(1024, 256, (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 1024, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(1024), + activation_fn, + nn.Conv2d(1024, 256, (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 1024, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(1024), + activation_fn, + nn.Conv2d(1024, 256, (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 1024, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(1024), + activation_fn, + nn.Conv2d(1024, 256, (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 1024, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(1024), + activation_fn, + nn.Conv2d(1024, 256, (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 1024, (1, 1)), + ), + Lambda(lambda x: 
x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(1024), + activation_fn, + nn.Conv2d(1024, 256, (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 1024, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(1024), + activation_fn, + nn.Conv2d(1024, 256, (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 1024, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(1024), + activation_fn, + nn.Conv2d(1024, 256, (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 1024, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(1024), + activation_fn, + nn.Conv2d(1024, 256, (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 1024, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(1024), + activation_fn, + nn.Conv2d(1024, 256, (1, 1)), + 
nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 1024, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(1024), + activation_fn, + nn.Conv2d(1024, 256, (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 1024, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(1024), + activation_fn, + nn.Conv2d(1024, 256, (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 1024, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(1024), + activation_fn, + nn.Conv2d(1024, 256, (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 1024, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(1024), + activation_fn, + nn.Conv2d(1024, 256, (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 1024, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + 
nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(1024), + activation_fn, + nn.Conv2d(1024, 256, (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 1024, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(1024), + activation_fn, + nn.Conv2d(1024, 256, (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 1024, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(1024), + activation_fn, + nn.Conv2d(1024, 256, (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 1024, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(1024), + activation_fn, + nn.Conv2d(1024, 256, (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 1024, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(1024), + activation_fn, + nn.Conv2d(1024, 256, (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), + 
nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 1024, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(1024), + activation_fn, + nn.Conv2d(1024, 256, (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 1024, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(1024), + activation_fn, + nn.Conv2d(1024, 256, (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 1024, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(1024), + activation_fn, + nn.Conv2d(1024, 256, (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 1024, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(1024), + activation_fn, + nn.Conv2d(1024, 256, (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 1024, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # 
Sequential, + nn.BatchNorm2d(1024), + activation_fn, + nn.Conv2d(1024, 256, (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 1024, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(1024), + activation_fn, + nn.Conv2d(1024, 256, (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 1024, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(1024), + activation_fn, + nn.Conv2d(1024, 256, (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 1024, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(1024), + activation_fn, + nn.Conv2d(1024, 256, (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 1024, (1, 1)), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.BatchNorm2d(1024), + activation_fn, + nn.Conv2d(1024, 256, (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 1024, (1, 1)), + ), + Lambda(lambda x: 
class ResNet200(nn.Module):
    """Classifier wrapping the converted ``fbresnet200_features`` trunk with a linear head.

    The trunk ends in ``BatchNorm2d(2048)`` followed by the activation, so the
    pooled feature vector has ``num_features == 2048`` entries.

    Args:
        num_classes: Output width of the final ``nn.Linear`` classifier.
        activation_fn: Activation module instance handed to
            ``fbresnet200_features``. NOTE(review): the default is created once
            at definition time and therefore shared by every model built with
            the default; harmless for a stateless ``nn.ReLU``.
        drop_rate: Dropout probability applied to the pooled features before
            the classifier; ``0`` disables dropout.
        global_pool: Pooling mode string forwarded to ``adaptive_avgmax_pool2d``.
    """

    def __init__(self, num_classes=1000, activation_fn=nn.ReLU(), drop_rate=0., global_pool='avg'):
        super(ResNet200, self).__init__()
        self.drop_rate = drop_rate
        self.global_pool = global_pool
        self.num_classes = num_classes
        self.num_features = 2048
        self.features = fbresnet200_features(activation_fn=activation_fn)
        # Derive the head width from num_features so the two cannot drift apart.
        self.fc = nn.Linear(self.num_features, num_classes)
        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal init for every Conv2d; unit scale / zero shift for every BatchNorm2d."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # In-place variant; the un-suffixed ``kaiming_normal`` is a deprecated alias.
                init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def get_classifier(self):
        """Return the final linear classification layer."""
        return self.fc

    def reset_classifier(self, num_classes, global_pool='avg'):
        """Replace the classification head with a freshly initialised one.

        Args:
            num_classes: Output width of the new ``nn.Linear`` head.
            global_pool: New pooling mode string stored on the model.
        """
        self.global_pool = global_pool
        self.num_classes = num_classes
        del self.fc
        self.fc = nn.Linear(self.num_features, num_classes)

    def forward_features(self, x, pool=True):
        """Run the convolutional trunk.

        Args:
            x: Input batch fed to ``self.features``.
            pool: When true, apply ``adaptive_avgmax_pool2d`` with the stored
                ``global_pool`` mode and flatten to ``(batch, -1)``; when
                false, return the raw trunk output.
        """
        x = self.features(x)
        if pool:
            # adaptive_avgmax_pool2d lives in the sibling pooling module;
            # presumably reduces the spatial dims to 1x1 -- verify there.
            x = adaptive_avgmax_pool2d(x, self.global_pool)
            x = x.view(x.size(0), -1)
        return x

    def forward(self, x):
        """Return classifier outputs for input batch ``x``."""
        x = self.forward_features(x)
        if self.drop_rate > 0:
            x = F.dropout(x, p=self.drop_rate, training=self.training)
        x = self.fc(x)
        return x
class BasicConv2d(nn.Module):
    """Bias-free ``Conv2d`` followed by ``BatchNorm2d(eps=1e-3)`` and a non-inplace ``ReLU``."""

    def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
        super(BasicConv2d, self).__init__()
        # Submodules are created in conv -> bn -> relu order so attribute
        # names and registration order stay as checkpoints expect.
        conv_options = dict(kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
        self.conv = nn.Conv2d(in_planes, out_planes, **conv_options)
        self.bn = nn.BatchNorm2d(out_planes, eps=.001)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        """Apply convolution, batch norm and ReLU in sequence."""
        return self.relu(self.bn(self.conv(x)))
class Block35(nn.Module):
    """Residual Inception block over a 320-channel input.

    Three branches (32 + 32 + 64 channels) are concatenated, projected back to
    320 channels by a 1x1 convolution, scaled by ``scale``, added to the input
    and passed through ReLU.
    """

    def __init__(self, scale=1.0):
        super(Block35, self).__init__()
        self.scale = scale

        # Branch 0: single 1x1 projection.
        self.branch0 = BasicConv2d(320, 32, kernel_size=1, stride=1)

        # Branch 1: 1x1 reduce, then one padded 3x3.
        b1_reduce = BasicConv2d(320, 32, kernel_size=1, stride=1)
        b1_conv = BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1)
        self.branch1 = nn.Sequential(b1_reduce, b1_conv)

        # Branch 2: 1x1 reduce, then two padded 3x3 convolutions.
        b2_reduce = BasicConv2d(320, 32, kernel_size=1, stride=1)
        b2_conv_a = BasicConv2d(32, 48, kernel_size=3, stride=1, padding=1)
        b2_conv_b = BasicConv2d(48, 64, kernel_size=3, stride=1, padding=1)
        self.branch2 = nn.Sequential(b2_reduce, b2_conv_a, b2_conv_b)

        # 128 = 32 + 32 + 64 concatenated channels, projected back to 320.
        self.conv2d = nn.Conv2d(128, 320, kernel_size=1, stride=1)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        """Return ``relu(conv2d(cat(branches)) * scale + x)``."""
        branches = (self.branch0, self.branch1, self.branch2)
        merged = torch.cat([branch(x) for branch in branches], 1)
        scaled = self.conv2d(merged) * self.scale
        return self.relu(scaled + x)
class InceptionResnetV2(nn.Module):
    """Inception-ResNet-v2 classifier: conv stem, Block35/Block17/Block8 stacks, linear head.

    Args:
        num_classes: Output width of the final linear classifier.
        drop_rate: Dropout probability applied to the pooled features before
            the classifier; ``0`` disables dropout.
        global_pool: Pooling mode string forwarded to ``adaptive_avgmax_pool2d``.
    """

    def __init__(self, num_classes=1001, drop_rate=0., global_pool='avg'):
        super(InceptionResnetV2, self).__init__()
        self.drop_rate = drop_rate
        self.global_pool = global_pool
        self.num_classes = num_classes
        self.conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2)
        self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1)
        self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.maxpool_3a = nn.MaxPool2d(3, stride=2)
        self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1)
        self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1)
        self.maxpool_5a = nn.MaxPool2d(3, stride=2)
        self.mixed_5b = Mixed_5b()
        # The repeated stacks keep child names '0'..'n-1', so checkpoint keys
        # (e.g. 'repeat.3....') are unchanged.
        self.repeat = nn.Sequential(*[Block35(scale=0.17) for _ in range(10)])
        self.mixed_6a = Mixed_6a()
        self.repeat_1 = nn.Sequential(*[Block17(scale=0.10) for _ in range(20)])
        self.mixed_7a = Mixed_7a()
        self.repeat_2 = nn.Sequential(*[Block8(scale=0.20) for _ in range(9)])
        self.block8 = Block8(noReLU=True)
        self.conv2d_7b = BasicConv2d(2080, 1536, kernel_size=1, stride=1)
        self.num_features = 1536
        self.classif = nn.Linear(self.num_features, num_classes)

    def get_classifier(self):
        """Return the classification layer, or ``None`` if it has been removed."""
        return self.classif

    def reset_classifier(self, num_classes, global_pool='avg'):
        """Replace the classification head.

        Args:
            num_classes: Output width of the new linear head. A falsy value
                (``0`` / ``None``) removes the head; ``forward`` then returns
                the pooled features.
            global_pool: New pooling mode string stored on the model.
        """
        self.global_pool = global_pool
        self.num_classes = num_classes
        del self.classif
        if num_classes:
            self.classif = torch.nn.Linear(self.num_features, num_classes)
        else:
            self.classif = None

    def forward_features(self, x, pool=True):
        """Run the convolutional trunk.

        Args:
            x: Input batch fed to the stem.
            pool: When true, apply ``adaptive_avgmax_pool2d`` with the stored
                ``global_pool`` mode and flatten to ``(batch, -1)``; when
                false, return the output of ``conv2d_7b`` unpooled.
        """
        x = self.conv2d_1a(x)
        x = self.conv2d_2a(x)
        x = self.conv2d_2b(x)
        x = self.maxpool_3a(x)
        x = self.conv2d_3b(x)
        x = self.conv2d_4a(x)
        x = self.maxpool_5a(x)
        x = self.mixed_5b(x)
        x = self.repeat(x)
        x = self.mixed_6a(x)
        x = self.repeat_1(x)
        x = self.mixed_7a(x)
        x = self.repeat_2(x)
        x = self.block8(x)
        x = self.conv2d_7b(x)
        if pool:
            # adaptive_avgmax_pool2d lives in the sibling pooling module;
            # presumably reduces the spatial dims to 1x1 -- verify there.
            x = adaptive_avgmax_pool2d(x, self.global_pool, count_include_pad=False)
            x = x.view(x.size(0), -1)
        return x

    def _classify(self, x):
        """Apply dropout (if enabled) and the classifier (if present) to pooled features."""
        if self.drop_rate > 0:
            x = F.dropout(x, p=self.drop_rate, training=self.training)
        # reset_classifier(0) leaves classif as None; calling it would raise
        # TypeError, so pass the features through untouched in that case.
        if self.classif is not None:
            x = self.classif(x)
        return x

    def forward(self, x):
        """Return classifier outputs, or pooled features when the head was removed."""
        x = self.forward_features(x, pool=True)
        return self._classify(x)
class BasicConv2d(nn.Module):
    """Bias-free ``Conv2d`` followed by ``BatchNorm2d(eps=0.001)`` and an in-place ``ReLU``."""

    def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
        super(BasicConv2d, self).__init__()
        # Submodules are created in conv -> bn -> relu order so attribute
        # names and registration order stay as checkpoints expect.
        conv_options = dict(kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
        self.conv = nn.Conv2d(in_planes, out_planes, **conv_options)
        self.bn = nn.BatchNorm2d(out_planes, eps=0.001)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply convolution, batch norm and ReLU in sequence."""
        return self.relu(self.bn(self.conv(x)))
class Inception_A(nn.Module):
    """Four parallel branches over a 384-channel input, concatenated on dim 1.

    Each branch ends in 96 channels, so the output again has 384 channels.
    """

    def __init__(self):
        super(Inception_A, self).__init__()
        # 1x1 conv only.
        self.branch0 = BasicConv2d(384, 96, kernel_size=1, stride=1)

        # 1x1 reduction followed by one padded 3x3 conv.
        branch1_stages = [
            BasicConv2d(384, 64, kernel_size=1, stride=1),
            BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1),
        ]
        self.branch1 = nn.Sequential(*branch1_stages)

        # 1x1 reduction followed by two padded 3x3 convs.
        branch2_stages = [
            BasicConv2d(384, 64, kernel_size=1, stride=1),
            BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1),
            BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1),
        ]
        self.branch2 = nn.Sequential(*branch2_stages)

        # 3x3 average pool (padding excluded from the mean) then 1x1 conv.
        branch3_stages = [
            nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
            BasicConv2d(384, 96, kernel_size=1, stride=1),
        ]
        self.branch3 = nn.Sequential(*branch3_stages)

    def forward(self, x):
        """Evaluate branch0..branch3 in order and concatenate along channels."""
        branches = (self.branch0, self.branch1, self.branch2, self.branch3)
        return torch.cat([branch(x) for branch in branches], 1)
class Inception_C(nn.Module):
    """Inception block over 1536 channels whose branches 1 and 2 fork at the end.

    Output channels: 256 (branch0) + 2*256 (branch1) + 2*256 (branch2)
    + 256 (branch3) = 1536.
    """

    def __init__(self):
        super(Inception_C, self).__init__()

        self.branch0 = BasicConv2d(1536, 256, kernel_size=1, stride=1)

        # Branch 1: shared 1x1 stem, then parallel 1x3 and 3x1 heads.
        self.branch1_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1)
        self.branch1_1a = BasicConv2d(
            384, 256, kernel_size=(1, 3), stride=1, padding=(0, 1))
        self.branch1_1b = BasicConv2d(
            384, 256, kernel_size=(3, 1), stride=1, padding=(1, 0))

        # Branch 2: 1x1 -> 3x1 -> 1x3 stem, then parallel 1x3 and 3x1 heads.
        self.branch2_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1)
        self.branch2_1 = BasicConv2d(
            384, 448, kernel_size=(3, 1), stride=1, padding=(1, 0))
        self.branch2_2 = BasicConv2d(
            448, 512, kernel_size=(1, 3), stride=1, padding=(0, 1))
        self.branch2_3a = BasicConv2d(
            512, 256, kernel_size=(1, 3), stride=1, padding=(0, 1))
        self.branch2_3b = BasicConv2d(
            512, 256, kernel_size=(3, 1), stride=1, padding=(1, 0))

        # Branch 3: average pool then 1x1 conv.
        pool = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
        self.branch3 = nn.Sequential(
            pool, BasicConv2d(1536, 256, kernel_size=1, stride=1))

    def forward(self, x):
        """Compute the four branches in order and concatenate on dim 1."""
        out0 = self.branch0(x)

        stem1 = self.branch1_0(x)
        out1 = torch.cat((self.branch1_1a(stem1), self.branch1_1b(stem1)), 1)

        stem2 = self.branch2_2(self.branch2_1(self.branch2_0(x)))
        out2 = torch.cat((self.branch2_3a(stem2), self.branch2_3b(stem2)), 1)

        out3 = self.branch3(x)

        return torch.cat((out0, out1, out2, out3), 1)
class MedianPool2d(nn.Module):
    """Median pooling over sliding 2-D windows (a median filter when stride=1).

    Args:
        kernel_size: window size, int or 2-tuple
        stride: window stride, int or 2-tuple
        padding: int or 4-tuple (l, r, t, b) in the order F.pad expects
        same: when True, ignore ``padding`` and derive it from the input size
    """

    def __init__(self, kernel_size=3, stride=1, padding=0, same=False):
        super(MedianPool2d, self).__init__()
        self.k = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _quadruple(padding)  # expands an int to (l, r, t, b)
        self.same = same

    @staticmethod
    def _same_total(size, kernel, stride):
        """Total padding along one axis for the 'same' mode."""
        leftover = size % stride
        covered = leftover if leftover != 0 else stride
        return max(kernel - covered, 0)

    def _padding(self, x):
        """Return the (l, r, t, b) padding to apply to ``x``."""
        if not self.same:
            return self.padding
        height, width = x.size()[2:]
        pad_h = self._same_total(height, self.k[0], self.stride[0])
        pad_w = self._same_total(width, self.k[1], self.stride[1])
        # Any odd pixel goes to the right / bottom edge.
        left = pad_w // 2
        top = pad_h // 2
        return (left, pad_w - left, top, pad_h - top)

    def forward(self, x):
        """Reflect-pad ``x``, gather the windows and take each window's median."""
        padded = F.pad(x, self._padding(x), mode='reflect')
        windows = padded.unfold(2, self.k[0], self.stride[0])
        windows = windows.unfold(3, self.k[1], self.stride[1])
        # Merge the two trailing window axes so median reduces over one axis.
        flat = windows.contiguous().view(windows.size()[:4] + (-1,))
        values, _ = flat.median(dim=-1)
        return values
def create_model(
        model_name='resnet50',
        pretrained=None,
        num_classes=1000,
        checkpoint_path='',
        **kwargs):
    """Build a model by name and optionally load weights from a checkpoint.

    Args:
        model_name: key selecting the factory function (see the tables below).
        pretrained: forwarded to the factory as ``pretrained``.
        num_classes: forwarded to the factory as ``num_classes``.
        checkpoint_path: when non-empty and ``pretrained`` is falsy, the file is
            loaded into the model through ``load_checkpoint``.
        **kwargs: ``test_time_pool`` (default 0) is always removed and passed
            only to the DPN factories. The remaining kwargs are forwarded to
            the resnet / densenet / inception / wrn / fbresnet factories and
            ignored by the DPN and SE-ResNet ones.

    Returns:
        The constructed model.

    Raises:
        ValueError: if ``model_name`` is not a known model. The previous
            ``assert False and "Invalid model"`` lost its message and vanished
            under ``python -O``, which surfaced as an unbound ``model``.
    """
    test_time_pool = kwargs.pop('test_time_pool', 0)

    # Factories that take test_time_pool and no other extra kwargs.
    dpn_factories = {
        'dpn68': dpn68,
        'dpn68b': dpn68b,
        'dpn92': dpn92,
        'dpn98': dpn98,
        'dpn131': dpn131,
        'dpn107': dpn107,
    }
    # Factories that receive the remaining **kwargs.
    kwargs_factories = {
        'resnet18': resnet18,
        'resnet34': resnet34,
        'resnet50': resnet50,
        'resnet101': resnet101,
        'resnet152': resnet152,
        'densenet121': densenet121,
        'densenet161': densenet161,
        'densenet169': densenet169,
        'densenet201': densenet201,
        'inception_resnet_v2': inception_resnet_v2,
        'inception_v4': inception_v4,
        'wrn50': wrn50_2,
        'fbresnet200': fbresnet200,
    }
    # Factories called with num_classes / pretrained only.
    plain_factories = {
        'seresnet18': se_resnet18,
        'seresnet34': se_resnet34,
    }

    if model_name in dpn_factories:
        model = dpn_factories[model_name](
            num_classes=num_classes, pretrained=pretrained, test_time_pool=test_time_pool)
    elif model_name in kwargs_factories:
        model = kwargs_factories[model_name](
            num_classes=num_classes, pretrained=pretrained, **kwargs)
    elif model_name in plain_factories:
        model = plain_factories[model_name](
            num_classes=num_classes, pretrained=pretrained)
    else:
        raise ValueError('Invalid model: %r' % (model_name,))

    if checkpoint_path and not pretrained:
        print(checkpoint_path)
        load_checkpoint(model, checkpoint_path)

    return model
def get_transforms_eval(model_name, img_size=224, crop_pct=None):
    """Build the evaluation pipeline: resize, center crop, to-tensor, normalize.

    Args:
        model_name: selects the normalizer; names containing ``'dpn'`` use the
            DPN mean/std, names containing ``'inception'`` use ``LeNormalize``,
            everything else uses the torchvision ImageNet mean/std.
        img_size: side length of the center crop.
        crop_pct: crop fraction; the image is first resized to
            ``floor(img_size / crop_pct)``. A falsy value means "use the
            default", as before.

    Returns:
        A ``transforms.Compose`` with bicubic resize, center crop, ``ToTensor``
        and the chosen normalizer.
    """
    # The original overwrote crop_pct with DEFAULT_CROP_PCT before the DPN
    # branch, so its ``crop_pct is None`` test could never be true. Remember
    # whether the caller supplied a value before applying the default.
    use_default_crop = not crop_pct
    if use_default_crop:
        crop_pct = DEFAULT_CROP_PCT

    if 'dpn' in model_name:
        if use_default_crop and img_size != 224:
            # Larger-than-native input: use a 100% crop, which the original
            # comment says improves test-time results.
            scale_size = img_size
        else:
            # Default 87.5% crop at the native 224, or the caller's crop_pct.
            scale_size = int(math.floor(img_size / crop_pct))
        normalize = transforms.Normalize(
            mean=[124 / 255, 117 / 255, 104 / 255],
            std=[1 / (.0167 * 255)] * 3)
    elif 'inception' in model_name:
        scale_size = int(math.floor(img_size / crop_pct))
        normalize = LeNormalize()
    else:
        scale_size = int(math.floor(img_size / crop_pct))
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225])

    return transforms.Compose([
        transforms.Resize(scale_size, Image.BICUBIC),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        normalize])
def densenet161(pretrained=False, **kwargs):
    r"""Densenet-161 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        **kwargs: forwarded to the ``DenseNet`` constructor

    Returns:
        A ``DenseNet`` built with 96 initial features, growth rate 48 and
        block configuration (6, 12, 36, 24).
    """
    # The leftover debug ``print(kwargs)`` was removed; the sibling factories
    # (densenet121/169/201) do not print.
    model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['densenet161']))
    return model
def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
             num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000, global_pool='avg'):
    """Assemble the stem, the dense blocks with transitions, and the classifier.

    ``global_pool`` and ``num_classes`` are stored on the instance. The
    channel count after the last dense block is kept as ``num_features``.
    """
    super(DenseNet, self).__init__()
    self.global_pool = global_pool
    self.num_classes = num_classes

    # Stem: 7x7 stride-2 conv, batch norm, ReLU, 3x3 stride-2 max pool.
    stem = OrderedDict()
    stem['conv0'] = nn.Conv2d(
        3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)
    stem['norm0'] = nn.BatchNorm2d(num_init_features)
    stem['relu0'] = nn.ReLU(inplace=True)
    stem['pool0'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
    self.features = nn.Sequential(stem)

    # Dense blocks; every block except the last is followed by a transition
    # that halves the channel count.
    channels = num_init_features
    final_stage = len(block_config) - 1
    for stage, depth in enumerate(block_config):
        suffix = stage + 1
        dense = _DenseBlock(
            num_layers=depth, num_input_features=channels,
            bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate)
        self.features.add_module('denseblock%d' % suffix, dense)
        channels = channels + depth * growth_rate
        if stage == final_stage:
            continue
        halved = channels // 2
        self.features.add_module(
            'transition%d' % suffix,
            _Transition(num_input_features=channels, num_output_features=halved))
        channels = halved

    # Final batch norm
    self.features.add_module('norm5', nn.BatchNorm2d(channels))

    # Linear layer
    self.classifier = torch.nn.Linear(channels, num_classes)

    self.num_features = channels
def forward_features(self, x, pool=True):
    """Compute ReLU-activated features, optionally pooled and flattened.

    Args:
        x: input batch fed to ``self.features``.
        pool: when True, apply ``adaptive_avgmax_pool2d`` with
            ``self.global_pool`` and flatten to ``(out.size(0), -1)``;
            when False, return the activated feature map unchanged.

    Returns:
        The pooled, flattened features, or the un-pooled activation.
    """
    features = self.features(x)
    out = F.relu(features, inplace=True)
    if pool:
        out = adaptive_avgmax_pool2d(out, self.global_pool)
        # Flatten the pooled activations. The original reshaped the raw
        # input ``x`` here, discarding everything the network computed.
        out = out.view(out.size(0), -1)
    return out
class Bottleneck(nn.Module):
    """Residual block: 1x1 -> 3x3 (strided) -> 1x1 convs, output ``planes * 4``.

    Optional dropout is applied after the first conv/bn/ReLU stage.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, drop_rate=0.0):
        super(Bottleneck, self).__init__()
        widened = planes * 4
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, widened, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(widened)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.drop_rate = drop_rate

    def forward(self, x):
        """Return ReLU(main_path(x) + shortcut(x))."""
        out = self.relu(self.bn1(self.conv1(x)))

        if self.drop_rate > 0.:
            out = F.dropout(out, p=self.drop_rate, training=self.training)

        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # The shortcut is the input, projected when a downsample module is given.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0], drop_rate=block_drop_rate) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, drop_rate=block_drop_rate) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, drop_rate=block_drop_rate) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, drop_rate=block_drop_rate) + self.global_pool = AdaptiveAvgMaxPool2d(pool_type=global_pool) + self.num_features = 512 * block.expansion + self.fc = nn.Linear(self.num_features, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1, drop_rate=0.): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [block(self.inplanes, planes, stride, downsample, drop_rate)] + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def get_classifier(self): + return self.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.global_pool = AdaptiveAvgMaxPool2d(pool_type=global_pool) + self.num_classes = num_classes + del self.fc + if num_classes: + self.fc = nn.Linear(512 * self.expansion, num_classes) + else: + self.fc = None + + def forward_features(self, x, pool=True): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + 
def forward(self, x):
    """Run the network: pooled features, optional dropout, then ``fc``.

    Args:
        x: input batch, passed unchanged to ``forward_features``.

    Returns:
        The output of ``self.fc``, or the pooled (and, when
        ``drop_rate > 0``, dropped-out) features if ``fc`` was removed.
    """
    x = self.forward_features(x)
    if self.drop_rate > 0.:
        x = F.dropout(x, p=self.drop_rate, training=self.training)
    # reset_classifier() sets ``fc`` to None when num_classes is falsy;
    # calling it would raise TypeError, so return the features instead.
    if self.fc is None:
        return x
    return self.fc(x)
+ + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) + return model \ No newline at end of file diff --git a/models/pnasnet.py b/models/pnasnet.py new file mode 100644 index 00000000..c169c695 --- /dev/null +++ b/models/pnasnet.py @@ -0,0 +1,401 @@ +from __future__ import print_function, division, absolute_import +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.utils.model_zoo as model_zoo + + +pretrained_settings = { + 'pnasnet5large': { + 'imagenet': { + 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/pnasnet5large-bf079911.pth', + 'input_space': 'RGB', + 'input_size': [3, 331, 331], + 'input_range': [0, 1], + 'mean': [0.5, 0.5, 0.5], + 'std': [0.5, 0.5, 0.5], + 'num_classes': 1000 + }, + 'imagenet+background': { + 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/pnasnet5large-bf079911.pth', + 'input_space': 'RGB', + 'input_size': [3, 331, 331], + 'input_range': [0, 1], + 'mean': [0.5, 0.5, 0.5], + 'std': [0.5, 0.5, 0.5], + 'num_classes': 1001 + } + } +} + + +class MaxPool(nn.Module): + + def __init__(self, kernel_size, stride=1, padding=1, zero_pad=False): + super(MaxPool, self).__init__() + self.zero_pad = nn.ZeroPad2d((1, 0, 1, 0)) if zero_pad else None + self.pool = nn.MaxPool2d(kernel_size, stride=stride, padding=padding) + + def forward(self, x): + if self.zero_pad: + x = self.zero_pad(x) + x = self.pool(x) + if self.zero_pad: + x = x[:, :, 1:, 1:] + return x + + +class SeparableConv2d(nn.Module): + + def __init__(self, in_channels, out_channels, dw_kernel_size, dw_stride, + dw_padding): + super(SeparableConv2d, self).__init__() + self.depthwise_conv2d = nn.Conv2d(in_channels, in_channels, + kernel_size=dw_kernel_size, + stride=dw_stride, padding=dw_padding, + groups=in_channels, bias=False) + self.pointwise_conv2d = nn.Conv2d(in_channels, 
class BranchSeparables(nn.Module):
    """Two stacked ReLU -> separable conv -> batch-norm stages.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the second stage.
        kernel_size: Depthwise kernel size; padding is ``kernel_size // 2``.
        stride: Stride of the first separable convolution (the second one
            always uses stride 1).
        stem_cell: If true the first stage already widens to
            ``out_channels``; otherwise it keeps ``in_channels``.
        zero_pad: If true, one zero row/column is added at the top/left
            before the first convolution and the first output row/column
            is dropped afterwards.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 stem_cell=False, zero_pad=False):
        super(BranchSeparables, self).__init__()
        padding = kernel_size // 2
        middle_channels = out_channels if stem_cell else in_channels
        self.zero_pad = nn.ZeroPad2d((1, 0, 1, 0)) if zero_pad else None
        self.relu_1 = nn.ReLU()
        self.separable_1 = SeparableConv2d(in_channels, middle_channels,
                                           kernel_size, dw_stride=stride,
                                           dw_padding=padding)
        self.bn_sep_1 = nn.BatchNorm2d(middle_channels, eps=0.001)
        self.relu_2 = nn.ReLU()
        self.separable_2 = SeparableConv2d(middle_channels, out_channels,
                                           kernel_size, dw_stride=1,
                                           dw_padding=padding)
        self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x):
        """Run both stages on ``x`` and return the result."""
        x = self.relu_1(x)
        # Explicit None check instead of relying on nn.Module truthiness,
        # which depends on __len__ for container modules.
        if self.zero_pad is not None:
            x = self.separable_1(self.zero_pad(x))
            # Remove the extra leading row/column introduced by the padding.
            x = x[:, :, 1:, 1:].contiguous()
        else:
            x = self.separable_1(x)
        x = self.bn_sep_1(x)
        x = self.relu_2(x)
        x = self.separable_2(x)
        x = self.bn_sep_2(x)
        return x
class CellBase(nn.Module):
    """Shared combination logic for PNASNet cells.

    Subclasses are expected to define the callables
    ``comb_iter_{0..4}_left`` and ``comb_iter_{0..4}_right``;
    ``comb_iter_4_right`` may be ``None``, in which case the right input is
    passed through unchanged for that branch.
    """

    def cell_forward(self, x_left, x_right):
        """Combine the two cell inputs into one tensor.

        Five branch sums are computed and concatenated along dimension 1.

        Args:
            x_left: Tensor fed to branches 0 and 4 (left side).
            x_right: Tensor fed to branches 1-3 and to the right side of
                branch 4.

        Returns:
            The concatenation of the five branch outputs along dim 1.
        """
        x_comb_iter_0 = (self.comb_iter_0_left(x_left)
                         + self.comb_iter_0_right(x_left))
        x_comb_iter_1 = (self.comb_iter_1_left(x_right)
                         + self.comb_iter_1_right(x_right))
        x_comb_iter_2 = (self.comb_iter_2_left(x_right)
                         + self.comb_iter_2_right(x_right))
        # Branch 3 consumes the output of branch 2 on its left side.
        x_comb_iter_3 = (self.comb_iter_3_left(x_comb_iter_2)
                         + self.comb_iter_3_right(x_right))

        x_comb_iter_4_left = self.comb_iter_4_left(x_left)
        # Explicit None check: module truthiness is not a reliable
        # "is this branch configured" test (containers define __len__).
        if self.comb_iter_4_right is not None:
            x_comb_iter_4_right = self.comb_iter_4_right(x_right)
        else:
            x_comb_iter_4_right = x_right
        x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right

        return torch.cat(
            [x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3,
             x_comb_iter_4], 1)
class Cell(CellBase):
    """Regular or reduction PNASNet cell.

    Args:
        in_channels_left / out_channels_left: Channels in and out of the
            1x1 projection applied to the left input.
        in_channels_right / out_channels_right: Channels in and out of the
            1x1 projection applied to the right input.
        is_reduction: Use stride 2 in the strided branches, which roughly
            halves the spatial size of the cell output.
        zero_pad: Forwarded to the strided branches.
        match_prev_layer_dimensions: Project the left input with
            ``FactorizedReduction`` (roughly halving its spatial size)
            instead of a plain 1x1 ``ReluConvBn``.
    """

    def __init__(self, in_channels_left, out_channels_left, in_channels_right,
                 out_channels_right, is_reduction=False, zero_pad=False,
                 match_prev_layer_dimensions=False):
        super(Cell, self).__init__()

        stride = 2 if is_reduction else 1

        def strided_separable(channels, kernel_size):
            # Channel-preserving separable branch sharing stride/zero_pad.
            return BranchSeparables(channels, channels,
                                    kernel_size=kernel_size, stride=stride,
                                    zero_pad=zero_pad)

        def strided_pool():
            # 3x3 max pool sharing the cell's stride/zero_pad.
            return MaxPool(3, stride=stride, zero_pad=zero_pad)

        self.match_prev_layer_dimensions = match_prev_layer_dimensions
        self.conv_prev_1x1 = (
            FactorizedReduction(in_channels_left, out_channels_left)
            if match_prev_layer_dimensions
            else ReluConvBn(in_channels_left, out_channels_left,
                            kernel_size=1)
        )
        self.conv_1x1 = ReluConvBn(in_channels_right, out_channels_right,
                                   kernel_size=1)

        # Sub-modules are created in the same order as before so that
        # parameter registration order is unchanged.
        self.comb_iter_0_left = strided_separable(out_channels_left, 5)
        self.comb_iter_0_right = strided_pool()
        self.comb_iter_1_left = strided_separable(out_channels_right, 7)
        self.comb_iter_1_right = strided_pool()
        self.comb_iter_2_left = strided_separable(out_channels_right, 5)
        self.comb_iter_2_right = strided_separable(out_channels_right, 3)
        # This branch runs on an already-combined tensor: default stride,
        # no zero padding.
        self.comb_iter_3_left = BranchSeparables(out_channels_right,
                                                 out_channels_right,
                                                 kernel_size=3)
        self.comb_iter_3_right = strided_pool()
        self.comb_iter_4_left = strided_separable(out_channels_left, 3)
        self.comb_iter_4_right = (
            ReluConvBn(out_channels_right, out_channels_right,
                       kernel_size=1, stride=stride)
            if is_reduction else None
        )

    def forward(self, x_left, x_right):
        """Project both inputs with their 1x1 paths and combine them."""
        projected_left = self.conv_prev_1x1(x_left)
        projected_right = self.conv_1x1(x_right)
        return self.cell_forward(projected_left, projected_right)
nn.BatchNorm2d(96, eps=0.001)) + ])) + self.cell_stem_0 = CellStem0(in_channels_left=96, out_channels_left=54, + in_channels_right=96, + out_channels_right=54) + self.cell_stem_1 = Cell(in_channels_left=96, out_channels_left=108, + in_channels_right=270, out_channels_right=108, + match_prev_layer_dimensions=True, + is_reduction=True) + self.cell_0 = Cell(in_channels_left=270, out_channels_left=216, + in_channels_right=540, out_channels_right=216, + match_prev_layer_dimensions=True) + self.cell_1 = Cell(in_channels_left=540, out_channels_left=216, + in_channels_right=1080, out_channels_right=216) + self.cell_2 = Cell(in_channels_left=1080, out_channels_left=216, + in_channels_right=1080, out_channels_right=216) + self.cell_3 = Cell(in_channels_left=1080, out_channels_left=216, + in_channels_right=1080, out_channels_right=216) + self.cell_4 = Cell(in_channels_left=1080, out_channels_left=432, + in_channels_right=1080, out_channels_right=432, + is_reduction=True, zero_pad=True) + self.cell_5 = Cell(in_channels_left=1080, out_channels_left=432, + in_channels_right=2160, out_channels_right=432, + match_prev_layer_dimensions=True) + self.cell_6 = Cell(in_channels_left=2160, out_channels_left=432, + in_channels_right=2160, out_channels_right=432) + self.cell_7 = Cell(in_channels_left=2160, out_channels_left=432, + in_channels_right=2160, out_channels_right=432) + self.cell_8 = Cell(in_channels_left=2160, out_channels_left=864, + in_channels_right=2160, out_channels_right=864, + is_reduction=True) + self.cell_9 = Cell(in_channels_left=2160, out_channels_left=864, + in_channels_right=4320, out_channels_right=864, + match_prev_layer_dimensions=True) + self.cell_10 = Cell(in_channels_left=4320, out_channels_left=864, + in_channels_right=4320, out_channels_right=864) + self.cell_11 = Cell(in_channels_left=4320, out_channels_left=864, + in_channels_right=4320, out_channels_right=864) + self.relu = nn.ReLU() + self.avg_pool = nn.AvgPool2d(11, stride=1, padding=0) + self.dropout 
def pnasnet5large(num_classes=1001, pretrained='imagenet'):
    r"""Build a PNASNet-5 Large model
    ("Progressive Neural Architecture Search" paper).

    Args:
        num_classes: Size of the classifier output. When ``pretrained`` is
            set it must equal the ``num_classes`` of the chosen settings.
        pretrained: Key into ``pretrained_settings['pnasnet5large']``
            (``'imagenet'`` or ``'imagenet+background'``), or a falsy value
            for a randomly initialised model.

    Returns:
        The constructed ``PNASNet5Large``; pretrained models also carry
        ``input_space``, ``input_size``, ``input_range``, ``mean``, ``std``.

    NOTE(review): the 'imagenet' settings declare 1000 classes while the
    default ``num_classes`` is 1001, so calling this with both defaults
    trips the assertion below.
    """
    if not pretrained:
        return PNASNet5Large(num_classes=num_classes)

    settings = pretrained_settings['pnasnet5large'][pretrained]
    expected_classes = settings['num_classes']
    assert num_classes == expected_classes, \
        'num_classes should be {}, but is {}'.format(
            expected_classes, num_classes)

    # Both pretrained variants come from the same 1001-way checkpoint.
    model = PNASNet5Large(num_classes=1001)
    model.load_state_dict(model_zoo.load_url(settings['url']))

    if pretrained == 'imagenet':
        # Drop classifier row 0 to obtain the 1000-way head.
        old_head = model.last_linear
        trimmed_head = nn.Linear(old_head.in_features, 1000)
        trimmed_head.weight.data = old_head.weight.data[1:]
        trimmed_head.bias.data = old_head.bias.data[1:]
        model.last_linear = trimmed_head

    # Copy the preprocessing metadata onto the model, in the original order.
    for attr in ('input_space', 'input_size', 'input_range', 'mean', 'std'):
        setattr(model, attr, settings[attr])
    return model
class SEModule(nn.Module):
    """Squeeze-and-Excitation gate.

    The input is globally average-pooled to 1x1, passed through two 1x1
    convolutions (``channels -> channels // reduction -> channels``) with a
    ReLU in between, squashed by a sigmoid, and multiplied back onto the
    input.
    """

    def __init__(self, channels, reduction):
        super(SEModule, self).__init__()
        squeezed = channels // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(channels, squeezed, kernel_size=1, padding=0)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(squeezed, channels, kernel_size=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled by its per-channel gate."""
        gate = self.fc1(self.avg_pool(x))
        gate = self.fc2(self.relu(gate))
        return x * self.sigmoid(gate)
class SEResNeXtBottleneck(Bottleneck):
    """
    ResNeXt type-C bottleneck followed by a Squeeze-and-Excitation module.

    The grouped 3x3 convolution runs on
    ``floor(planes * base_width / 64) * groups`` channels and carries the
    block's stride; the block outputs ``planes * 4`` channels.
    """
    expansion = 4

    def __init__(self, inplanes, planes, groups, reduction, stride=1,
                 downsample=None, base_width=4):
        super(SEResNeXtBottleneck, self).__init__()
        width = math.floor(planes * (base_width / 64)) * groups
        out_planes = planes * 4

        # 1x1 reduce
        self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False,
                               stride=1)
        self.bn1 = nn.BatchNorm2d(width)
        # grouped 3x3, strided
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride,
                               padding=1, groups=groups, bias=False)
        self.bn2 = nn.BatchNorm2d(width)
        # 1x1 expand
        self.conv3 = nn.Conv2d(width, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        self.relu = nn.ReLU(inplace=True)
        self.se_module = SEModule(out_planes, reduction=reduction)
        self.downsample = downsample
        self.stride = stride
self.relu = nn.ReLU(inplace=True) + self.se_module = SEModule(planes, reduction=reduction) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out = self.se_module(out) + residual + out = self.relu(out) + + return out + + +class SENet(nn.Module): + + def __init__(self, block, layers, groups, reduction, dropout_p=0.2, + inch=3, inplanes=128, input_3x3=True, downsample_kernel_size=3, + downsample_padding=1, num_classes=1000): + """ + Parameters + ---------- + block (nn.Module): Bottleneck class. + - For SENet154: SEBottleneck + - For SE-ResNet models: SEResNetBottleneck + - For SE-ResNeXt models: SEResNeXtBottleneck + layers (list of ints): Number of residual blocks for 4 layers of the + network (layer1...layer4). + groups (int): Number of groups for the 3x3 convolution in each + bottleneck block. + - For SENet154: 64 + - For SE-ResNet models: 1 + - For SE-ResNeXt models: 32 + reduction (int): Reduction ratio for Squeeze-and-Excitation modules. + - For all models: 16 + dropout_p (float or None): Drop probability for the Dropout layer. + If `None` the Dropout layer is not used. + - For SENet154: 0.2 + - For SE-ResNet models: None + - For SE-ResNeXt models: None + inplanes (int): Number of input channels for layer1. + - For SENet154: 128 + - For SE-ResNet models: 64 + - For SE-ResNeXt models: 64 + input_3x3 (bool): If `True`, use three 3x3 convolutions instead of + a single 7x7 convolution in layer0. + - For SENet154: True + - For SE-ResNet models: False + - For SE-ResNeXt models: False + downsample_kernel_size (int): Kernel size for downsampling convolutions + in layer2, layer3 and layer4. 
+ - For SENet154: 3 + - For SE-ResNet models: 1 + - For SE-ResNeXt models: 1 + downsample_padding (int): Padding for downsampling convolutions in + layer2, layer3 and layer4. + - For SENet154: 1 + - For SE-ResNet models: 0 + - For SE-ResNeXt models: 0 + num_classes (int): Number of outputs in `last_linear` layer. + - For all models: 1000 + """ + super(SENet, self).__init__() + self.inplanes = inplanes + if input_3x3: + layer0_modules = [ + ('conv1', nn.Conv2d(inch, 64, 3, stride=2, padding=1, bias=False)), + ('bn1', nn.BatchNorm2d(64)), + ('relu1', nn.ReLU(inplace=True)), + ('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1, bias=False)), + ('bn2', nn.BatchNorm2d(64)), + ('relu2', nn.ReLU(inplace=True)), + ('conv3', nn.Conv2d(64, inplanes, 3, stride=1, padding=1, bias=False)), + ('bn3', nn.BatchNorm2d(inplanes)), + ('relu3', nn.ReLU(inplace=True)), + ] + else: + layer0_modules = [ + ('conv1', nn.Conv2d( + inch, inplanes, kernel_size=7, stride=2, padding=3, bias=False)), + ('bn1', nn.BatchNorm2d(inplanes)), + ('relu1', nn.ReLU(inplace=True)), + ] + # To preserve compatibility with Caffe weights `ceil_mode=True` + # is used instead of `padding=1`. 
+ layer0_modules.append(('pool', nn.MaxPool2d(3, stride=2, ceil_mode=True))) + self.layer0 = nn.Sequential(OrderedDict(layer0_modules)) + self.layer1 = self._make_layer( + block, + planes=64, + blocks=layers[0], + groups=groups, + reduction=reduction, + downsample_kernel_size=1, + downsample_padding=0 + ) + self.layer2 = self._make_layer( + block, + planes=128, + blocks=layers[1], + stride=2, + groups=groups, + reduction=reduction, + downsample_kernel_size=downsample_kernel_size, + downsample_padding=downsample_padding + ) + self.layer3 = self._make_layer( + block, + planes=256, + blocks=layers[2], + stride=2, + groups=groups, + reduction=reduction, + downsample_kernel_size=downsample_kernel_size, + downsample_padding=downsample_padding + ) + self.layer4 = self._make_layer( + block, + planes=512, + blocks=layers[3], + stride=2, + groups=groups, + reduction=reduction, + downsample_kernel_size=downsample_kernel_size, + downsample_padding=downsample_padding + ) + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None + self.last_linear = nn.Linear(512 * block.expansion, num_classes) + + def _make_layer(self, block, planes, blocks, groups, reduction, stride=1, + downsample_kernel_size=1, downsample_padding=0): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=downsample_kernel_size, stride=stride, + padding=downsample_padding, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [block( + self.inplanes, planes, groups, reduction, stride, downsample)] + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, groups, reduction)) + + return nn.Sequential(*layers) + + def forward_features(self, x): + x = self.layer0(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + 
def se_resnet18(num_classes=1000, pretrained='imagenet'):
    """Build an SE-ResNet-18.

    Uses the two-convolution ``SEResNetBlock`` with the ``[2, 2, 2, 2]``
    stage layout. The previous version paired that layout with
    ``SEResNetBottleneck``, which is not an 18-layer network and is
    inconsistent with ``se_resnet34`` in this file.

    Args:
        num_classes: Size of the classifier output.
        pretrained: Key into ``pretrained_config['se_resnet18']``, or a
            falsy value to skip weight loading.

    Returns:
        The constructed ``SENet``.

    NOTE(review): the 'se_resnet18' entry of ``pretrained_config`` uses the
    se_resnet50 checkpoint URL; loading it into this architecture will
    presumably fail. Pass a falsy ``pretrained`` until a matching
    checkpoint is configured.
    """
    model = SENet(SEResNetBlock, [2, 2, 2, 2], groups=1, reduction=16,
                  dropout_p=None, inplanes=64, input_3x3=False,
                  downsample_kernel_size=1, downsample_padding=0,
                  num_classes=num_classes)
    if pretrained:
        config = pretrained_config['se_resnet18'][pretrained]
        initialize_pretrained_model(model, num_classes, config)
    return model
def se_resnet101(num_classes=1000, pretrained='imagenet'):
    """Build an SE-ResNet-101 (``SEResNetBottleneck``, stages 3/4/23/3).

    Args:
        num_classes: Size of the classifier output.
        pretrained: Key into ``pretrained_config['se_resnet101']``, or a
            falsy value to skip weight loading.

    Returns:
        The constructed ``SENet``.
    """
    arch_kwargs = dict(
        groups=1,
        reduction=16,
        dropout_p=None,
        inplanes=64,
        input_3x3=False,
        downsample_kernel_size=1,
        downsample_padding=0,
    )
    model = SENet(SEResNetBottleneck, [3, 4, 23, 3],
                  num_classes=num_classes, **arch_kwargs)
    if not pretrained:
        return model
    initialize_pretrained_model(
        model, num_classes, pretrained_config['se_resnet101'][pretrained])
    return model
class LambdaBase(nn.Sequential):
    """Sequential container that also stores a function for subclasses.

    The child modules are not chained: ``forward_prepare`` applies each of
    them to the same input.
    """

    def __init__(self, fn, *args):
        super(LambdaBase, self).__init__(*args)
        self.lambda_func = fn

    def forward_prepare(self, input):
        """Apply every child module to ``input``.

        Returns:
            A list with one result per child, or ``input`` itself when the
            container has no children.
        """
        results = [branch(input) for branch in self._modules.values()]
        return results or input
nn.Sequential( # Sequential, + nn.Conv2d(64, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), + nn.BatchNorm2d(256), + ), + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + activation_fn, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Conv2d(256, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 1, bias=False), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), + nn.BatchNorm2d(256), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + activation_fn, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Conv2d(256, 128, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 1, bias=False), + nn.BatchNorm2d(128), + activation_fn, + nn.Conv2d(128, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), + nn.BatchNorm2d(256), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + activation_fn, + ), + ), + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Conv2d(256, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 256, (3, 3), (2, 2), (1, 1), 1, 1, bias=False), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), + nn.BatchNorm2d(512), + ), + nn.Sequential( # Sequential, + nn.Conv2d(256, 512, (1, 1), (2, 2), (0, 0), 1, 1, bias=False), + nn.BatchNorm2d(512), + ), + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + activation_fn, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Conv2d(512, 256, (1, 
1), (1, 1), (0, 0), 1, 1, bias=False), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 1, bias=False), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), + nn.BatchNorm2d(512), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + activation_fn, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Conv2d(512, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 1, bias=False), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), + nn.BatchNorm2d(512), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + activation_fn, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Conv2d(512, 256, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 1, bias=False), + nn.BatchNorm2d(256), + activation_fn, + nn.Conv2d(256, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), + nn.BatchNorm2d(512), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + activation_fn, + ), + ), + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Conv2d(512, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), + nn.BatchNorm2d(512), + activation_fn, + nn.Conv2d(512, 512, (3, 3), (2, 2), (1, 1), 1, 1, bias=False), + nn.BatchNorm2d(512), + activation_fn, + nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), + nn.BatchNorm2d(1024), + ), + nn.Sequential( # Sequential, + nn.Conv2d(512, 1024, (1, 1), (2, 2), (0, 0), 1, 1, bias=False), + nn.BatchNorm2d(1024), + ), + ), + 
LambdaReduce(lambda x, y: x + y), # CAddTable, + activation_fn, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), + nn.BatchNorm2d(512), + activation_fn, + nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 1, bias=False), + nn.BatchNorm2d(512), + activation_fn, + nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + activation_fn, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), + nn.BatchNorm2d(512), + activation_fn, + nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 1, bias=False), + nn.BatchNorm2d(512), + activation_fn, + nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + activation_fn, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), + nn.BatchNorm2d(512), + activation_fn, + nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 1, bias=False), + nn.BatchNorm2d(512), + activation_fn, + nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + activation_fn, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), + nn.BatchNorm2d(512), + activation_fn, + nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 1, bias=False), + nn.BatchNorm2d(512), + activation_fn, + nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, 
bias=False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + activation_fn, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Conv2d(1024, 512, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), + nn.BatchNorm2d(512), + activation_fn, + nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 1, bias=False), + nn.BatchNorm2d(512), + activation_fn, + nn.Conv2d(512, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), + nn.BatchNorm2d(1024), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + activation_fn, + ), + ), + nn.Sequential( # Sequential, + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Conv2d(1024, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), + nn.BatchNorm2d(1024), + activation_fn, + nn.Conv2d(1024, 1024, (3, 3), (2, 2), (1, 1), 1, 1, bias=False), + nn.BatchNorm2d(1024), + activation_fn, + nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), + nn.BatchNorm2d(2048), + ), + nn.Sequential( # Sequential, + nn.Conv2d(1024, 2048, (1, 1), (2, 2), (0, 0), 1, 1, bias=False), + nn.BatchNorm2d(2048), + ), + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + activation_fn, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Conv2d(2048, 1024, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), + nn.BatchNorm2d(1024), + activation_fn, + nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 1, bias=False), + nn.BatchNorm2d(1024), + activation_fn, + nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), + nn.BatchNorm2d(2048), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + activation_fn, + ), + nn.Sequential( # Sequential, + LambdaMap(lambda x: x, # ConcatTable, + nn.Sequential( # Sequential, + nn.Conv2d(2048, 1024, (1, 1), (1, 1), (0, 0), 
1, 1, bias=False), + nn.BatchNorm2d(1024), + activation_fn, + nn.Conv2d(1024, 1024, (3, 3), (1, 1), (1, 1), 1, 1, bias=False), + nn.BatchNorm2d(1024), + activation_fn, + nn.Conv2d(1024, 2048, (1, 1), (1, 1), (0, 0), 1, 1, bias=False), + nn.BatchNorm2d(2048), + ), + Lambda(lambda x: x), # Identity, + ), + LambdaReduce(lambda x, y: x + y), # CAddTable, + activation_fn, + ), + ), + ) + return features + + +class Wrn50_2(nn.Module): + def __init__(self, num_classes=1000, activation_fn=nn.ReLU(), drop_rate=0., global_pool='avg'): + super(Wrn50_2, self).__init__() + self.drop_rate = drop_rate + self.num_classes = num_classes + self.num_features = 2048 + self.global_pool = global_pool + self.features = wrn_50_2_features(activation_fn=activation_fn) + self.fc = nn.Linear(2048, num_classes) + + def get_classifier(self): + return self.fc + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool = global_pool + self.fc = nn.Linear(2048, num_classes) + + def forward_features(self, x, pool=True): + x = self.features(x) + if pool: + x = adaptive_avgmax_pool2d(x, self.global_pool) + x = x.view(x.size(0), -1) + return x + + def forward(self, x): + x = self.forward_features(x, pool=True) + if self.drop_rate > 0: + x = F.dropout(x, p=self.drop_rate, training=self.training) + x = self.fc(x) + return x + + +def wrn50_2(pretrained=False, num_classes=1000, **kwargs): + model = Wrn50_2(num_classes=num_classes, **kwargs) + if pretrained: + # Remap pretrained weights to match our class module with features + fc + pretrained_weights = model_zoo.load_url(model_urls['wrn50_2']) + feature_keys = filter(lambda k: '10.1.' not in k, pretrained_weights.keys()) + remapped_weights = OrderedDict() + for k in feature_keys: + remapped_weights['features.' 
class SeparableConv2d(nn.Module):
    """Depthwise convolution followed by a 1x1 (pointwise) convolution.

    ``conv1`` convolves each input channel separately (``groups=in_channels``)
    with the requested kernel/stride/padding/dilation; ``pointwise`` then mixes
    channels with a 1x1 convolution to produce ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False):
        super(SeparableConv2d, self).__init__()

        # Per-channel spatial filtering; channel count is unchanged here.
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=bias)
        # Channel mixing / projection to the output width.
        self.pointwise = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            bias=bias)

    def forward(self, x):
        """Apply the depthwise then the pointwise convolution to ``x``."""
        return self.pointwise(self.conv1(x))
bias=False) + self.bn1 = nn.BatchNorm2d(32) + self.relu = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(32, 64, 3, bias=False) + self.bn2 = nn.BatchNorm2d(64) + # do relu here + + self.block1 = Block(64, 128, 2, 2, start_with_relu=False, grow_first=True) + self.block2 = Block(128, 256, 2, 2, start_with_relu=True, grow_first=True) + self.block3 = Block(256, 728, 2, 2, start_with_relu=True, grow_first=True) + + self.block4 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) + self.block5 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) + self.block6 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) + self.block7 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) + + self.block8 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) + self.block9 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) + self.block10 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) + self.block11 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True) + + self.block12 = Block(728, 1024, 2, 2, start_with_relu=True, grow_first=False) + + self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1) + self.bn3 = nn.BatchNorm2d(1536) + + # do relu here + self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1) + self.bn4 = nn.BatchNorm2d(2048) + + self.fc = nn.Linear(2048, num_classes) + + # #------- init weights -------- + # for m in self.modules(): + # if isinstance(m, nn.Conv2d): + # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + # m.weight.data.normal_(0, math.sqrt(2. 
class Nadam(Optimizer):
    """Implements Nadam algorithm (a variant of Adam based on Nesterov momentum).

    It has been proposed in `Incorporating Nesterov Momentum into Adam`__.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 2e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        schedule_decay (float, optional): momentum schedule decay (default: 4e-3)

    __ http://cs229.stanford.edu/proj2015/054_report.pdf
    __ http://www.cs.toronto.edu/~fritz/absps/momentum.pdf
    """

    def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, schedule_decay=4e-3):
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, schedule_decay=schedule_decay)
        super(Nadam, self).__init__(params, defaults)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    state['m_schedule'] = 1.
                    state['exp_avg'] = torch.zeros_like(grad)
                    state['exp_avg_sq'] = torch.zeros_like(grad)

                # Warming momentum schedule
                m_schedule = state['m_schedule']
                schedule_decay = group['schedule_decay']
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']
                eps = group['eps']
                state['step'] += 1
                t = state['step']

                if group['weight_decay'] != 0:
                    # Out-of-place add: p.grad itself must not be modified.
                    grad = grad.add(p.data, alpha=group['weight_decay'])

                momentum_cache_t = beta1 * \
                    (1. - 0.5 * (0.96 ** (t * schedule_decay)))
                momentum_cache_t_1 = beta1 * \
                    (1. - 0.5 * (0.96 ** ((t + 1) * schedule_decay)))
                m_schedule_new = m_schedule * momentum_cache_t
                m_schedule_next = m_schedule * momentum_cache_t * momentum_cache_t_1
                state['m_schedule'] = m_schedule_new

                # Decay the first and second moment running average coefficient.
                # Scalars are passed by keyword (alpha=/value=); the positional
                # scalar-first overloads are deprecated.
                exp_avg.mul_(beta1).add_(grad, alpha=1. - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1. - beta2)
                exp_avg_sq_prime = exp_avg_sq / (1. - beta2 ** t)
                denom = exp_avg_sq_prime.sqrt_().add_(eps)

                # Nesterov-style update: one term from the current gradient,
                # one from the (look-ahead weighted) first-moment estimate.
                p.data.addcdiv_(
                    grad, denom,
                    value=-group['lr'] * (1. - momentum_cache_t) / (1. - m_schedule_new))
                p.data.addcdiv_(
                    exp_avg, denom,
                    value=-group['lr'] * momentum_cache_t_1 / (1. - m_schedule_next))

        return loss
help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)') +parser.add_argument('--pretrained', action='store_true', default=False, + help='Start with pretrained version of specified network (if avail)') +parser.add_argument('--img-size', type=int, default=224, metavar='N', + help='Image patch size (default: 224)') +parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N', + help='input batch size for training (default: 32)') +parser.add_argument('-s', '--initial-batch-size', type=int, default=0, metavar='N', + help='initial input batch size for training (default: 0)') +parser.add_argument('--epochs', type=int, default=200, metavar='N', + help='number of epochs to train (default: 2)') +parser.add_argument('--start-epoch', default=None, type=int, metavar='N', + help='manual epoch number (useful on restarts)') +parser.add_argument('--decay-epochs', type=int, default=30, metavar='N', + help='epoch interval to decay LR') +parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', + help='LR decay rate (default: 0.1)') +parser.add_argument('--drop', type=float, default=0.0, metavar='DROP', + help='Dropout rate (default: 0.1)') +parser.add_argument('--lr', type=float, default=0.01, metavar='LR', + help='learning rate (default: 0.01)') +parser.add_argument('--momentum', type=float, default=0.9, metavar='M', + help='SGD momentum (default: 0.9)') +parser.add_argument('--weight-decay', type=float, default=0.0005, metavar='M', + help='weight decay (default: 0.0001)') +parser.add_argument('--seed', type=int, default=42, metavar='S', + help='random seed (default: 42)') +parser.add_argument('--log-interval', type=int, default=50, metavar='N', + help='how many batches to wait before logging training status') +parser.add_argument('--recovery-interval', type=int, default=1000, metavar='N', + help='how many batches to wait before writing recovery checkpoint') +parser.add_argument('-j', '--workers', type=int, 
def main():
    """Train a model end to end using the command-line arguments.

    Builds the model, data loaders, loss and optimizer, optionally resumes
    from a checkpoint, then for each epoch trains, writes a recovery
    checkpoint, validates, appends a row to ``summary.csv`` and saves a
    regular checkpoint keyed on the evaluation loss.

    Returns:
        False if ``--resume`` points at a missing file, otherwise None.

    Raises:
        ValueError: if ``--opt`` names an unsupported optimizer.
    """
    args = parser.parse_args()

    if args.output:
        output_base = args.output
    else:
        output_base = './output'
    exp_name = '-'.join([
        datetime.now().strftime("%Y%m%d-%H%M%S"),
        args.model,
        str(args.img_size)])
    output_dir = get_outdir(output_base, 'train', exp_name)

    batch_size = args.batch_size
    num_epochs = args.epochs
    torch.manual_seed(args.seed)

    model = model_factory.create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=1000,
        drop_rate=args.drop,
        global_pool=args.gp,
        checkpoint_path=args.initial_checkpoint)

    if args.initial_batch_size:
        batch_size = adjust_batch_size(
            epoch=0, initial_bs=args.initial_batch_size, target_bs=args.batch_size)
        print('Setting batch-size to %d' % batch_size)

    dataset_train = Dataset(
        os.path.join(args.data, 'train'),
        transform=get_transforms_train(args.model))

    loader_train = data.DataLoader(
        dataset_train,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=True,
        num_workers=args.workers
    )

    dataset_eval = Dataset(
        os.path.join(args.data, 'validation'),
        transform=get_transforms_eval(args.model))

    loader_eval = data.DataLoader(
        dataset_eval,
        batch_size=4 * args.batch_size,
        pin_memory=True,
        shuffle=False,
        num_workers=args.workers
    )

    train_loss_fn = validate_loss_fn = torch.nn.CrossEntropyLoss()
    train_loss_fn = train_loss_fn.cuda()
    validate_loss_fn = validate_loss_fn.cuda()

    # Move the parameters to the GPU *before* the optimizer is created (and
    # before any optimizer state is restored below); state loaded with
    # optimizer.load_state_dict() is placed on the device the parameters live
    # on at that moment.
    model.cuda()

    opt_name = args.opt.lower()
    if opt_name == 'sgd':
        optimizer = optim.SGD(
            model.parameters(), lr=args.lr,
            momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
    elif opt_name == 'adam':
        optimizer = optim.Adam(
            model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
    elif opt_name == 'nadam':
        optimizer = nadam.Nadam(
            model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
    elif opt_name == 'adadelta':
        optimizer = optim.Adadelta(
            model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
    elif opt_name == 'rmsprop':
        optimizer = optim.RMSprop(
            model.parameters(), lr=args.lr, alpha=0.9, eps=args.opt_eps,
            momentum=args.momentum, weight_decay=args.weight_decay)
    else:
        # An explicit exception survives `python -O` and reports the bad value.
        raise ValueError("Invalid optimizer: %s" % args.opt)

    # With a fixed decay interval the LR is set manually each epoch; otherwise
    # fall back to plateau-based scheduling on the eval loss.
    if not args.decay_epochs:
        lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=8)
    else:
        lr_scheduler = None

    # optionally resume from a checkpoint
    start_epoch = 0 if args.start_epoch is None else args.start_epoch
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
                new_state_dict = OrderedDict()
                for k, v in checkpoint['state_dict'].items():
                    if k.startswith('module'):
                        name = k[7:]  # remove `module.`
                    else:
                        name = k
                    new_state_dict[name] = v
                model.load_state_dict(new_state_dict)
                if 'optimizer' in checkpoint:
                    optimizer.load_state_dict(checkpoint['optimizer'])
                print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
                start_epoch = checkpoint['epoch'] if args.start_epoch is None else args.start_epoch
            else:
                # Bare state_dict without epoch/optimizer metadata.
                model.load_state_dict(checkpoint)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
            return False

    saver = CheckpointSaver(checkpoint_dir=output_dir)

    # Wrap only after weights are loaded so keys never need a `module.` prefix.
    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()

    best_loss = None
    try:
        for epoch in range(start_epoch, num_epochs):
            if args.decay_epochs:
                adjust_learning_rate(
                    optimizer, epoch, initial_lr=args.lr,
                    decay_rate=args.decay_rate, decay_epochs=args.decay_epochs)

            if args.initial_batch_size:
                next_batch_size = adjust_batch_size(
                    epoch, initial_bs=args.initial_batch_size, target_bs=args.batch_size)
                if next_batch_size > batch_size:
                    print("Changing batch size from %d to %d" % (batch_size, next_batch_size))
                    batch_size = next_batch_size
                    loader_train = data.DataLoader(
                        dataset_train,
                        batch_size=batch_size,
                        pin_memory=True,
                        shuffle=True,
                        num_workers=args.workers)

            train_metrics = train_epoch(
                epoch, model, loader_train, optimizer, train_loss_fn, args,
                saver=saver, output_dir=output_dir)

            # save a recovery in case validation blows up
            saver.save_recovery({
                'epoch': epoch + 1,
                'arch': args.model,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'loss': train_loss_fn.state_dict(),
                'args': args,
                'gp': args.gp,
                },
                epoch=epoch + 1,
                batch_idx=0)

            step = epoch * len(loader_train)
            eval_metrics = validate(
                step, model, loader_eval, validate_loss_fn, args,
                output_dir=output_dir)

            if lr_scheduler is not None:
                lr_scheduler.step(eval_metrics['eval_loss'])

            rowd = OrderedDict(epoch=epoch)
            rowd.update(train_metrics)
            rowd.update(eval_metrics)
            with open(os.path.join(output_dir, 'summary.csv'), mode='a') as cf:
                dw = csv.DictWriter(cf, fieldnames=rowd.keys())
                if best_loss is None:  # first iteration (epoch == 1 can't be used)
                    dw.writeheader()
                dw.writerow(rowd)

            # save proper checkpoint with eval metric
            best_loss = saver.save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.model,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'args': args,
                'gp': args.gp,
                },
                epoch=epoch + 1,
                metric=eval_metrics['eval_loss'])

    except KeyboardInterrupt:
        pass
    # best_loss stays None when no epoch completed (early interrupt or an
    # empty epoch range); there is nothing to report in that case.
    if best_loss is not None:
        print('*** Best loss: {0} (epoch {1})'.format(best_loss[1], best_loss[0]))
* batch_idx / len(loader), + loss=losses_m, + batch_time=batch_time_m, + rate=input.size(0) / batch_time_m.val, + rate_avg=input.size(0) / batch_time_m.avg, + data_time=data_time_m)) + + if args.save_images: + torchvision.utils.save_image( + input, + os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx), + padding=0, + normalize=True) + + if saver is not None and batch_idx % args.recovery_interval == 0: + saver.save_recovery({ + 'epoch': epoch, + 'arch': args.model, + 'state_dict': model.state_dict(), + 'optimizer': optimizer.state_dict(), + 'args': args, + 'gp': args.gp, + }, + epoch=epoch, + batch_idx=batch_idx) + + end = time.time() + + return OrderedDict([('train_loss', losses_m.avg)]) + + +def validate(step, model, loader, loss_fn, args, output_dir=''): + batch_time_m = AverageMeter() + losses_m = AverageMeter() + prec1_m = AverageMeter() + prec5_m = AverageMeter() + + model.eval() + + end = time.time() + with torch.no_grad(): + for batch_idx, (input, target) in enumerate(loader): + input = input.cuda() + if isinstance(target, list): + target = target[0].cuda() + else: + target = target.cuda() + + output = model(input) + + if isinstance(output, list): + output = output[0] + + # augmentation reduction + reduce_factor = loader.dataset.get_aug_factor() + if reduce_factor > 1: + output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2) + target = target[0:target.size(0):reduce_factor] + + # calc loss + loss = loss_fn(output, target) + losses_m.update(loss.item(), input.size(0)) + + # metrics + prec1, prec5 = accuracy(output, target, topk=(1, 3)) + prec1_m.update(prec1.item(), output.size(0)) + prec5_m.update(prec5.item(), output.size(0)) + + batch_time_m.update(time.time() - end) + end = time.time() + if batch_idx % args.log_interval == 0: + print('Test: [{0}/{1}]\t' + 'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) ' + 'Loss {loss.val:.4f} ({loss.avg:.4f}) ' + 'Prec@1 {top1.val:.4f} ({top1.avg:.4f}) ' + 'Prec@5 {top5.val:.4f} 
class CheckpointSaver:
    """Writes training checkpoints and recovery snapshots to disk.

    Regular checkpoints are kept in a list sorted by ascending metric (lower
    is better) and capped at ``max_history`` entries; the best one is also
    copied to ``model_best.pth.tar``. Recovery snapshots are rotated so that
    older ones are removed as new ones are written.
    """

    def __init__(
            self,
            checkpoint_prefix='checkpoint',
            recovery_prefix='recovery',
            checkpoint_dir='',
            recovery_dir='',
            max_history=10):

        self.checkpoint_files = []  # (path, metric) tuples, best (lowest) first
        self.best_metric = None  # (epoch, metric) of the best checkpoint so far
        self.worst_metric = None
        self.max_history = max_history
        assert self.max_history >= 1
        self.curr_recovery_file = ''
        self.last_recovery_file = ''
        self.checkpoint_dir = checkpoint_dir
        self.recovery_dir = recovery_dir
        self.save_prefix = checkpoint_prefix
        self.recovery_prefix = recovery_prefix
        self.extension = '.pth.tar'

    def save_checkpoint(self, state, epoch, metric=None):
        """Save ``state`` if there is room or it beats the worst kept checkpoint.

        Args:
            state: dict passed to ``torch.save``; gains a ``'metric'`` key when
                ``metric`` is given.
            epoch: epoch number used in the file name.
            metric: value to rank by, lower is better.

        Returns:
            ``(epoch, metric)`` of the best checkpoint seen so far, or
            ``(None, None)`` if no checkpoint with a metric has been saved.
        """
        worst_metric = self.checkpoint_files[-1] if self.checkpoint_files else None
        if len(self.checkpoint_files) < self.max_history or metric < worst_metric[1]:
            if len(self.checkpoint_files) >= self.max_history:
                # Make room by dropping the current worst entry.
                self._cleanup_checkpoints(1)

            filename = '-'.join([self.save_prefix, str(epoch)]) + self.extension
            save_path = os.path.join(self.checkpoint_dir, filename)
            if metric is not None:
                state['metric'] = metric
            torch.save(state, save_path)
            self.checkpoint_files.append((save_path, metric))
            self.checkpoint_files = sorted(self.checkpoint_files, key=lambda x: x[1])

            print("Current checkpoints:")
            for c in self.checkpoint_files:
                print(c)

            if metric is not None and (self.best_metric is None or metric < self.best_metric[1]):
                self.best_metric = (epoch, metric)
                shutil.copyfile(save_path, os.path.join(self.checkpoint_dir, 'model_best' + self.extension))
        # Parenthesised on purpose: without the parentheses the conditional
        # binds only to the second tuple element and callers would receive
        # (None, (epoch, metric)).
        return (None, None) if self.best_metric is None else self.best_metric

    def _cleanup_checkpoints(self, trim=0):
        """Delete the worst checkpoints so that ``max_history - trim`` remain."""
        trim = min(len(self.checkpoint_files), trim)
        delete_index = self.max_history - trim
        if delete_index <= 0 or len(self.checkpoint_files) <= delete_index:
            return
        to_delete = self.checkpoint_files[delete_index:]
        for d in to_delete:
            try:
                print('Cleaning checkpoint: ', d)
                os.remove(d[0])
            except Exception as e:
                # Best effort: a file that cannot be removed must not stop training.
                print('Exception (%s) while deleting checkpoint' % str(e))
        self.checkpoint_files = self.checkpoint_files[:delete_index]

    def save_recovery(self, state, epoch, batch_idx):
        """Write a recovery snapshot and remove the one from two saves ago."""
        filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension
        save_path = os.path.join(self.recovery_dir, filename)
        torch.save(state, save_path)
        if os.path.exists(self.last_recovery_file):
            try:
                print('Cleaning recovery', self.last_recovery_file)
                os.remove(self.last_recovery_file)
            except Exception as e:
                print("Exception (%s) while removing %s" % (str(e), self.last_recovery_file))
        self.last_recovery_file = self.curr_recovery_file
        self.curr_recovery_file = save_path

    def find_recovery(self):
        """Return the first recovery file in sorted name order, or '' if none.

        NOTE(review): names sort lexicographically, so this is not necessarily
        the most recent snapshot - confirm which one callers expect.
        """
        recovery_path = os.path.join(self.recovery_dir, self.recovery_prefix)
        files = glob.glob(recovery_path + '*' + self.extension)
        files = sorted(files)
        if len(files):
            return files[0]
        else:
            return ''
def accuracy(output, target, topk=(1,)):
    """Compute the precision@k for the specified values of k.

    Args:
        output: 2-D tensor of scores; the top-k is taken along dimension 1.
        target: Tensor with one class index per row of ``output``.
        topk: Iterable of ``k`` values to evaluate.

    Returns:
        A list with one 0-dim tensor per ``k``: the percentage (0-100) of
        rows whose target is among the ``k`` highest-scoring indices.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # `correct` inherits the transposed (non-contiguous) layout of `pred`,
        # so `.view(-1)` raises for k > 1; `.reshape(-1)` copies when needed.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
def main():
    """Evaluate a model on a dataset and print loss, Prec@1 and Prec@5.

    Reads its configuration from the module-level ``parser``, builds the model
    through ``model_factory``, optionally restores weights from
    ``--restore-checkpoint``, then runs one pass over the data on the GPU
    without gradients. Exits with status 1 when neither a checkpoint file nor
    ``--pretrained`` supplies weights.
    """
    import sys  # local import: used only for the failure exit below

    args = parser.parse_args()

    # Test-time pooling is only enabled for DPN models evaluated above 224px,
    # unless it was disabled with --no-test-pool.
    test_time_pool = False
    if 'dpn' in args.model and args.img_size > 224 and not args.no_test_pool:
        test_time_pool = True

    # create model
    num_classes = 1000
    model = model_factory.create_model(
        args.model,
        num_classes=num_classes,
        pretrained=args.pretrained,
        test_time_pool=test_time_pool)

    print('Model %s created, param count: %d' %
          (args.model, sum(p.numel() for p in model.parameters())))

    print(model)

    # optionally resume from a checkpoint
    if args.restore_checkpoint and os.path.isfile(args.restore_checkpoint):
        print("=> loading checkpoint '{}'".format(args.restore_checkpoint))
        # Load onto the CPU: the model is still on the CPU here and is moved
        # to the GPU afterwards, so this does not depend on the device the
        # checkpoint was saved from.
        checkpoint = torch.load(args.restore_checkpoint, map_location='cpu')
        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            model.load_state_dict(checkpoint['state_dict'])
        else:
            model.load_state_dict(checkpoint)
        print("=> loaded checkpoint '{}'".format(args.restore_checkpoint))
    elif not args.pretrained:
        print("=> no checkpoint found at '{}'".format(args.restore_checkpoint))
        # sys.exit rather than the `exit` builtin, which is injected by the
        # `site` module and is not guaranteed to exist.
        sys.exit(1)
    elif args.restore_checkpoint:
        # Pretrained weights are used, but the requested checkpoint is missing;
        # say so instead of silently ignoring the argument.
        print("=> no checkpoint found at '{}', using pretrained weights".format(
            args.restore_checkpoint))

    if args.multi_gpu:
        model = torch.nn.DataParallel(model).cuda()
    else:
        model = model.cuda()

    # define loss function (criterion)
    criterion = nn.CrossEntropyLoss().cuda()

    cudnn.benchmark = True

    transforms = model_factory.get_transforms_eval(
        args.model,
        args.img_size)

    dataset = Dataset(
        args.data,
        transforms)

    loader = data.DataLoader(
        dataset,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()
    end = time.time()
    with torch.no_grad():
        for i, (images, target) in enumerate(loader):
            target = target.cuda()
            images = images.cuda()

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss, weighted by batch size
            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(prec1.item(), images.size(0))
            top5.update(prec5.item(), images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                          i, len(loader), batch_time=batch_time, loss=losses,
                          top1=top1, top5=top5))

    # Final summary: accuracy followed by the corresponding error rate.
    print(' * Prec@1 {top1.avg:.3f} ({top1a:.3f}) Prec@5 {top5.avg:.3f} ({top5a:.3f})'.format(
        top1=top1, top1a=100-top1.avg, top5=top5, top5a=100.-top5.avg))


class AverageMeter(object):
    """Computes and stores the average and current value.

    ``val`` is the most recent value, ``sum`` the weighted total, ``count``
    the total weight and ``avg`` their ratio.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero all statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Compute the precision@k for the specified values of k.

    Args:
        output: 2-D tensor of scores; the top-k is taken along dimension 1.
        target: Tensor with one class index per row of ``output``.
        topk: Iterable of ``k`` values to evaluate.

    Returns:
        A list with one 0-dim tensor per ``k``: the percentage (0-100) of
        rows whose target is among the ``k`` highest-scoring indices.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # `correct` inherits the transposed (non-contiguous) layout of `pred`,
        # so `.view(-1)` raises for k > 1; `.reshape(-1)` copies when needed.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


if __name__ == '__main__':
    main()