From ed41d3263764e6bf3aa8d34ab17a9317bdd9c8ea Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 28 Oct 2021 17:33:36 -0700 Subject: [PATCH] Add repr to auto_augment and random_erasing impl --- timm/data/auto_augment.py | 32 ++++++++++++++++++++++++++++++++ timm/data/random_erasing.py | 14 ++++++++++---- 2 files changed, 42 insertions(+), 4 deletions(-) diff --git a/timm/data/auto_augment.py b/timm/data/auto_augment.py index 7d80d702..8907e504 100644 --- a/timm/data/auto_augment.py +++ b/timm/data/auto_augment.py @@ -316,6 +316,7 @@ class AugmentOp: def __init__(self, name, prob=0.5, magnitude=10, hparams=None): hparams = hparams or _HPARAMS_DEFAULT + self.name = name self.aug_fn = NAME_TO_OP[name] self.level_fn = LEVEL_TO_ARG[name] self.prob = prob @@ -351,6 +352,14 @@ class AugmentOp: level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple() return self.aug_fn(img, *level_args, **self.kwargs) + def __repr__(self): + fs = self.__class__.__name__ + f'(name={self.name}, p={self.prob}' + fs += f', m={self.magnitude}, mstd={self.magnitude_std}' + if self.magnitude_max is not None: + fs += f', mmax={self.magnitude_max}' + fs += ')' + return fs + def auto_augment_policy_v0(hparams): # ImageNet v0 policy from TPU EfficientNet impl, cannot find a paper reference. @@ -510,6 +519,15 @@ class AutoAugment: img = op(img) return img + def __repr__(self): + fs = self.__class__.__name__ + f'(policy=' + for p in self.policy: + fs += '\n\t[' + fs += ', '.join([str(op) for op in p]) + fs += ']' + fs += ')' + return fs + def auto_augment_transform(config_str, hparams): """ @@ -634,6 +652,13 @@ class RandAugment: img = op(img) return img + def __repr__(self): + fs = self.__class__.__name__ + f'(n={self.num_layers}, ops=' + for op in self.ops: + fs += f'\n\t{op}' + fs += ')' + return fs + def rand_augment_transform(config_str, hparams): """ @@ -782,6 +807,13 @@ class AugMixAugment: mixed = self._apply_basic(img, mixing_weights, m) return mixed + def __repr__(self): + fs = self.__class__.__name__ + f'(alpha={self.alpha}, width={self.width}, depth={self.depth}, ops=' + for op in self.ops: + fs += f'\n\t{op}' + fs += ')' + return fs + def augment_and_mix_transform(config_str, hparams): """ Create AugMix PyTorch transform diff --git a/timm/data/random_erasing.py b/timm/data/random_erasing.py index 78967d10..2fa63153 100644 --- a/timm/data/random_erasing.py +++ b/timm/data/random_erasing.py @@ -54,15 +54,15 @@ class RandomErasing: self.min_count = min_count self.max_count = max_count or min_count self.num_splits = num_splits - mode = mode.lower() + self.mode = mode.lower() self.rand_color = False self.per_pixel = False - if mode == 'rand': + if self.mode == 'rand': self.rand_color = True # per block random normal - elif mode == 'pixel': + elif self.mode == 'pixel': self.per_pixel = True # per pixel random normal else: - assert not mode or mode == 'const' + assert not self.mode or self.mode == 'const' self.device = device def _erase(self, img, chan, img_h, img_w, dtype): @@ -95,3 +95,9 @@ class RandomErasing: for i in range(batch_start, batch_size): self._erase(input[i], chan, img_h, img_w, input.dtype) return input + + def __repr__(self): + # NOTE simplified state for repr + fs = self.__class__.__name__ + f'(p={self.probability}, mode={self.mode}' + fs += f', count=({self.min_count}, {self.max_count}))' + return fs