Enable uniform augmentation magnitude sampling and set AugMix default

pull/571/head
Norman Mu 4 years ago
parent c1cf9712fc
commit 79640fcc1f

@ -332,13 +332,17 @@ class AugmentOp:
# in the usually fixed policy and sample magnitude from a normal distribution # in the usually fixed policy and sample magnitude from a normal distribution
# with mean `magnitude` and std-dev of `magnitude_std`. # with mean `magnitude` and std-dev of `magnitude_std`.
# NOTE This is my own hack, being tested, not in papers or reference impls. # NOTE This is my own hack, being tested, not in papers or reference impls.
# If magnitude_std is inf, we sample magnitude from a uniform distribution
self.magnitude_std = self.hparams.get('magnitude_std', 0) self.magnitude_std = self.hparams.get('magnitude_std', 0)
def __call__(self, img): def __call__(self, img):
if self.prob < 1.0 and random.random() > self.prob: if self.prob < 1.0 and random.random() > self.prob:
return img return img
magnitude = self.magnitude magnitude = self.magnitude
if self.magnitude_std and self.magnitude_std > 0: if self.magnitude_std:
if self.magnitude_std == float('inf'):
magnitude = random.uniform(0, magnitude)
elif self.magnitude_std > 0:
magnitude = random.gauss(magnitude, self.magnitude_std) magnitude = random.gauss(magnitude, self.magnitude_std)
magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range
level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple() level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple()
@ -790,6 +794,7 @@ def augment_and_mix_transform(config_str, hparams):
depth = -1 depth = -1
alpha = 1. alpha = 1.
blended = False blended = False
hparams['magnitude_std'] = float('inf')
config = config_str.split('-') config = config_str.split('-')
assert config[0] == 'augmix' assert config[0] == 'augmix'
config = config[1:] config = config[1:]

Loading…
Cancel
Save