diff --git a/timm/data/auto_augment.py b/timm/data/auto_augment.py index a148f771..04c0b60a 100644 --- a/timm/data/auto_augment.py +++ b/timm/data/auto_augment.py @@ -251,18 +251,22 @@ class AutoAugmentOp: self.level_fn = level_to_arg(hparams)[name] self.prob = prob self.magnitude = magnitude + # If std deviation of magnitude is > 0, we introduce some randomness + # in the usually fixed policy and sample magnitude from normal dist + # with mean magnitude and std-dev of magnitude_std. + # NOTE This is being tested as it's not in paper or reference impl. + self.magnitude_std = 0.5 # FIXME add arg/hparam self.kwargs = { 'fillcolor': hparams['img_mean'] if 'img_mean' in hparams else _FILL, 'resample': hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION } - self.rand_magnitude = True def __call__(self, img): if self.prob < random.random(): return img magnitude = self.magnitude - if self.rand_magnitude: - magnitude = random.normalvariate(magnitude, 0.5) + if self.magnitude_std and self.magnitude_std > 0: + magnitude = random.gauss(magnitude, self.magnitude_std) magnitude = min(_MAX_LEVEL, max(0, magnitude)) level_args = self.level_fn(magnitude) return self.aug_fn(img, *level_args, **self.kwargs)