fix for non-distributed GPU configurations

pull/5/head
Deniz Ugur 3 years ago
parent 4180346273
commit a96e18a5fe

@@ -79,10 +79,16 @@ class Trainer():
     def save(self, ):
         if self.args.global_rank == 0:
             print(f'\nsaving {self.iteration} model to {self.args.save_dir} ...')
-            torch.save(self.netG.module.state_dict(),
-                os.path.join(self.args.save_dir, f'G{str(self.iteration).zfill(7)}.pt'))
-            torch.save(self.netD.module.state_dict(),
-                os.path.join(self.args.save_dir, f'D{str(self.iteration).zfill(7)}.pt'))
+            if self.args.distributed:
+                torch.save(self.netG.module.state_dict(),
+                    os.path.join(self.args.save_dir, f'G{str(self.iteration).zfill(7)}.pt'))
+                torch.save(self.netD.module.state_dict(),
+                    os.path.join(self.args.save_dir, f'D{str(self.iteration).zfill(7)}.pt'))
+            else:
+                torch.save(self.netG.state_dict(),
+                    os.path.join(self.args.save_dir, f'G{str(self.iteration).zfill(7)}.pt'))
+                torch.save(self.netD.state_dict(),
+                    os.path.join(self.args.save_dir, f'D{str(self.iteration).zfill(7)}.pt'))
             torch.save(
                 {'optimG': self.optimG.state_dict(), 'optimD': self.optimD.state_dict()},
                 os.path.join(self.args.save_dir, f'O{str(self.iteration).zfill(7)}.pt'))
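
Why the branch is needed: torch.nn.parallel.DistributedDataParallel (like DataParallel) keeps the wrapped network under a .module attribute, whereas a plain nn.Module has no such attribute, so the old unconditional self.netG.module.state_dict() call would raise an AttributeError on single-GPU, non-distributed runs. Below is a minimal sketch of the setup this fix presumes; the __init__ wrapping shown here and the args.local_rank field are assumptions for illustration, not code taken from this repository (a process group is also assumed to be initialized before DDP is used):

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

class Trainer():
    def __init__(self, args, netG: nn.Module, netD: nn.Module):
        self.args = args
        if args.distributed:
            # Distributed run (assumed setup): DDP wraps each model, so the
            # underlying weights are reached via netG.module / netD.module.
            self.netG = DDP(netG.cuda(), device_ids=[args.local_rank])
            self.netD = DDP(netD.cuda(), device_ids=[args.local_rank])
        else:
            # Single-GPU run: the models are used as-is and expose no .module.
            self.netG = netG.cuda()
            self.netD = netD.cuda()

An alternative to branching at every save site would be to unwrap once, e.g. netG = self.netG.module if self.args.distributed else self.netG, and call state_dict() on the result; the committed fix instead keeps the branch local to save().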
