You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
154 lines
6.3 KiB
154 lines
6.3 KiB
import os
|
|
import importlib
|
|
from tqdm import tqdm
|
|
from glob import glob
|
|
|
|
import torch
|
|
import torch.optim as optim
|
|
from torchvision.utils import make_grid
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
|
|
from data import create_loader
|
|
from loss import loss as loss_module
|
|
from .common import timer, reduce_loss_dict
|
|
|
|
|
|
class Trainer:
    """Adversarial training loop for an inpainting generator/discriminator pair.

    The trainer builds the data loader, the loss callables, both networks and
    their Adam optimizers from ``args``, resumes from the newest checkpoint in
    ``args.save_dir`` when one exists, and exposes :meth:`train` to run the
    optimisation loop and :meth:`save` to write checkpoints.

    Attributes read from ``args`` in this class: ``rec_loss`` (mapping of loss
    name to weight), ``gan_type``, ``model``, ``lrg``, ``lrd``, ``beta1``,
    ``beta2``, ``distributed``, ``local_rank``, ``global_rank``,
    ``world_size``, ``tensorboard``, ``save_dir``, ``iterations``,
    ``adv_weight``, ``print_every`` and ``save_every``.
    """

    def __init__(self, args):
        """Build loaders, losses, networks and optimizers, then try to resume.

        Args:
            args: Configuration namespace (see the class docstring for the
                attributes that are read). It is also forwarded to
                ``create_loader`` and ``InpaintGenerator``.
        """
        self.args = args
        self.iteration = 0

        # setup data set and data loader
        self.dataloader = create_loader(args)

        # set up losses and metrics: one callable per configured
        # reconstruction loss, looked up by name in the loss module
        self.rec_loss_func = {
            key: getattr(loss_module, key)() for key in args.rec_loss}
        self.adv_loss = getattr(loss_module, args.gan_type)()

        # Image generator input: [rgb(3) + mask(1)], discriminator input: [rgb(3)]
        net = importlib.import_module('model.' + args.model)

        self.netG = net.InpaintGenerator(args).cuda()
        self.optimG = torch.optim.Adam(
            self.netG.parameters(), lr=args.lrg, betas=(args.beta1, args.beta2))

        self.netD = net.Discriminator().cuda()
        self.optimD = torch.optim.Adam(
            self.netD.parameters(), lr=args.lrd, betas=(args.beta1, args.beta2))

        # Restore weights before wrapping in DDP so that the checkpoint keys
        # (saved from the unwrapped modules) match.
        self.load()
        if args.distributed:
            # output_device must be a single device index, not a list.
            self.netG = DDP(
                self.netG, device_ids=[args.local_rank], output_device=args.local_rank)
            self.netD = DDP(
                self.netD, device_ids=[args.local_rank], output_device=args.local_rank)

        if args.tensorboard:
            self.writer = SummaryWriter(os.path.join(args.save_dir, 'log'))

    def _latest_checkpoint(self, prefix):
        """Return the last ``<prefix>*.pt`` path in ``save_dir`` (sorted by name), or None."""
        candidates = sorted(glob(os.path.join(self.args.save_dir, f'{prefix}*.pt')))
        return candidates[-1] if candidates else None

    def _report_load_failure(self, what, path, err):
        """Print a warning for a checkpoint that exists but could not be restored."""
        print(f'[!!] rank {self.args.global_rank}: '
              f'could not load {what} from {path}: {err!r}')

    @staticmethod
    def _unwrap(net):
        """Return the underlying module of a DDP wrapper, or ``net`` itself."""
        return net.module if isinstance(net, DDP) else net

    def load(self):
        """Resume generator, discriminator and optimizers from ``save_dir``.

        Each of the three checkpoint kinds (``G*.pt``, ``D*.pt``, ``O*.pt``)
        is handled independently and on a best-effort basis: a missing file is
        skipped silently, while a file that exists but fails to load is
        reported and then skipped, so training can still start from scratch.
        ``self.iteration`` is taken from the digits in the generator
        checkpoint's file name.
        """
        is_main = self.args.global_rank == 0

        gpath = self._latest_checkpoint('G')
        if gpath is not None:
            try:
                self.netG.load_state_dict(torch.load(gpath, map_location='cuda'))
                # File name is 'G<iteration>.pt'; strip the prefix and suffix.
                self.iteration = int(os.path.basename(gpath)[1:-3])
            except Exception as err:  # best-effort resume: report, don't crash
                self._report_load_failure('generator network', gpath, err)
            else:
                if is_main:
                    print(f'[**] Loading generator network from {gpath}')

        dpath = self._latest_checkpoint('D')
        if dpath is not None:
            try:
                self.netD.load_state_dict(torch.load(dpath, map_location='cuda'))
            except Exception as err:  # best-effort resume: report, don't crash
                self._report_load_failure('discriminator network', dpath, err)
            else:
                if is_main:
                    print(f'[**] Loading discriminator network from {dpath}')

        opath = self._latest_checkpoint('O')
        if opath is not None:
            try:
                data = torch.load(opath, map_location='cuda')
                self.optimG.load_state_dict(data['optimG'])
                self.optimD.load_state_dict(data['optimD'])
            except Exception as err:  # best-effort resume: report, don't crash
                self._report_load_failure('optimizer', opath, err)
            else:
                if is_main:
                    print(f'[**] Loading optimizer from {opath}')

    def save(self):
        """Write generator, discriminator and optimizer checkpoints (rank 0 only).

        Files are named ``G``/``D``/``O`` followed by the zero-padded (7
        digits) iteration and ``.pt``. Networks are unwrapped from DDP first,
        so this also works when training is not distributed.
        """
        if self.args.global_rank != 0:
            return

        print(f'\nsaving {self.iteration} model to {self.args.save_dir} ...')
        tag = str(self.iteration).zfill(7)
        torch.save(self._unwrap(self.netG).state_dict(),
                   os.path.join(self.args.save_dir, f'G{tag}.pt'))
        torch.save(self._unwrap(self.netD).state_dict(),
                   os.path.join(self.args.save_dir, f'D{tag}.pt'))
        torch.save(
            {'optimG': self.optimG.state_dict(), 'optimD': self.optimD.state_dict()},
            os.path.join(self.args.save_dir, f'O{tag}.pt'))

    def train(self):
        """Run the training loop from ``self.iteration`` up to ``args.iterations``.

        Every rank iterates over the same range of steps; rank 0 additionally
        shows a progress bar, logs losses/images every ``args.print_every``
        iterations and writes checkpoints every ``args.save_every`` iterations.
        """
        is_main = self.args.global_rank == 0

        # All ranks must run the same number of steps, so the progress bar
        # wraps the very same range instead of restarting from zero.
        steps = range(self.iteration, self.args.iterations)
        pbar = steps
        if is_main:
            pbar = tqdm(steps, initial=self.iteration, total=self.args.iterations,
                        dynamic_ncols=True, smoothing=0.01)
            timer_data, timer_model = timer(), timer()

        for _ in pbar:
            self.iteration += 1
            images, masks, _filenames = next(self.dataloader)
            images, masks = images.cuda(), masks.cuda()
            # Pixels where the mask is 1 are overwritten with 1
            # (assumes a binary mask -- TODO confirm against the data loader).
            images_masked = (images * (1 - masks).float()) + masks

            if is_main:
                timer_data.hold()
                timer_model.tic()

            # in: [rgb(3) + mask(1)]
            pred_img = self.netG(images_masked, masks)
            # Keep the original pixels outside the mask, the prediction inside.
            comp_img = (1 - masks) * images + masks * pred_img

            # reconstruction losses
            losses = {}
            for name, weight in self.args.rec_loss.items():
                losses[name] = weight * self.rec_loss_func[name](pred_img, images)

            # adversarial loss
            dis_loss, gen_loss = self.adv_loss(self.netD, comp_img, images, masks)
            losses['advg'] = gen_loss * self.args.adv_weight

            # backward
            # NOTE(review): the generator objective is back-propagated before
            # the discriminator loss; whether that also deposits gradients in
            # netD depends on the adv_loss implementation -- confirm.
            self.optimG.zero_grad()
            self.optimD.zero_grad()
            sum(losses.values()).backward()
            # 'advd' is added after the generator backward so that it is
            # logged but not part of the generator objective.
            losses['advd'] = dis_loss
            dis_loss.backward()
            self.optimG.step()
            self.optimD.step()

            if is_main:
                timer_model.hold()
                timer_data.tic()

            # logs
            # NOTE(review): the value returned here was never used by the
            # logging below, which reports this rank's local losses. The call
            # is kept in case it performs a collective operation -- confirm
            # whether the reduced values were meant to be logged instead.
            reduce_loss_dict(losses, self.args.world_size)
            if is_main and self.iteration % self.args.print_every == 0:
                # The bar already advances once per iteration while being
                # iterated, so no manual pbar.update() is needed here.
                parts = [f'mt:{timer_model.release():.1f}s, dt:{timer_data.release():.1f}s, ']
                for key, val in losses.items():
                    scalar = val.item()
                    parts.append(f'{key}:{scalar:.3f}, ')
                    if self.args.tensorboard:
                        self.writer.add_scalar(key, scalar, self.iteration)
                pbar.set_description(''.join(parts))
                if self.args.tensorboard:
                    # (x + 1) / 2 rescales for display; presumably the image
                    # tensors live in [-1, 1] -- confirm with the data loader.
                    self.writer.add_image('mask', make_grid(masks), self.iteration)
                    self.writer.add_image('orig', make_grid((images + 1.0) / 2.0), self.iteration)
                    self.writer.add_image('pred', make_grid((pred_img + 1.0) / 2.0), self.iteration)
                    self.writer.add_image('comp', make_grid((comp_img + 1.0) / 2.0), self.iteration)

            if is_main and self.iteration % self.args.save_every == 0:
                self.save()
|
|
|
|
|