You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

154 lines
6.3 KiB

3 years ago
import os
import importlib
from tqdm import tqdm
from glob import glob
import torch
import torch.optim as optim
from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter
from torch.nn.parallel import DistributedDataParallel as DDP
from data import create_loader
from loss import loss as loss_module
from .common import timer, reduce_loss_dict
class Trainer():
def __init__(self, args):
self.args = args
self.iteration = 0
# setup data set and data loader
self.dataloader = create_loader(args)
# set up losses and metrics
self.rec_loss_func = {
key: getattr(loss_module, key)() for key, val in args.rec_loss.items()}
self.adv_loss = getattr(loss_module, args.gan_type)()
# Image generator input: [rgb(3) + mask(1)], discriminator input: [rgb(3)]
net = importlib.import_module('model.'+args.model)
self.netG = net.InpaintGenerator(args).cuda()
self.optimG = torch.optim.Adam(
self.netG.parameters(), lr=args.lrg, betas=(args.beta1, args.beta2))
self.netD = net.Discriminator().cuda()
self.optimD = torch.optim.Adam(
self.netD.parameters(), lr=args.lrd, betas=(args.beta1, args.beta2))
self.load()
if args.distributed:
self.netG = DDP(self.netG, device_ids= [args.local_rank], output_device=[args.local_rank])
self.netD = DDP(self.netD, device_ids= [args.local_rank], output_device=[args.local_rank])
if args.tensorboard:
self.writer = SummaryWriter(os.path.join(args.save_dir, 'log'))
def load(self):
try:
gpath = sorted(list(glob(os.path.join(self.args.save_dir, 'G*.pt'))))[-1]
self.netG.load_state_dict(torch.load(gpath, map_location='cuda'))
self.iteration = int(os.path.basename(gpath)[1:-3])
if self.args.global_rank == 0:
print(f'[**] Loading generator network from {gpath}')
except:
pass
try:
dpath = sorted(list(glob(os.path.join(self.args.save_dir, 'D*.pt'))))[-1]
self.netD.load_state_dict(torch.load(dpath, map_location='cuda'))
if self.args.global_rank == 0:
print(f'[**] Loading discriminator network from {dpath}')
except:
pass
try:
opath = sorted(list(glob(os.path.join(self.args.save_dir, 'O*.pt'))))[-1]
data = torch.load(opath, map_location='cuda')
self.optimG.load_state_dict(data['optimG'])
self.optimD.load_state_dict(data['optimD'])
if self.args.global_rank == 0:
print(f'[**] Loading optimizer from {opath}')
except:
pass
def save(self, ):
if self.args.global_rank == 0:
print(f'\nsaving {self.iteration} model to {self.args.save_dir} ...')
torch.save(self.netG.module.state_dict(),
os.path.join(self.args.save_dir, f'G{str(self.iteration).zfill(7)}.pt'))
torch.save(self.netD.module.state_dict(),
os.path.join(self.args.save_dir, f'D{str(self.iteration).zfill(7)}.pt'))
torch.save(
{'optimG': self.optimG.state_dict(), 'optimD': self.optimD.state_dict()},
os.path.join(self.args.save_dir, f'O{str(self.iteration).zfill(7)}.pt'))
def train(self):
pbar = range(self.iteration, self.args.iterations)
if self.args.global_rank == 0:
pbar = tqdm(range(self.args.iterations), initial=self.iteration, dynamic_ncols=True, smoothing=0.01)
timer_data, timer_model = timer(), timer()
for idx in pbar:
self.iteration += 1
images, masks, filename = next(self.dataloader)
images, masks = images.cuda(), masks.cuda()
images_masked = (images * (1 - masks).float()) + masks
if self.args.global_rank == 0:
timer_data.hold()
timer_model.tic()
# in: [rgb(3) + edge(1)]
pred_img = self.netG(images_masked, masks)
comp_img = (1 - masks) * images + masks * pred_img
# reconstruction losses
losses = {}
for name, weight in self.args.rec_loss.items():
losses[name] = weight * self.rec_loss_func[name](pred_img, images)
# adversarial loss
dis_loss, gen_loss = self.adv_loss(self.netD, comp_img, images, masks)
losses[f"advg"] = gen_loss * self.args.adv_weight
# backforward
self.optimG.zero_grad()
self.optimD.zero_grad()
sum(losses.values()).backward()
losses[f"advd"] = dis_loss
dis_loss.backward()
self.optimG.step()
self.optimD.step()
if self.args.global_rank == 0:
timer_model.hold()
timer_data.tic()
# logs
scalar_reduced = reduce_loss_dict(losses, self.args.world_size)
if self.args.global_rank == 0 and (self.iteration % self.args.print_every == 0):
pbar.update(self.args.print_every)
description = f'mt:{timer_model.release():.1f}s, dt:{timer_data.release():.1f}s, '
for key, val in losses.items():
description += f'{key}:{val.item():.3f}, '
if self.args.tensorboard:
self.writer.add_scalar(key, val.item(), self.iteration)
pbar.set_description((description))
if self.args.tensorboard:
self.writer.add_image('mask', make_grid(masks), self.iteration)
self.writer.add_image('orig', make_grid((images+1.0)/2.0), self.iteration)
self.writer.add_image('pred', make_grid((pred_img+1.0)/2.0), self.iteration)
self.writer.add_image('comp', make_grid((comp_img+1.0)/2.0), self.iteration)
if self.args.global_rank == 0 and (self.iteration % self.args.save_every) == 0:
self.save()