Merge pull request #571 from normster/augmix-fix

Enable uniform augmentation magnitude sampling and set AugMix default
pull/581/head
Ross Wightman 3 years ago committed by GitHub
commit 9a1bd358c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -332,14 +332,18 @@ class AugmentOp:
# in the usually fixed policy and sample magnitude from a normal distribution # in the usually fixed policy and sample magnitude from a normal distribution
# with mean `magnitude` and std-dev of `magnitude_std`. # with mean `magnitude` and std-dev of `magnitude_std`.
# NOTE This is my own hack, being tested, not in papers or reference impls. # NOTE This is my own hack, being tested, not in papers or reference impls.
# If magnitude_std is inf, we sample magnitude from a uniform distribution
self.magnitude_std = self.hparams.get('magnitude_std', 0) self.magnitude_std = self.hparams.get('magnitude_std', 0)
def __call__(self, img): def __call__(self, img):
if self.prob < 1.0 and random.random() > self.prob: if self.prob < 1.0 and random.random() > self.prob:
return img return img
magnitude = self.magnitude magnitude = self.magnitude
if self.magnitude_std and self.magnitude_std > 0: if self.magnitude_std:
magnitude = random.gauss(magnitude, self.magnitude_std) if self.magnitude_std == float('inf'):
magnitude = random.uniform(0, magnitude)
elif self.magnitude_std > 0:
magnitude = random.gauss(magnitude, self.magnitude_std)
magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range
level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple() level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple()
return self.aug_fn(img, *level_args, **self.kwargs) return self.aug_fn(img, *level_args, **self.kwargs)
@ -790,6 +794,7 @@ def augment_and_mix_transform(config_str, hparams):
depth = -1 depth = -1
alpha = 1. alpha = 1.
blended = False blended = False
hparams['magnitude_std'] = float('inf')
config = config_str.split('-') config = config_str.split('-')
assert config[0] == 'augmix' assert config[0] == 'augmix'
config = config[1:] config = config[1:]

Loading…
Cancel
Save