Working on an implementation of AugMix with a Jensen-Shannon divergence loss that's compatible with my AutoAugment and RandAugment impl
parent ff8688ca3d, commit 232ab7fb12
@@ -0,0 +1,164 @@
import math

import torch
from torchvision import transforms

from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT
from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform
from timm.data.transforms import _pil_interp, RandomResizedCropAndInterpolation, ToNumpy, ToTensor
from timm.data.random_erasing import RandomErasing


def transforms_imagenet_train(
        img_size=224,
        scale=(0.08, 1.0),
        color_jitter=0.4,
        auto_augment=None,
        interpolation='random',
        random_erasing=0.4,
        random_erasing_mode='const',
        use_prefetcher=False,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        separate=False,
):
    """ ImageNet training transform pipeline.

    With separate=True, the primary (random resized crop + flip), secondary
    (auto-augment / color jitter), and final (tensor conversion, normalize,
    random erasing) stages are returned as separate Compose objects so an
    AugMix-style pipeline can apply them independently; otherwise a single
    composed transform is returned.
    """
    primary_tfl = [
        RandomResizedCropAndInterpolation(
            img_size, scale=scale, interpolation=interpolation),
        transforms.RandomHorizontalFlip()
    ]

    secondary_tfl = []
    if auto_augment:
        assert isinstance(auto_augment, str)
        if isinstance(img_size, tuple):
            img_size_min = min(img_size)
        else:
            img_size_min = img_size
        aa_params = dict(
            translate_const=int(img_size_min * 0.45),
            img_mean=tuple([min(255, round(255 * x)) for x in mean]),
        )
        if interpolation and interpolation != 'random':
            aa_params['interpolation'] = _pil_interp(interpolation)
        if auto_augment.startswith('rand'):
            secondary_tfl += [rand_augment_transform(auto_augment, aa_params)]
        elif auto_augment.startswith('augmix'):
            aa_params['translate_pct'] = 0.3
            secondary_tfl += [augment_and_mix_transform(auto_augment, aa_params)]
        else:
            secondary_tfl += [auto_augment_transform(auto_augment, aa_params)]
    elif color_jitter is not None:
        # color jitter is enabled when not using AA
        if isinstance(color_jitter, (list, tuple)):
            # color jitter should be a 3-tuple/list if specifying brightness/contrast/saturation
            # or 4 if also augmenting hue
            assert len(color_jitter) in (3, 4)
        else:
            # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
            color_jitter = (float(color_jitter),) * 3
        secondary_tfl += [transforms.ColorJitter(*color_jitter)]

    final_tfl = []
    if use_prefetcher:
        # prefetcher and collate will handle tensor conversion and norm
        final_tfl += [ToNumpy()]
    else:
        final_tfl += [
            transforms.ToTensor(),
            transforms.Normalize(
                mean=torch.tensor(mean),
                std=torch.tensor(std))
        ]
        if random_erasing > 0.:
            final_tfl.append(RandomErasing(random_erasing, mode=random_erasing_mode, device='cpu'))

    if separate:
        return transforms.Compose(primary_tfl), transforms.Compose(secondary_tfl), transforms.Compose(final_tfl)
    else:
        return transforms.Compose(primary_tfl + secondary_tfl + final_tfl)
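
A minimal usage sketch of the separate=True path for AugMix-style training: the crop/flip stage is applied once, then a clean view and the augmented views each pass through the final tensor/normalize stage. The 'augmix-m5-w4-d2' config string is only an assumed example; the accepted grammar is whatever augment_and_mix_transform parses.

# hypothetical usage sketch, assuming an AugMix config string of this form
primary, secondary, final = transforms_imagenet_train(
    img_size=224,
    auto_augment='augmix-m5-w4-d2',
    interpolation='bilinear',
    separate=True)

def augmix_views(img, num_aug=2):
    # one shared crop/flip, then a clean view plus num_aug augmented views
    img = primary(img)
    return [final(img)] + [final(secondary(img)) for _ in range(num_aug)]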
def transforms_imagenet_eval(
        img_size=224,
        crop_pct=None,
        interpolation='bilinear',
        use_prefetcher=False,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD):
    crop_pct = crop_pct or DEFAULT_CROP_PCT

    if isinstance(img_size, tuple):
        assert len(img_size) == 2
        if img_size[-1] == img_size[-2]:
            # fall-back to older behaviour so Resize scales to shortest edge if target is square
            scale_size = int(math.floor(img_size[0] / crop_pct))
        else:
            scale_size = tuple([int(x / crop_pct) for x in img_size])
    else:
        scale_size = int(math.floor(img_size / crop_pct))

    tfl = [
        transforms.Resize(scale_size, _pil_interp(interpolation)),
        transforms.CenterCrop(img_size),
    ]
    if use_prefetcher:
        # prefetcher and collate will handle tensor conversion and norm
        tfl += [ToNumpy()]
    else:
        tfl += [
            transforms.ToTensor(),
            transforms.Normalize(
                mean=torch.tensor(mean),
                std=torch.tensor(std))
        ]

    return transforms.Compose(tfl)
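
A quick worked example of the crop_pct scaling, assuming the common crop_pct of 0.875:

# img_size=224, crop_pct=0.875:
#   scale_size = floor(224 / 0.875) = 256  -> Resize shorter edge to 256
#   CenterCrop(224)                        -> keep the central 224x224 region
eval_tf = transforms_imagenet_eval(img_size=224, crop_pct=0.875)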
def create_transform(
        input_size,
        is_training=False,
        use_prefetcher=False,
        color_jitter=0.4,
        auto_augment=None,
        interpolation='bilinear',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        crop_pct=None,
        tf_preprocessing=False,
        separate=False):

    if isinstance(input_size, tuple):
        img_size = input_size[-2:]
    else:
        img_size = input_size

    if tf_preprocessing and use_prefetcher:
        assert not separate, "Separate transforms not supported for TF preprocessing"
        from timm.data.tf_preprocessing import TfPreprocessTransform
        transform = TfPreprocessTransform(
            is_training=is_training, size=img_size, interpolation=interpolation)
    else:
        if is_training:
            transform = transforms_imagenet_train(
                img_size,
                color_jitter=color_jitter,
                auto_augment=auto_augment,
                interpolation=interpolation,
                use_prefetcher=use_prefetcher,
                mean=mean,
                std=std,
                separate=separate)
        else:
            assert not separate, "Separate transforms not supported for validation preprocessing"
            transform = transforms_imagenet_eval(
                img_size,
                interpolation=interpolation,
                use_prefetcher=use_prefetcher,
                mean=mean,
                std=std,
                crop_pct=crop_pct)

    return transform
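
A minimal sketch of calling the factory; the config strings 'rand-m9-mstd0.5' and 'augmix-m5-w4-d2' are illustrative assumptions, not verified values.

# training transform with RandAugment; input_size may be an int or a (chans, h, w) tuple
train_tf = create_transform(input_size=(3, 224, 224), is_training=True,
                            auto_augment='rand-m9-mstd0.5', interpolation='random')

# eval transform (resize by crop_pct, center crop, normalize)
eval_tf = create_transform(input_size=224, is_training=False, crop_pct=0.875)

# separate=True (training only) returns the three stages for AugMix-style use
primary, secondary, final = create_transform(
    input_size=224, is_training=True, auto_augment='augmix-m5-w4-d2', separate=True)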
@@ -1 +1,2 @@
from .cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from .jsd import JsdCrossEntropy
@@ -0,0 +1,34 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .cross_entropy import LabelSmoothingCrossEntropy


class JsdCrossEntropy(nn.Module):
    """ Jensen-Shannon Divergence + Cross-Entropy Loss

    Expects the batch to be split into num_splits groups along dim 0 (clean
    images first, augmented versions after). Cross-entropy is computed on the
    clean split only; an alpha-weighted Jensen-Shannon consistency term is
    computed across all splits.
    """
    def __init__(self, num_splits=3, alpha=12, smoothing=0.1):
        super().__init__()
        self.num_splits = num_splits
        self.alpha = alpha
        if smoothing is not None and smoothing > 0:
            self.cross_entropy_loss = LabelSmoothingCrossEntropy(smoothing)
        else:
            self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

    def __call__(self, output, target):
        split_size = output.shape[0] // self.num_splits
        assert split_size * self.num_splits == output.shape[0]
        logits_split = torch.split(output, split_size)

        # Cross-entropy is only computed on clean images
        loss = self.cross_entropy_loss(logits_split[0], target[:split_size])
        probs = [F.softmax(logits, dim=1) for logits in logits_split]

        # Clamp mixture distribution to avoid exploding KL divergence
        logp_mixture = torch.clamp(torch.stack(probs).mean(axis=0), 1e-7, 1).log()
        loss += self.alpha * sum([F.kl_div(
            logp_mixture, p_split, reduction='batchmean') for p_split in probs]) / len(probs)
        return loss
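
A minimal sketch of the expected batch layout, assuming clean images and their augmented versions are concatenated along the batch dimension in split order. The consistency term is alpha times the mean of KL(p_i || M) over the splits, where M is the clamped mean of the per-split softmax distributions, i.e. a generalized Jensen-Shannon divergence. Shapes and numbers below are placeholders.

# hypothetical usage: output rows are ordered [clean | aug1 | aug2]
import torch

loss_fn = JsdCrossEntropy(num_splits=3, alpha=12, smoothing=0.1)

batch_size, num_classes = 8, 1000
clean, aug1, aug2 = [torch.randn(batch_size, num_classes) for _ in range(3)]  # stand-in logits
output = torch.cat([clean, aug1, aug2], dim=0)         # [3 * batch_size, num_classes]
target = torch.randint(0, num_classes, (batch_size,))  # labels for the clean split
loss = loss_fn(output, target)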