Fix augmix variable name scope overlap, default non-blended mode

pull/74/head
Ross Wightman 5 years ago
parent 3afc2a4dc0
commit 3cc0f91e23

@ -691,7 +691,7 @@ class AugMixAugment:
self.alpha = alpha
self.width = width
self.depth = depth
self.blended = True
self.blended = False
def _calc_blended_weights(self, ws, m):
ws = ws * m
@ -703,13 +703,13 @@ class AugMixAugment:
rws.append(alpha)
return np.array(rws[::-1], dtype=np.float32)
def _apply_blended(self, img, ws, m):
def _apply_blended(self, img, mixing_weights, m):
# This is my first crack and implementing a slightly faster mixed augmentation. Instead
# of accumulating the mix for each chain in a Numpy array and then blending with original,
# it recomputes the blending coefficients and applies one PIL image blend per chain.
# TODO I've verified the results are in the right ballpark but they differ by more than rounding.
img_orig = img.copy()
ws = self._calc_blended_weights(ws, m)
ws = self._calc_blended_weights(mixing_weights, m)
for w in ws:
depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
ops = np.random.choice(self.ops, depth, replace=True)
@ -719,21 +719,19 @@ class AugMixAugment:
img = Image.blend(img, img_aug, w)
return img
def _apply_basic(self, img, ws, m):
def _apply_basic(self, img, mixing_weights, m):
# This is a literal adaptation of the paper/official implementation without normalizations and
# PIL <-> Numpy conversions between every op. It is still quite CPU compute heavy compared to the
# typical augmentation transforms, could use a GPU / Kornia implementation.
w, h = img.size
c = len(img.getbands())
mixed = np.zeros((w, h, c), dtype=np.float32)
for w in ws:
img_shape = img.size[0], img.size[1], len(img.getbands())
mixed = np.zeros(img_shape, dtype=np.float32)
for mw in mixing_weights:
depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
ops = np.random.choice(self.ops, depth, replace=True)
img_aug = img # no ops are in-place, deep copy not necessary
for op in ops:
img_aug = op(img_aug)
img_aug = np.asarray(img_aug, dtype=np.float32)
mixed += w * img_aug
mixed += mw * np.asarray(img_aug, dtype=np.float32)
np.clip(mixed, 0, 255., out=mixed)
mixed = Image.fromarray(mixed.astype(np.uint8))
return Image.blend(img, mixed, m)

Loading…
Cancel
Save