Add repr to auto_augment and random_erasing impl

more_datasets
Ross Wightman 3 years ago
parent 135a48d024
commit ed41d32637

@ -316,6 +316,7 @@ class AugmentOp:
def __init__(self, name, prob=0.5, magnitude=10, hparams=None): def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
hparams = hparams or _HPARAMS_DEFAULT hparams = hparams or _HPARAMS_DEFAULT
self.name = name
self.aug_fn = NAME_TO_OP[name] self.aug_fn = NAME_TO_OP[name]
self.level_fn = LEVEL_TO_ARG[name] self.level_fn = LEVEL_TO_ARG[name]
self.prob = prob self.prob = prob
@ -351,6 +352,14 @@ class AugmentOp:
level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple() level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple()
return self.aug_fn(img, *level_args, **self.kwargs) return self.aug_fn(img, *level_args, **self.kwargs)
def __repr__(self):
fs = self.__class__.__name__ + f'(name={self.name}, p={self.prob}'
fs += f', m={self.magnitude}, mstd={self.magnitude_std}'
if self.magnitude_max is not None:
fs += f', mmax={self.magnitude_max}'
fs += ')'
return fs
def auto_augment_policy_v0(hparams): def auto_augment_policy_v0(hparams):
# ImageNet v0 policy from TPU EfficientNet impl, cannot find a paper reference. # ImageNet v0 policy from TPU EfficientNet impl, cannot find a paper reference.
@ -510,6 +519,15 @@ class AutoAugment:
img = op(img) img = op(img)
return img return img
def __repr__(self):
fs = self.__class__.__name__ + f'(policy='
for p in self.policy:
fs += '\n\t['
fs += ', '.join([str(op) for op in p])
fs += ']'
fs += ')'
return fs
def auto_augment_transform(config_str, hparams): def auto_augment_transform(config_str, hparams):
""" """
@ -634,6 +652,13 @@ class RandAugment:
img = op(img) img = op(img)
return img return img
def __repr__(self):
fs = self.__class__.__name__ + f'(n={self.num_layers}, ops='
for op in self.ops:
fs += f'\n\t{op}'
fs += ')'
return fs
def rand_augment_transform(config_str, hparams): def rand_augment_transform(config_str, hparams):
""" """
@ -782,6 +807,13 @@ class AugMixAugment:
mixed = self._apply_basic(img, mixing_weights, m) mixed = self._apply_basic(img, mixing_weights, m)
return mixed return mixed
def __repr__(self):
fs = self.__class__.__name__ + f'(alpha={self.alpha}, width={self.width}, depth={self.depth}, ops='
for op in self.ops:
fs += f'\n\t{op}'
fs += ')'
return fs
def augment_and_mix_transform(config_str, hparams): def augment_and_mix_transform(config_str, hparams):
""" Create AugMix PyTorch transform """ Create AugMix PyTorch transform

@ -54,15 +54,15 @@ class RandomErasing:
self.min_count = min_count self.min_count = min_count
self.max_count = max_count or min_count self.max_count = max_count or min_count
self.num_splits = num_splits self.num_splits = num_splits
mode = mode.lower() self.mode = mode.lower()
self.rand_color = False self.rand_color = False
self.per_pixel = False self.per_pixel = False
if mode == 'rand': if self.mode == 'rand':
self.rand_color = True # per block random normal self.rand_color = True # per block random normal
elif mode == 'pixel': elif self.mode == 'pixel':
self.per_pixel = True # per pixel random normal self.per_pixel = True # per pixel random normal
else: else:
assert not mode or mode == 'const' assert not self.mode or self.mode == 'const'
self.device = device self.device = device
def _erase(self, img, chan, img_h, img_w, dtype): def _erase(self, img, chan, img_h, img_w, dtype):
@ -95,3 +95,9 @@ class RandomErasing:
for i in range(batch_start, batch_size): for i in range(batch_start, batch_size):
self._erase(input[i], chan, img_h, img_w, input.dtype) self._erase(input[i], chan, img_h, img_w, input.dtype)
return input return input
def __repr__(self):
# NOTE simplified state for repr
fs = self.__class__.__name__ + f'(p={self.probability}, mode={self.mode}'
fs += f', count=({self.min_count}, {self.max_count}))'
return fs

Loading…
Cancel
Save