Merge pull request #571 from normster/augmix-fix

Enable uniform augmentation magnitude sampling and set AugMix default
pull/581/head
Ross Wightman 3 years ago committed by GitHub
commit 9a1bd358c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -332,14 +332,18 @@ class AugmentOp:
# in the usually fixed policy and sample magnitude from a normal distribution
# with mean `magnitude` and std-dev of `magnitude_std`.
# NOTE This is my own hack, being tested, not in papers or reference impls.
# If magnitude_std is inf, we sample magnitude from a uniform distribution
self.magnitude_std = self.hparams.get('magnitude_std', 0)
def __call__(self, img):
if self.prob < 1.0 and random.random() > self.prob:
return img
magnitude = self.magnitude
if self.magnitude_std and self.magnitude_std > 0:
magnitude = random.gauss(magnitude, self.magnitude_std)
if self.magnitude_std:
if self.magnitude_std == float('inf'):
magnitude = random.uniform(0, magnitude)
elif self.magnitude_std > 0:
magnitude = random.gauss(magnitude, self.magnitude_std)
magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range
level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple()
return self.aug_fn(img, *level_args, **self.kwargs)
@ -790,6 +794,7 @@ def augment_and_mix_transform(config_str, hparams):
depth = -1
alpha = 1.
blended = False
hparams['magnitude_std'] = float('inf')
config = config_str.split('-')
assert config[0] == 'augmix'
config = config[1:]

Loading…
Cancel
Save