Merge pull request #880 from rwightman/fixes_bce_regnet

A collection of fixes, model experiments, etc
pull/910/head
Ross Wightman 3 years ago committed by GitHub
commit cd638d50a5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -41,6 +41,10 @@ def checkpoint_metric(checkpoint_path):
metric = None
if 'metric' in checkpoint:
metric = checkpoint['metric']
elif 'metrics' in checkpoint and 'metric_name' in checkpoint:
metrics = checkpoint['metrics']
print(metrics)
metric = metrics[checkpoint['metric_name']]
return metric

@ -3,8 +3,11 @@
Prefetcher and Fast Collate inspired by NVIDIA APEX example at
https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#diff-cf86c282ff7fba81fad27a559379d5bf
Hacked together by / Copyright 2020 Ross Wightman
Hacked together by / Copyright 2021 Ross Wightman
"""
import random
from functools import partial
from typing import Callable
import torch.utils.data
import numpy as np
@ -125,6 +128,22 @@ class PrefetchLoader:
self.loader.collate_fn.mixup_enabled = x
def _worker_init(worker_id, worker_seeding='all'):
worker_info = torch.utils.data.get_worker_info()
assert worker_info.id == worker_id
if isinstance(worker_seeding, Callable):
seed = worker_seeding(worker_info)
random.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed % (2 ** 32 - 1))
else:
assert worker_seeding in ('all', 'part')
# random / torch seed already called in dataloader iter class w/ worker_info.seed
# to reproduce some old results (same seed + hparam combo), partial seeding is required (skip numpy re-seed)
if worker_seeding == 'all':
np.random.seed(worker_info.seed % (2 ** 32 - 1))
def create_loader(
dataset,
input_size,
@ -156,6 +175,7 @@ def create_loader(
tf_preprocessing=False,
use_multi_epochs_loader=False,
persistent_workers=True,
worker_seeding='all',
):
re_num_splits = 0
if re_split:
@ -202,7 +222,6 @@ def create_loader(
collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate
loader_class = torch.utils.data.DataLoader
if use_multi_epochs_loader:
loader_class = MultiEpochsDataLoader
@ -214,7 +233,9 @@ def create_loader(
collate_fn=collate_fn,
pin_memory=pin_memory,
drop_last=is_training,
persistent_workers=persistent_workers)
worker_init_fn=partial(_worker_init, worker_seeding=worker_seeding),
persistent_workers=persistent_workers
)
try:
loader = loader_class(dataset, **loader_args)
except TypeError as e:

@ -1,4 +1,4 @@
from .asymmetric_loss import AsymmetricLossMultiLabel, AsymmetricLossSingleLabel
from .binary_cross_entropy import DenseBinaryCrossEntropy
from .binary_cross_entropy import BinaryCrossEntropy
from .cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from .jsd import JsdCrossEntropy

@ -1,23 +1,47 @@
""" Binary Cross Entropy w/ a few extras
Hacked together by / Copyright 2021 Ross Wightman
"""
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
class DenseBinaryCrossEntropy(nn.Module):
""" BCE using one-hot from dense targets w/ label smoothing
class BinaryCrossEntropy(nn.Module):
""" BCE with optional one-hot from dense targets, label smoothing, thresholding
NOTE for experiments comparing CE to BCE /w label smoothing, may remove
"""
def __init__(self, smoothing=0.1):
super(DenseBinaryCrossEntropy, self).__init__()
def __init__(
self, smoothing=0.1, target_threshold: Optional[float] = None, weight: Optional[torch.Tensor] = None,
reduction: str = 'mean', pos_weight: Optional[torch.Tensor] = None):
super(BinaryCrossEntropy, self).__init__()
assert 0. <= smoothing < 1.0
self.smoothing = smoothing
self.bce = nn.BCEWithLogitsLoss()
self.target_threshold = target_threshold
self.reduction = reduction
self.register_buffer('weight', weight)
self.register_buffer('pos_weight', pos_weight)
def forward(self, x, target):
num_classes = x.shape[-1]
off_value = self.smoothing / num_classes
on_value = 1. - self.smoothing + off_value
target = target.long().view(-1, 1)
target = torch.full(
(target.size()[0], num_classes), off_value, device=x.device, dtype=x.dtype).scatter_(1, target, on_value)
return self.bce(x, target)
def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
assert x.shape[0] == target.shape[0]
if target.shape != x.shape:
# NOTE currently assume smoothing or other label softening is applied upstream if targets are already sparse
num_classes = x.shape[-1]
# FIXME should off/on be different for smoothing w/ BCE? Other impl out there differ
off_value = self.smoothing / num_classes
on_value = 1. - self.smoothing + off_value
target = target.long().view(-1, 1)
target = torch.full(
(target.size()[0], num_classes),
off_value,
device=x.device, dtype=x.dtype).scatter_(1, target, on_value)
if self.target_threshold is not None:
# Make target 0, or 1 if threshold set
target = target.gt(self.target_threshold).to(dtype=target.dtype)
return F.binary_cross_entropy_with_logits(
x, target,
self.weight,
pos_weight=self.pos_weight,
reduction=self.reduction)

@ -1,23 +1,23 @@
""" Cross Entropy w/ smoothing or soft targets
Hacked together by / Copyright 2021 Ross Wightman
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class LabelSmoothingCrossEntropy(nn.Module):
"""
NLL loss with label smoothing.
""" NLL loss with label smoothing.
"""
def __init__(self, smoothing=0.1):
"""
Constructor for the LabelSmoothing module.
:param smoothing: label smoothing factor
"""
super(LabelSmoothingCrossEntropy, self).__init__()
assert smoothing < 1.0
self.smoothing = smoothing
self.confidence = 1. - smoothing
def forward(self, x, target):
def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
logprobs = F.log_softmax(x, dim=-1)
nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
nll_loss = nll_loss.squeeze(1)
@ -31,6 +31,6 @@ class SoftTargetCrossEntropy(nn.Module):
def __init__(self):
super(SoftTargetCrossEntropy, self).__init__()
def forward(self, x, target):
def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1)
return loss.mean()

@ -3,7 +3,7 @@
A flexible network w/ dataclass based config for stacking NN blocks including
self-attention (or similar) layers.
Currently used to implement experimential variants of:
Currently used to implement experimental variants of:
* Bottleneck Transformers
* Lambda ResNets
* HaloNets
@ -23,7 +23,7 @@ __all__ = []
def _cfg(url='', **kwargs):
return {
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bicubic',
'crop_pct': 0.95, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem.conv1.conv', 'classifier': 'head.fc',
'fixed_input_size': False, 'min_input_size': (3, 224, 224),
@ -34,7 +34,7 @@ def _cfg(url='', **kwargs):
default_cfgs = {
# GPU-Efficient (ResNet) weights
'botnet26t_256': _cfg(
url='',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/botnet26t_c1_256-167a0e9f.pth',
fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
'botnet50ts_256': _cfg(
url='',
@ -46,19 +46,26 @@ default_cfgs = {
'halonet_h1': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
'halonet26t': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halonet26t_256-9b4bf0b3.pth',
input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
'sehalonet33ts': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/sehalonet33ts_256-87e053f9.pth',
input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
'halonet50ts': _cfg(
url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halonet50ts_256_ra3-f07eab9f.pth',
input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
'eca_halonext26ts': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_halonext26ts_256-1e55880b.pth',
input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
'lambda_resnet26t': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet26t_256-b040fce6.pth',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet26t_a2h_256-25ded63d.pth',
min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)),
'lambda_resnet50ts': _cfg(
url='',
min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)),
'lambda_resnet26rpt_256': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet26rpt_a2h_256-482adad8.pth',
fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
}
@ -198,6 +205,33 @@ model_cfgs = dict(
self_attn_layer='lambda',
self_attn_kwargs=dict(r=9)
),
lambda_resnet50ts=ByoModelCfg(
blocks=(
ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=4, d=4, c=512, s=2, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), d=6, c=1024, s=2, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), d=3, c=2048, s=2, gs=0, br=0.25),
),
stem_chs=64,
stem_type='tiered',
stem_pool='maxpool',
act_layer='silu',
self_attn_layer='lambda',
self_attn_kwargs=dict(r=9)
),
lambda_resnet26rpt_256=ByoModelCfg(
blocks=(
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
),
stem_chs=64,
stem_type='tiered',
stem_pool='maxpool',
self_attn_layer='lambda',
self_attn_kwargs=dict(r=None)
),
)
@ -275,6 +309,21 @@ def eca_halonext26ts(pretrained=False, **kwargs):
@register_model
def lambda_resnet26t(pretrained=False, **kwargs):
""" Lambda-ResNet-26T. Lambda layers in last two stages.
""" Lambda-ResNet-26-T. Lambda layers w/ conv pos in last two stages.
"""
return _create_byoanet('lambda_resnet26t', pretrained=pretrained, **kwargs)
@register_model
def lambda_resnet50ts(pretrained=False, **kwargs):
""" Lambda-ResNet-50-TS. SiLU act. Lambda layers w/ conv pos in last two stages.
"""
return _create_byoanet('lambda_resnet50ts', pretrained=pretrained, **kwargs)
@register_model
def lambda_resnet26rpt_256(pretrained=False, **kwargs):
""" Lambda-ResNet-26-R-T. Lambda layers w/ rel pos embed in last two stages.
"""
kwargs.setdefault('img_size', 256)
return _create_byoanet('lambda_resnet26rpt_256', pretrained=pretrained, **kwargs)

@ -51,6 +51,16 @@ def _cfg(url='', **kwargs):
}
def _cfgr(url='', **kwargs):
return {
'url': url, 'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (8, 8),
'crop_pct': 0.9, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem.conv1.conv', 'classifier': 'head.fc',
**kwargs
}
default_cfgs = {
# GPU-Efficient (ResNet) weights
'gernet_s': _cfg(
@ -92,51 +102,50 @@ default_cfgs = {
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet51q_ra2-d47dcc76.pth',
first_conv='stem.conv1', input_size=(3, 256, 256), pool_size=(8, 8),
test_input_size=(3, 288, 288), crop_pct=1.0),
'resnet61q': _cfg(
'resnet61q': _cfgr(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet61q_ra2-6afc536c.pth',
first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8),
test_input_size=(3, 288, 288), crop_pct=1.0, interpolation='bicubic'),
'resnext26ts': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnext26ts_256_ra2-8bbd9106.pth',
first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
'gcresnext26ts': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnext26ts_256-e414378b.pth',
first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
'seresnext26ts': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/seresnext26ts_256-6f0d74a3.pth',
first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
'eca_resnext26ts': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_resnext26ts_256-5a1d030f.pth',
first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
'bat_resnext26ts': _cfg(
test_input_size=(3, 288, 288), crop_pct=1.0),
'resnext26ts': _cfgr(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnext26ts_256_ra2-8bbd9106.pth'),
'gcresnext26ts': _cfgr(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnext26ts_256-e414378b.pth'),
'seresnext26ts': _cfgr(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/seresnext26ts_256-6f0d74a3.pth'),
'eca_resnext26ts': _cfgr(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_resnext26ts_256-5a1d030f.pth'),
'bat_resnext26ts': _cfgr(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/bat_resnext26ts_256-fa6fd595.pth',
first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic',
min_input_size=(3, 256, 256)),
'resnet32ts': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnet32ts_256-aacf5250.pth',
first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
'resnet33ts': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnet33ts_256-e91b09a4.pth',
first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
'gcresnet33ts': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnet33ts_256-0e0cd345.pth',
first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
'seresnet33ts': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/seresnet33ts_256-f8ad44d9.pth',
first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
'eca_resnet33ts': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_resnet33ts_256-8f98face.pth',
first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
'gcresnet50t': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnet50t_256-96374d1c.pth',
first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
'gcresnext50ts': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnext50ts_256-3e0f515e.pth',
first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
'resnet32ts': _cfgr(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnet32ts_256-aacf5250.pth'),
'resnet33ts': _cfgr(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnet33ts_256-e91b09a4.pth'),
'gcresnet33ts': _cfgr(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnet33ts_256-0e0cd345.pth'),
'seresnet33ts': _cfgr(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/seresnet33ts_256-f8ad44d9.pth'),
'eca_resnet33ts': _cfgr(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_resnet33ts_256-8f98face.pth'),
'gcresnet50t': _cfgr(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnet50t_256-96374d1c.pth'),
'gcresnext50ts': _cfgr(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnext50ts_256-3e0f515e.pth'),
# experimental models, likely to change ot be removed
'regnetz_b': _cfgr(
url='',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
input_size=(3, 224, 224), pool_size=(7, 7), first_conv='stem.conv'),
'regnetz_c': _cfgr(
url='',
imean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), first_conv='stem.conv'),
'regnetz_d': _cfgr(
url='',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
}
@ -489,6 +498,57 @@ model_cfgs = dict(
act_layer='silu',
attn_layer='gca',
),
# experimental models, closer to a RegNetZ than a ResNet. Similar to EfficientNets but w/ groups instead of DW
regnetz_b=ByoModelCfg(
blocks=(
ByoBlockCfg(type='bottle', d=2, c=48, s=2, gs=16, br=3),
ByoBlockCfg(type='bottle', d=6, c=96, s=2, gs=16, br=3),
ByoBlockCfg(type='bottle', d=12, c=192, s=2, gs=16, br=3),
ByoBlockCfg(type='bottle', d=2, c=288, s=2, gs=16, br=3),
),
stem_chs=32,
stem_pool='',
downsample='',
num_features=1536,
act_layer='silu',
attn_layer='se',
attn_kwargs=dict(rd_ratio=0.25),
block_kwargs=dict(bottle_in=True, linear_out=True),
),
regnetz_c=ByoModelCfg(
blocks=(
ByoBlockCfg(type='bottle', d=2, c=48, s=2, gs=16, br=4),
ByoBlockCfg(type='bottle', d=6, c=96, s=2, gs=16, br=4),
ByoBlockCfg(type='bottle', d=12, c=192, s=2, gs=16, br=4),
ByoBlockCfg(type='bottle', d=2, c=288, s=2, gs=16, br=4),
),
stem_chs=32,
stem_pool='',
downsample='',
num_features=1536,
act_layer='silu',
attn_layer='se',
attn_kwargs=dict(rd_ratio=0.25),
block_kwargs=dict(bottle_in=True, linear_out=True),
),
regnetz_d=ByoModelCfg(
blocks=(
ByoBlockCfg(type='bottle', d=3, c=64, s=1, gs=32, br=4),
ByoBlockCfg(type='bottle', d=6, c=128, s=2, gs=32, br=4),
ByoBlockCfg(type='bottle', d=12, c=256, s=2, gs=32, br=4),
ByoBlockCfg(type='bottle', d=3, c=384, s=2, gs=32, br=4),
),
stem_chs=64,
stem_type='tiered',
stem_pool='',
downsample='',
num_features=1792,
act_layer='silu',
attn_layer='se',
attn_kwargs=dict(rd_ratio=0.25),
block_kwargs=dict(bottle_in=True, linear_out=True),
),
)
@ -678,6 +738,27 @@ def gcresnext50ts(pretrained=False, **kwargs):
return _create_byobnet('gcresnext50ts', pretrained=pretrained, **kwargs)
@register_model
def regnetz_b(pretrained=False, **kwargs):
"""
"""
return _create_byobnet('regnetz_b', pretrained=pretrained, **kwargs)
@register_model
def regnetz_c(pretrained=False, **kwargs):
"""
"""
return _create_byobnet('regnetz_c', pretrained=pretrained, **kwargs)
@register_model
def regnetz_d(pretrained=False, **kwargs):
"""
"""
return _create_byobnet('regnetz_d', pretrained=pretrained, **kwargs)
def expand_blocks_cfg(stage_blocks_cfg: Union[ByoBlockCfg, Sequence[ByoBlockCfg]]) -> List[ByoBlockCfg]:
if not isinstance(stage_blocks_cfg, Sequence):
stage_blocks_cfg = (stage_blocks_cfg,)
@ -722,11 +803,17 @@ class DownsampleAvg(nn.Module):
return self.conv(self.pool(x))
def create_downsample(downsample_type, layers: LayerFn, **kwargs):
if downsample_type == 'avg':
return DownsampleAvg(**kwargs)
def create_shortcut(downsample_type, layers: LayerFn, in_chs, out_chs, stride, dilation, **kwargs):
assert downsample_type in ('avg', 'conv1x1', '')
if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
if not downsample_type:
return None # no shortcut
elif downsample_type == 'avg':
return DownsampleAvg(in_chs, out_chs, stride=stride, dilation=dilation[0], **kwargs)
else:
return layers.conv_norm_act(in_chs, out_chs, kernel_size=1, stride=stride, dilation=dilation[0], **kwargs)
else:
return layers.conv_norm_act(kwargs.pop('in_chs'), kwargs.pop('out_chs'), kernel_size=1, **kwargs)
return nn.Identity() # identity shortcut
class BasicBlock(nn.Module):
@ -742,12 +829,9 @@ class BasicBlock(nn.Module):
mid_chs = make_divisible(out_chs * bottle_ratio)
groups = num_groups(group_size, mid_chs)
if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
self.shortcut = create_downsample(
downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0],
apply_act=False, layers=layers)
else:
self.shortcut = nn.Identity()
self.shortcut = create_shortcut(
downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation,
apply_act=False, layers=layers)
self.conv1_kxk = layers.conv_norm_act(in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0])
self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs)
@ -758,23 +842,21 @@ class BasicBlock(nn.Module):
self.act = nn.Identity() if linear_out else layers.act(inplace=True)
def init_weights(self, zero_init_last: bool = False):
if zero_init_last:
if zero_init_last and self.shortcut is not None:
nn.init.zeros_(self.conv2_kxk.bn.weight)
for attn in (self.attn, self.attn_last):
if hasattr(attn, 'reset_parameters'):
attn.reset_parameters()
def forward(self, x):
shortcut = self.shortcut(x)
# residual path
shortcut = x
x = self.conv1_kxk(x)
x = self.conv2_kxk(x)
x = self.attn(x)
x = self.drop_path(x)
x = self.act(x + shortcut)
return x
if self.shortcut is not None:
x = x + self.shortcut(shortcut)
return self.act(x)
class BottleneckBlock(nn.Module):
@ -782,19 +864,16 @@ class BottleneckBlock(nn.Module):
"""
def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None,
downsample='avg', attn_last=False, linear_out=False, extra_conv=False, layers: LayerFn = None,
drop_block=None, drop_path_rate=0.):
downsample='avg', attn_last=False, linear_out=False, extra_conv=False, bottle_in=False,
layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
super(BottleneckBlock, self).__init__()
layers = layers or LayerFn()
mid_chs = make_divisible(out_chs * bottle_ratio)
mid_chs = make_divisible((in_chs if bottle_in else out_chs) * bottle_ratio)
groups = num_groups(group_size, mid_chs)
if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
self.shortcut = create_downsample(
downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0],
apply_act=False, layers=layers)
else:
self.shortcut = nn.Identity()
self.shortcut = create_shortcut(
downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation,
apply_act=False, layers=layers)
self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1)
self.conv2_kxk = layers.conv_norm_act(
@ -812,15 +891,14 @@ class BottleneckBlock(nn.Module):
self.act = nn.Identity() if linear_out else layers.act(inplace=True)
def init_weights(self, zero_init_last: bool = False):
if zero_init_last:
if zero_init_last and self.shortcut is not None:
nn.init.zeros_(self.conv3_1x1.bn.weight)
for attn in (self.attn, self.attn_last):
if hasattr(attn, 'reset_parameters'):
attn.reset_parameters()
def forward(self, x):
shortcut = self.shortcut(x)
shortcut = x
x = self.conv1_1x1(x)
x = self.conv2_kxk(x)
x = self.conv2b_kxk(x)
@ -828,9 +906,9 @@ class BottleneckBlock(nn.Module):
x = self.conv3_1x1(x)
x = self.attn_last(x)
x = self.drop_path(x)
x = self.act(x + shortcut)
return x
if self.shortcut is not None:
x = x + self.shortcut(shortcut)
return self.act(x)
class DarkBlock(nn.Module):
@ -852,12 +930,9 @@ class DarkBlock(nn.Module):
mid_chs = make_divisible(out_chs * bottle_ratio)
groups = num_groups(group_size, mid_chs)
if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
self.shortcut = create_downsample(
downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0],
apply_act=False, layers=layers)
else:
self.shortcut = nn.Identity()
self.shortcut = create_shortcut(
downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation,
apply_act=False, layers=layers)
self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1)
self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs)
@ -869,22 +944,22 @@ class DarkBlock(nn.Module):
self.act = nn.Identity() if linear_out else layers.act(inplace=True)
def init_weights(self, zero_init_last: bool = False):
if zero_init_last:
if zero_init_last and self.shortcut is not None:
nn.init.zeros_(self.conv2_kxk.bn.weight)
for attn in (self.attn, self.attn_last):
if hasattr(attn, 'reset_parameters'):
attn.reset_parameters()
def forward(self, x):
shortcut = self.shortcut(x)
shortcut = x
x = self.conv1_1x1(x)
x = self.attn(x)
x = self.conv2_kxk(x)
x = self.attn_last(x)
x = self.drop_path(x)
x = self.act(x + shortcut)
return x
if self.shortcut is not None:
x = x + self.shortcut(shortcut)
return self.act(x)
class EdgeBlock(nn.Module):
@ -905,12 +980,9 @@ class EdgeBlock(nn.Module):
mid_chs = make_divisible(out_chs * bottle_ratio)
groups = num_groups(group_size, mid_chs)
if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
self.shortcut = create_downsample(
downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0],
apply_act=False, layers=layers)
else:
self.shortcut = nn.Identity()
self.shortcut = create_shortcut(
downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation,
apply_act=False, layers=layers)
self.conv1_kxk = layers.conv_norm_act(
in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0],
@ -922,22 +994,22 @@ class EdgeBlock(nn.Module):
self.act = nn.Identity() if linear_out else layers.act(inplace=True)
def init_weights(self, zero_init_last: bool = False):
if zero_init_last:
if zero_init_last and self.shortcut is not None:
nn.init.zeros_(self.conv2_1x1.bn.weight)
for attn in (self.attn, self.attn_last):
if hasattr(attn, 'reset_parameters'):
attn.reset_parameters()
def forward(self, x):
shortcut = self.shortcut(x)
shortcut = x
x = self.conv1_kxk(x)
x = self.attn(x)
x = self.conv2_1x1(x)
x = self.attn_last(x)
x = self.drop_path(x)
x = self.act(x + shortcut)
return x
if self.shortcut is not None:
x = x + self.shortcut(shortcut)
return self.act(x)
class RepVggBlock(nn.Module):
@ -982,8 +1054,7 @@ class RepVggBlock(nn.Module):
x = self.drop_path(x) # not in the paper / official impl, experimental
x = x + identity
x = self.attn(x) # no attn in the paper / official impl, experimental
x = self.act(x)
return x
return self.act(x)
class SelfAttnBlock(nn.Module):
@ -991,19 +1062,16 @@ class SelfAttnBlock(nn.Module):
"""
def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None,
downsample='avg', extra_conv=False, linear_out=False, post_attn_na=True, feat_size=None,
layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
downsample='avg', extra_conv=False, linear_out=False, bottle_in=False, post_attn_na=True,
feat_size=None, layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
super(SelfAttnBlock, self).__init__()
assert layers is not None
mid_chs = make_divisible(out_chs * bottle_ratio)
mid_chs = make_divisible((in_chs if bottle_in else out_chs) * bottle_ratio)
groups = num_groups(group_size, mid_chs)
if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
self.shortcut = create_downsample(
downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0],
apply_act=False, layers=layers)
else:
self.shortcut = nn.Identity()
self.shortcut = create_shortcut(
downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation,
apply_act=False, layers=layers)
self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1)
if extra_conv:
@ -1022,7 +1090,7 @@ class SelfAttnBlock(nn.Module):
self.act = nn.Identity() if linear_out else layers.act(inplace=True)
def init_weights(self, zero_init_last: bool = False):
if zero_init_last:
if zero_init_last and self.shortcut is not None:
nn.init.zeros_(self.conv3_1x1.bn.weight)
if hasattr(self.self_attn, 'reset_parameters'):
self.self_attn.reset_parameters()

@ -277,7 +277,6 @@ class EdgeResidual(nn.Module):
mid_chs = make_divisible(force_in_chs * exp_ratio)
else:
mid_chs = make_divisible(in_chs * exp_ratio)
has_se = se_layer is not None and se_ratio > 0.
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
self.drop_path_rate = drop_path_rate

@ -24,13 +24,20 @@ _logger = logging.getLogger(__name__)
def load_state_dict(checkpoint_path, use_ema=False):
if checkpoint_path and os.path.isfile(checkpoint_path):
checkpoint = torch.load(checkpoint_path, map_location='cpu')
state_dict_key = 'state_dict'
state_dict_key = ''
if isinstance(checkpoint, dict):
if use_ema and 'state_dict_ema' in checkpoint:
if use_ema and checkpoint.get('state_dict_ema', None) is not None:
state_dict_key = 'state_dict_ema'
if state_dict_key and state_dict_key in checkpoint:
elif use_ema and checkpoint.get('model_ema', None) is not None:
state_dict_key = 'model_ema'
elif 'state_dict' in checkpoint:
state_dict_key = 'state_dict'
elif 'model' in checkpoint:
state_dict_key = 'model'
if state_dict_key:
state_dict = checkpoint[state_dict_key]
new_state_dict = OrderedDict()
for k, v in checkpoint[state_dict_key].items():
for k, v in state_dict.items():
# strip `module.` prefix
name = k[7:] if k.startswith('module') else k
new_state_dict[name] = v

@ -118,12 +118,12 @@ class BottleneckAttn(nn.Module):
x = x.reshape(B, -1, self.dim_head, H * W).transpose(-1, -2)
q, k, v = torch.split(x, self.num_heads, dim=1)
attn_logits = (q @ k.transpose(-1, -2)) * self.scale
attn_logits = attn_logits + self.pos_embed(q) # B, num_heads, H * W, H * W
attn = (q @ k.transpose(-1, -2)) * self.scale
attn = attn + self.pos_embed(q) # B, num_heads, H * W, H * W
attn = attn.softmax(dim=-1)
attn_out = attn_logits.softmax(dim=-1)
attn_out = (attn_out @ v).transpose(1, 2).reshape(B, self.dim_out, H, W) # B, dim_out, H, W
attn_out = self.pool(attn_out)
return attn_out
out = (attn @ v).transpose(-1, -2).reshape(B, self.dim_out, H, W) # B, dim_out, H, W
out = self.pool(out)
return out

@ -106,22 +106,23 @@ class HaloAttn(nn.Module):
assert dim_out % num_heads == 0
self.stride = stride
self.num_heads = num_heads
self.dim_head = dim_head or dim // num_heads
self.dim_qk = num_heads * self.dim_head
self.dim_v = dim_out
self.dim_head_qk = dim_head or dim_out // num_heads
self.dim_head_v = dim_out // self.num_heads
self.dim_out_qk = num_heads * self.dim_head_qk
self.dim_out_v = num_heads * self.dim_head_v
self.block_size = block_size
self.halo_size = halo_size
self.win_size = block_size + halo_size * 2 # neighbourhood window size
self.scale = self.dim_head ** -0.5
self.scale = self.dim_head_qk ** -0.5
# FIXME not clear if this stride behaviour is what the paper intended
# Also, the paper mentions using a 3D conv for dealing with the blocking/gather, and leaving
# data in unfolded block form. I haven't wrapped my head around how that'd look.
self.q = nn.Conv2d(dim, self.dim_qk, 1, stride=self.stride, bias=qkv_bias)
self.kv = nn.Conv2d(dim, self.dim_qk + self.dim_v, 1, bias=qkv_bias)
self.q = nn.Conv2d(dim, self.dim_out_qk, 1, stride=self.stride, bias=qkv_bias)
self.kv = nn.Conv2d(dim, self.dim_out_qk + self.dim_out_v, 1, bias=qkv_bias)
self.pos_embed = PosEmbedRel(
block_size=block_size // self.stride, win_size=self.win_size, dim_head=self.dim_head, scale=self.scale)
block_size=block_size // self.stride, win_size=self.win_size, dim_head=self.dim_head_qk, scale=self.scale)
self.reset_parameters()
@ -143,37 +144,42 @@ class HaloAttn(nn.Module):
q = self.q(x)
# unfold
q = q.reshape(-1, self.dim_head, num_h_blocks, bs_stride, num_w_blocks, bs_stride).permute(0, 1, 3, 5, 2, 4)
q = q.reshape(-1, self.dim_head_qk, num_h_blocks, bs_stride, num_w_blocks, bs_stride).permute(0, 1, 3, 5, 2, 4)
# B, num_heads * dim_head * block_size ** 2, num_blocks
q = q.reshape(B * self.num_heads, self.dim_head, -1, num_blocks).transpose(1, 3)
q = q.reshape(B * self.num_heads, self.dim_head_qk, -1, num_blocks).transpose(1, 3)
# B * num_heads, num_blocks, block_size ** 2, dim_head
kv = self.kv(x)
# generate overlapping windows for kv
kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size])
kv = kv.unfold(2, self.win_size, self.block_size).unfold(3, self.win_size, self.block_size).reshape(
B * self.num_heads, self.dim_head + (self.dim_v // self.num_heads), num_blocks, -1).permute(0, 2, 3, 1)
# NOTE these two alternatives are equivalent, but above is the best balance of performance and clarity
# if self.stride_tricks:
# kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]).contiguous()
# kv = kv.as_strided((
# B, self.dim_qk + self.dim_v, self.win_size, self.win_size, num_h_blocks, num_w_blocks),
# stride=(kv.stride(0), kv.stride(1), kv.shape[-1], 1, self.block_size * kv.shape[-1], self.block_size))
# else:
# kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size)
# kv = kv.reshape(
# B * self.num_heads, self.dim_head + (self.dim_v // self.num_heads), -1, num_blocks).transpose(1, 3)
k, v = torch.split(kv, [self.dim_head, self.dim_v // self.num_heads], dim=-1)
# B * num_heads, num_blocks, block_size ** 2, dim_head or dim_v // num_heads
attn_logits = (q @ k.transpose(-1, -2)) * self.scale # FIXME should usual attn scale be applied?
attn_logits = attn_logits + self.pos_embed(q) # B * num_heads, block_size ** 2, win_size ** 2
attn_out = attn_logits.softmax(dim=-1)
attn_out = (attn_out @ v).transpose(1, 3) # B * num_heads, dim_v // num_heads, block_size ** 2, num_blocks
B * self.num_heads, self.dim_head_qk + self.dim_head_v, num_blocks, -1).permute(0, 2, 3, 1)
k, v = torch.split(kv, [self.dim_head_qk, self.dim_head_v], dim=-1)
# B * num_heads, num_blocks, win_size ** 2, dim_head_qk or dim_head_v
attn = (q @ k.transpose(-1, -2)) * self.scale
attn = attn + self.pos_embed(q) # B * num_heads, num_blocks, block_size ** 2, win_size ** 2
attn = attn.softmax(dim=-1)
out = (attn @ v).transpose(1, 3) # B * num_heads, dim_head_v, block_size ** 2, num_blocks
# fold
attn_out = attn_out.reshape(-1, bs_stride, bs_stride, num_h_blocks, num_w_blocks)
attn_out = attn_out.permute(0, 3, 1, 4, 2).contiguous().view(B, self.dim_v, H // self.stride, W // self.stride)
out = out.reshape(-1, bs_stride, bs_stride, num_h_blocks, num_w_blocks)
out = out.permute(0, 3, 1, 4, 2).contiguous().view(B, self.dim_out_v, H // self.stride, W // self.stride)
# B, dim_out, H // stride, W // stride
return attn_out
return out
""" Two alternatives for overlapping windows.
`.unfold().unfold()` is same speed as stride tricks with similar clarity as F.unfold()
if self.stride_tricks:
kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]).contiguous()
kv = kv.as_strided((
B, self.dim_out_qk + self.dim_out_v, self.win_size, self.win_size, num_h_blocks, num_w_blocks),
stride=(kv.stride(0), kv.stride(1), kv.shape[-1], 1, self.block_size * kv.shape[-1], self.block_size))
else:
kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size)
kv = kv.reshape(
B * self.num_heads, self.dim_head_qk + self.dim_head_v, -1, num_blocks).transpose(1, 3)
"""

@ -24,18 +24,30 @@ import torch
from torch import nn
import torch.nn.functional as F
from .helpers import to_2tuple
from .weight_init import trunc_normal_
def rel_pos_indices(size):
size = to_2tuple(size)
pos = torch.stack(torch.meshgrid(torch.arange(size[0]), torch.arange(size[1]))).flatten(1)
rel_pos = pos[:, None, :] - pos[:, :, None]
rel_pos[0] += size[0] - 1
rel_pos[1] += size[1] - 1
return rel_pos # 2, H * W, H * W
class LambdaLayer(nn.Module):
"""Lambda Layer w/ lambda conv position embedding
"""Lambda Layer
Paper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention`
- https://arxiv.org/abs/2102.08602
NOTE: intra-depth parameter 'u' is fixed at 1. It did not appear worth the complexity to add.
"""
def __init__(
self,
dim, dim_out=None, stride=1, num_heads=4, dim_head=16, r=7, qkv_bias=False):
dim, dim_out=None, feat_size=None, stride=1, num_heads=4, dim_head=16, r=7, qkv_bias=False):
super().__init__()
self.dim = dim
self.dim_out = dim_out or dim
@ -43,7 +55,6 @@ class LambdaLayer(nn.Module):
self.num_heads = num_heads
assert self.dim_out % num_heads == 0, ' should be divided by num_heads'
self.dim_v = self.dim_out // num_heads # value depth 'v'
self.r = r # relative position neighbourhood (lambda conv kernel size)
self.qkv = nn.Conv2d(
dim,
@ -52,8 +63,19 @@ class LambdaLayer(nn.Module):
self.norm_q = nn.BatchNorm2d(num_heads * dim_head)
self.norm_v = nn.BatchNorm2d(self.dim_v)
# NOTE currently only supporting the local lambda convolutions for positional
self.conv_lambda = nn.Conv3d(1, dim_head, (r, r, 1), padding=(r // 2, r // 2, 0))
if r is not None:
# local lambda convolution for pos
self.conv_lambda = nn.Conv3d(1, dim_head, (r, r, 1), padding=(r // 2, r // 2, 0))
self.pos_emb = None
self.rel_pos_indices = None
else:
# relative pos embedding
assert feat_size is not None
feat_size = to_2tuple(feat_size)
rel_size = [2 * s - 1 for s in feat_size]
self.conv_lambda = None
self.pos_emb = nn.Parameter(torch.zeros(rel_size[0], rel_size[1], self.dim_k))
self.register_buffer('rel_pos_indices', rel_pos_indices(feat_size), persistent=False)
self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
@ -61,12 +83,14 @@ class LambdaLayer(nn.Module):
def reset_parameters(self):
trunc_normal_(self.qkv.weight, std=self.dim ** -0.5)
trunc_normal_(self.conv_lambda.weight, std=self.dim_k ** -0.5)
if self.conv_lambda is not None:
trunc_normal_(self.conv_lambda.weight, std=self.dim_k ** -0.5)
if self.pos_emb is not None:
trunc_normal_(self.pos_emb, std=.02)
def forward(self, x):
B, C, H, W = x.shape
M = H * W
qkv = self.qkv(x)
q, k, v = torch.split(qkv, [
self.num_heads * self.dim_k, self.dim_k, self.dim_v], dim=1)
@ -77,10 +101,15 @@ class LambdaLayer(nn.Module):
content_lam = k @ v # B, K, V
content_out = q @ content_lam.unsqueeze(1) # B, num_heads, M, V
position_lam = self.conv_lambda(v.reshape(B, 1, H, W, self.dim_v)) # B, H, W, V, K
position_lam = position_lam.reshape(B, 1, self.dim_k, H * W, self.dim_v).transpose(2, 3) # B, 1, M, K, V
if self.pos_emb is None:
position_lam = self.conv_lambda(v.reshape(B, 1, H, W, self.dim_v)) # B, H, W, V, K
position_lam = position_lam.reshape(B, 1, self.dim_k, H * W, self.dim_v).transpose(2, 3) # B, 1, M, K, V
else:
# FIXME relative pos embedding path not fully verified
pos_emb = self.pos_emb[self.rel_pos_indices[0], self.rel_pos_indices[1]].expand(B, -1, -1, -1)
position_lam = (pos_emb.transpose(-1, -2) @ v.unsqueeze(1)).unsqueeze(1) # B, 1, M, K, V
position_out = (q.unsqueeze(-2) @ position_lam).squeeze(-2) # B, num_heads, M, V
out = (content_out + position_out).transpose(3, 1).reshape(B, C, H, W) # B, C (num_heads * V), H, W
out = (content_out + position_out).transpose(-1, -2).reshape(B, C, H, W) # B, C (num_heads * V), H, W
out = self.pool(out)
return out

@ -41,7 +41,7 @@ class StdConv2d(nn.Conv2d):
def forward(self, x):
weight = F.batch_norm(
self.weight.view(1, self.out_channels, -1), None, None,
self.weight.reshape(1, self.out_channels, -1), None, None,
training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
return x
@ -67,7 +67,7 @@ class StdConv2dSame(nn.Conv2d):
if self.same_pad:
x = pad_same(x, self.kernel_size, self.stride, self.dilation)
weight = F.batch_norm(
self.weight.view(1, self.out_channels, -1), None, None,
self.weight.reshape(1, self.out_channels, -1), None, None,
training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
return x
@ -96,7 +96,7 @@ class ScaledStdConv2d(nn.Conv2d):
def forward(self, x):
weight = F.batch_norm(
self.weight.view(1, self.out_channels, -1), None, None,
self.weight.reshape(1, self.out_channels, -1), None, None,
weight=(self.gain * self.scale).view(-1),
training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
@ -127,7 +127,7 @@ class ScaledStdConv2dSame(nn.Conv2d):
if self.same_pad:
x = pad_same(x, self.kernel_size, self.stride, self.dilation)
weight = F.batch_norm(
self.weight.view(1, self.out_channels, -1), None, None,
self.weight.reshape(1, self.out_channels, -1), None, None,
weight=(self.gain * self.scale).view(-1),
training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)

@ -53,11 +53,11 @@ default_cfgs = {
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnet26t_256_ra2-6f6fa748.pth',
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8)),
'resnet50': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50_ram-a26f946b.pth',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1_0-00ca2c6a.pth',
interpolation='bicubic'),
'resnet50d': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50d_ra2-464e36ba.pth',
interpolation='bicubic', first_conv='conv1.0'),
interpolation='bicubic', first_conv='conv1.0', crop_pct=0.95),
'resnet50t': _cfg(
url='',
interpolation='bicubic', first_conv='conv1.0'),

@ -105,7 +105,8 @@ default_cfgs = {
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, interpolation='bicubic'),
'resnetv2_50': _cfg(
interpolation='bicubic'),
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnetv2_50_a1_h-000cdf49.pth',
interpolation='bicubic', crop_pct=0.95),
'resnetv2_50d': _cfg(
interpolation='bicubic', first_conv='stem.conv1'),
'resnetv2_50t': _cfg(
@ -344,7 +345,7 @@ class ResNetV2(nn.Module):
num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
width_factor=1, stem_chs=64, stem_type='', avg_down=False, preact=True,
act_layer=nn.ReLU, conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32),
drop_rate=0., drop_path_rate=0., zero_init_last=True):
drop_rate=0., drop_path_rate=0., zero_init_last=False):
super().__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate

@ -1 +1 @@
__version__ = '0.4.13'
__version__ = '0.5.0'

@ -190,6 +190,8 @@ parser.add_argument('--jsd-loss', action='store_true', default=False,
help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
parser.add_argument('--bce-loss', action='store_true', default=False,
help='Enable BCE loss w/ Mixup/CutMix use.')
parser.add_argument('--bce-target-thresh', type=float, default=None,
help='Threshold for binarizing softened BCE targets (default: None, disabled)')
parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
help='Random erase prob (default: 0.)')
parser.add_argument('--remode', type=str, default='pixel',
@ -250,6 +252,8 @@ parser.add_argument('--model-ema-decay', type=float, default=0.9998,
# Misc
parser.add_argument('--seed', type=int, default=42, metavar='S',
help='random seed (default: 42)')
parser.add_argument('--worker-seeding', type=str, default='all',
help='worker seed mode (default: all)')
parser.add_argument('--log-interval', type=int, default=50, metavar='N',
help='how many batches to wait before logging training status')
parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
@ -266,6 +270,8 @@ parser.add_argument('--apex-amp', action='store_true', default=False,
help='Use NVIDIA Apex AMP mixed precision')
parser.add_argument('--native-amp', action='store_true', default=False,
help='Use Native Torch AMP mixed precision')
parser.add_argument('--no-ddp-bb', action='store_true', default=False,
help='Force broadcast buffers for native DDP to off.')
parser.add_argument('--channels-last', action='store_true', default=False,
help='Use channels_last memory layout')
parser.add_argument('--pin-mem', action='store_true', default=False,
@ -459,7 +465,7 @@ def main():
else:
if args.local_rank == 0:
_logger.info("Using native Torch DistributedDataParallel.")
model = NativeDDP(model, device_ids=[args.local_rank]) # can use device str in Torch >= 1.1
model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.no_ddp_bb)
# NOTE: EMA model does not need to be wrapped by DDP
# setup learning rate schedule and starting epoch
@ -533,7 +539,8 @@ def main():
distributed=args.distributed,
collate_fn=collate_fn,
pin_memory=args.pin_mem,
use_multi_epochs_loader=args.use_multi_epochs_loader
use_multi_epochs_loader=args.use_multi_epochs_loader,
worker_seeding=args.worker_seeding,
)
loader_eval = create_loader(
@ -558,12 +565,12 @@ def main():
elif mixup_active:
# smoothing is handled with mixup target transform which outputs sparse, soft targets
if args.bce_loss:
train_loss_fn = nn.BCEWithLogitsLoss()
train_loss_fn = BinaryCrossEntropy(target_threshold=args.bce_target_thresh)
else:
train_loss_fn = SoftTargetCrossEntropy()
elif args.smoothing:
if args.bce_loss:
train_loss_fn = DenseBinaryCrossEntropy(smoothing=args.smoothing)
train_loss_fn = BinaryCrossEntropy(smoothing=args.smoothing, target_threshold=args.bce_target_thresh)
else:
train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
else:

@ -296,6 +296,11 @@ def main():
model_names = list_models(args.model)
model_cfgs = [(n, '') for n in model_names]
if not model_cfgs and os.path.isfile(args.model):
with open(args.model) as f:
model_names = [line.rstrip() for line in f]
model_cfgs = [(n, None) for n in model_names if n]
if len(model_cfgs):
results_file = args.results_file or './results-all.csv'
_logger.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names)))

Loading…
Cancel
Save