You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

114 lines
3.4 KiB

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
from .common import BaseNetwork
class InpaintGenerator(BaseNetwork):
    """Encoder / AOT-block / decoder inpainting generator.

    ``forward`` concatenates ``x`` and ``mask`` along the channel axis. The first
    convolution expects 4 channels in total (presumably a 3-channel image plus a
    1-channel mask; confirm against callers). The stack is:
    - an encoder: a 7x7 stem followed by two stride-2 convolutions;
    - ``args.block_num`` ``AOTBlock`` modules on 256 channels, built with
      dilation rates ``args.rates``;
    - a decoder: two ``UpConv`` stages and a final 3x3 convolution to 3 channels.

    The output is squashed into (-1, 1) by ``tanh``.
    """

    def __init__(self, args):
        super().__init__()

        # Encoder stem: reflection padding keeps the 7x7 conv size-preserving.
        stem = [nn.ReflectionPad2d(3), nn.Conv2d(4, 64, 7), nn.ReLU(True)]
        downsample = []
        for in_ch, out_ch in ((64, 128), (128, 256)):
            downsample.append(nn.Conv2d(in_ch, out_ch, 4, stride=2, padding=1))
            downsample.append(nn.ReLU(True))
        self.encoder = nn.Sequential(*stem, *downsample)

        blocks = [AOTBlock(256, args.rates) for _ in range(args.block_num)]
        self.middle = nn.Sequential(*blocks)

        upsample = []
        for in_ch, out_ch in ((256, 128), (128, 64)):
            upsample.append(UpConv(in_ch, out_ch))
            upsample.append(nn.ReLU(True))
        self.decoder = nn.Sequential(*upsample, nn.Conv2d(64, 3, 3, stride=1, padding=1))

        self.init_weights()

    def forward(self, x, mask):
        """Return the generated image for input ``x`` and ``mask``, in (-1, 1)."""
        stacked = torch.cat([x, mask], dim=1)
        return torch.tanh(self.decoder(self.middle(self.encoder(stacked))))
class UpConv(nn.Module):
    """Upsample by bilinear interpolation, then apply a 3x3 convolution.

    Args:
        inc: number of input channels.
        outc: number of output channels.
        scale: spatial upsampling factor passed to ``F.interpolate`` (default 2).
    """

    def __init__(self, inc, outc, scale=2):
        super().__init__()
        self.scale = scale
        # 3x3 conv with stride 1 and padding 1 keeps the upsampled spatial size.
        self.conv = nn.Conv2d(inc, outc, 3, stride=1, padding=1)

    def forward(self, x):
        # Use the configured scale. The previous code hard-coded scale_factor=2
        # and silently ignored self.scale.
        # 'bilinear' mode requires a 4-D (N, C, H, W) input.
        upsampled = F.interpolate(x, scale_factor=self.scale, mode='bilinear', align_corners=True)
        return self.conv(upsampled)
class AOTBlock(nn.Module):
    """Multi-rate dilated-conv block blended with its input through a learned gate.

    Each entry of ``rates`` gets one branch: reflection padding, a 3x3
    convolution with that dilation, and a ReLU. The branch outputs are
    concatenated along channels and fused back to ``dim`` channels by a 3x3
    conv. A sigmoid gate, computed from the input, blends the two:
    ``x * (1 - g) + fused * g``.

    Args:
        dim: number of input (and output) channels.
        rates: sequence of dilation rates, one branch per rate.

    Raises:
        ValueError: if ``rates`` is empty.
    """

    def __init__(self, dim, rates):
        super().__init__()
        if len(rates) == 0:
            raise ValueError("AOTBlock requires at least one dilation rate")
        self.rates = rates
        # Split channels evenly across branches. The previous hard-coded
        # dim // 4 only matched fuse's input when exactly four rates were
        # given; with four rates the result is identical.
        branch_dim = dim // len(rates)
        for i, rate in enumerate(rates):
            # Names block00, block01, ... are preserved so existing state_dict
            # keys still match.
            self.add_module(
                f'block{i:02d}',
                nn.Sequential(
                    nn.ReflectionPad2d(rate),
                    nn.Conv2d(dim, branch_dim, 3, padding=0, dilation=rate),
                    nn.ReLU(True)))
        self.fuse = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(branch_dim * len(rates), dim, 3, padding=0, dilation=1))
        self.gate = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(dim, dim, 3, padding=0, dilation=1))

    def forward(self, x):
        branches = [getattr(self, f'block{i:02d}')(x) for i in range(len(self.rates))]
        out = self.fuse(torch.cat(branches, 1))
        # Gate: the gate conv's output is normalized by my_layer_norm
        # (over dims 2 and 3), then squashed to (0, 1).
        mask = torch.sigmoid(my_layer_norm(self.gate(x)))
        return x * (1 - mask) + out * mask
def my_layer_norm(feat):
    """Normalize ``feat`` over dims 2 and 3, then remap as ``5 * (2 * z - 1)``.

    ``z = (feat - mean) / (std + 1e-9)``, where the mean and std are taken per
    leading index over dims 2 and 3 (presumably the spatial H and W of an NCHW
    tensor; confirm with callers). ``std`` uses torch's default unbiased
    estimator.

    The result is centered at -5, not 0, with a standard deviation of about 10.
    NOTE(review): with a single spatial element the unbiased std is expected to
    be NaN, so the output would be NaN too.
    """
    centered = feat - feat.mean((2, 3), keepdim=True)
    # The tiny epsilon guards against division by zero for constant maps.
    spread = feat.std((2, 3), keepdim=True) + 1e-9
    return 5 * (2 * centered / spread - 1)
# ----- discriminator -----
class Discriminator(BaseNetwork):
    """Fully convolutional discriminator built from spectrally normalized convs.

    It applies three stride-2 4x4 convolutions and one stride-1 4x4
    convolution, each spectrally normalized and followed by LeakyReLU(0.2).
    A final plain 4x4 convolution produces a one-channel spatial score map;
    no global pooling is applied.

    Args:
        inc: number of channels in the input image (default 3, matching the
            previously hard-coded value).
    """

    def __init__(self, inc=3):
        super().__init__()
        self.conv = nn.Sequential(
            spectral_norm(nn.Conv2d(inc, 64, 4, stride=2, padding=1, bias=False)),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(nn.Conv2d(64, 128, 4, stride=2, padding=1, bias=False)),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(nn.Conv2d(128, 256, 4, stride=2, padding=1, bias=False)),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(nn.Conv2d(256, 512, 4, stride=1, padding=1, bias=False)),
            nn.LeakyReLU(0.2, inplace=True),
            # The output layer is intentionally left without spectral norm, as
            # in the original.
            nn.Conv2d(512, 1, 4, stride=1, padding=1)
        )
        self.init_weights()

    def forward(self, x):
        """Return the one-channel score map for input ``x``."""
        return self.conv(x)