Tweaking tanh scheduler, senet weight init (for BN), transform defaults

pull/1/head
Ross Wightman 5 years ago
parent 48360625f2
commit b5255960d9

@ -104,6 +104,18 @@ pretrained_config = {
}
def _weight_init(m, n='', ll=''):
print(m, n, ll)
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
if ll and n == ll:
nn.init.constant_(m.weight, 0.)
else:
nn.init.constant_(m.weight, 1.)
nn.init.constant_(m.bias, 0.)
class SEModule(nn.Module):
def __init__(self, channels, reduction):
@ -116,6 +128,9 @@ class SEModule(nn.Module):
channels // reduction, channels, kernel_size=1, padding=0)
self.sigmoid = nn.Sigmoid()
for m in self.modules():
_weight_init(m)
def forward(self, x):
module_input = x
x = self.avg_pool(x)
@ -176,6 +191,9 @@ class SEBottleneck(Bottleneck):
self.downsample = downsample
self.stride = stride
for n, m in self.named_modules():
_weight_init(m, n, ll='bn3')
class SEResNetBottleneck(Bottleneck):
"""
@ -201,6 +219,9 @@ class SEResNetBottleneck(Bottleneck):
self.downsample = downsample
self.stride = stride
for n, m in self.named_modules():
_weight_init(m, n, ll='bn3')
class SEResNeXtBottleneck(Bottleneck):
"""
@ -225,6 +246,9 @@ class SEResNeXtBottleneck(Bottleneck):
self.downsample = downsample
self.stride = stride
for n, m in self.named_modules():
_weight_init(m, n, ll='bn3')
class SEResNetBlock(nn.Module):
expansion = 1
@ -242,6 +266,9 @@ class SEResNetBlock(nn.Module):
self.downsample = downsample
self.stride = stride
for n, m in self.named_modules():
_weight_init(m, n, ll='bn2')
def forward(self, x):
residual = x
@ -378,6 +405,12 @@ class SENet(nn.Module):
self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None
self.last_linear = nn.Linear(512 * block.expansion, num_classes)
for n, m in self.named_children():
if n == 'layer0':
m.apply(_weight_init)
else:
_weight_init(m)
def _make_layer(self, block, planes, blocks, groups, reduction, stride=1,
downsample_kernel_size=1, downsample_padding=0):
downsample = None

@ -21,7 +21,7 @@ class LeNormalize(object):
return tensor
def transforms_imagenet_train(model_name, img_size=224, scale=(0.08, 1.0), color_jitter=(0.3, 0.3, 0.3)):
def transforms_imagenet_train(model_name, img_size=224, scale=(0.1, 1.0), color_jitter=(0.333, 0.333, 0.333)):
if 'dpn' in model_name:
normalize = transforms.Normalize(
mean=IMAGENET_DPN_MEAN,

@ -23,14 +23,20 @@ class TanhLRScheduler(Scheduler):
t_mul: float = 1.,
lr_min: float = 0.,
decay_rate: float = 1.,
warmup_updates=0,
warmup_t=0,
warmup_lr_init=0,
warmup_prefix=False,
cycle_limit=0,
t_in_epochs=False,
initialize=True) -> None:
super().__init__(optimizer, param_group_field="lr", initialize=initialize)
assert t_initial > 0
assert lr_min >= 0
assert lb < ub
assert cycle_limit >= 0
assert warmup_t >= 0
assert warmup_lr_init >= 0
self.lb = lb
self.ub = ub
self.t_initial = t_initial
@ -38,33 +44,33 @@ class TanhLRScheduler(Scheduler):
self.lr_min = lr_min
self.decay_rate = decay_rate
self.cycle_limit = cycle_limit
self.warmup_updates = warmup_updates
self.warmup_t = warmup_t
self.warmup_lr_init = warmup_lr_init
if self.warmup_updates:
self.warmup_steps = [(v - warmup_lr_init) / self.warmup_updates for v in self.base_values]
self.warmup_prefix = warmup_prefix
self.t_in_epochs = t_in_epochs
if self.warmup_t:
t_v = self.base_values if self.warmup_prefix else self._get_lr(self.warmup_t)
print(t_v)
self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in t_v]
super().update_groups(self.warmup_lr_init)
else:
self.warmup_steps = [1 for _ in self.base_values]
if self.warmup_lr_init:
super().update_groups(self.warmup_lr_init)
def get_epoch_values(self, epoch: int):
# this scheduler doesn't update on epoch
return None
def get_update_values(self, num_updates: int):
if num_updates < self.warmup_updates:
lrs = [self.warmup_lr_init + num_updates * s for s in self.warmup_steps]
def _get_lr(self, t):
if t < self.warmup_t:
lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
else:
curr_updates = num_updates - self.warmup_updates
if self.warmup_prefix:
t = t - self.warmup_t
if self.t_mul != 1:
i = math.floor(math.log(1 - curr_updates / self.t_initial * (1 - self.t_mul), self.t_mul))
i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul))
t_i = self.t_mul ** i * self.t_initial
t_curr = curr_updates - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial
t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial
else:
i = curr_updates // self.t_initial
i = t // self.t_initial
t_i = self.t_initial
t_curr = curr_updates - (self.t_initial * i)
t_curr = t - (self.t_initial * i)
if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit):
gamma = self.decay_rate ** i
@ -78,5 +84,16 @@ class TanhLRScheduler(Scheduler):
]
else:
lrs = [self.lr_min * (self.decay_rate ** self.cycle_limit) for _ in self.base_values]
return lrs
def get_epoch_values(self, epoch: int):
if self.t_in_epochs:
return self._get_lr(epoch)
else:
return None
def get_update_values(self, num_updates: int):
if not self.t_in_epochs:
return self._get_lr(num_updates)
else:
return None

@ -162,7 +162,7 @@ def main():
if args.opt.lower() == 'sgd':
optimizer = optim.SGD(
model.parameters(), lr=args.lr,
momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
momentum=args.momentum, weight_decay=args.weight_decay, nesterov=False)
elif args.opt.lower() == 'adam':
optimizer = optim.Adam(
model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
@ -183,15 +183,27 @@ def main():
if optimizer_state is not None:
optimizer.load_state_dict(optimizer_state)
updates_per_epoch = len(loader_train)
if args.sched == 'cosine':
lr_scheduler = scheduler.CosineLRScheduler(
optimizer,
t_initial=13 * len(loader_train),
t_mul=2.0,
t_initial=100 * updates_per_epoch,
t_mul=1.0,
lr_min=0,
decay_rate=0.5,
warmup_lr_init=1e-4,
warmup_updates=len(loader_train)
warmup_updates=1 * updates_per_epoch
)
elif args.sched == 'tanh':
lr_scheduler = scheduler.TanhLRScheduler(
optimizer,
t_initial=80 * updates_per_epoch,
t_mul=1.0,
lr_min=1e-5,
decay_rate=0.5,
warmup_lr_init=.001,
warmup_t=5 * updates_per_epoch,
cycle_limit=1
)
else:
lr_scheduler = scheduler.StepLRScheduler(
@ -354,7 +366,7 @@ def validate(model, loader, loss_fn, args):
losses_m.update(loss.item(), input.size(0))
# metrics
prec1, prec5 = accuracy(output, target, topk=(1, 3))
prec1, prec5 = accuracy(output, target, topk=(1, 5))
prec1_m.update(prec1.item(), output.size(0))
prec5_m.update(prec5.item(), output.size(0))
@ -375,16 +387,5 @@ def validate(model, loader, loss_fn, args):
return metrics
def update_summary(epoch, train_metrics, eval_metrics, output_dir, write_header=False):
rowd = OrderedDict(epoch=epoch)
rowd.update(train_metrics)
rowd.update(eval_metrics)
with open(os.path.join(output_dir, 'summary.csv'), mode='a') as cf:
dw = csv.DictWriter(cf, fieldnames=rowd.keys())
if write_header: # first iteration (epoch == 1 can't be used)
dw.writeheader()
dw.writerow(rowd)
if __name__ == '__main__':
main()

Loading…
Cancel
Save