import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm

from .common import BaseNetwork

class InpaintGenerator(BaseNetwork):
    def __init__(self, args):
        super(InpaintGenerator, self).__init__()

        # Encoder: 4-channel input (RGB image concatenated with a 1-channel
        # mask) mapped to 256 feature maps at 1/4 resolution.
        self.encoder = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(4, 64, 7),
            nn.ReLU(True),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.ReLU(True),
            nn.Conv2d(128, 256, 4, stride=2, padding=1),
            nn.ReLU(True)
        )

        # Bottleneck: args.block_num AOT blocks, each aggregating the
        # dilation rates in args.rates at constant spatial resolution.
        self.middle = nn.Sequential(
            *[AOTBlock(256, args.rates) for _ in range(args.block_num)])

        # Decoder: two x2 upsamplings back to the input resolution, then a
        # 3x3 conv down to a 3-channel RGB output.
        self.decoder = nn.Sequential(
            UpConv(256, 128),
            nn.ReLU(True),
            UpConv(128, 64),
            nn.ReLU(True),
            nn.Conv2d(64, 3, 3, stride=1, padding=1)
        )

        self.init_weights()

    def forward(self, x, mask):
        x = torch.cat([x, mask], dim=1)
        x = self.encoder(x)
        x = self.middle(x)
        x = self.decoder(x)
        x = torch.tanh(x)  # outputs in [-1, 1]
        return x

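# Shape trace (illustrative, not part of the original file): for a 3 x 512 x 512
# image and a 1 x 512 x 512 mask, the concatenated 4-channel input stays at
# 512 x 512 through the padded 7x7 conv, drops to 128 x 256 x 256 and then
# 256 x 128 x 128 through the two stride-2 convs, keeps that shape through the
# AOT blocks, and is decoded back to 3 x 512 x 512 before the tanh.
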
class UpConv(nn.Module):
    def __init__(self, inc, outc, scale=2):
        super(UpConv, self).__init__()
        self.scale = scale
        self.conv = nn.Conv2d(inc, outc, 3, stride=1, padding=1)

    def forward(self, x):
        # Bilinear upsample by self.scale, then refine with a 3x3 conv
        # (a common alternative to transposed convs that avoids
        # checkerboard artifacts).
        x = F.interpolate(x, scale_factor=self.scale, mode='bilinear',
                          align_corners=True)
        return self.conv(x)

class AOTBlock(nn.Module):
    def __init__(self, dim, rates):
        super(AOTBlock, self).__init__()
        self.rates = rates
        # One dilated 3x3 branch per rate, each producing dim // 4 channels.
        # The fuse conv below takes dim input channels, so this assumes
        # len(rates) == 4 (e.g. rates = [1, 2, 4, 8]).
        for i, rate in enumerate(rates):
            self.__setattr__(
                'block{}'.format(str(i).zfill(2)),
                nn.Sequential(
                    nn.ReflectionPad2d(rate),
                    nn.Conv2d(dim, dim // 4, 3, padding=0, dilation=rate),
                    nn.ReLU(True)))
        self.fuse = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(dim, dim, 3, padding=0, dilation=1))
        self.gate = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(dim, dim, 3, padding=0, dilation=1))

    def forward(self, x):
        # Run every dilated branch on the same input and fuse the results.
        out = [self.__getattr__(f'block{str(i).zfill(2)}')(x)
               for i in range(len(self.rates))]
        out = torch.cat(out, 1)
        out = self.fuse(out)
        # Spatially-varying gate blending the input with the fused features.
        mask = my_layer_norm(self.gate(x))
        mask = torch.sigmoid(mask)
        return x * (1 - mask) + out * mask


def my_layer_norm(feat):
    # Per-channel normalization over the spatial dimensions, followed by an
    # affine rescale so the sigmoid in AOTBlock gets well-spread logits.
    mean = feat.mean((2, 3), keepdim=True)
    std = feat.std((2, 3), keepdim=True) + 1e-9
    feat = 2 * (feat - mean) / std - 1
    feat = 5 * feat
    return feat

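# Gate scaling sketch (derived from my_layer_norm above, not in the original
# file): the gate logits come out as 10 * (g - mean) / std - 5, so a gate
# feature at its spatial mean maps to sigmoid(-5) ~ 0.007 and the block
# returns nearly unchanged x there; only positions well above the spatial
# mean of the gate response let the fused branch output through.
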
# ----- discriminator -----
class Discriminator(BaseNetwork):
    def __init__(self):
        super(Discriminator, self).__init__()
        inc = 3
        # PatchGAN-style discriminator with spectral normalization: three
        # stride-2 convs, then two stride-1 convs down to a 1-channel
        # per-patch score map.
        self.conv = nn.Sequential(
            spectral_norm(nn.Conv2d(inc, 64, 4, stride=2, padding=1, bias=False)),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(nn.Conv2d(64, 128, 4, stride=2, padding=1, bias=False)),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(nn.Conv2d(128, 256, 4, stride=2, padding=1, bias=False)),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(nn.Conv2d(256, 512, 4, stride=1, padding=1, bias=False)),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, 4, stride=1, padding=1)
        )

        self.init_weights()

    def forward(self, x):
        feat = self.conv(x)
        return feat
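
# Minimal smoke test; a sketch, not part of the original module. The rates and
# block_num values below are illustrative assumptions (chosen so that
# len(rates) == 4, as AOTBlock requires), not defaults read from this file.
# Because of the relative import above, run it as a module, e.g.
# `python -m <package>.<this_module>`.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(rates=[1, 2, 4, 8], block_num=8)  # assumed values
    gen = InpaintGenerator(args)
    disc = Discriminator()

    image = torch.randn(1, 3, 256, 256)  # RGB in [-1, 1]
    mask = torch.ones(1, 1, 256, 256)    # 1 marks the region to fill
    out = gen(image, mask)
    print(out.shape, disc(out).shape)    # (1, 3, 256, 256), (1, 1, 30, 30)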