|
|
|
@ -79,10 +79,16 @@ class Trainer():
|
|
|
|
|
def save(self, ):
|
|
|
|
|
if self.args.global_rank == 0:
|
|
|
|
|
print(f'\nsaving {self.iteration} model to {self.args.save_dir} ...')
|
|
|
|
|
torch.save(self.netG.module.state_dict(),
|
|
|
|
|
os.path.join(self.args.save_dir, f'G{str(self.iteration).zfill(7)}.pt'))
|
|
|
|
|
torch.save(self.netD.module.state_dict(),
|
|
|
|
|
os.path.join(self.args.save_dir, f'D{str(self.iteration).zfill(7)}.pt'))
|
|
|
|
|
if self.args.distributed:
|
|
|
|
|
torch.save(self.netG.module.state_dict(),
|
|
|
|
|
os.path.join(self.args.save_dir, f'G{str(self.iteration).zfill(7)}.pt'))
|
|
|
|
|
torch.save(self.netD.module.state_dict(),
|
|
|
|
|
os.path.join(self.args.save_dir, f'D{str(self.iteration).zfill(7)}.pt'))
|
|
|
|
|
else:
|
|
|
|
|
torch.save(self.netG.state_dict(),
|
|
|
|
|
os.path.join(self.args.save_dir, f'G{str(self.iteration).zfill(7)}.pt'))
|
|
|
|
|
torch.save(self.netD.state_dict(),
|
|
|
|
|
os.path.join(self.args.save_dir, f'D{str(self.iteration).zfill(7)}.pt'))
|
|
|
|
|
torch.save(
|
|
|
|
|
{'optimG': self.optimG.state_dict(), 'optimD': self.optimD.state_dict()},
|
|
|
|
|
os.path.join(self.args.save_dir, f'O{str(self.iteration).zfill(7)}.pt'))
|
|
|
|
|