diff --git a/timm/data/auto_augment.py b/timm/data/auto_augment.py index 864ca6e0..ec2602b3 100644 --- a/timm/data/auto_augment.py +++ b/timm/data/auto_augment.py @@ -691,7 +691,7 @@ class AugMixAugment: self.alpha = alpha self.width = width self.depth = depth - self.blended = True + self.blended = False def _calc_blended_weights(self, ws, m): ws = ws * m @@ -703,13 +703,13 @@ class AugMixAugment: rws.append(alpha) return np.array(rws[::-1], dtype=np.float32) - def _apply_blended(self, img, ws, m): + def _apply_blended(self, img, mixing_weights, m): # This is my first crack and implementing a slightly faster mixed augmentation. Instead # of accumulating the mix for each chain in a Numpy array and then blending with original, # it recomputes the blending coefficients and applies one PIL image blend per chain. # TODO I've verified the results are in the right ballpark but they differ by more than rounding. img_orig = img.copy() - ws = self._calc_blended_weights(ws, m) + ws = self._calc_blended_weights(mixing_weights, m) for w in ws: depth = self.depth if self.depth > 0 else np.random.randint(1, 4) ops = np.random.choice(self.ops, depth, replace=True) @@ -719,21 +719,19 @@ class AugMixAugment: img = Image.blend(img, img_aug, w) return img - def _apply_basic(self, img, ws, m): + def _apply_basic(self, img, mixing_weights, m): # This is a literal adaptation of the paper/official implementation without normalizations and # PIL <-> Numpy conversions between every op. It is still quite CPU compute heavy compared to the # typical augmentation transforms, could use a GPU / Kornia implementation. - w, h = img.size - c = len(img.getbands()) - mixed = np.zeros((w, h, c), dtype=np.float32) - for w in ws: + img_shape = img.size[0], img.size[1], len(img.getbands()) + mixed = np.zeros(img_shape, dtype=np.float32) + for mw in mixing_weights: depth = self.depth if self.depth > 0 else np.random.randint(1, 4) ops = np.random.choice(self.ops, depth, replace=True) img_aug = img # no ops are in-place, deep copy not necessary for op in ops: img_aug = op(img_aug) - img_aug = np.asarray(img_aug, dtype=np.float32) - mixed += w * img_aug + mixed += mw * np.asarray(img_aug, dtype=np.float32) np.clip(mixed, 0, 255., out=mixed) mixed = Image.fromarray(mixed.astype(np.uint8)) return Image.blend(img, mixed, m)