"""Training loop for the adversarial image-inpainting model.

Builds the generator/discriminator pair, resumes from the latest checkpoint
in ``args.save_dir`` when available, and runs the optimization loop with
optional DistributedDataParallel and TensorBoard logging.
"""
import os
import importlib
from tqdm import tqdm
from glob import glob
import torch
import torch.optim as optim
from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter
from torch.nn.parallel import DistributedDataParallel as DDP
from data import create_loader
from loss import loss as loss_module
from .common import timer, reduce_loss_dict
class Trainer():
    """Coordinates adversarial training of the inpainting generator.

    Owns the data loader, both networks and their optimizers, checkpoint
    save/resume, and (on rank 0) tqdm/TensorBoard progress reporting.
    """

    def __init__(self, args):
        self.args = args
        self.iteration = 0

        # setup data set and data loader
        self.dataloader = create_loader(args)

        # set up losses and metrics; rec_loss maps loss-name -> weight, and
        # each name must match a class exported by the loss module.
        self.rec_loss_func = {
            key: getattr(loss_module, key)() for key in args.rec_loss}
        self.adv_loss = getattr(loss_module, args.gan_type)()

        # Image generator input: [rgb(3) + mask(1)], discriminator input: [rgb(3)]
        net = importlib.import_module('model.' + args.model)
        self.netG = net.InpaintGenerator(args).cuda()
        self.optimG = torch.optim.Adam(
            self.netG.parameters(), lr=args.lrg, betas=(args.beta1, args.beta2))
        self.netD = net.Discriminator().cuda()
        self.optimD = torch.optim.Adam(
            self.netD.parameters(), lr=args.lrd, betas=(args.beta1, args.beta2))

        # Resume (if checkpoints exist) BEFORE DDP wrapping so the saved
        # plain-module state_dicts load without a 'module.' prefix mismatch.
        self.load()
        if args.distributed:
            # FIX: DDP's `output_device` takes a single device index; the
            # original passed a one-element list, which DDP rejects.
            self.netG = DDP(self.netG, device_ids=[args.local_rank],
                            output_device=args.local_rank)
            self.netD = DDP(self.netD, device_ids=[args.local_rank],
                            output_device=args.local_rank)

        if args.tensorboard:
            self.writer = SummaryWriter(os.path.join(args.save_dir, 'log'))

    def load(self):
        """Best-effort resume from the newest G/D/optimizer checkpoints.

        Missing or unreadable checkpoints are skipped silently (fresh start);
        only genuinely expected failures are suppressed, not every exception.
        """
        try:
            # Names sort lexicographically, so [-1] is the newest checkpoint.
            gpath = sorted(glob(os.path.join(self.args.save_dir, 'G*.pt')))[-1]
            self.netG.load_state_dict(torch.load(gpath, map_location='cuda'))
            # Filenames look like G0012345.pt -> iteration 12345.
            self.iteration = int(os.path.basename(gpath)[1:-3])
            if self.args.global_rank == 0:
                print(f'[**] Loading generator network from {gpath}')
        except (IndexError, ValueError, RuntimeError, OSError):
            # No checkpoint (empty glob), bad filename, corrupt file or
            # state_dict mismatch: start the generator from scratch.
            pass
        try:
            dpath = sorted(glob(os.path.join(self.args.save_dir, 'D*.pt')))[-1]
            self.netD.load_state_dict(torch.load(dpath, map_location='cuda'))
            if self.args.global_rank == 0:
                print(f'[**] Loading discriminator network from {dpath}')
        except (IndexError, RuntimeError, OSError):
            pass
        try:
            opath = sorted(glob(os.path.join(self.args.save_dir, 'O*.pt')))[-1]
            data = torch.load(opath, map_location='cuda')
            self.optimG.load_state_dict(data['optimG'])
            self.optimD.load_state_dict(data['optimD'])
            if self.args.global_rank == 0:
                print(f'[**] Loading optimizer from {opath}')
        except (IndexError, KeyError, RuntimeError, OSError):
            pass

    def save(self):
        """Write G/D weights and both optimizer states to ``args.save_dir``.

        Only rank 0 writes. Networks are unwrapped from DDP (when wrapped)
        so the checkpoints load back into plain modules.
        """
        if self.args.global_rank == 0:
            print(f'\nsaving {self.iteration} model to {self.args.save_dir} ...')
            # FIX: `.module` only exists on DDP-wrapped models; fall back to
            # the bare module so saving also works without --distributed.
            netG = self.netG.module if hasattr(self.netG, 'module') else self.netG
            netD = self.netD.module if hasattr(self.netD, 'module') else self.netD
            tag = str(self.iteration).zfill(7)
            torch.save(netG.state_dict(),
                       os.path.join(self.args.save_dir, f'G{tag}.pt'))
            torch.save(netD.state_dict(),
                       os.path.join(self.args.save_dir, f'D{tag}.pt'))
            torch.save(
                {'optimG': self.optimG.state_dict(),
                 'optimD': self.optimD.state_dict()},
                os.path.join(self.args.save_dir, f'O{tag}.pt'))

    def train(self):
        """Main optimization loop: alternate G and D updates each iteration."""
        # FIX: every rank must execute the SAME number of iterations or the
        # collective ops in reduce_loss_dict deadlock; the original had rank 0
        # iterate the full range while other ranks started at self.iteration.
        pbar = range(self.iteration, self.args.iterations)
        if self.args.global_rank == 0:
            pbar = tqdm(pbar, initial=self.iteration,
                        total=self.args.iterations,
                        dynamic_ncols=True, smoothing=0.01)

        timer_data, timer_model = timer(), timer()
        for _ in pbar:
            self.iteration += 1
            images, masks, filename = next(self.dataloader)
            images, masks = images.cuda(), masks.cuda()
            # Holes are filled with 1.0 so the generator can tell them apart
            # from valid (unmasked) content.
            images_masked = (images * (1 - masks).float()) + masks

            if self.args.global_rank == 0:
                timer_data.hold()
                timer_model.tic()

            # in: [rgb(3) + edge(1)]
            pred_img = self.netG(images_masked, masks)
            # Composite: keep known pixels, take predictions inside the hole.
            comp_img = (1 - masks) * images + masks * pred_img

            # reconstruction losses (weighted per args.rec_loss)
            losses = {}
            for name, weight in self.args.rec_loss.items():
                losses[name] = weight * self.rec_loss_func[name](pred_img, images)

            # adversarial loss
            dis_loss, gen_loss = self.adv_loss(self.netD, comp_img, images, masks)
            losses['advg'] = gen_loss * self.args.adv_weight

            # backward: generator losses first, then the discriminator loss
            # (added to `losses` only AFTER the generator backward so it is
            # not double-counted in the summed generator objective).
            self.optimG.zero_grad()
            self.optimD.zero_grad()
            sum(losses.values()).backward()
            losses['advd'] = dis_loss
            dis_loss.backward()
            self.optimG.step()
            self.optimD.step()

            if self.args.global_rank == 0:
                timer_model.hold()
                timer_data.tic()

            # logs — this performs a collective reduce, so EVERY rank must
            # call it even though only rank 0 reports.
            # NOTE(review): the reduced dict is discarded and the LOCAL loss
            # values are logged below, matching the original behavior —
            # presumably the reduced values were meant to be logged; confirm.
            scalar_reduced = reduce_loss_dict(losses, self.args.world_size)
            if self.args.global_rank == 0 and (self.iteration % self.args.print_every == 0):
                pbar.update(self.args.print_every)
                description = f'mt:{timer_model.release():.1f}s, dt:{timer_data.release():.1f}s, '
                for key, val in losses.items():
                    description += f'{key}:{val.item():.3f}, '
                    if self.args.tensorboard:
                        self.writer.add_scalar(key, val.item(), self.iteration)
                pbar.set_description((description))
                if self.args.tensorboard:
                    # Images are in [-1, 1]; rescale to [0, 1] for display.
                    self.writer.add_image('mask', make_grid(masks), self.iteration)
                    self.writer.add_image('orig', make_grid((images + 1.0) / 2.0), self.iteration)
                    self.writer.add_image('pred', make_grid((pred_img + 1.0) / 2.0), self.iteration)
                    self.writer.add_image('comp', make_grid((comp_img + 1.0) / 2.0), self.iteration)

            if self.args.global_rank == 0 and (self.iteration % self.args.save_every) == 0:
                self.save()