|
|
@ -332,13 +332,17 @@ class AugmentOp:
|
|
|
|
# in the usually fixed policy and sample magnitude from a normal distribution
|
|
|
|
# in the usually fixed policy and sample magnitude from a normal distribution
|
|
|
|
# with mean `magnitude` and std-dev of `magnitude_std`.
|
|
|
|
# with mean `magnitude` and std-dev of `magnitude_std`.
|
|
|
|
# NOTE This is my own hack, being tested, not in papers or reference impls.
|
|
|
|
# NOTE This is my own hack, being tested, not in papers or reference impls.
|
|
|
|
|
|
|
|
# If magnitude_std is inf, we sample magnitude from a uniform distribution
|
|
|
|
self.magnitude_std = self.hparams.get('magnitude_std', 0)
|
|
|
|
self.magnitude_std = self.hparams.get('magnitude_std', 0)
|
|
|
|
|
|
|
|
|
|
|
|
def __call__(self, img):
|
|
|
|
def __call__(self, img):
|
|
|
|
if self.prob < 1.0 and random.random() > self.prob:
|
|
|
|
if self.prob < 1.0 and random.random() > self.prob:
|
|
|
|
return img
|
|
|
|
return img
|
|
|
|
magnitude = self.magnitude
|
|
|
|
magnitude = self.magnitude
|
|
|
|
if self.magnitude_std and self.magnitude_std > 0:
|
|
|
|
if self.magnitude_std:
|
|
|
|
|
|
|
|
if self.magnitude_std == float('inf'):
|
|
|
|
|
|
|
|
magnitude = random.uniform(0, magnitude)
|
|
|
|
|
|
|
|
elif self.magnitude_std > 0:
|
|
|
|
magnitude = random.gauss(magnitude, self.magnitude_std)
|
|
|
|
magnitude = random.gauss(magnitude, self.magnitude_std)
|
|
|
|
magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range
|
|
|
|
magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range
|
|
|
|
level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple()
|
|
|
|
level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple()
|
|
|
@ -790,6 +794,7 @@ def augment_and_mix_transform(config_str, hparams):
|
|
|
|
depth = -1
|
|
|
|
depth = -1
|
|
|
|
alpha = 1.
|
|
|
|
alpha = 1.
|
|
|
|
blended = False
|
|
|
|
blended = False
|
|
|
|
|
|
|
|
hparams['magnitude_std'] = float('inf')
|
|
|
|
config = config_str.split('-')
|
|
|
|
config = config_str.split('-')
|
|
|
|
assert config[0] == 'augmix'
|
|
|
|
assert config[0] == 'augmix'
|
|
|
|
config = config[1:]
|
|
|
|
config = config[1:]
|
|
|
|