Update AugMix, JSD, etc comments and references

pull/74/head
Ross Wightman 5 years ago
parent 833066b540
commit 3eb4a96eda

@ -1,7 +1,19 @@
""" AutoAugment and RandAugment """ AutoAugment, RandAugment, and AugMix for PyTorch
Implementation adapted from:
This code implements the searched ImageNet policies with various tweaks and improvements and
does not include any of the search code.
AA and RA Implementation adapted from:
https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
Papers: https://arxiv.org/abs/1805.09501, https://arxiv.org/abs/1906.11172, and https://arxiv.org/abs/1909.13719
AugMix adapted from:
https://github.com/google-research/augmix
Papers:
AutoAugment: Learning Augmentation Policies from Data - https://arxiv.org/abs/1805.09501
Learning Data Augmentation Strategies for Object Detection - https://arxiv.org/abs/1906.11172
RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719
AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - https://arxiv.org/abs/1912.02781
Hacked together by Ross Wightman Hacked together by Ross Wightman
""" """
@ -691,12 +703,17 @@ def augmix_ops(magnitude=10, hparams=None, transforms=None):
class AugMixAugment: class AugMixAugment:
""" AugMix Transform
Adapted and improved from impl here: https://github.com/google-research/augmix/blob/master/imagenet.py
From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty -
https://arxiv.org/abs/1912.02781
"""
def __init__(self, ops, alpha=1., width=3, depth=-1, blended=False): def __init__(self, ops, alpha=1., width=3, depth=-1, blended=False):
self.ops = ops self.ops = ops
self.alpha = alpha self.alpha = alpha
self.width = width self.width = width
self.depth = depth self.depth = depth
self.blended = blended self.blended = blended # blended mode is faster but not well tested
def _calc_blended_weights(self, ws, m): def _calc_blended_weights(self, ws, m):
ws = ws * m ws = ws * m

@ -1,3 +1,6 @@
""" Transforms Factory
Factory methods for building image transforms for use with TIMM (PyTorch Image Models)
"""
import math import math
import torch import torch
@ -24,7 +27,13 @@ def transforms_imagenet_train(
re_num_splits=0, re_num_splits=0,
separate=False, separate=False,
): ):
"""
If separate==True, the transforms are returned as a tuple of 3 separate transforms
for use in a mixing dataset that passes
* all data through the first (primary) transform, called the 'clean' data
* a portion of the data through the secondary transform
* normalizes and converts the branches above with the third, final transform
"""
primary_tfl = [ primary_tfl = [
RandomResizedCropAndInterpolation( RandomResizedCropAndInterpolation(
img_size, scale=scale, interpolation=interpolation), img_size, scale=scale, interpolation=interpolation),

@ -8,6 +8,11 @@ from .cross_entropy import LabelSmoothingCrossEntropy
class JsdCrossEntropy(nn.Module): class JsdCrossEntropy(nn.Module):
""" Jensen-Shannon Divergence + Cross-Entropy Loss """ Jensen-Shannon Divergence + Cross-Entropy Loss
Based on impl here: https://github.com/google-research/augmix/blob/master/imagenet.py
From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty -
https://arxiv.org/abs/1912.02781
Hacked together by Ross Wightman
""" """
def __init__(self, num_splits=3, alpha=12, smoothing=0.1): def __init__(self, num_splits=3, alpha=12, smoothing=0.1):
super().__init__() super().__init__()

@ -1,6 +1,18 @@
""" Split BatchNorm
A PyTorch BatchNorm layer that splits input batch into N equal parts and passes each through
a separate BN layer. The first split is passed through the parent BN layers with weight/bias
keys the same as the original BN. All other splits pass through BN sub-layers under the '.aux_bn'
namespace.
This allows easily removing the auxiliary BN layers after training to efficiently
achieve the 'Auxiliary BatchNorm' as described in the AdvProp Paper, section 4.2,
'Disentangled Learning via An Auxiliary BN'
Hacked together by Ross Wightman
"""
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
class SplitBatchNorm2d(torch.nn.BatchNorm2d): class SplitBatchNorm2d(torch.nn.BatchNorm2d):

@ -237,8 +237,9 @@ def main():
data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0) data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)
num_aug_splits = 0 num_aug_splits = 0
if args.aug_splits: if args.aug_splits > 0:
num_aug_splits = max(args.aug_splits, 2) # split of 1 makes no sense assert args.aug_splits > 1, 'A split of 1 makes no sense'
num_aug_splits = args.aug_splits
if args.split_bn: if args.split_bn:
assert num_aug_splits > 1 or args.resplit assert num_aug_splits > 1 or args.resplit

Loading…
Cancel
Save