import torch
import torch.nn as nn
class BaseNetwork(nn.Module):
    """Common base class for networks: parameter-count reporting and weight init."""

    # Weight-initialisation schemes accepted by ``init_weights``.
    _INIT_TYPES = ('normal', 'xavier', 'xavier_uniform', 'kaiming', 'orthogonal', 'none')

    def __init__(self):
        super(BaseNetwork, self).__init__()

    def print_network(self):
        """Print the network's class name and its total parameter count in millions."""
        # Tolerates being called unbound on a list of networks, e.g.
        # ``BaseNetwork.print_network([net])``; only the first one is reported.
        net = self[0] if isinstance(self, list) else self
        num_params = sum(param.numel() for param in net.parameters())
        print('Network [%s] was created. Total number of parameters: %.1f million. '
              'To see the architecture, do print(network).'
              % (type(net).__name__, num_params / 1000000))

    def init_weights(self, init_type='normal', gain=0.02):
        """Initialize the network's weights in place.

        InstanceNorm2d layers get weight 1 and bias 0. Conv*/Linear* layers get
        their weight drawn according to ``init_type`` and their bias set to 0.

        Args:
            init_type: 'normal' | 'xavier' | 'xavier_uniform' | 'kaiming' |
                'orthogonal' | 'none'. 'none' re-applies PyTorch's default
                initialization (``reset_parameters``) to Conv/Linear layers.
            gain: standard deviation for 'normal'; gain for 'xavier' and
                'orthogonal'. Not used by 'xavier_uniform' (gain fixed at 1.0)
                or 'kaiming'.

        Raises:
            NotImplementedError: if ``init_type`` is not one of the schemes above.
        """
        # Validate up front so an unsupported scheme is reported even when the
        # network contains no Conv/Linear layer, and before any layer is touched.
        if init_type not in self._INIT_TYPES:
            raise NotImplementedError(
                'initialization method [%s] is not implemented' % init_type)

        def init_func(m):
            classname = m.__class__.__name__
            if classname.find('InstanceNorm2d') != -1:
                # Affine parameters only exist when affine=True; otherwise None.
                if hasattr(m, 'weight') and m.weight is not None:
                    nn.init.constant_(m.weight, 1.0)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)
            elif hasattr(m, 'weight') and (classname.find('Conv') != -1
                                           or classname.find('Linear') != -1):
                if init_type == 'normal':
                    nn.init.normal_(m.weight, 0.0, gain)
                elif init_type == 'xavier':
                    nn.init.xavier_normal_(m.weight, gain=gain)
                elif init_type == 'xavier_uniform':
                    nn.init.xavier_uniform_(m.weight, gain=1.0)
                elif init_type == 'kaiming':
                    nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                elif init_type == 'orthogonal':
                    nn.init.orthogonal_(m.weight, gain=gain)
                elif init_type == 'none':  # uses pytorch's default init method
                    if hasattr(m, 'reset_parameters'):
                        m.reset_parameters()
                else:
                    # Unreachable after the up-front check; kept as a guard.
                    raise NotImplementedError(
                        'initialization method [%s] is not implemented' % init_type)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)

        # Visit this module and every descendant.
        self.apply(init_func)

        # propagate to children: children that define their own init_weights
        # (e.g. nested BaseNetworks) are re-initialized through it, so an
        # override there takes effect after the generic pass above.
        for m in self.children():
            if hasattr(m, 'init_weights'):
                m.init_weights(init_type, gain)