Cleanup transforms, add custom schedulers, tweak senet34 model

pull/1/head
Ross Wightman 5 years ago
parent c57717d325
commit cf0c280e1b

@@ -9,6 +9,7 @@ import re
import torch
from PIL import Image
IMG_EXTENSIONS = ['.png', '.jpg', '.jpeg']
@@ -53,7 +54,7 @@ class Dataset(data.Dataset):
def __init__(
self,
root,
transform=None):
transform):
imgs, _, _ = find_images_and_targets(root)
if len(imgs) == 0:
@@ -66,8 +67,7 @@ class Dataset(data.Dataset):
def __getitem__(self, index):
path, target = self.imgs[index]
img = Image.open(path).convert('RGB')
if self.transform is not None:
img = self.transform(img)
img = self.transform(img)
if target is None:
target = torch.zeros(1).long()
return img, target
@@ -75,9 +75,6 @@ class Dataset(data.Dataset):
def __len__(self):
return len(self.imgs)
def set_transform(self, transform):
self.transform = transform
def filenames(self, indices=[], basename=False):
if indices:
if basename:
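Since the default transform=None is gone, callers now have to pass a transform when building the Dataset. A minimal sketch of the new calling pattern (the data path and model name below are placeholders, not from this commit):

from dataset import Dataset
from models import transforms_imagenet_train

dataset_train = Dataset(
    '/data/imagenet/train',                           # placeholder path
    transform=transforms_imagenet_train('resnet50'))  # any non-dpn, non-inception name gets the default normalization
img, target = dataset_train[0]                        # img: normalized 3xHxW float tensor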

@@ -1 +1,2 @@
from .model_factory import create_model, get_transforms_eval, get_transforms_train
from .model_factory import create_model
from .transforms import transforms_imagenet_eval, transforms_imagenet_train

@@ -129,66 +129,3 @@ def load_checkpoint(model, checkpoint_path):
else:
print("Error: No checkpoint found at %s." % checkpoint_path)
class LeNormalize(object):
"""Normalize to -1..1 in Google Inception style
"""
def __call__(self, tensor):
for t in tensor:
t.sub_(0.5).mul_(2.0)
return tensor
DEFAULT_CROP_PCT = 0.875
def get_transforms_train(model_name, img_size=224):
if 'dpn' in model_name:
normalize = transforms.Normalize(
mean=[124 / 255, 117 / 255, 104 / 255],
std=[1 / (.0167 * 255)] * 3)
elif 'inception' in model_name:
normalize = LeNormalize()
else:
normalize = transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
return transforms.Compose([
transforms.RandomResizedCrop(img_size, scale=(0.3, 1.0)),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(0.3, 0.3, 0.3),
transforms.ToTensor(),
normalize])
def get_transforms_eval(model_name, img_size=224, crop_pct=None):
crop_pct = crop_pct or DEFAULT_CROP_PCT
if 'dpn' in model_name:
if crop_pct is None:
# Use default 87.5% crop for model's native img_size
# but use 100% crop for larger than native as it
# improves test time results across all models.
if img_size == 224:
scale_size = int(math.floor(img_size / DEFAULT_CROP_PCT))
else:
scale_size = img_size
else:
scale_size = int(math.floor(img_size / crop_pct))
normalize = transforms.Normalize(
mean=[124 / 255, 117 / 255, 104 / 255],
std=[1 / (.0167 * 255)] * 3)
elif 'inception' in model_name:
scale_size = int(math.floor(img_size / crop_pct))
normalize = LeNormalize()
else:
scale_size = int(math.floor(img_size / crop_pct))
normalize = transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
return transforms.Compose([
transforms.Resize(scale_size, Image.BICUBIC),
transforms.CenterCrop(img_size),
transforms.ToTensor(),
normalize])

@@ -441,7 +441,7 @@ def senet154(num_classes=1000, pretrained='imagenet'):
def se_resnet18(num_classes=1000, pretrained='imagenet'):
model = SENet(SEResNetBottleneck, [2, 2, 2, 2], groups=1, reduction=16,
model = SENet(SEResNetBlock, [2, 2, 2, 2], groups=1, reduction=16,
dropout_p=None, inplanes=64, input_3x3=False,
downsample_kernel_size=1, downsample_padding=0,
num_classes=num_classes)
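For reference, the tweaked builder is now equivalent to constructing SENet with basic SE blocks in the ResNet-18 layout rather than bottlenecks. A quick sketch (the models.senet module path is assumed; the constructor arguments mirror the builder above):

import torch
from models.senet import SENet, SEResNetBlock  # module path assumed

model = SENet(SEResNetBlock, [2, 2, 2, 2], groups=1, reduction=16,
              dropout_p=None, inplanes=64, input_3x3=False,
              downsample_kernel_size=1, downsample_padding=0,
              num_classes=1000)
logits = model(torch.randn(1, 3, 224, 224))  # -> shape [1, 1000]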

@@ -0,0 +1,73 @@
import torch
from torchvision import transforms
from PIL import Image
import math
DEFAULT_CROP_PCT = 0.875
IMAGENET_DPN_MEAN = [124 / 255, 117 / 255, 104 / 255]
IMAGENET_DPN_STD = [1 / (.0167 * 255)] * 3
IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406]
IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225]
class LeNormalize(object):
"""Normalize to -1..1 in Google Inception style
"""
def __call__(self, tensor):
for t in tensor:
t.sub_(0.5).mul_(2.0)
return tensor
def transforms_imagenet_train(model_name, img_size=224, scale=(0.08, 1.0), color_jitter=(0.3, 0.3, 0.3)):
if 'dpn' in model_name:
normalize = transforms.Normalize(
mean=IMAGENET_DPN_MEAN,
std=IMAGENET_DPN_STD)
elif 'inception' in model_name:
normalize = LeNormalize()
else:
normalize = transforms.Normalize(
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD)
return transforms.Compose([
transforms.RandomResizedCrop(img_size, scale=scale),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(*color_jitter),
transforms.ToTensor(),
normalize])
def transforms_imagenet_eval(model_name, img_size=224, crop_pct=None):
if 'dpn' in model_name:
if crop_pct is None:
# Use default 87.5% crop for model's native img_size
# but use 100% crop for larger than native as it
# improves test time results across all models.
if img_size == 224:
scale_size = int(math.floor(img_size / DEFAULT_CROP_PCT))
else:
scale_size = img_size
else:
scale_size = int(math.floor(img_size / crop_pct))
normalize = transforms.Normalize(
mean=IMAGENET_DPN_MEAN,
std=IMAGENET_DPN_STD)
elif 'inception' in model_name:
scale_size = int(math.floor(img_size / (crop_pct or DEFAULT_CROP_PCT)))
normalize = LeNormalize()
else:
scale_size = int(math.floor(img_size / (crop_pct or DEFAULT_CROP_PCT)))
normalize = transforms.Normalize(
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD)
return transforms.Compose([
transforms.Resize(scale_size, Image.BICUBIC),
transforms.CenterCrop(img_size),
transforms.ToTensor(),
normalize])
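Usage sketch for the two factory functions, plus the crop arithmetic the eval path relies on (the model name is just an example; any name without 'dpn' or 'inception' picks the default ImageNet normalization):

import math
from models import transforms_imagenet_train, transforms_imagenet_eval

train_tf = transforms_imagenet_train('resnet50', img_size=224)  # RandomResizedCrop + flip + jitter + normalize
eval_tf = transforms_imagenet_eval('resnet50', img_size=224)    # Resize + CenterCrop + normalize

# eval resizes the short side to img_size / crop_pct before the center crop:
# floor(224 / 0.875) = 256, i.e. the classic Resize(256) -> CenterCrop(224) pipeline
assert int(math.floor(224 / 0.875)) == 256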

@@ -0,0 +1,3 @@
from .cosine_lr import CosineLRScheduler
from .plateau_lr import PlateauLRScheduler
from .step_lr import StepLRScheduler

@@ -0,0 +1,72 @@
import logging
import math
import numpy as np
import torch
from .scheduler import Scheduler
logger = logging.getLogger(__name__)
class CosineLRScheduler(Scheduler):
"""
Cosine annealing with restarts.
This is described in the paper https://arxiv.org/abs/1608.03983.
"""
def __init__(self,
optimizer: torch.optim.Optimizer,
t_initial: int,
t_mul: float = 1.,
lr_min: float = 0.,
decay_rate: float = 1.,
warmup_updates=0,
warmup_lr_init=0,
initialize=True) -> None:
super().__init__(optimizer, param_group_field="lr", initialize=initialize)
assert t_initial > 0
assert lr_min >= 0
if t_initial == 1 and t_mul == 1 and decay_rate == 1:
logger.warning("Cosine annealing scheduler will have no effect on the learning "
"rate since t_initial = t_mul = eta_mul = 1.")
self.t_initial = t_initial
self.t_mul = t_mul
self.lr_min = lr_min
self.decay_rate = decay_rate
self.warmup_updates = warmup_updates
self.warmup_lr_init = warmup_lr_init
if self.warmup_updates:
self.warmup_steps = [(v - warmup_lr_init) / self.warmup_updates for v in self.base_values]
else:
self.warmup_steps = [1 for _ in self.base_values]
if self.warmup_lr_init:
super().update_groups(self.warmup_lr_init)
def get_epoch_values(self, epoch: int):
# this scheduler doesn't update on epoch
return None
def get_update_values(self, num_updates: int):
if num_updates < self.warmup_updates:
lrs = [self.warmup_lr_init + num_updates * s for s in self.warmup_steps]
else:
curr_updates = num_updates - self.warmup_updates
if self.t_mul != 1:
i = math.floor(math.log(1 - curr_updates / self.t_initial * (1 - self.t_mul), self.t_mul))
t_i = self.t_mul ** i * self.t_initial
t_curr = curr_updates - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial
else:
i = curr_updates // self.t_initial
t_i = self.t_initial
t_curr = curr_updates - (self.t_initial * i)
gamma = self.decay_rate ** i
lr_min = self.lr_min * gamma
lr_max_values = [v * gamma for v in self.base_values]
lrs = [
lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / t_i)) for lr_max in lr_max_values
]
return lrs
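Written out, get_update_values implements SGDR-style cosine annealing with a per-restart decay. With \gamma = decay_rate, i the restart index, \eta_{\max} the base LR and t_{\mathrm{cur}} the updates into the current cycle, the post-warmup rule is (a LaTeX restatement of the code above, not a quote from the paper):

T_i = t_{\mathrm{initial}} \cdot t_{\mathrm{mul}}^{\,i}, \qquad
\eta_t = \gamma^{i}\,\eta_{\min} + \tfrac{1}{2}\,\gamma^{i}\bigl(\eta_{\max} - \eta_{\min}\bigr)\Bigl(1 + \cos\bigl(\pi\, t_{\mathrm{cur}} / T_i\bigr)\Bigr)

and while num_updates < warmup_updates the LR ramps linearly:

\eta_t = \eta_{\mathrm{warmup\_init}} + t \cdot \bigl(\eta_{\mathrm{base}} - \eta_{\mathrm{warmup\_init}}\bigr) / \mathrm{warmup\_updates}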

@@ -0,0 +1,68 @@
import torch
from .scheduler import Scheduler
class PlateauLRScheduler(Scheduler):
"""Decay the LR by a factor every time the validation loss plateaus."""
def __init__(self,
optimizer,
factor=0.1,
patience=10,
verbose=False,
threshold=1e-4,
cooldown_epochs=0,
warmup_updates=0,
warmup_lr_init=0,
lr_min=0,
):
super().__init__(optimizer, 'lr', initialize=False)
self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
self.optimizer.optimizer,
patience=patience,
factor=factor,
verbose=verbose,
threshold=threshold,
cooldown=cooldown_epochs,
min_lr=lr_min
)
self.warmup_updates = warmup_updates
self.warmup_lr_init = warmup_lr_init
if self.warmup_updates:
self.warmup_active = warmup_updates > 0 # this state updates with num_updates
self.warmup_steps = [(v - warmup_lr_init) / self.warmup_updates for v in self.base_values]
super().update_groups(self.warmup_lr_init)
else:
self.warmup_steps = [1 for _ in self.base_values]
self.warmup_active = False  # no warmup configured
def state_dict(self):
return {
'best': self.lr_scheduler.best,
'last_epoch': self.lr_scheduler.last_epoch,
}
def load_state_dict(self, state_dict):
self.lr_scheduler.best = state_dict['best']
if 'last_epoch' in state_dict:
self.lr_scheduler.last_epoch = state_dict['last_epoch']
# override the base class step fn completely
def step(self, epoch, val_loss=None):
"""Update the learning rate at the end of the given epoch."""
if val_loss is not None and not self.warmup_active:
self.lr_scheduler.step(val_loss, epoch)
else:
self.lr_scheduler.last_epoch = epoch
def get_update_values(self, num_updates: int):
if num_updates < self.warmup_updates:
lrs = [self.warmup_lr_init + num_updates * s for s in self.warmup_steps]
else:
self.warmup_active = False # warmup cancelled by first update past warmup_update count
lrs = None # no change on update after warmup stage
return lrs
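After warmup, PlateauLRScheduler delegates the per-epoch decision to torch's ReduceLROnPlateau, so step() needs the validation loss. The toy snippet below only exercises the wrapped torch scheduler to show that behaviour (synthetic losses, not this repo's training loop):

import torch

p = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([p], lr=0.1)
plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, factor=0.1, patience=2)
for val_loss in [1.0, 0.9, 0.9, 0.9, 0.9, 0.9]:
    plateau.step(val_loss)        # no improvement for more than `patience` epochs triggers the decay
print(opt.param_groups[0]['lr'])  # 0.1 -> 0.01 after the plateau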

@@ -0,0 +1,73 @@
from typing import Dict, Any
import torch
class Scheduler:
""" Parameter Scheduler Base Class
A scheduler base class that can be used to schedule any optimizer parameter group.
Unlike the builtin PyTorch schedulers, this is intended to be consistently called
* At the END of each epoch, before incrementing the epoch count, to calculate next epoch's value
* At the END of each optimizer update, after incrementing the update count, to calculate next update's value
The schedulers built on this should try to remain as stateless as possible (for simplicity).
This family of schedulers avoids the confusion around the meaning of 'last_epoch' and the
special-behaviour -1 values used by the built-in PyTorch schedulers. All epoch and update counts
must be tracked in the training code and explicitly passed in to the schedulers on the
corresponding step or step_update call.
Based on ideas from:
* https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler
* https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers
"""
def __init__(self,
optimizer: torch.optim.Optimizer,
param_group_field: str,
initialize: bool = True) -> None:
self.optimizer = optimizer
self.param_group_field = param_group_field
self._initial_param_group_field = f"initial_{param_group_field}"
if initialize:
for i, group in enumerate(self.optimizer.param_groups):
if param_group_field not in group:
raise KeyError(f"{param_group_field} missing from param_groups[{i}]")
group.setdefault(self._initial_param_group_field, group[param_group_field])
else:
for i, group in enumerate(self.optimizer.param_groups):
if self._initial_param_group_field not in group:
raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]")
self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups]
self.metric = None # any point to having this for all?
self.update_groups(self.base_values)
def state_dict(self) -> Dict[str, Any]:
return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
self.__dict__.update(state_dict)
def get_epoch_values(self, epoch: int):
return None
def get_update_values(self, num_updates: int):
return None
def step(self, epoch: int, metric: float = None) -> None:
self.metric = metric
values = self.get_epoch_values(epoch)
if values is not None:
self.update_groups(values)
def step_update(self, num_updates: int, metric: float = None):
self.metric = metric
values = self.get_update_values(num_updates)
if values is not None:
self.update_groups(values)
def update_groups(self, values):
if not isinstance(values, (list, tuple)):
values = [values] * len(self.optimizer.param_groups)
for param_group, value in zip(self.optimizer.param_groups, values):
param_group[self.param_group_field] = value
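The calling convention from the docstring, as a runnable skeleton: step_update() after every optimizer update with the running update count, step() once at the end of each epoch. StepLRScheduler is used here only because it needs no extra state; the toy model and counts are placeholders:

import torch
from scheduler import StepLRScheduler

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
lr_scheduler = StepLRScheduler(optimizer, decay_epochs=2, decay_rate=0.5)

num_updates = 0
for epoch in range(4):
    for _ in range(10):                        # stand-in for iterating loader_train
        loss = model(torch.randn(8, 4)).sum()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        num_updates += 1
        lr_scheduler.step_update(num_updates=num_updates, metric=loss.item())
    lr_scheduler.step(epoch, metric=None)      # eval metric would go here
    print(epoch, optimizer.param_groups[0]['lr'])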

@@ -0,0 +1,48 @@
import math
import torch
from .scheduler import Scheduler
class StepLRScheduler(Scheduler):
"""
"""
def __init__(self,
optimizer: torch.optim.Optimizer,
decay_epochs: int,
decay_rate: float = 1.,
warmup_updates=0,
warmup_lr_init=0,
initialize=True) -> None:
super().__init__(optimizer, param_group_field="lr", initialize=initialize)
self.decay_epochs = decay_epochs
self.decay_rate = decay_rate
self.warmup_updates = warmup_updates
self.warmup_lr_init = warmup_lr_init
if self.warmup_updates:
self.warmup_active = warmup_updates > 0 # this state updates with num_updates
self.warmup_steps = [(v - warmup_lr_init) / self.warmup_updates for v in self.base_values]
super().update_groups(self.warmup_lr_init)
else:
self.warmup_steps = [1 for _ in self.base_values]
self.warmup_active = False  # avoid AttributeError when no warmup is configured
def get_epoch_values(self, epoch: int):
if not self.warmup_active:
lrs = [v * (self.decay_rate ** ((epoch + 1) // self.decay_epochs))
for v in self.base_values]
else:
lrs = None # no epoch updates while warming up
return lrs
def get_update_values(self, num_updates: int):
if num_updates < self.warmup_updates:
lrs = [self.warmup_lr_init + num_updates * s for s in self.warmup_steps]
else:
self.warmup_active = False # warmup cancelled by first update past warmup_update count
lrs = None # no change on update after warmup stage
return lrs
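With the train.py defaults further down (lr=0.01, decay-epochs=30, decay-rate=0.1) and no warmup, the schedule produced by get_epoch_values works out as below; note step() runs at the end of an epoch, so each value applies to the following one:

# lr = 0.01 * 0.1 ** ((epoch + 1) // 30)
for epoch in (0, 28, 29, 58, 59):
    print(epoch, 0.01 * 0.1 ** ((epoch + 1) // 30))
# -> 0.01 for epochs 0-28, 0.001 for 29-58, 0.0001 from 59 onward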

@@ -6,9 +6,10 @@ from collections import OrderedDict
from datetime import datetime
from dataset import Dataset
from models import model_factory, get_transforms_eval, get_transforms_train
from models import model_factory, transforms_imagenet_eval, transforms_imagenet_train
from utils import *
from optim import nadam
import scheduler
import torch
import torch.nn
@@ -48,6 +49,8 @@ parser.add_argument('--decay-epochs', type=int, default=30, metavar='N',
help='epoch interval to decay LR')
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
help='LR decay rate (default: 0.1)')
parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER',
help='LR scheduler (default: "step")')
parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
help='Dropout rate (default: 0.0)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
@@ -93,22 +96,9 @@ def main():
num_epochs = args.epochs
torch.manual_seed(args.seed)
model = model_factory.create_model(
args.model,
pretrained=args.pretrained,
num_classes=1000,
drop_rate=args.drop,
global_pool=args.gp,
checkpoint_path=args.initial_checkpoint)
if args.initial_batch_size:
batch_size = adjust_batch_size(
epoch=0, initial_bs=args.initial_batch_size, target_bs=args.batch_size)
print('Setting batch-size to %d' % batch_size)
dataset_train = Dataset(
os.path.join(args.data, 'train'),
transform=get_transforms_train(args.model))
transform=transforms_imagenet_train(args.model))
loader_train = data.DataLoader(
dataset_train,
@@ -119,7 +109,7 @@ def main():
dataset_eval = Dataset(
os.path.join(args.data, 'validation'),
transform=get_transforms_eval(args.model))
transform=transforms_imagenet_eval(args.model))
loader_eval = data.DataLoader(
dataset_eval,
@@ -128,38 +118,17 @@ def main():
num_workers=args.workers
)
train_loss_fn = validate_loss_fn = torch.nn.CrossEntropyLoss()
train_loss_fn = train_loss_fn.cuda()
validate_loss_fn = validate_loss_fn.cuda()
if args.opt.lower() == 'sgd':
optimizer = optim.SGD(
model.parameters(), lr=args.lr,
momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
elif args.opt.lower() == 'adam':
optimizer = optim.Adam(
model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
elif args.opt.lower() == 'nadam':
optimizer = nadam.Nadam(
model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
elif args.opt.lower() == 'adadelta':
optimizer = optim.Adadelta(
model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
elif args.opt.lower() == 'rmsprop':
optimizer = optim.RMSprop(
model.parameters(), lr=args.lr, alpha=0.9, eps=args.opt_eps,
momentum=args.momentum, weight_decay=args.weight_decay)
else:
assert False and "Invalid optimizer"
exit(1)
if not args.decay_epochs:
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=8)
else:
lr_scheduler = None
model = model_factory.create_model(
args.model,
pretrained=args.pretrained,
num_classes=1000,
drop_rate=args.drop,
global_pool=args.gp,
checkpoint_path=args.initial_checkpoint)
# optionally resume from a checkpoint
start_epoch = 0 if args.start_epoch is None else args.start_epoch
optimizer_state = None
if args.resume:
if os.path.isfile(args.resume):
print("=> loading checkpoint '{}'".format(args.resume))
@@ -174,7 +143,7 @@ def main():
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
if 'optimizer' in checkpoint:
optimizer.load_state_dict(checkpoint['optimizer'])
optimizer_state = checkpoint['optimizer']
print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
start_epoch = checkpoint['epoch'] if args.start_epoch is None else args.start_epoch
else:
@@ -183,55 +152,73 @@ def main():
print("=> no checkpoint found at '{}'".format(args.resume))
return False
saver = CheckpointSaver(checkpoint_dir=output_dir)
if args.num_gpu > 1:
model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
else:
model.cuda()
train_loss_fn = validate_loss_fn = torch.nn.CrossEntropyLoss()
train_loss_fn = train_loss_fn.cuda()
validate_loss_fn = validate_loss_fn.cuda()
if args.opt.lower() == 'sgd':
optimizer = optim.SGD(
model.parameters(), lr=args.lr,
momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
elif args.opt.lower() == 'adam':
optimizer = optim.Adam(
model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
elif args.opt.lower() == 'nadam':
optimizer = nadam.Nadam(
model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
elif args.opt.lower() == 'adadelta':
optimizer = optim.Adadelta(
model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
elif args.opt.lower() == 'rmsprop':
optimizer = optim.RMSprop(
model.parameters(), lr=args.lr, alpha=0.9, eps=args.opt_eps,
momentum=args.momentum, weight_decay=args.weight_decay)
else:
assert False and "Invalid optimizer"
exit(1)
if optimizer_state is not None:
optimizer.load_state_dict(optimizer_state)
if args.sched == 'cosine':
lr_scheduler = scheduler.CosineLRScheduler(
optimizer,
t_initial=13 * len(loader_train),
t_mul=2.0,
lr_min=0,
decay_rate=0.5,
warmup_lr_init=1e-4,
warmup_updates=len(loader_train)
)
else:
lr_scheduler = scheduler.StepLRScheduler(
optimizer,
decay_epochs=args.decay_epochs,
decay_rate=args.decay_rate,
)
saver = CheckpointSaver(checkpoint_dir=output_dir)
best_loss = None
try:
for epoch in range(start_epoch, num_epochs):
if args.decay_epochs:
adjust_learning_rate(
optimizer, epoch, initial_lr=args.lr,
decay_rate=args.decay_rate, decay_epochs=args.decay_epochs)
if args.initial_batch_size:
next_batch_size = adjust_batch_size(
epoch, initial_bs=args.initial_batch_size, target_bs=args.batch_size)
if next_batch_size > batch_size:
print("Changing batch size from %d to %d" % (batch_size, next_batch_size))
batch_size = next_batch_size
loader_train = data.DataLoader(
dataset_train,
batch_size=batch_size,
pin_memory=True,
shuffle=True,
# sampler=sampler,
num_workers=args.workers)
train_metrics = train_epoch(
epoch, model, loader_train, optimizer, train_loss_fn, args,
saver=saver, output_dir=output_dir)
lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir)
step = epoch * len(loader_train)
eval_metrics = validate(
step, model, loader_eval, validate_loss_fn, args,
output_dir=output_dir)
model, loader_eval, validate_loss_fn, args)
if lr_scheduler is not None:
lr_scheduler.step(eval_metrics['eval_loss'])
lr_scheduler.step(epoch, eval_metrics['eval_loss'])
rowd = OrderedDict(epoch=epoch)
rowd.update(train_metrics)
rowd.update(eval_metrics)
with open(os.path.join(output_dir, 'summary.csv'), mode='a') as cf:
dw = csv.DictWriter(cf, fieldnames=rowd.keys())
if best_loss is None: # first iteration (epoch == 1 can't be used)
dw.writeheader()
dw.writerow(rowd)
update_summary(
epoch, train_metrics, eval_metrics, output_dir, write_header=best_loss is None)
# save proper checkpoint with eval metric
best_loss = saver.save_checkpoint({
@@ -252,9 +239,8 @@ def main():
def train_epoch(
epoch, model, loader, optimizer, loss_fn, args,
saver=None, output_dir=''):
lr_scheduler=None, saver=None, output_dir=''):
epoch_step = (epoch - 1) * len(loader)
batch_time_m = AverageMeter()
data_time_m = AverageMeter()
losses_m = AverageMeter()
@@ -263,9 +249,9 @@ def train_epoch(
end = time.time()
last_idx = len(loader) - 1
num_updates = epoch * len(loader)
for batch_idx, (input, target) in enumerate(loader):
last_batch = batch_idx == last_idx
step = epoch_step + batch_idx
data_time_m.update(time.time() - end)
input = input.cuda()
@@ -283,20 +269,27 @@ def train_epoch(
loss.backward()
optimizer.step()
num_updates += 1
batch_time_m.update(time.time() - end)
if last_batch or batch_idx % args.log_interval == 0:
lrl = [param_group['lr'] for param_group in optimizer.param_groups]
lr = sum(lrl) / len(lrl)
print('Train: {} [{}/{} ({:.0f}%)] '
'Loss: {loss.val:.6f} ({loss.avg:.4f}) '
'Time: {batch_time.val:.3f}s, {rate:.3f}/s '
'({batch_time.avg:.3f}s, {rate_avg:.3f}/s) '
'LR: {lr:.4f} '
'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
epoch,
batch_idx * len(input), len(loader.sampler),
batch_idx, len(loader),
100. * batch_idx / last_idx,
loss=losses_m,
batch_time=batch_time_m,
rate=input.size(0) / batch_time_m.val,
rate_avg=input.size(0) / batch_time_m.avg,
lr=lr,
data_time=data_time_m))
if args.save_images:
@@ -319,12 +312,15 @@ def train_epoch(
epoch=save_epoch,
batch_idx=batch_idx)
if lr_scheduler is not None:
lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
end = time.time()
return OrderedDict([('train_loss', losses_m.avg)])
def validate(step, model, loader, loss_fn, args, output_dir=''):
def validate(model, loader, loss_fn, args):
batch_time_m = AverageMeter()
losses_m = AverageMeter()
prec1_m = AverageMeter()
@@ -345,7 +341,6 @@ def validate(step, model, loader, loss_fn, args, output_dir=''):
target = target.cuda()
output = model(input)
if isinstance(output, (tuple, list)):
output = output[0]
@@ -381,17 +376,15 @@ def validate(step, model, loader, loss_fn, args, output_dir=''):
return metrics
def adjust_learning_rate(optimizer, epoch, initial_lr, decay_rate=0.1, decay_epochs=30):
"""Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
lr = initial_lr * (decay_rate ** (epoch // decay_epochs))
print('Setting LR to', lr)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
def adjust_batch_size(epoch, initial_bs, target_bs, decay_epochs=1):
batch_size = min(target_bs, initial_bs * (2 ** (epoch // decay_epochs)))
return batch_size
def update_summary(epoch, train_metrics, eval_metrics, output_dir, write_header=False):
rowd = OrderedDict(epoch=epoch)
rowd.update(train_metrics)
rowd.update(eval_metrics)
with open(os.path.join(output_dir, 'summary.csv'), mode='a') as cf:
dw = csv.DictWriter(cf, fieldnames=rowd.keys())
if write_header: # first iteration (epoch == 1 can't be used)
dw.writeheader()
dw.writerow(rowd)
if __name__ == '__main__':
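For a quick read on the hard-coded cosine settings in main() (t_initial = 13 * len(loader_train), t_mul = 2.0, decay_rate = 0.5, one epoch of warmup), the sketch below prints where the restarts land in epoch units; it is just arithmetic over those constants, not code from this commit:

t_initial_epochs, t_mul, decay_rate = 13, 2.0, 0.5
start, peak = 1.0, 1.0   # warmup takes the first epoch; peak LR is relative to the base LR
for i in range(4):
    length = t_initial_epochs * t_mul ** i
    print(f"cycle {i}: epochs {start:.0f}-{start + length:.0f}, peak lr x {peak}")
    start += length
    peak *= decay_rate
# cycle 0: epochs 1-14 at x1.0, cycle 1: 14-40 at x0.5, cycle 2: 40-92 at x0.25, cycle 3: 92-196 at x0.125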
