You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

105 lines
3.2 KiB

import torch
import torch.nn as nn
import torch.nn.functional as F
from .common import VGG19, gaussian_blur
class L1:
    """Callable wrapper around ``torch.nn.L1Loss`` with default settings."""

    def __init__(self):
        # The criterion is built once and reused for every call.
        self.calc = nn.L1Loss()

    def __call__(self, x, y):
        """Return the L1 loss between ``x`` and ``y``."""
        criterion = self.calc
        return criterion(x, y)
class Perceptual(nn.Module):
    """Weighted L1 distance between VGG19 feature maps of two inputs.

    Both inputs are passed through ``VGG19``; the entries ``relu1_1`` ..
    ``relu5_1`` of the returned mappings are compared with ``L1Loss`` and
    summed using the per-layer weights.

    Args:
        weights: Sequence of per-layer weights. Entries 0-4 scale the
            ``relu1_1`` .. ``relu5_1`` terms, in that order. The sequence is
            copied, so instances never share (or mutate) the default.
    """

    # Feature-map keys read from the VGG output, in weight order.
    _LAYERS = ('relu1_1', 'relu2_1', 'relu3_1', 'relu4_1', 'relu5_1')

    def __init__(self, weights=(1.0, 1.0, 1.0, 1.0, 1.0)):
        super().__init__()
        vgg = VGG19()
        # Use the GPU when one exists, but stay usable on CPU-only hosts
        # instead of failing inside ``.cuda()``.
        self.vgg = vgg.cuda() if torch.cuda.is_available() else vgg
        self.criterion = torch.nn.L1Loss()
        self.weights = list(weights)

    def forward(self, x, y):
        """Return the weighted sum of per-layer L1 feature distances.

        Defined as ``forward`` (rather than overriding ``__call__``) so that
        ``nn.Module`` hooks keep working; ``instance(x, y)`` is unchanged for
        callers.
        """
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        content_loss = 0.0
        for i, layer in enumerate(self._LAYERS):
            content_loss += self.weights[i] * self.criterion(
                x_vgg[layer], y_vgg[layer])
        return content_loss
class Style(nn.Module):
    """L1 distance between Gram matrices of VGG19 feature maps.

    Both inputs are passed through ``VGG19``; the entries ``relu2_2``,
    ``relu3_4``, ``relu4_4`` and ``relu5_2`` of the returned mappings are
    turned into Gram matrices and compared with ``L1Loss``.
    """

    # Feature-map keys read from the VGG output.
    _LAYERS = ('relu2_2', 'relu3_4', 'relu4_4', 'relu5_2')

    def __init__(self):
        super().__init__()
        vgg = VGG19()
        # Use the GPU when one exists, but stay usable on CPU-only hosts
        # instead of failing inside ``.cuda()``.
        self.vgg = vgg.cuda() if torch.cuda.is_available() else vgg
        self.criterion = torch.nn.L1Loss()

    def compute_gram(self, x):
        """Return the normalised Gram matrix of a 4-D feature map.

        Args:
            x: Tensor of size ``(b, c, h, w)``.

        Returns:
            Tensor of size ``(b, c, c)``: the channel-by-channel inner
            products of the flattened maps, divided by ``h * w * c``.
        """
        b, c, h, w = x.size()
        # ``reshape`` rather than ``view``: ``view`` raises on feature maps
        # that are not contiguous in memory (e.g. after a permute).
        f = x.reshape(b, c, h * w)
        f_T = f.transpose(1, 2)
        G = f.bmm(f_T) / (h * w * c)
        return G

    def forward(self, x, y):
        """Return the summed L1 distance between per-layer Gram matrices.

        Defined as ``forward`` (rather than overriding ``__call__``) so that
        ``nn.Module`` hooks keep working; ``instance(x, y)`` is unchanged for
        callers.
        """
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        style_loss = 0.0
        for layer in self._LAYERS:
            style_loss += self.criterion(
                self.compute_gram(x_vgg[layer]),
                self.compute_gram(y_vgg[layer]))
        return style_loss
class nsgan:
    """Softplus-based adversarial losses for a discriminator/generator pair."""

    def __init__(self):
        self.loss_fn = nn.Softplus()

    def __call__(self, netD, fake, real):
        """Return ``(dis_loss, gen_loss)``.

        The discriminator term sees a detached copy of ``fake`` so its loss
        does not back-propagate into the generator; the generator term uses
        ``fake`` itself.
        """
        softplus = self.loss_fn

        # Discriminator side: score the detached fake first, then the real.
        score_fake_d = netD(fake.detach())
        score_real = netD(real)
        real_term = softplus(-score_real).mean()
        fake_term = softplus(score_fake_d).mean()
        loss_d = real_term + fake_term

        # Generator side: score the fake with its graph attached.
        score_fake_g = netD(fake)
        loss_g = softplus(-score_fake_g).mean()

        return loss_d, loss_g
class smgan:
    """MSE-based adversarial losses whose fake target is a blurred mask.

    Args:
        ksize: Value passed twice, as ``(ksize, ksize)``, to
            ``gaussian_blur`` (presumably the kernel size; confirm against
            ``common.gaussian_blur``).
    """

    def __init__(self, ksize=71):
        self.ksize = ksize
        self.loss_fn = nn.MSELoss()

    def __call__(self, netD, fake, real, masks):
        """Return ``(dis_loss, gen_loss)`` as scalars.

        Args:
            netD: Discriminator; called on ``fake``, detached ``fake`` and
                ``real``. Its outputs must be 4-D.
            fake: Generator output.
            real: Ground-truth input.
            masks: 4-D mask tensor; its last two dims set the size the
                discriminator outputs are resized to.
        """
        fake_detach = fake.detach()

        g_fake = netD(fake)
        d_fake = netD(fake_detach)
        d_real = netD(real)

        _, _, h, w = g_fake.size()
        _, _, ht, wt = masks.size()

        # Handle inconsistent size between outputs and masks
        if h != ht or w != wt:
            target_size = (ht, wt)
            g_fake = F.interpolate(g_fake, size=target_size, mode='bilinear', align_corners=True)
            d_fake = F.interpolate(d_fake, size=target_size, mode='bilinear', align_corners=True)
            d_real = F.interpolate(d_real, size=target_size, mode='bilinear', align_corners=True)

        # Labels follow the device of the predictions they are compared with
        # instead of being forced onto CUDA, so CPU runs work too.
        # (10, 10) is presumably the blur sigma; confirm in gaussian_blur.
        d_fake_label = gaussian_blur(
            masks, (self.ksize, self.ksize), (10, 10)).detach().to(d_fake.device)
        d_real_label = torch.zeros_like(d_real)
        g_fake_label = torch.ones_like(g_fake)

        dis_loss = self.loss_fn(d_fake, d_fake_label) + self.loss_fn(d_real, d_real_label)
        # NOTE(review): loss_fn is nn.MSELoss() with its default mean
        # reduction, so the first factor is already a scalar. After the final
        # .mean() the masks / mean(masks) factor averages out to 1 and does
        # not re-weight individual pixels; an all-zero mask yields 0/0 = NaN.
        # Kept unchanged to preserve training behaviour; confirm whether
        # per-pixel weighting (reduction='none') was intended.
        gen_loss = self.loss_fn(g_fake, g_fake_label) * masks / torch.mean(masks)

        return dis_loss.mean(), gen_loss.mean()