import torch
import torch.nn as nn
import torch.nn.functional as F

from .cross_entropy import LabelSmoothingCrossEntropy


class JsdCrossEntropy(nn.Module):
    """ Jensen-Shannon Divergence + Cross-Entropy Loss

    Cross-entropy is computed on the clean (first) split of the batch only; a
    Jensen-Shannon consistency term is computed across all splits and weighted by `alpha`.
    """
    def __init__(self, num_splits=3, alpha=12, smoothing=0.1):
        super().__init__()
        self.num_splits = num_splits
        self.alpha = alpha
        if smoothing is not None and smoothing > 0:
            self.cross_entropy_loss = LabelSmoothingCrossEntropy(smoothing)
        else:
            self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

    def __call__(self, output, target):
        # The batch is expected to contain `num_splits` equally sized groups
        # (e.g. clean images followed by augmented copies) stacked along dim 0.
        split_size = output.shape[0] // self.num_splits
        assert split_size * self.num_splits == output.shape[0]
        logits_split = torch.split(output, split_size)

        # Cross-entropy is only computed on clean images
        loss = self.cross_entropy_loss(logits_split[0], target[:split_size])
        probs = [F.softmax(logits, dim=1) for logits in logits_split]

        # Clamp mixture distribution to avoid exploding KL divergence
        logp_mixture = torch.clamp(torch.stack(probs).mean(axis=0), 1e-7, 1).log()
        # Generalized JSD: mean_i KL(p_i || M), where M is the mean of the split distributions
        loss += self.alpha * sum([F.kl_div(
            logp_mixture, p_split, reduction='batchmean') for p_split in probs]) / len(probs)
        return loss
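A minimal usage sketch follows, assuming this module ships inside timm's loss package so the class can be imported as shown; the batch size, number of classes, and random logits are illustrative placeholders rather than values from the original file.

import torch
from timm.loss import JsdCrossEntropy  # assumed import path for this class

num_splits, batch_size, num_classes = 3, 8, 10
criterion = JsdCrossEntropy(num_splits=num_splits, alpha=12, smoothing=0.1)

# Expected batch layout: clean images first, then (num_splits - 1) augmented copies,
# concatenated along the batch dimension -> (num_splits * batch_size, num_classes) logits.
logits = torch.randn(num_splits * batch_size, num_classes, requires_grad=True)
# Targets only need to cover the clean (first) split.
target = torch.randint(0, num_classes, (batch_size,))

loss = criterion(logits, target)  # CE on the clean split + alpha-weighted consistency term
loss.backward()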