From aa4354f4668ee7cf349f24de60e3a23dee43a6ef Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Wed, 19 Jun 2019 17:19:37 -0700
Subject: [PATCH] Big re-org, working towards making pip/module as 'timm'

---
 inference.py | 9 ++-
 optim/adam_tf.py | 0
 requirements.txt | 2 +
 setup.py | 55 +++++++++++++++++++
 timm/__init__.py | 2 +
 {data => timm/data}/__init__.py | 0
 {data => timm/data}/config.py | 2 +-
 {data => timm/data}/constants.py | 0
 {data => timm/data}/dataset.py | 0
 {data => timm/data}/distributed_sampler.py | 0
 {data => timm/data}/loader.py | 8 +--
 {data => timm/data}/mixup.py | 0
 {data => timm/data}/random_erasing.py | 2 -
 {data => timm/data}/tf_preprocessing.py | 0
 {data => timm/data}/transforms.py | 4 +-
 {loss => timm/loss}/__init__.py | 0
 {loss => timm/loss}/cross_entropy.py | 0
 {models => timm/models}/__init__.py | 0
 .../models}/adaptive_avgmax_pool.py | 0
 {models => timm/models}/conv2d_same.py | 0
 {models => timm/models}/densenet.py | 9 +--
 {models => timm/models}/dpn.py | 8 +--
 {models => timm/models}/gen_efficientnet.py | 8 +--
 {models => timm/models}/gluon_resnet.py | 7 +--
 {models => timm/models}/helpers.py | 0
 .../models}/inception_resnet_v2.py | 9 +--
 {models => timm/models}/inception_v3.py | 4 +-
 {models => timm/models}/inception_v4.py | 9 +--
 {models => timm/models}/median_pool.py | 0
 {models => timm/models}/model_factory.py | 28 +++++-----
 {models => timm/models}/pnasnet.py | 4 +-
 {models => timm/models}/resnet.py | 7 +--
 {models => timm/models}/senet.py | 6 +-
 {models => timm/models}/test_time_pool.py | 2 +-
 {models => timm/models}/xception.py | 4 +-
 {optim => timm/optim}/__init__.py | 0
 {optim => timm/optim}/nadam.py | 0
 {optim => timm/optim}/optim_factory.py | 2 +-
 {optim => timm/optim}/rmsprop_tf.py | 0
 {scheduler => timm/scheduler}/__init__.py | 0
 {scheduler => timm/scheduler}/cosine_lr.py | 0
 {scheduler => timm/scheduler}/plateau_lr.py | 0
 {scheduler => timm/scheduler}/scheduler.py | 0
 .../scheduler}/scheduler_factory.py | 7 +--
 {scheduler => timm/scheduler}/step_lr.py | 0
 {scheduler => timm/scheduler}/tanh_lr.py | 0
 utils.py => timm/utils.py | 0
 timm/version.py | 1 +
 train.py | 13 ++---
 validate.py | 6 +-
 50 files changed, 132 insertions(+), 86 deletions(-)
 delete mode 100644 optim/adam_tf.py
 create mode 100644 requirements.txt
 create mode 100644 setup.py
 create mode 100644 timm/__init__.py
 rename {data => timm/data}/__init__.py (100%)
 rename {data => timm/data}/config.py (99%)
 rename {data => timm/data}/constants.py (100%)
 rename {data => timm/data}/dataset.py (100%)
 rename {data => timm/data}/distributed_sampler.py (100%)
 rename {data => timm/data}/loader.py (95%)
 rename {data => timm/data}/mixup.py (100%)
 rename {data => timm/data}/random_erasing.py (98%)
 rename {data => timm/data}/tf_preprocessing.py (100%)
 rename {data => timm/data}/transforms.py (98%)
 rename {loss => timm/loss}/__init__.py (100%)
 rename {loss => timm/loss}/cross_entropy.py (100%)
 rename {models => timm/models}/__init__.py (100%)
 rename {models => timm/models}/adaptive_avgmax_pool.py (100%)
 rename {models => timm/models}/conv2d_same.py (100%)
 rename {models => timm/models}/densenet.py (97%)
 rename {models => timm/models}/dpn.py (98%)
 rename {models => timm/models}/gen_efficientnet.py (99%)
 rename {models => timm/models}/gluon_resnet.py (99%)
 rename {models => timm/models}/helpers.py (100%)
 rename {models => timm/models}/inception_resnet_v2.py (98%)
 rename {models => timm/models}/inception_v3.py (96%)
 rename {models => timm/models}/inception_v4.py (97%)
 rename {models => timm/models}/median_pool.py (100%)
 rename {models => timm/models}/model_factory.py (65%)
 rename {models => timm/models}/pnasnet.py (99%)
 rename {models => timm/models}/resnet.py (98%)
 rename {models => timm/models}/senet.py (99%)
 rename {models => timm/models}/test_time_pool.py (96%)
 rename {models => timm/models}/xception.py (98%)
 rename {optim => timm/optim}/__init__.py (100%)
 rename {optim => timm/optim}/nadam.py (100%)
 rename {optim => timm/optim}/optim_factory.py (97%)
 rename {optim => timm/optim}/rmsprop_tf.py (100%)
 rename {scheduler => timm/scheduler}/__init__.py (100%)
 rename {scheduler => timm/scheduler}/cosine_lr.py (100%)
 rename {scheduler => timm/scheduler}/plateau_lr.py (100%)
 rename {scheduler => timm/scheduler}/scheduler.py (100%)
 rename {scheduler => timm/scheduler}/scheduler_factory.py (86%)
 rename {scheduler => timm/scheduler}/step_lr.py (100%)
 rename {scheduler => timm/scheduler}/tanh_lr.py (100%)
 rename utils.py => timm/utils.py (100%)
 create mode 100644 timm/version.py

diff --git a/inference.py b/inference.py
index d6c7e48a..5aeb258f 100644
--- a/inference.py
+++ b/inference.py
@@ -11,9 +11,9 @@ import argparse
 import numpy as np
 import torch

-from models import create_model, apply_test_time_pool
-from data import Dataset, create_loader, resolve_data_config
-from utils import AverageMeter
+from timm.models import create_model, apply_test_time_pool
+from timm.data import Dataset, create_loader, resolve_data_config
+from timm.utils import AverageMeter

 torch.backends.cudnn.benchmark = True

@@ -55,6 +55,9 @@ parser.add_argument('--topk', default=5, type=int,
 def main():
     args = parser.parse_args()

+    # might as well try to do something useful...
+    args.pretrained = args.pretrained or not args.checkpoint
+
     # create model
     model = create_model(
         args.model,
diff --git a/optim/adam_tf.py b/optim/adam_tf.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 00000000..b021ba18
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,2 @@
+torch~=1.0
+torchvision
\ No newline at end of file
diff --git a/setup.py b/setup.py
new file mode 100644
index 00000000..01c07ccd
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,55 @@
+""" Setup
+"""
+from setuptools import setup, find_packages
+from codecs import open
+from os import path
+
+here = path.abspath(path.dirname(__file__))
+
+# Get the long description from the README file
+with open(path.join(here, 'README.md'), encoding='utf-8') as f:
+    long_description = f.read()
+
+exec(open('timm/version.py').read())
+setup(
+    name='timm',
+    version=__version__,
+    description='(Unofficial) PyTorch Image Models',
+    long_description=long_description,
+    url='https://github.com/rwightman/pytorch-image-models',
+    author='Ross Wightman',
+    author_email='hello@rwightman.com',
+    classifiers=[  # Optional
+        # How mature is this project? Common values are
+        #   3 - Alpha
+        #   4 - Beta
+        #   5 - Production/Stable
+        'Development Status :: 3 - Alpha',
+        'Intended Audience :: Developers',
+        'Topic :: Software Development :: Build Tools',
+        'License :: OSI Approved :: Apache License',
+        'Programming Language :: Python :: 3.6',
+    ],
+
+    # Note that this is a string of words separated by whitespace, not a list.
+    keywords='pytorch pretrained models efficientnet mobilenetv3 mnasnet',
+
+    # You can just specify package directories manually here if your project is
+    # simple. Or you can use find_packages().
+    #
+    # Alternatively, if you just want to distribute a single Python file, use
+    # the `py_modules` argument instead as follows, which will expect a file
+    # called `my_module.py` to exist:
+    #
+    #   py_modules=["my_module"],
+    #
+    packages=find_packages(exclude=['convert']),
+
+    # This field lists other packages that your project depends on to run.
+    # Any package you put here will be installed by pip when your project is
+    # installed, so they must be valid existing projects.
+    #
+    # For an analysis of "install_requires" vs pip's requirements files see:
+    # https://packaging.python.org/en/latest/requirements.html
+    install_requires=['torch', 'torchvision'],
+)
diff --git a/timm/__init__.py b/timm/__init__.py
new file mode 100644
index 00000000..325a273a
--- /dev/null
+++ b/timm/__init__.py
@@ -0,0 +1,2 @@
+from .version import __version__
+from .models import create_model
diff --git a/data/__init__.py b/timm/data/__init__.py
similarity index 100%
rename from data/__init__.py
rename to timm/data/__init__.py
diff --git a/data/config.py b/timm/data/config.py
similarity index 99%
rename from data/config.py
rename to timm/data/config.py
index 5b15dba8..29d6f9e3 100644
--- a/data/config.py
+++ b/timm/data/config.py
@@ -1,4 +1,4 @@
-from data.constants import *
+from .constants import *


 def resolve_data_config(model, args, default_cfg={}, verbose=True):
diff --git a/data/constants.py b/timm/data/constants.py
similarity index 100%
rename from data/constants.py
rename to timm/data/constants.py
diff --git a/data/dataset.py b/timm/data/dataset.py
similarity index 100%
rename from data/dataset.py
rename to timm/data/dataset.py
diff --git a/data/distributed_sampler.py b/timm/data/distributed_sampler.py
similarity index 100%
rename from data/distributed_sampler.py
rename to timm/data/distributed_sampler.py
diff --git a/data/loader.py b/timm/data/loader.py
similarity index 95%
rename from data/loader.py
rename to timm/data/loader.py
index f6710a3b..777eb878 100644
--- a/data/loader.py
+++ b/timm/data/loader.py
@@ -1,7 +1,7 @@
 import torch.utils.data
-from data.transforms import *
-from data.distributed_sampler import OrderedDistributedSampler
-from data.mixup import FastCollateMixup
+from .transforms import *
+from .distributed_sampler import OrderedDistributedSampler
+from .mixup import FastCollateMixup


 def fast_collate(batch):
@@ -101,7 +101,7 @@ def create_loader(
         img_size = input_size

     if tf_preprocessing and use_prefetcher:
-        from data.tf_preprocessing import TfPreprocessTransform
+        from timm.data.tf_preprocessing import TfPreprocessTransform
         transform = TfPreprocessTransform(is_training=is_training, size=img_size)
     else:
         if is_training:
diff --git a/data/mixup.py b/timm/data/mixup.py
similarity index 100%
rename from data/mixup.py
rename to timm/data/mixup.py
diff --git a/data/random_erasing.py b/timm/data/random_erasing.py
similarity index 98%
rename from data/random_erasing.py
rename to timm/data/random_erasing.py
index 43f5f57e..c16725ae 100644
--- a/data/random_erasing.py
+++ b/timm/data/random_erasing.py
@@ -1,5 +1,3 @@
-from __future__ import absolute_import
-
 import random
 import math
 import torch
diff --git a/data/tf_preprocessing.py b/timm/data/tf_preprocessing.py
similarity index 100%
rename from data/tf_preprocessing.py
rename to timm/data/tf_preprocessing.py
diff --git a/data/transforms.py b/timm/data/transforms.py
similarity index 98%
rename from data/transforms.py
rename to timm/data/transforms.py
index e777fbca..bee505a2 100644
--- a/data/transforms.py
+++ b/timm/data/transforms.py
@@ -7,8 +7,8 @@ import math
 import random
 import numpy as np

-from data import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from data.random_erasing import RandomErasing
+from .constants import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .random_erasing import RandomErasing


 class ToNumpy:
diff --git a/loss/__init__.py b/timm/loss/__init__.py
similarity index 100%
rename from loss/__init__.py
rename to timm/loss/__init__.py
diff --git a/loss/cross_entropy.py b/timm/loss/cross_entropy.py
similarity index 100%
rename from loss/cross_entropy.py
rename to timm/loss/cross_entropy.py
diff --git a/models/__init__.py b/timm/models/__init__.py
similarity index 100%
rename from models/__init__.py
rename to timm/models/__init__.py
diff --git a/models/adaptive_avgmax_pool.py b/timm/models/adaptive_avgmax_pool.py
similarity index 100%
rename from models/adaptive_avgmax_pool.py
rename to timm/models/adaptive_avgmax_pool.py
diff --git a/models/conv2d_same.py b/timm/models/conv2d_same.py
similarity index 100%
rename from models/conv2d_same.py
rename to timm/models/conv2d_same.py
diff --git a/models/densenet.py b/timm/models/densenet.py
similarity index 97%
rename from models/densenet.py
rename to timm/models/densenet.py
index 2e8e160a..5f9aeb35 100644
--- a/models/densenet.py
+++ b/timm/models/densenet.py
@@ -2,14 +2,11 @@
 This file is a copy of https://github.com/pytorch/vision 'densenet.py' (BSD-3-Clause) with
 fixed kwargs passthrough and addition of dynamic global avg/max pool.
 """
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
 from collections import OrderedDict
-from models.helpers import load_pretrained
-from models.adaptive_avgmax_pool import *
-from data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .helpers import load_pretrained
+from .adaptive_avgmax_pool import *
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 import re


 _models = ['densenet121', 'densenet169', 'densenet201', 'densenet161']
diff --git a/models/dpn.py b/timm/models/dpn.py
similarity index 98%
rename from models/dpn.py
rename to timm/models/dpn.py
index ea766411..04aa9a5c 100644
--- a/models/dpn.py
+++ b/timm/models/dpn.py
@@ -9,15 +9,13 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import os
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from collections import OrderedDict

-from models.helpers import load_pretrained
-from models.adaptive_avgmax_pool import select_adaptive_pool2d
-from data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD
+from .helpers import load_pretrained
+from .adaptive_avgmax_pool import select_adaptive_pool2d
+from timm.data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD

 _models = ['dpn68', 'dpn68b', 'dpn92', 'dpn98', 'dpn131', 'dpn107']
 __all__ = ['DPN'] + _models
diff --git a/models/gen_efficientnet.py b/timm/models/gen_efficientnet.py
similarity index 99%
rename from models/gen_efficientnet.py
rename to timm/models/gen_efficientnet.py
index d3bb0f30..3d26d4f6 100644
--- a/models/gen_efficientnet.py
+++ b/timm/models/gen_efficientnet.py
@@ -22,10 +22,10 @@ from copy import deepcopy
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from models.helpers import load_pretrained
-from models.adaptive_avgmax_pool import SelectAdaptivePool2d
-from models.conv2d_same import sconv2d
-from data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .helpers import load_pretrained
+from .adaptive_avgmax_pool import SelectAdaptivePool2d
+from .conv2d_same import sconv2d
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

 _models = [
     'mnasnet_050', 'mnasnet_075', 'mnasnet_100', 'mnasnet_b1', 'mnasnet_140', 'semnasnet_050', 'semnasnet_075',
diff --git a/models/gluon_resnet.py b/timm/models/gluon_resnet.py
similarity index 99%
rename from models/gluon_resnet.py
rename to timm/models/gluon_resnet.py
index 3232aa25..a8877bd9 100644
--- a/models/gluon_resnet.py
+++ b/timm/models/gluon_resnet.py
@@ -3,13 +3,12 @@ This file evolved from https://github.com/pytorch/vision 'resnet.py' with (SE)-R
 and ports of Gluon variations (https://github.com/dmlc/gluon-cv/blob/master/gluoncv/model_zoo/resnet.py)
 by Ross Wightman
 """
-import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import math
-from models.helpers import load_pretrained
-from models.adaptive_avgmax_pool import SelectAdaptivePool2d
-from data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .helpers import load_pretrained
+from .adaptive_avgmax_pool import SelectAdaptivePool2d
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

 _models = [
     'gluon_resnet18_v1b', 'gluon_resnet34_v1b', 'gluon_resnet50_v1b', 'gluon_resnet101_v1b', 'gluon_resnet152_v1b',
diff --git a/models/helpers.py b/timm/models/helpers.py
similarity index 100%
rename from models/helpers.py
rename to timm/models/helpers.py
diff --git a/models/inception_resnet_v2.py b/timm/models/inception_resnet_v2.py
similarity index 98%
rename from models/inception_resnet_v2.py
rename to timm/models/inception_resnet_v2.py
index dad1396d..05452b4b 100644
--- a/models/inception_resnet_v2.py
+++ b/timm/models/inception_resnet_v2.py
@@ -2,12 +2,9 @@
 Sourced from https://github.com/Cadene/tensorflow-model-zoo.torch (MIT License) which is
 based upon Google's Tensorflow implementation and pretrained weights (Apache 2.0 License)
 """
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from models.helpers import load_pretrained
-from models.adaptive_avgmax_pool import *
-from data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
+from .helpers import load_pretrained
+from .adaptive_avgmax_pool import *
+from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD

 _models = ['inception_resnet_v2']
 __all__ = ['InceptionResnetV2'] + _models
diff --git a/models/inception_v3.py b/timm/models/inception_v3.py
similarity index 96%
rename from models/inception_v3.py
rename to timm/models/inception_v3.py
index 70ccb37f..d9e48ffc 100644
--- a/models/inception_v3.py
+++ b/timm/models/inception_v3.py
@@ -1,6 +1,6 @@
 from torchvision.models import Inception3
-from models.helpers import load_pretrained
-from data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
+from .helpers import load_pretrained
+from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD

 _models = ['inception_v3', 'tf_inception_v3', 'adv_inception_v3', 'gluon_inception_v3']
 __all__ = _models
diff --git a/models/inception_v4.py b/timm/models/inception_v4.py
similarity index 97%
rename from models/inception_v4.py
rename to timm/models/inception_v4.py
index e251d0c4..2c47c45a 100644
--- a/models/inception_v4.py
+++ b/timm/models/inception_v4.py
@@ -2,12 +2,9 @@
 Sourced from https://github.com/Cadene/tensorflow-model-zoo.torch (MIT License) which is
 based upon Google's Tensorflow implementation and pretrained weights (Apache 2.0 License)
 """
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from models.helpers import load_pretrained
-from models.adaptive_avgmax_pool import *
-from data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
+from .helpers import load_pretrained
+from .adaptive_avgmax_pool import *
+from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD

 _models = ['inception_v4']
 __all__ = ['InceptionV4'] + _models
diff --git a/models/median_pool.py b/timm/models/median_pool.py
similarity index 100%
rename from models/median_pool.py
rename to timm/models/median_pool.py
diff --git a/models/model_factory.py b/timm/models/model_factory.py
similarity index 65%
rename from models/model_factory.py
rename to timm/models/model_factory.py
index 09a8bb95..7ffb423e 100644
--- a/models/model_factory.py
+++ b/timm/models/model_factory.py
@@ -1,21 +1,21 @@
-from models.inception_v4 import *
-from models.inception_resnet_v2 import *
-from models.densenet import *
-from models.resnet import *
-from models.dpn import *
-from models.senet import *
-from models.xception import *
-from models.pnasnet import *
-from models.gen_efficientnet import *
-from models.inception_v3 import *
-from models.gluon_resnet import *
+from .inception_v4 import *
+from .inception_resnet_v2 import *
+from .densenet import *
+from .resnet import *
+from .dpn import *
+from .senet import *
+from .xception import *
+from .pnasnet import *
+from .gen_efficientnet import *
+from .inception_v3 import *
+from .gluon_resnet import *

-from models.helpers import load_checkpoint
+from .helpers import load_checkpoint


 def create_model(
-        model_name='resnet50',
-        pretrained=None,
+        model_name,
+        pretrained=False,
         num_classes=1000,
         in_chans=3,
         checkpoint_path='',
diff --git a/models/pnasnet.py b/timm/models/pnasnet.py
similarity index 99%
rename from models/pnasnet.py
rename to timm/models/pnasnet.py
index af348b5f..3102e3f8 100644
--- a/models/pnasnet.py
+++ b/timm/models/pnasnet.py
@@ -12,8 +12,8 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from models.helpers import load_pretrained
-from models.adaptive_avgmax_pool import SelectAdaptivePool2d
+from .helpers import load_pretrained
+from .adaptive_avgmax_pool import SelectAdaptivePool2d

 _models = ['pnasnet5large']
 __all__ = ['PNASNet5Large'] + _models
diff --git a/models/resnet.py b/timm/models/resnet.py
similarity index 98%
rename from models/resnet.py
rename to timm/models/resnet.py
index 6abf0fb0..54b9efd3 100644
--- a/models/resnet.py
+++ b/timm/models/resnet.py
@@ -4,13 +4,12 @@ additional dropout and dynamic global avg/max pool.
 ResNext additions added by Ross Wightman
 """
-import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import math
-from models.helpers import load_pretrained
-from models.adaptive_avgmax_pool import SelectAdaptivePool2d
-from data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .helpers import load_pretrained
+from .adaptive_avgmax_pool import SelectAdaptivePool2d
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

 _models = ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x4d',
            'resnext101_64x4d', 'resnext152_32x4d']
diff --git a/models/senet.py b/timm/models/senet.py
similarity index 99%
rename from models/senet.py
rename to timm/models/senet.py
index 5cfd78f9..76690c10 100644
--- a/models/senet.py
+++ b/timm/models/senet.py
@@ -15,9 +15,9 @@ import math
 import torch.nn as nn
 import torch.nn.functional as F

-from models.helpers import load_pretrained
-from models.adaptive_avgmax_pool import SelectAdaptivePool2d
-from data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .helpers import load_pretrained
+from .adaptive_avgmax_pool import SelectAdaptivePool2d
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

 _models = ['seresnet18', 'seresnet34', 'seresnet50', 'seresnet101', 'seresnet152', 'senet154',
            'seresnext26_32x4d', 'seresnext50_32x4d', 'seresnext101_32x4d']
diff --git a/models/test_time_pool.py b/timm/models/test_time_pool.py
similarity index 96%
rename from models/test_time_pool.py
rename to timm/models/test_time_pool.py
index b24d82a3..ec36380b 100644
--- a/models/test_time_pool.py
+++ b/timm/models/test_time_pool.py
@@ -1,6 +1,6 @@
 from torch import nn
 import torch.nn.functional as F
-from models.adaptive_avgmax_pool import adaptive_avgmax_pool2d
+from .adaptive_avgmax_pool import adaptive_avgmax_pool2d


 class TestTimePoolHead(nn.Module):
diff --git a/models/xception.py b/timm/models/xception.py
similarity index 98%
rename from models/xception.py
rename to timm/models/xception.py
index 96389b29..44094db7 100644
--- a/models/xception.py
+++ b/timm/models/xception.py
@@ -27,8 +27,8 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from models.helpers import load_pretrained
-from models.adaptive_avgmax_pool import select_adaptive_pool2d
+from .helpers import load_pretrained
+from .adaptive_avgmax_pool import select_adaptive_pool2d

 _models = ['xception']
 __all__ = ['Xception'] + _models
diff --git a/optim/__init__.py b/timm/optim/__init__.py
similarity index 100%
rename from optim/__init__.py
rename to timm/optim/__init__.py
diff --git a/optim/nadam.py b/timm/optim/nadam.py
similarity index 100%
rename from optim/nadam.py
rename to timm/optim/nadam.py
diff --git a/optim/optim_factory.py b/timm/optim/optim_factory.py
similarity index 97%
rename from optim/optim_factory.py
rename to timm/optim/optim_factory.py
index efd246a5..7fe3e1e4 100644
--- a/optim/optim_factory.py
+++ b/timm/optim/optim_factory.py
@@ -1,5 +1,5 @@
 from torch import optim as optim
-from optim import Nadam, RMSpropTF
+from timm.optim import Nadam, RMSpropTF


 def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
diff --git a/optim/rmsprop_tf.py b/timm/optim/rmsprop_tf.py
similarity index 100%
rename from optim/rmsprop_tf.py
rename to timm/optim/rmsprop_tf.py
diff --git a/scheduler/__init__.py b/timm/scheduler/__init__.py
similarity index 100%
rename from scheduler/__init__.py
rename to timm/scheduler/__init__.py
diff --git a/scheduler/cosine_lr.py b/timm/scheduler/cosine_lr.py
similarity index 100%
rename from scheduler/cosine_lr.py
rename to timm/scheduler/cosine_lr.py
diff --git a/scheduler/plateau_lr.py b/timm/scheduler/plateau_lr.py
similarity index 100%
rename from scheduler/plateau_lr.py
rename to timm/scheduler/plateau_lr.py
diff --git a/scheduler/scheduler.py b/timm/scheduler/scheduler.py
similarity index 100%
rename from scheduler/scheduler.py
rename to timm/scheduler/scheduler.py
diff --git a/scheduler/scheduler_factory.py b/timm/scheduler/scheduler_factory.py
similarity index 86%
rename from scheduler/scheduler_factory.py
rename to timm/scheduler/scheduler_factory.py
index 55c4927d..e64c34d1 100644
--- a/scheduler/scheduler_factory.py
+++ b/timm/scheduler/scheduler_factory.py
@@ -1,7 +1,6 @@
-from scheduler.cosine_lr import CosineLRScheduler
-from scheduler.plateau_lr import PlateauLRScheduler
-from scheduler.tanh_lr import TanhLRScheduler
-from scheduler.step_lr import StepLRScheduler
+from .cosine_lr import CosineLRScheduler
+from .tanh_lr import TanhLRScheduler
+from .step_lr import StepLRScheduler


 def create_scheduler(args, optimizer):
diff --git a/scheduler/step_lr.py b/timm/scheduler/step_lr.py
similarity index 100%
rename from scheduler/step_lr.py
rename to timm/scheduler/step_lr.py
diff --git a/scheduler/tanh_lr.py b/timm/scheduler/tanh_lr.py
similarity index 100%
rename from scheduler/tanh_lr.py
rename to timm/scheduler/tanh_lr.py
diff --git a/utils.py b/timm/utils.py
similarity index 100%
rename from utils.py
rename to timm/utils.py
diff --git a/timm/version.py b/timm/version.py
new file mode 100644
index 00000000..df9144c5
--- /dev/null
+++ b/timm/version.py
@@ -0,0 +1 @@
+__version__ = '0.1.1'
diff --git a/train.py b/train.py
index 3811f7b4..b2b033c8 100644
--- a/train.py
+++ b/train.py
@@ -11,16 +11,15 @@ try:
 except ImportError:
     has_apex = False

-from data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_target
-from models import create_model, resume_checkpoint, load_checkpoint
-from utils import *
-from loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
-from optim import create_optimizer
-from scheduler import create_scheduler
+from timm.data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_target
+from timm.models import create_model, resume_checkpoint
+from timm.utils import *
+from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
+from timm.optim import create_optimizer
+from timm.scheduler import create_scheduler

 import torch
 import torch.nn as nn
-import torch.distributed as dist
 import torchvision.utils

 torch.backends.cudnn.benchmark = True
diff --git a/validate.py b/validate.py
index 5a09f0cd..453ea514 100644
--- a/validate.py
+++ b/validate.py
@@ -12,9 +12,9 @@ import torch.nn as nn
 import torch.nn.parallel
 from collections import OrderedDict

-from models import create_model, apply_test_time_pool, load_checkpoint
-from data import Dataset, create_loader, resolve_data_config
-from utils import accuracy, AverageMeter, natural_key
+from timm.models import create_model, apply_test_time_pool, load_checkpoint
+from timm.data import Dataset, create_loader, resolve_data_config
+from timm.utils import accuracy, AverageMeter, natural_key

 torch.backends.cudnn.benchmark = True
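
For reference, a minimal usage sketch of the package layout this patch introduces, assuming the repo is installed (for example with "pip install -e ." against the setup.py above). The model name 'resnet50' is just an illustrative pick from the _models lists in the diff; any other listed name would work the same way:

import timm

# '0.1.1', re-exported from timm/version.py by timm/__init__.py
print(timm.__version__)

# create_model is re-exported at the package root; per the new factory
# signature, model_name is now a required positional argument and
# pretrained defaults to False (previously 'resnet50' / None).
model = timm.create_model('resnet50', pretrained=False, num_classes=1000, in_chans=3)
model.eval()

The train.py, validate.py, and inference.py scripts above consume the same entry points through the timm.data, timm.models, timm.loss, timm.optim, and timm.scheduler subpackages instead of the old top-level modules.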