From 011a11d9879f975cd53b65632a192fef478ba2e0 Mon Sep 17 00:00:00 2001 From: talrid Date: Sun, 12 Apr 2020 18:44:12 +0300 Subject: [PATCH 1/2] TResNet models --- timm/models/__init__.py | 1 + timm/models/layers/__init__.py | 2 + timm/models/layers/anti_aliasing.py | 61 ++++++ timm/models/layers/space_to_depth.py | 53 +++++ timm/models/tresnet.py | 292 +++++++++++++++++++++++++++ 5 files changed, 409 insertions(+) create mode 100644 timm/models/layers/anti_aliasing.py create mode 100644 timm/models/layers/space_to_depth.py create mode 100644 timm/models/tresnet.py diff --git a/timm/models/__init__.py b/timm/models/__init__.py index cc4d470e..b073eb3a 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -17,6 +17,7 @@ from .res2net import * from .dla import * from .hrnet import * from .sknet import * +from .tresnet import * from .registry import * from .factory import create_model diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index 3dec7498..568a9e11 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -16,3 +16,5 @@ from .adaptive_avgmax_pool import \ from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path from .test_time_pool import TestTimePoolHead, apply_test_time_pool from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model +from .anti_aliasing import AntiAliasDownsampleLayer +from .space_to_depth import SpaceToDepthModule \ No newline at end of file diff --git a/timm/models/layers/anti_aliasing.py b/timm/models/layers/anti_aliasing.py new file mode 100644 index 00000000..a1f7535a --- /dev/null +++ b/timm/models/layers/anti_aliasing.py @@ -0,0 +1,61 @@ +import torch +import torch.nn.parallel +import torch.nn as nn +import torch.nn.functional as F + + +class AntiAliasDownsampleLayer(nn.Module): + def __init__(self, remove_aa_jit: bool = False, filt_size: int = 3, stride: int = 2, + channels: int = 0): + super(AntiAliasDownsampleLayer, self).__init__() + if not remove_aa_jit: + self.op = DownsampleJIT(filt_size, stride, channels) + else: + self.op = Downsample(filt_size, stride, channels) + + def forward(self, x): + return self.op(x) + + +@torch.jit.script +class DownsampleJIT(object): + def __init__(self, filt_size: int = 3, stride: int = 2, channels: int = 0): + self.stride = stride + self.filt_size = filt_size + self.channels = channels + + assert self.filt_size == 3 + assert stride == 2 + a = torch.tensor([1., 2., 1.]) + + filt = (a[:, None] * a[None, :]).clone().detach() + filt = filt / torch.sum(filt) + self.filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1)).cuda().half() + + def __call__(self, input: torch.Tensor): + if input.dtype != self.filt.dtype: + self.filt = self.filt.float() + input_pad = F.pad(input, (1, 1, 1, 1), 'reflect') + return F.conv2d(input_pad, self.filt, stride=2, padding=0, groups=input.shape[1]) + + +class Downsample(nn.Module): + def __init__(self, filt_size=3, stride=2, channels=None): + super(Downsample, self).__init__() + self.filt_size = filt_size + self.stride = stride + self.channels = channels + + + assert self.filt_size == 3 + a = torch.tensor([1., 2., 1.]) + + filt = (a[:, None] * a[None, :]) + filt = filt / torch.sum(filt) + + # self.filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1)) + self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1))) + + def forward(self, input): + input_pad = F.pad(input, (1, 1, 1, 1), 'reflect') + return F.conv2d(input_pad, self.filt, stride=self.stride, padding=0, groups=input.shape[1]) \ No newline at end of file diff --git a/timm/models/layers/space_to_depth.py b/timm/models/layers/space_to_depth.py new file mode 100644 index 00000000..70bf7db9 --- /dev/null +++ b/timm/models/layers/space_to_depth.py @@ -0,0 +1,53 @@ +import torch +import torch.nn as nn + + +class SpaceToDepth(nn.Module): + def __init__(self, block_size=4): + super().__init__() + assert block_size == 4 + self.bs = block_size + + def forward(self, x): + N, C, H, W = x.size() + x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs) # (N, C, H//bs, bs, W//bs, bs) + x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) + x = x.view(N, C * (self.bs ** 2), H // self.bs, W // self.bs) # (N, C*bs^2, H//bs, W//bs) + return x + + +@torch.jit.script +class SpaceToDepthJit(object): + def __call__(self, x: torch.Tensor): + # assuming hard-coded that block_size==4 for acceleration + N, C, H, W = x.size() + x = x.view(N, C, H // 4, 4, W // 4, 4) # (N, C, H//bs, bs, W//bs, bs) + x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) + x = x.view(N, C * 16, H // 4, W // 4) # (N, C*bs^2, H//bs, W//bs) + return x + + +class SpaceToDepthModule(nn.Module): + def __init__(self, remove_model_jit=False): + super().__init__() + if not remove_model_jit: + self.op = SpaceToDepthJit() + else: + self.op = SpaceToDepth() + + def forward(self, x): + return self.op(x) + + +class DepthToSpace(nn.Module): + + def __init__(self, block_size): + super().__init__() + self.bs = block_size + + def forward(self, x): + N, C, H, W = x.size() + x = x.view(N, self.bs, self.bs, C // (self.bs ** 2), H, W) # (N, bs, bs, C//bs^2, H, W) + x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # (N, C//bs^2, H, bs, W, bs) + x = x.view(N, C // (self.bs ** 2), H * self.bs, W * self.bs) # (N, C//bs^2, H * bs, W * bs) + return x \ No newline at end of file diff --git a/timm/models/tresnet.py b/timm/models/tresnet.py new file mode 100644 index 00000000..dc01e0fb --- /dev/null +++ b/timm/models/tresnet.py @@ -0,0 +1,292 @@ +""" +TResNet: High Performance GPU-Dedicated Architecture +https://arxiv.org/pdf/2003.13630.pdf + +Original model: https://github.com/mrT23/TResNet + +""" +from functools import partial +import torch +import torch.nn as nn +from collections import OrderedDict +from .layers import SpaceToDepthModule, AntiAliasDownsampleLayer +from .registry import register_model +from .helpers import load_pretrained + +try: + from inplace_abn import InPlaceABN + has_iabn = True +except ImportError: + has_iabn = False + +__all__ = ['tresnet_m', 'tresnet_l', 'tresnet_xl'] + + +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': (0, 0, 0), 'std': (1, 1, 1), + 'first_conv': 'layer0.conv1', 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + 'tresnet_m': + _cfg(url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/tresnet/tresnet_m_80_8.pth'), + 'tresnet_l': + _cfg(url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/tresnet/tresnet_l_81_5.pth'), + 'tresnet_xl': + _cfg(url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/tresnet/tresnet_xl_82_0.pth') +} + + +class FastGlobalAvgPool2d(nn.Module): + def __init__(self, flatten=False): + super(FastGlobalAvgPool2d, self).__init__() + self.flatten = flatten + + def forward(self, x): + if self.flatten: + in_size = x.size() + return x.view((in_size[0], in_size[1], -1)).mean(dim=2) + else: + return x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1) + + +class FastSEModule(nn.Module): + + def __init__(self, channels, reduction_channels, inplace=True): + super(FastSEModule, self).__init__() + self.avg_pool = FastGlobalAvgPool2d() + self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, padding=0, bias=True) + self.relu = nn.ReLU(inplace=inplace) + self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, padding=0, bias=True) + self.activation = nn.Sigmoid() + + def forward(self, x): + x_se = self.avg_pool(x) + x_se2 = self.fc1(x_se) + x_se2 = self.relu(x_se2) + x_se = self.fc2(x_se2) + x_se = self.activation(x_se) + return x * x_se + + +def IABN2Float(module: nn.Module) -> nn.Module: + "If `module` is IABN don't use half precision." + if isinstance(module, InPlaceABN): + module.float() + for child in module.children(): IABN2Float(child) + return module + + +def conv2d_ABN(ni, nf, stride, activation="leaky_relu", kernel_size=3, activation_param=1e-2, groups=1): + return nn.Sequential( + nn.Conv2d(ni, nf, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2, groups=groups, + bias=False), + InPlaceABN(num_features=nf, activation=activation, activation_param=activation_param) + ) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True, anti_alias_layer=None): + super(BasicBlock, self).__init__() + if stride == 1: + self.conv1 = conv2d_ABN(inplanes, planes, stride=1, activation_param=1e-3) + else: + if anti_alias_layer is None: + self.conv1 = conv2d_ABN(inplanes, planes, stride=2, activation_param=1e-3) + else: + self.conv1 = nn.Sequential(conv2d_ABN(inplanes, planes, stride=1, activation_param=1e-3), + anti_alias_layer(channels=planes, filt_size=3, stride=2)) + + self.conv2 = conv2d_ABN(planes, planes, stride=1, activation="identity") + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + reduce_layer_planes = max(planes * self.expansion // 4, 64) + self.se = FastSEModule(planes * self.expansion, reduce_layer_planes) if use_se else None + + def forward(self, x): + if self.downsample is not None: + residual = self.downsample(x) + else: + residual = x + + out = self.conv1(x) + out = self.conv2(out) + + if self.se is not None: out = self.se(out) + + out += residual + + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True, anti_alias_layer=None): + super(Bottleneck, self).__init__() + self.conv1 = conv2d_ABN(inplanes, planes, kernel_size=1, stride=1, activation="leaky_relu", + activation_param=1e-3) + if stride == 1: + self.conv2 = conv2d_ABN(planes, planes, kernel_size=3, stride=1, activation="leaky_relu", + activation_param=1e-3) + else: + if anti_alias_layer is None: + self.conv2 = conv2d_ABN(planes, planes, kernel_size=3, stride=2, activation="leaky_relu", + activation_param=1e-3) + else: + self.conv2 = nn.Sequential(conv2d_ABN(planes, planes, kernel_size=3, stride=1, + activation="leaky_relu", activation_param=1e-3), + anti_alias_layer(channels=planes, filt_size=3, stride=2)) + + self.conv3 = conv2d_ABN(planes, planes * self.expansion, kernel_size=1, stride=1, + activation="identity") + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + reduce_layer_planes = max(planes * self.expansion // 8, 64) + self.se = FastSEModule(planes, reduce_layer_planes) if use_se else None + + def forward(self, x): + if self.downsample is not None: + residual = self.downsample(x) + else: + residual = x + + out = self.conv1(x) + out = self.conv2(out) + if self.se is not None: out = self.se(out) + + out = self.conv3(out) + out = out + residual # no inplace + out = self.relu(out) + + return out + + +class TResNet(nn.Module): + def __init__(self, layers, in_chans=3, num_classes=1000, width_factor=1.0, remove_aa_jit=False): + if not has_iabn: + raise " For TResNet models, please install InplaceABN: 'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.11' " + + super(TResNet, self).__init__() + + # JIT layers + space_to_depth = SpaceToDepthModule() + anti_alias_layer = partial(AntiAliasDownsampleLayer, remove_aa_jit=remove_aa_jit) + global_pool_layer = FastGlobalAvgPool2d(flatten=True) + + # TResnet stages + self.inplanes = int(64 * width_factor) + self.planes = int(64 * width_factor) + conv1 = conv2d_ABN(in_chans * 16, self.planes, stride=1, kernel_size=3) + layer1 = self._make_layer(BasicBlock, self.planes, layers[0], stride=1, use_se=True, + anti_alias_layer=anti_alias_layer) # 56x56 + layer2 = self._make_layer(BasicBlock, self.planes * 2, layers[1], stride=2, use_se=True, + anti_alias_layer=anti_alias_layer) # 28x28 + layer3 = self._make_layer(Bottleneck, self.planes * 4, layers[2], stride=2, use_se=True, + anti_alias_layer=anti_alias_layer) # 14x14 + layer4 = self._make_layer(Bottleneck, self.planes * 8, layers[3], stride=2, use_se=False, + anti_alias_layer=anti_alias_layer) # 7x7 + + # body + self.body = nn.Sequential(OrderedDict([ + ('SpaceToDepth', space_to_depth), + ('conv1', conv1), + ('layer1', layer1), + ('layer2', layer2), + ('layer3', layer3), + ('layer4', layer4)])) + + # head + self.embeddings = [] + self.global_pool = nn.Sequential(OrderedDict([('global_pool_layer', global_pool_layer)])) + self.num_features = (self.planes * 8) * Bottleneck.expansion + fc = nn.Linear(self.num_features, num_classes) + self.head = nn.Sequential(OrderedDict([('fc', fc)])) + + # model initilization + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') + elif isinstance(m, nn.BatchNorm2d) or isinstance(m, InPlaceABN): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # residual connections special initialization + for m in self.modules(): + if isinstance(m, BasicBlock): + m.conv2[1].weight = nn.Parameter(torch.zeros_like(m.conv2[1].weight)) # BN to zero + if isinstance(m, Bottleneck): + m.conv3[1].weight = nn.Parameter(torch.zeros_like(m.conv3[1].weight)) # BN to zero + if isinstance(m, nn.Linear): m.weight.data.normal_(0, 0.01) + + def _make_layer(self, block, planes, blocks, stride=1, use_se=True, anti_alias_layer=None): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + layers = [] + if stride == 2: + # avg pooling before 1x1 conv + layers.append(nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True, count_include_pad=False)) + layers += [conv2d_ABN(self.inplanes, planes * block.expansion, kernel_size=1, stride=1, + activation="identity")] + downsample = nn.Sequential(*layers) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, use_se=use_se, + anti_alias_layer=anti_alias_layer)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): layers.append( + block(self.inplanes, planes, use_se=use_se, anti_alias_layer=anti_alias_layer)) + return nn.Sequential(*layers) + + def forward(self, x): + x = self.body(x) + self.embeddings = self.global_pool(x) + logits = self.head(self.embeddings) + return logits + + +def filter_fn(input): + return input['model'] + + +@register_model +def tresnet_m(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + default_cfg = default_cfgs['tresnet_m'] + model = TResNet(layers=[3, 4, 11, 3], num_classes=num_classes, in_chans=in_chans) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=filter_fn) + return model + + +@register_model +def tresnet_l(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + default_cfg = default_cfgs['tresnet_l'] + model = TResNet(layers=[4, 5, 18, 3], num_classes=num_classes, in_chans=in_chans, width_factor=1.2) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=filter_fn) + return model + + +@register_model +def tresnet_xl(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + default_cfg = default_cfgs['tresnet_xl'] + model = TResNet(layers=[4, 5, 24, 3], num_classes=num_classes, in_chans=in_chans, width_factor=1.3) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=filter_fn) + return model From 27fadaa922e7a6432ab89915594bbec313caa670 Mon Sep 17 00:00:00 2001 From: talrid Date: Fri, 16 Oct 2020 17:12:28 +0300 Subject: [PATCH 2/2] asymmetric_loss --- timm/loss/__init__.py | 3 +- timm/loss/asymmetric_loss.py | 97 ++++++++++++++++++++++++++++++++++++ 2 files changed, 99 insertions(+), 1 deletion(-) create mode 100644 timm/loss/asymmetric_loss.py diff --git a/timm/loss/__init__.py b/timm/loss/__init__.py index b781472f..28a686ce 100644 --- a/timm/loss/__init__.py +++ b/timm/loss/__init__.py @@ -1,2 +1,3 @@ from .cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy -from .jsd import JsdCrossEntropy \ No newline at end of file +from .jsd import JsdCrossEntropy +from .asymmetric_loss import AsymmetricLossMultiLabel, AsymmetricLossSingleLabel \ No newline at end of file diff --git a/timm/loss/asymmetric_loss.py b/timm/loss/asymmetric_loss.py new file mode 100644 index 00000000..96a97788 --- /dev/null +++ b/timm/loss/asymmetric_loss.py @@ -0,0 +1,97 @@ +import torch +import torch.nn as nn + + +class AsymmetricLossMultiLabel(nn.Module): + def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False): + super(AsymmetricLossMultiLabel, self).__init__() + + self.gamma_neg = gamma_neg + self.gamma_pos = gamma_pos + self.clip = clip + self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss + self.eps = eps + + def forward(self, x, y): + """" + Parameters + ---------- + x: input logits + y: targets (multi-label binarized vector) + """ + + # Calculating Probabilities + x_sigmoid = torch.sigmoid(x) + xs_pos = x_sigmoid + xs_neg = 1 - x_sigmoid + + # Asymmetric Clipping + if self.clip is not None and self.clip > 0: + xs_neg = (xs_neg + self.clip).clamp(max=1) + + # Basic CE calculation + los_pos = y * torch.log(xs_pos.clamp(min=self.eps)) + los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps)) + loss = los_pos + los_neg + + # Asymmetric Focusing + if self.gamma_neg > 0 or self.gamma_pos > 0: + if self.disable_torch_grad_focal_loss: + torch._C.set_grad_enabled(False) + pt0 = xs_pos * y + pt1 = xs_neg * (1 - y) # pt = p if t > 0 else 1-p + pt = pt0 + pt1 + one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y) + one_sided_w = torch.pow(1 - pt, one_sided_gamma) + if self.disable_torch_grad_focal_loss: + torch._C.set_grad_enabled(True) + loss *= one_sided_w + + return -loss.sum() + + +class AsymmetricLossSingleLabel(nn.Module): + def __init__(self, gamma_pos=1, gamma_neg=4, eps: float = 0.1, reduction='mean'): + super(AsymmetricLossSingleLabel, self).__init__() + + self.eps = eps + self.logsoftmax = nn.LogSoftmax(dim=-1) + self.targets_classes = [] # prevent gpu repeated memory allocation + self.gamma_pos = gamma_pos + self.gamma_neg = gamma_neg + self.reduction = reduction + + def forward(self, inputs, target, reduction=None): + """" + Parameters + ---------- + x: input logits + y: targets (1-hot vector) + """ + + num_classes = inputs.size()[-1] + log_preds = self.logsoftmax(inputs) + self.targets_classes = torch.zeros_like(inputs).scatter_(1, target.long().unsqueeze(1), 1) + + # ASL weights + targets = self.targets_classes + anti_targets = 1 - targets + xs_pos = torch.exp(log_preds) + xs_neg = 1 - xs_pos + xs_pos = xs_pos * targets + xs_neg = xs_neg * anti_targets + asymmetric_w = torch.pow(1 - xs_pos - xs_neg, + self.gamma_pos * targets + self.gamma_neg * anti_targets) + log_preds = log_preds * asymmetric_w + + if self.eps > 0: # label smoothing + self.targets_classes.mul_(1 - self.eps).add_(self.eps / num_classes) + + # loss calculation + loss = - self.targets_classes.mul(log_preds) + + loss = loss.sum(dim=-1) + if self.reduction == 'mean': + loss = loss.mean() + + return loss