From 4d505e0785db1800889d19aa2afa93f79ba93f61 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Sat, 10 Aug 2019 13:52:01 -0700
Subject: [PATCH 1/2] Add working Gluon Xception-65 model. Some cleanup still needed.

---
 convert/convert_from_mxnet.py |   2 +-
 timm/models/__init__.py       |   1 +
 timm/models/gluon_xception.py | 438 ++++++++++++++++++++++++++++++++++
 3 files changed, 440 insertions(+), 1 deletion(-)
 create mode 100644 timm/models/gluon_xception.py

diff --git a/convert/convert_from_mxnet.py b/convert/convert_from_mxnet.py
index ba404a1e..f2c64c25 100644
--- a/convert/convert_from_mxnet.py
+++ b/convert/convert_from_mxnet.py
@@ -5,7 +5,7 @@ import os
 import mxnet as mx
 import gluoncv
 import torch
-from models.model_factory import create_model
+from timm import create_model
 
 parser = argparse.ArgumentParser(description='Convert from MXNet')
 parser.add_argument('--model', default='all', type=str, metavar='MODEL',

diff --git a/timm/models/__init__.py b/timm/models/__init__.py
index 1a864f87..f0825077 100644
--- a/timm/models/__init__.py
+++ b/timm/models/__init__.py
@@ -10,6 +10,7 @@ from .pnasnet import *
 from .gen_efficientnet import *
 from .inception_v3 import *
 from .gluon_resnet import *
+from .gluon_xception import *
 from .registry import *
 
 from .factory import create_model

diff --git a/timm/models/gluon_xception.py b/timm/models/gluon_xception.py
new file mode 100644
index 00000000..6fc33d8f
--- /dev/null
+++ b/timm/models/gluon_xception.py
@@ -0,0 +1,438 @@
+"""Pytorch impl of Gluon Xception
+This is a port of the Gluon Xception code and weights, itself ported from a PyTorch DeepLab impl.
+
+Gluon model: (https://gluon-cv.mxnet.io/_modules/gluoncv/model_zoo/xception.html)
+Original PyTorch DeepLab impl: https://github.com/jfzhang95/pytorch-deeplab-xception
+
+Hacked together by Ross Wightman
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from collections import OrderedDict
+
+from .registry import register_model
+from .helpers import load_pretrained
+from .adaptive_avgmax_pool import select_adaptive_pool2d
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+
+__all__ = ['Xception65', 'Xception71']
+
+default_cfgs = {
+    'gluon_xception65': {
+        'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_xception-7015a15c.pth',
+        'input_size': (3, 299, 299),
+        'crop_pct': 0.875,
+        'interpolation': 'bicubic',
+        'mean': IMAGENET_DEFAULT_MEAN,
+        'std': IMAGENET_DEFAULT_STD,
+        'num_classes': 1000,
+        'first_conv': 'conv1',
+        'classifier': 'fc'
+        # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
+    },
+    'gluon_xception71': {
+        'url': '',
+        'input_size': (3, 299, 299),
+        'crop_pct': 0.875,
+        'interpolation': 'bicubic',
+        'mean': IMAGENET_DEFAULT_MEAN,
+        'std': IMAGENET_DEFAULT_STD,
+        'num_classes': 1000,
+        'first_conv': 'conv1',
+        'classifier': 'fc'
+        # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
+    }
+}
+
+
+""" PADDING NOTES
+The original PyTorch and Gluon impl of these models dutifully reproduced the
+aligned padding added to Tensorflow models for Deeplab. This padding was compensating
+for Tensorflow 'SAME' padding. PyTorch symmetric padding behaves the way we'd want it to.
+
+So, I'm phasing out the 'fixed_padding' ported from TF and replacing it with normal
+PyTorch padding, with some asserts to validate the equivalence for any scenario we'd
+care about, before removing it altogether.
+"""
+_USE_FIXED_PAD = False
+
+
+def _pytorch_padding(kernel_size, stride=1, dilation=1, **_):
+    if _USE_FIXED_PAD:
+        return 0  # FIXME remove once verified
+    else:
+        padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
+
+        # FIXME remove once verified
+        fp = _fixed_padding(kernel_size, dilation)
+        assert all(padding == p for p in fp)
+
+        return padding
+
+
+def _fixed_padding(kernel_size, dilation):
+    kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
+    pad_total = kernel_size_effective - 1
+    pad_beg = pad_total // 2
+    pad_end = pad_total - pad_beg
+    return [pad_beg, pad_end, pad_beg, pad_end]
+
+
+class SeparableConv2d(nn.Module):
+    def __init__(self, inplanes, planes, kernel_size=3, stride=1,
+                 dilation=1, bias=False, norm_layer=None, norm_kwargs=None):
+        super(SeparableConv2d, self).__init__()
+        norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
+        self.kernel_size = kernel_size
+        self.dilation = dilation
+
+        padding = _fixed_padding(self.kernel_size, self.dilation)
+        if _USE_FIXED_PAD and any(p > 0 for p in padding):
+            self.fixed_padding = nn.ZeroPad2d(padding)
+        else:
+            self.fixed_padding = None
+
+        # depthwise convolution
+        self.conv_dw = nn.Conv2d(
+            inplanes, inplanes, kernel_size, stride=stride,
+            padding=_pytorch_padding(kernel_size, stride, dilation), dilation=dilation, groups=inplanes, bias=bias)
+        self.bn = norm_layer(num_features=inplanes, **norm_kwargs)
+        # pointwise convolution
+        self.conv_pw = nn.Conv2d(inplanes, planes, kernel_size=1, bias=bias)
+
+    def forward(self, x):
+        if self.fixed_padding is not None:
+            # FIXME remove once verified
+            x = self.fixed_padding(x)
+        x = self.conv_dw(x)
+        x = self.bn(x)
+        x = self.conv_pw(x)
+        return x
+
+
+class Block(nn.Module):
+    def __init__(self, inplanes, planes, num_reps, stride=1, dilation=1, norm_layer=None,
+                 norm_kwargs=None, start_with_relu=True, grow_first=True, is_last=False):
+        super(Block, self).__init__()
+        norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
+        if planes != inplanes or stride != 1:
+            self.skip = nn.Sequential()
+            self.skip.add_module('conv1', nn.Conv2d(
+                inplanes, planes, 1, stride=stride, bias=False)),
+            self.skip.add_module('bn1', norm_layer(num_features=planes, **norm_kwargs))
+        else:
+            self.skip = None
+
+        rep = OrderedDict()
+        l = 1
+        filters = inplanes
+        if grow_first:
+            if start_with_relu:
+                rep['act%d' % l] = nn.ReLU(inplace=False)  # NOTE: silent failure if inplace=True here
+            rep['conv%d' % l] = SeparableConv2d(
+                inplanes, planes, 3, 1, dilation, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
+            rep['bn%d' % l] = norm_layer(num_features=planes, **norm_kwargs)
+            filters = planes
+            l += 1
+
+        for _ in range(num_reps - 1):
+            if grow_first or start_with_relu:
+                # FIXME being conservative with inplace here, think it's fine to leave True?
+                rep['act%d' % l] = nn.ReLU(inplace=grow_first or not start_with_relu)
+            rep['conv%d' % l] = SeparableConv2d(
+                filters, filters, 3, 1, dilation, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
+            rep['bn%d' % l] = norm_layer(num_features=filters, **norm_kwargs)
+            l += 1
+
+        if not grow_first:
+            rep['act%d' % l] = nn.ReLU(inplace=True)
+            rep['conv%d' % l] = SeparableConv2d(
+                inplanes, planes, 3, 1, dilation, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
+            rep['bn%d' % l] = norm_layer(num_features=planes, **norm_kwargs)
+            l += 1
+
+        if stride != 1:
+            rep['act%d' % l] = nn.ReLU(inplace=True)
+            rep['conv%d' % l] = SeparableConv2d(
+                planes, planes, 3, stride, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
+            rep['bn%d' % l] = norm_layer(num_features=planes, **norm_kwargs)
+            l += 1
+        elif is_last:
+            rep['act%d' % l] = nn.ReLU(inplace=True)
+            rep['conv%d' % l] = SeparableConv2d(
+                planes, planes, 3, 1, dilation, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
+            rep['bn%d' % l] = norm_layer(num_features=planes, **norm_kwargs)
+            l += 1
+        self.rep = nn.Sequential(rep)
+
+    def forward(self, x):
+        skip = x
+        if self.skip is not None:
+            skip = self.skip(skip)
+        x = self.rep(x) + skip
+        return x
+
+
+class Xception65(nn.Module):
+    """Modified Aligned Xception
+    """
+
+    def __init__(self, num_classes=1000, in_chans=3, output_stride=32, norm_layer=nn.BatchNorm2d,
+                 norm_kwargs=None, drop_rate=0., global_pool='avg'):
+        super(Xception65, self).__init__()
+        self.drop_rate = drop_rate
+        norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
+        if output_stride == 32:
+            entry_block3_stride = 2
+            exit_block20_stride = 2
+            middle_block_dilation = 1
+            exit_block_dilations = (1, 1)
+        elif output_stride == 16:
+            entry_block3_stride = 2
+            exit_block20_stride = 1
+            middle_block_dilation = 1
+            exit_block_dilations = (1, 2)
+        elif output_stride == 8:
+            entry_block3_stride = 1
+            exit_block20_stride = 1
+            middle_block_dilation = 2
+            exit_block_dilations = (2, 4)
+        else:
+            raise NotImplementedError
+
+        # Entry flow
+        self.conv1 = nn.Conv2d(in_chans, 32, kernel_size=3, stride=2, padding=1, bias=False)
+        self.bn1 = norm_layer(num_features=32, **norm_kwargs)
+        self.relu = nn.ReLU(inplace=True)
+
+        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False)
+        self.bn2 = norm_layer(num_features=64)
+
+        self.block1 = Block(
+            64, 128, num_reps=2, stride=2,
+            norm_layer=norm_layer, norm_kwargs=norm_kwargs, start_with_relu=False)
+        self.block2 = Block(
+            128, 256, num_reps=2, stride=2,
+            norm_layer=norm_layer, norm_kwargs=norm_kwargs, start_with_relu=False, grow_first=True)
+        self.block3 = Block(
+            256, 728, num_reps=2, stride=entry_block3_stride,
+            norm_layer=norm_layer, norm_kwargs=norm_kwargs, start_with_relu=True, grow_first=True, is_last=True)
+
+        # Middle flow
+        self.mid = nn.Sequential(OrderedDict([('block%d' % i, Block(
+            728, 728, num_reps=3, stride=1, dilation=middle_block_dilation,
+            norm_layer=norm_layer, norm_kwargs=norm_kwargs, start_with_relu=True, grow_first=True))
+            for i in range(4, 20)]))
+
+        # Exit flow
+        self.block20 = Block(
+            728, 1024, num_reps=2, stride=exit_block20_stride, dilation=exit_block_dilations[0],
+            norm_layer=norm_layer, norm_kwargs=norm_kwargs, start_with_relu=True, grow_first=False, is_last=True)
+
+        self.conv3 = SeparableConv2d(
+            1024, 1536, 3, stride=1, dilation=exit_block_dilations[1],
+            norm_layer=norm_layer, norm_kwargs=norm_kwargs)
+        self.bn3 = norm_layer(num_features=1536, **norm_kwargs)
+
+        self.conv4 = SeparableConv2d(
+            1536, 1536, 3, stride=1, dilation=exit_block_dilations[1],
+            norm_layer=norm_layer, norm_kwargs=norm_kwargs)
+        self.bn4 = norm_layer(num_features=1536, **norm_kwargs)
+
+        self.conv5 = SeparableConv2d(
+            1536, 2048, 3, stride=1, dilation=exit_block_dilations[1],
+            norm_layer=norm_layer, norm_kwargs=norm_kwargs)
+        self.bn5 = norm_layer(num_features=2048, **norm_kwargs)
+        self.avgpool = nn.AdaptiveAvgPool2d(1)
+        self.fc = nn.Linear(in_features=2048, out_features=num_classes)
+
+    def forward(self, x):
+        # Entry flow
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+
+        x = self.conv2(x)
+        x = self.bn2(x)
+        x = self.relu(x)
+
+        x = self.block1(x)
+        # add relu here
+        x = self.relu(x)
+        # c1 = x
+        x = self.block2(x)
+        # c2 = x
+        x = self.block3(x)
+
+        # Middle flow
+        x = self.mid(x)
+        # c3 = x
+
+        # Exit flow
+        x = self.block20(x)
+        x = self.relu(x)
+        x = self.conv3(x)
+        x = self.bn3(x)
+        x = self.relu(x)
+
+        x = self.conv4(x)
+        x = self.bn4(x)
+        x = self.relu(x)
+
+        x = self.conv5(x)
+        x = self.bn5(x)
+        x = self.relu(x)
+
+        x = self.avgpool(x)
+        x = x.view(x.size(0), -1)
+        if self.drop_rate > 0.:
+            x = F.dropout(x, p=self.drop_rate, training=self.training)
+        x = self.fc(x)
+        return x
+
+
+class Xception71(nn.Module):
+    """Modified Aligned Xception
+    """
+
+    def __init__(self, num_classes=1000, in_chans=3, output_stride=32, norm_layer=nn.BatchNorm2d,
+                 norm_kwargs=None, drop_rate=0., global_pool='avg'):
+        super(Xception71, self).__init__()
+        self.drop_rate = drop_rate
+        norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
+        if output_stride == 32:
+            entry_block3_stride = 2
+            exit_block20_stride = 2
+            middle_block_dilation = 1
+            exit_block_dilations = (1, 1)
+        elif output_stride == 16:
+            entry_block3_stride = 2
+            exit_block20_stride = 1
+            middle_block_dilation = 1
+            exit_block_dilations = (1, 2)
+        elif output_stride == 8:
+            entry_block3_stride = 1
+            exit_block20_stride = 1
+            middle_block_dilation = 2
+            exit_block_dilations = (2, 4)
+        else:
+            raise NotImplementedError
+
+        # Entry flow
+        self.conv1 = nn.Conv2d(in_chans, 32, kernel_size=3, stride=2, padding=1, bias=False)
+        self.bn1 = norm_layer(num_features=32, **norm_kwargs)
+        self.relu = nn.ReLU(inplace=True)
+
+        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False)
+        self.bn2 = norm_layer(num_features=64)
+
+        self.block1 = Block(
+            64, 128, num_reps=2, stride=2, norm_layer=norm_layer,
+            norm_kwargs=norm_kwargs, start_with_relu=False)
+        self.block2 = nn.Sequential(*[
+            Block(
+                128, 256, num_reps=2, stride=1, norm_layer=norm_layer,
+                norm_kwargs=norm_kwargs, start_with_relu=False, grow_first=True),
+            Block(
+                256, 256, num_reps=2, stride=2, norm_layer=norm_layer,
+                norm_kwargs=norm_kwargs, start_with_relu=False, grow_first=True),
+            Block(
+                256, 728, num_reps=2, stride=2, norm_layer=norm_layer,
+                norm_kwargs=norm_kwargs, start_with_relu=False, grow_first=True)])
+        self.block3 = Block(
+            728, 728, num_reps=2, stride=entry_block3_stride, norm_layer=norm_layer,
+            norm_kwargs=norm_kwargs, start_with_relu=True, grow_first=True, is_last=True)
+
+        # Middle flow
+        self.mid = nn.Sequential(OrderedDict([('block%d' % i, Block(
+            728, 728, num_reps=3, stride=1, dilation=middle_block_dilation,
+            norm_layer=norm_layer, norm_kwargs=norm_kwargs, start_with_relu=True, grow_first=True))
+            for i in range(4, 20)]))
+
+        # Exit flow
+        self.block20 = Block(
+            728, 1024, num_reps=2, stride=exit_block20_stride, dilation=exit_block_dilations[0],
+            norm_layer=norm_layer, norm_kwargs=norm_kwargs, start_with_relu=True, grow_first=False, is_last=True)
+
+        self.conv3 = SeparableConv2d(
+            1024, 1536, 3, stride=1, dilation=exit_block_dilations[1],
+            norm_layer=norm_layer, norm_kwargs=norm_kwargs)
+        self.bn3 = norm_layer(num_features=1536, **norm_kwargs)
+
+        self.conv4 = SeparableConv2d(
+            1536, 1536, 3, stride=1, dilation=exit_block_dilations[1],
+            norm_layer=norm_layer, norm_kwargs=norm_kwargs)
+        self.bn4 = norm_layer(num_features=1536, **norm_kwargs)
+
+        self.conv5 = SeparableConv2d(
+            1536, 2048, 3, stride=1, dilation=exit_block_dilations[1],
+            norm_layer=norm_layer, norm_kwargs=norm_kwargs)
+        self.bn5 = norm_layer(num_features=2048, **norm_kwargs)
+        self.avgpool = nn.AdaptiveAvgPool2d(1)
+        self.fc = nn.Linear(in_features=2048, out_features=num_classes)
+
+    def forward(self, x):
+        # Entry flow
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+
+        x = self.conv2(x)
+        x = self.bn2(x)
+        x = self.relu(x)
+
+        x = self.block1(x)
+        # add relu here
+        x = self.relu(x)
+        # low_level_feat = x
+        x = self.block2(x)
+        # c2 = x
+        x = self.block3(x)
+
+        # Middle flow
+        x = self.mid(x)
+        # c3 = x
+
+        # Exit flow
+        x = self.block20(x)
+        x = self.relu(x)
+        x = self.conv3(x)
+        x = self.bn3(x)
+        x = self.relu(x)
+
+        x = self.conv4(x)
+        x = self.bn4(x)
+        x = self.relu(x)
+
+        x = self.conv5(x)
+        x = self.bn5(x)
+        x = self.relu(x)
+
+        x = self.avgpool(x)
+        x = x.view(x.size(0), -1)
+        if self.drop_rate > 0.:
+            x = F.dropout(x, p=self.drop_rate, training=self.training)
+        x = self.fc(x)
+        return x
+
+
+@register_model
+def gluon_xception65(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+    default_cfg = default_cfgs['gluon_xception65']
+    model = Xception65(num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfg
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
+
+
+@register_model
+def gluon_xception71(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+    default_cfg = default_cfgs['gluon_xception71']
+    model = Xception71(num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfg
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
+

From 3b4868f6dc617e542acd5767049d06a747e5863e Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Sun, 11 Aug 2019 23:06:23 -0700
Subject: [PATCH 2/2] A few more additions to Gluon Xception models to match interface of others.

---
 README.md                     |  1 +
 timm/models/gluon_xception.py | 80 +++++++++++++++++++++++++++--------
 2 files changed, 63 insertions(+), 18 deletions(-)

diff --git a/README.md b/README.md
index a648ce19..bcb6b843 100644
--- a/README.md
+++ b/README.md
@@ -120,6 +120,7 @@ I've leveraged the training scripts in this repository to train a few of the mod
 | gluon_resnet152_v1c | 79.916 (20.084) | 94.842 (5.158) | 60.21 | bicubic | 224 | |
 | gluon_seresnext50_32x4d | 79.912 (20.088) | 94.818 (5.182) | 27.56 | bicubic | 224 | |
 | gluon_resnet152_v1b | 79.692 (20.308) | 94.738 (5.262) | 60.19 | bicubic | 224 | |
+| gluon_xception65 | 79.604 (20.396) | 94.748 (5.252) | 39.92 | bicubic | 299 | |
 | gluon_resnet101_v1c | 79.544 (20.456) | 94.586 (5.414) | 44.57 | bicubic | 224 | |
 | gluon_resnext50_32x4d | 79.356 (20.644) | 94.424 (5.576) | 25.03 | bicubic | 224 | |
 | gluon_resnet101_v1b | 79.304 (20.696) | 94.524 (5.476) | 44.55 | bicubic | 224 | |

diff --git a/timm/models/gluon_xception.py b/timm/models/gluon_xception.py
index 6fc33d8f..9393e5ba 100644
--- a/timm/models/gluon_xception.py
+++ b/timm/models/gluon_xception.py
@@ -23,6 +23,7 @@ default_cfgs = {
         'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_xception-7015a15c.pth',
         'input_size': (3, 299, 299),
         'crop_pct': 0.875,
+        'pool_size': (10, 10),
         'interpolation': 'bicubic',
         'mean': IMAGENET_DEFAULT_MEAN,
         'std': IMAGENET_DEFAULT_STD,
@@ -35,6 +36,7 @@ default_cfgs = {
         'url': '',
         'input_size': (3, 299, 299),
         'crop_pct': 0.875,
+        'pool_size': (10, 10),
         'interpolation': 'bicubic',
         'mean': IMAGENET_DEFAULT_MEAN,
         'std': IMAGENET_DEFAULT_STD,
@@ -181,7 +183,9 @@ class Xception65(nn.Module):
     def __init__(self, num_classes=1000, in_chans=3, output_stride=32, norm_layer=nn.BatchNorm2d,
                  norm_kwargs=None, drop_rate=0., global_pool='avg'):
         super(Xception65, self).__init__()
+        self.num_classes = num_classes
         self.drop_rate = drop_rate
+        self.global_pool = global_pool
         norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
         if output_stride == 32:
             entry_block3_stride = 2
@@ -240,14 +244,26 @@ class Xception65(nn.Module):
             norm_layer=norm_layer, norm_kwargs=norm_kwargs)
         self.bn4 = norm_layer(num_features=1536, **norm_kwargs)
 
+        self.num_features = 2048
         self.conv5 = SeparableConv2d(
-            1536, 2048, 3, stride=1, dilation=exit_block_dilations[1],
+            1536, self.num_features, 3, stride=1, dilation=exit_block_dilations[1],
             norm_layer=norm_layer, norm_kwargs=norm_kwargs)
-        self.bn5 = norm_layer(num_features=2048, **norm_kwargs)
-        self.avgpool = nn.AdaptiveAvgPool2d(1)
-        self.fc = nn.Linear(in_features=2048, out_features=num_classes)
+        self.bn5 = norm_layer(num_features=self.num_features, **norm_kwargs)
+        self.fc = nn.Linear(in_features=self.num_features, out_features=num_classes)
+
+    def get_classifier(self):
+        return self.fc
+
+    def reset_classifier(self, num_classes, global_pool='avg'):
+        self.num_classes = num_classes
+        self.global_pool = global_pool
+        del self.fc
+        if num_classes:
+            self.fc = nn.Linear(self.num_features, num_classes)
+        else:
+            self.fc = None
 
-    def forward(self, x):
+    def forward_features(self, x, pool=True):
         # Entry flow
         x = self.conv1(x)
         x = self.bn1(x)
@@ -284,10 +300,15 @@ class Xception65(nn.Module):
         x = self.bn5(x)
         x = self.relu(x)
 
-        x = self.avgpool(x)
-        x = x.view(x.size(0), -1)
-        if self.drop_rate > 0.:
-            x = F.dropout(x, p=self.drop_rate, training=self.training)
+        if pool:
+            x = select_adaptive_pool2d(x, pool_type=self.global_pool)
+            x = x.view(x.size(0), -1)
+        return x
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        if self.drop_rate:
+            x = F.dropout(x, self.drop_rate, training=self.training)
         x = self.fc(x)
         return x
 
@@ -299,7 +320,9 @@ class Xception71(nn.Module):
     def __init__(self, num_classes=1000, in_chans=3, output_stride=32, norm_layer=nn.BatchNorm2d,
                  norm_kwargs=None, drop_rate=0., global_pool='avg'):
         super(Xception71, self).__init__()
+        self.num_classes = num_classes
         self.drop_rate = drop_rate
+        self.global_pool = global_pool
         norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
         if output_stride == 32:
             entry_block3_stride = 2
@@ -365,14 +388,26 @@ class Xception71(nn.Module):
             norm_layer=norm_layer, norm_kwargs=norm_kwargs)
         self.bn4 = norm_layer(num_features=1536, **norm_kwargs)
 
+        self.num_features = 2048
         self.conv5 = SeparableConv2d(
-            1536, 2048, 3, stride=1, dilation=exit_block_dilations[1],
+            1536, self.num_features, 3, stride=1, dilation=exit_block_dilations[1],
             norm_layer=norm_layer, norm_kwargs=norm_kwargs)
-        self.bn5 = norm_layer(num_features=2048, **norm_kwargs)
-        self.avgpool = nn.AdaptiveAvgPool2d(1)
-        self.fc = nn.Linear(in_features=2048, out_features=num_classes)
+        self.bn5 = norm_layer(num_features=self.num_features, **norm_kwargs)
+        self.fc = nn.Linear(in_features=self.num_features, out_features=num_classes)
+
+    def get_classifier(self):
+        return self.fc
+
+    def reset_classifier(self, num_classes, global_pool='avg'):
+        self.num_classes = num_classes
+        self.global_pool = global_pool
+        del self.fc
+        if num_classes:
+            self.fc = nn.Linear(self.num_features, num_classes)
+        else:
+            self.fc = None
 
-    def forward(self, x):
+    def forward_features(self, x, pool=True):
         # Entry flow
         x = self.conv1(x)
         x = self.bn1(x)
@@ -409,16 +444,23 @@ class Xception71(nn.Module):
         x = self.bn5(x)
         x = self.relu(x)
 
-        x = self.avgpool(x)
-        x = x.view(x.size(0), -1)
-        if self.drop_rate > 0.:
-            x = F.dropout(x, p=self.drop_rate, training=self.training)
+        if pool:
+            x = select_adaptive_pool2d(x, pool_type=self.global_pool)
+            x = x.view(x.size(0), -1)
+        return x
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        if self.drop_rate:
+            x = F.dropout(x, self.drop_rate, training=self.training)
         x = self.fc(x)
         return x
 
 
 @register_model
 def gluon_xception65(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+    """ Modified Aligned Xception-65
+    """
     default_cfg = default_cfgs['gluon_xception65']
     model = Xception65(num_classes=num_classes, in_chans=in_chans, **kwargs)
     model.default_cfg = default_cfg
@@ -429,6 +471,8 @@ def gluon_xception71(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
 
 @register_model
 def gluon_xception71(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+    """ Modified Aligned Xception-71
+    """
     default_cfg = default_cfgs['gluon_xception71']
     model = Xception71(num_classes=num_classes, in_chans=in_chans, **kwargs)
     model.default_cfg = default_cfg
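
A quick standalone check of the PADDING NOTES claim above: for the 3x3 separable convs these models actually use (dilation 1, 2 or 4), the symmetric padding computed by the patch's _pytorch_padding matches the TF-'SAME'-style _fixed_padding, which is what the runtime asserts in the patch also verify. This is only a sketch under those assumptions; it re-implements the two module-private helpers rather than importing them, the tensor sizes are arbitrary, and the tensor-level comparison is limited to stride 1.

import torch
import torch.nn as nn


def pytorch_padding(kernel_size, stride=1, dilation=1):
    # symmetric padding, mirroring _pytorch_padding in gluon_xception.py
    return ((stride - 1) + dilation * (kernel_size - 1)) // 2


def fixed_padding(kernel_size, dilation):
    # explicit TF 'SAME' style padding, mirroring _fixed_padding: [left, right, top, bottom]
    effective = kernel_size + (kernel_size - 1) * (dilation - 1)
    pad_total = effective - 1
    pad_beg = pad_total // 2
    return [pad_beg, pad_total - pad_beg, pad_beg, pad_total - pad_beg]


# the models only use 3x3 separable convs with dilations 1, 2 and 4
for stride in (1, 2):
    for dilation in (1, 2, 4):
        p = pytorch_padding(3, stride, dilation)
        assert all(p == q for q in fixed_padding(3, dilation)), (stride, dilation)

# tensor-level spot check for stride 1: symmetric Conv2d padding gives the same
# output as an explicit ZeroPad2d in front of the same conv with padding=0
x = torch.randn(1, 8, 37, 37)
conv_sym = nn.Conv2d(8, 8, 3, padding=pytorch_padding(3, 1, 2), dilation=2, groups=8, bias=False)
conv_fix = nn.Conv2d(8, 8, 3, padding=0, dilation=2, groups=8, bias=False)
conv_fix.weight.data.copy_(conv_sym.weight.data)
pad = nn.ZeroPad2d(fixed_padding(3, 2))
assert torch.allclose(conv_sym(x), conv_fix(pad(x)), atol=1e-5)
print('padding schemes agree')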
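
And a minimal usage sketch, assuming a timm checkout with both patches applied is importable. gluon_xception65 has pretrained weights wired into default_cfgs, while gluon_xception71 currently has an empty url and can only be built with pretrained=False. The forward_features/get_classifier/reset_classifier calls below are the interface added by the second patch.

import torch
import timm

model = timm.create_model('gluon_xception65', pretrained=True)
model.eval()

x = torch.randn(1, 3, 299, 299)  # default_cfg input_size is (3, 299, 299)
with torch.no_grad():
    logits = model(x)                                # (1, 1000) ImageNet logits
    features = model.forward_features(x, pool=True)  # (1, 2048) pooled features
print(logits.shape, features.shape)

# swap the head for fine-tuning, matching the interface of the other timm models
model.reset_classifier(num_classes=10)
assert model.get_classifier().out_features == 10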