diff --git a/timm/data/auto_augment.py b/timm/data/auto_augment.py index 8d7b36f9..e355eef5 100644 --- a/timm/data/auto_augment.py +++ b/timm/data/auto_augment.py @@ -1,7 +1,19 @@ -""" AutoAugment and RandAugment -Implementation adapted from: +""" AutoAugment, RandAugment, and AugMix for PyTorch + +This code implements the searched ImageNet policies with various tweaks and improvements and +does not include any of the search code. + +AA and RA Implementation adapted from: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py -Papers: https://arxiv.org/abs/1805.09501, https://arxiv.org/abs/1906.11172, and https://arxiv.org/abs/1909.13719 + +AugMix adapted from: + https://github.com/google-research/augmix + +Papers: + AutoAugment: Learning Augmentation Policies from Data - https://arxiv.org/abs/1805.09501 + Learning Data Augmentation Strategies for Object Detection - https://arxiv.org/abs/1906.11172 + RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719 + AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - https://arxiv.org/abs/1912.02781 Hacked together by Ross Wightman """ @@ -691,12 +703,17 @@ def augmix_ops(magnitude=10, hparams=None, transforms=None): class AugMixAugment: + """ AugMix Transform + Adapted and improved from impl here: https://github.com/google-research/augmix/blob/master/imagenet.py + From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - + https://arxiv.org/abs/1912.02781 + """ def __init__(self, ops, alpha=1., width=3, depth=-1, blended=False): self.ops = ops self.alpha = alpha self.width = width self.depth = depth - self.blended = blended + self.blended = blended # blended mode is faster but not well tested def _calc_blended_weights(self, ws, m): ws = ws * m diff --git a/timm/data/transforms_factory.py b/timm/data/transforms_factory.py index faf55b70..767dd157 100644 --- a/timm/data/transforms_factory.py +++ b/timm/data/transforms_factory.py @@ -1,3 +1,6 @@ +""" Transforms Factory +Factory methods for building image transforms for use with TIMM (PyTorch Image Models) +""" import math import torch @@ -24,7 +27,13 @@ def transforms_imagenet_train( re_num_splits=0, separate=False, ): - + """ + If separate==True, the transforms are returned as a tuple of 3 separate transforms + for use in a mixing dataset that passes + * all data through the first (primary) transform, called the 'clean' data + * a portion of the data through the secondary transform + * normalizes and converts the branches above with the third, final transform + """ primary_tfl = [ RandomResizedCropAndInterpolation( img_size, scale=scale, interpolation=interpolation), diff --git a/timm/loss/jsd.py b/timm/loss/jsd.py index ad6ca1e5..0f8eb696 100644 --- a/timm/loss/jsd.py +++ b/timm/loss/jsd.py @@ -8,6 +8,11 @@ from .cross_entropy import LabelSmoothingCrossEntropy class JsdCrossEntropy(nn.Module): """ Jensen-Shannon Divergence + Cross-Entropy Loss + Based on impl here: https://github.com/google-research/augmix/blob/master/imagenet.py + From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - + https://arxiv.org/abs/1912.02781 + + Hacked together by Ross Wightman """ def __init__(self, num_splits=3, alpha=12, smoothing=0.1): super().__init__() diff --git a/timm/models/split_batchnorm.py b/timm/models/split_batchnorm.py index 327c35ba..ad01cfeb 100644 --- a/timm/models/split_batchnorm.py +++ b/timm/models/split_batchnorm.py @@ -1,6 +1,18 @@ +""" Split BatchNorm + +A PyTorch BatchNorm layer that splits input batch into N equal parts and passes each through +a separate BN layer. The first split is passed through the parent BN layers with weight/bias +keys the same as the original BN. All other splits pass through BN sub-layers under the '.aux_bn' +namespace. + +This allows easily removing the auxiliary BN layers after training to efficiently +achieve the 'Auxiliary BatchNorm' as described in the AdvProp Paper, section 4.2, +'Disentangled Learning via An Auxiliary BN' + +Hacked together by Ross Wightman +""" import torch import torch.nn as nn -import torch.nn.functional as F class SplitBatchNorm2d(torch.nn.BatchNorm2d): diff --git a/train.py b/train.py index bb6db08d..4db2d7c3 100644 --- a/train.py +++ b/train.py @@ -237,8 +237,9 @@ def main(): data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0) num_aug_splits = 0 - if args.aug_splits: - num_aug_splits = max(args.aug_splits, 2) # split of 1 makes no sense + if args.aug_splits > 0: + assert args.aug_splits > 1, 'A split of 1 makes no sense' + num_aug_splits = args.aug_splits if args.split_bn: assert num_aug_splits > 1 or args.resplit