parent
871f4c1b0c
commit
aa4354f466
@ -0,0 +1,2 @@
|
|||||||
|
torch~=1.0
|
||||||
|
torchvision
|
@ -0,0 +1,55 @@
|
|||||||
|
""" Setup
|
||||||
|
"""
|
||||||
|
from setuptools import setup, find_packages
|
||||||
|
from codecs import open
|
||||||
|
from os import path
|
||||||
|
|
||||||
|
here = path.abspath(path.dirname(__file__))
|
||||||
|
|
||||||
|
# Get the long description from the README file
|
||||||
|
with open(path.join(here, 'README.md'), encoding='utf-8') as f:
|
||||||
|
long_description = f.read()
|
||||||
|
|
||||||
|
exec(open('timm/version.py').read())
|
||||||
|
setup(
|
||||||
|
name='timm',
|
||||||
|
version=__version__,
|
||||||
|
description='(Unofficial) PyTorch Image Models',
|
||||||
|
long_description=long_description,
|
||||||
|
url='https://github.com/rwightman/pytorch-image-models',
|
||||||
|
author='Ross Wightman',
|
||||||
|
author_email='hello@rwightman.com',
|
||||||
|
classifiers=[ # Optional
|
||||||
|
# How mature is this project? Common values are
|
||||||
|
# 3 - Alpha
|
||||||
|
# 4 - Beta
|
||||||
|
# 5 - Production/Stable
|
||||||
|
'Development Status :: 3 - Alpha',
|
||||||
|
'Intended Audience :: Developers',
|
||||||
|
'Topic :: Software Development :: Build Tools',
|
||||||
|
'License :: OSI Approved :: Apache License',
|
||||||
|
'Programming Language :: Python :: 3.6',
|
||||||
|
],
|
||||||
|
|
||||||
|
# Note that this is a string of words separated by whitespace, not a list.
|
||||||
|
keywords='pytorch pretrained models efficientnet mobilenetv3 mnasnet',
|
||||||
|
|
||||||
|
# You can just specify package directories manually here if your project is
|
||||||
|
# simple. Or you can use find_packages().
|
||||||
|
#
|
||||||
|
# Alternatively, if you just want to distribute a single Python file, use
|
||||||
|
# the `py_modules` argument instead as follows, which will expect a file
|
||||||
|
# called `my_module.py` to exist:
|
||||||
|
#
|
||||||
|
# py_modules=["my_module"],
|
||||||
|
#
|
||||||
|
packages=find_packages(exclude=['convert']),
|
||||||
|
|
||||||
|
# This field lists other packages that your project depends on to run.
|
||||||
|
# Any package you put here will be installed by pip when your project is
|
||||||
|
# installed, so they must be valid existing projects.
|
||||||
|
#
|
||||||
|
# For an analysis of "install_requires" vs pip's requirements files see:
|
||||||
|
# https://packaging.python.org/en/latest/requirements.html
|
||||||
|
install_requires=['torch', 'torchvision'],
|
||||||
|
)
|
@ -0,0 +1,2 @@
|
|||||||
|
from .version import __version__
|
||||||
|
from .models import create_model
|
@ -1,4 +1,4 @@
|
|||||||
from data.constants import *
|
from .constants import *
|
||||||
|
|
||||||
|
|
||||||
def resolve_data_config(model, args, default_cfg={}, verbose=True):
|
def resolve_data_config(model, args, default_cfg={}, verbose=True):
|
@ -1,5 +1,3 @@
|
|||||||
from __future__ import absolute_import
|
|
||||||
|
|
||||||
import random
|
import random
|
||||||
import math
|
import math
|
||||||
import torch
|
import torch
|
@ -1,6 +1,6 @@
|
|||||||
from torchvision.models import Inception3
|
from torchvision.models import Inception3
|
||||||
from models.helpers import load_pretrained
|
from .helpers import load_pretrained
|
||||||
from data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||||
|
|
||||||
_models = ['inception_v3', 'tf_inception_v3', 'adv_inception_v3', 'gluon_inception_v3']
|
_models = ['inception_v3', 'tf_inception_v3', 'adv_inception_v3', 'gluon_inception_v3']
|
||||||
__all__ = _models
|
__all__ = _models
|
@ -1,21 +1,21 @@
|
|||||||
from models.inception_v4 import *
|
from .inception_v4 import *
|
||||||
from models.inception_resnet_v2 import *
|
from .inception_resnet_v2 import *
|
||||||
from models.densenet import *
|
from .densenet import *
|
||||||
from models.resnet import *
|
from .resnet import *
|
||||||
from models.dpn import *
|
from .dpn import *
|
||||||
from models.senet import *
|
from .senet import *
|
||||||
from models.xception import *
|
from .xception import *
|
||||||
from models.pnasnet import *
|
from .pnasnet import *
|
||||||
from models.gen_efficientnet import *
|
from .gen_efficientnet import *
|
||||||
from models.inception_v3 import *
|
from .inception_v3 import *
|
||||||
from models.gluon_resnet import *
|
from .gluon_resnet import *
|
||||||
|
|
||||||
from models.helpers import load_checkpoint
|
from .helpers import load_checkpoint
|
||||||
|
|
||||||
|
|
||||||
def create_model(
|
def create_model(
|
||||||
model_name='resnet50',
|
model_name,
|
||||||
pretrained=None,
|
pretrained=False,
|
||||||
num_classes=1000,
|
num_classes=1000,
|
||||||
in_chans=3,
|
in_chans=3,
|
||||||
checkpoint_path='',
|
checkpoint_path='',
|
@ -1,6 +1,6 @@
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from models.adaptive_avgmax_pool import adaptive_avgmax_pool2d
|
from .adaptive_avgmax_pool import adaptive_avgmax_pool2d
|
||||||
|
|
||||||
|
|
||||||
class TestTimePoolHead(nn.Module):
|
class TestTimePoolHead(nn.Module):
|
@ -1,5 +1,5 @@
|
|||||||
from torch import optim as optim
|
from torch import optim as optim
|
||||||
from optim import Nadam, RMSpropTF
|
from timm.optim import Nadam, RMSpropTF
|
||||||
|
|
||||||
|
|
||||||
def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
|
def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
|
@ -1,7 +1,6 @@
|
|||||||
from scheduler.cosine_lr import CosineLRScheduler
|
from .cosine_lr import CosineLRScheduler
|
||||||
from scheduler.plateau_lr import PlateauLRScheduler
|
from .tanh_lr import TanhLRScheduler
|
||||||
from scheduler.tanh_lr import TanhLRScheduler
|
from .step_lr import StepLRScheduler
|
||||||
from scheduler.step_lr import StepLRScheduler
|
|
||||||
|
|
||||||
|
|
||||||
def create_scheduler(args, optimizer):
|
def create_scheduler(args, optimizer):
|
@ -0,0 +1 @@
|
|||||||
|
__version__ = '0.1.1'
|
Loading…
Reference in new issue