From 79640fcc1f72241431b534420f2f7c9868157e93 Mon Sep 17 00:00:00 2001 From: Norman Mu Date: Mon, 19 Apr 2021 14:19:31 -0700 Subject: [PATCH] Enable uniform augmentation magnitude sampling and set AugMix default --- timm/data/auto_augment.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/timm/data/auto_augment.py b/timm/data/auto_augment.py index cbf5464d..7cbd2dee 100644 --- a/timm/data/auto_augment.py +++ b/timm/data/auto_augment.py @@ -332,14 +332,18 @@ class AugmentOp: # in the usually fixed policy and sample magnitude from a normal distribution # with mean `magnitude` and std-dev of `magnitude_std`. # NOTE This is my own hack, being tested, not in papers or reference impls. + # If magnitude_std is inf, we sample magnitude from a uniform distribution self.magnitude_std = self.hparams.get('magnitude_std', 0) def __call__(self, img): if self.prob < 1.0 and random.random() > self.prob: return img magnitude = self.magnitude - if self.magnitude_std and self.magnitude_std > 0: - magnitude = random.gauss(magnitude, self.magnitude_std) + if self.magnitude_std: + if self.magnitude_std == float('inf'): + magnitude = random.uniform(0, magnitude) + elif self.magnitude_std > 0: + magnitude = random.gauss(magnitude, self.magnitude_std) magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple() return self.aug_fn(img, *level_args, **self.kwargs) @@ -790,6 +794,7 @@ def augment_and_mix_transform(config_str, hparams): depth = -1 alpha = 1. blended = False + hparams['magnitude_std'] = float('inf') config = config_str.split('-') assert config[0] == 'augmix' config = config[1:]