diff --git a/timm/data/transforms.py b/timm/data/transforms.py index d4f67bf9..33911638 100644 --- a/timm/data/transforms.py +++ b/timm/data/transforms.py @@ -179,8 +179,12 @@ def transforms_imagenet_train( transforms.RandomHorizontalFlip() ] if auto_augment: + if isinstance(img_size, tuple): + img_size_min = min(img_size) + else: + img_size_min = img_size aa_params = dict( - translate_const=img_size[-1] // 2 - 1, + translate_const=int(img_size_min * 0.45), img_mean=tuple([min(255, round(255 * x)) for x in mean]), ) if interpolation and interpolation != 'random':