Merge remote-tracking branch 'origin/main' into multi-weight

pull/1520/head
Ross Wightman 1 year ago
commit c59d88339b

@ -15,10 +15,10 @@ jobs:
name: Run tests on ${{ matrix.os }} with Python ${{ matrix.python }}
strategy:
matrix:
os: [ubuntu-latest, macOS-latest]
python: ['3.9']
torch: ['1.10.0']
torchvision: ['0.11.1']
os: [ubuntu-latest]
python: ['3.10']
torch: ['1.13.0']
torchvision: ['0.14.0']
runs-on: ${{ matrix.os }}
steps:
@ -34,6 +34,9 @@ jobs:
- name: Install torch on mac
if: startsWith(matrix.os, 'macOS')
run: pip install --no-cache-dir torch==${{ matrix.torch }} torchvision==${{ matrix.torchvision }}
- name: Install torch on Windows
if: startsWith(matrix.os, 'windows')
run: pip install --no-cache-dir torch==${{ matrix.torch }} torchvision==${{ matrix.torchvision }}
- name: Install torch on ubuntu
if: startsWith(matrix.os, 'ubuntu')
run: |
@ -42,11 +45,18 @@ jobs:
sudo apt install -y google-perftools
- name: Install requirements
run: |
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
pip install --no-cache-dir git+https://github.com/mapillary/inplace_abn.git@v1.1.0
- name: Run tests
pip install -r requirements.txt
pip install --no-cache-dir git+https://github.com/mapillary/inplace_abn.git
- name: Run tests on Windows
if: startsWith(matrix.os, 'windows')
env:
PYTHONDONTWRITEBYTECODE: 1
run: |
pytest -vv tests
- name: Run tests on Linux / Mac
if: ${{ !startsWith(matrix.os, 'windows') }}
env:
LD_PRELOAD: /usr/lib/x86_64-linux-gnu/libtcmalloc.so.4
PYTHONDONTWRITEBYTECODE: 1
run: |
export PYTHONDONTWRITEBYTECODE=1
pytest -vv --forked --durations=0 ./tests
pytest -vv --forked --durations=0 tests

@ -4,7 +4,7 @@ Hacked together by / Copyright 2021, Ross Wightman
"""
import os
from torchvision.datasets import CIFAR100, CIFAR10, MNIST, QMNIST, KMNIST, FashionMNIST, ImageNet, ImageFolder
from torchvision.datasets import CIFAR100, CIFAR10, MNIST, KMNIST, FashionMNIST, ImageFolder
try:
from torchvision.datasets import Places365
has_places365 = True
@ -15,6 +15,16 @@ try:
has_inaturalist = True
except ImportError:
has_inaturalist = False
try:
from torchvision.datasets import QMNIST
has_qmnist = True
except ImportError:
has_qmnist = False
try:
from torchvision.datasets import ImageNet
has_imagenet = True
except ImportError:
has_imagenet = False
from .dataset import IterableImageDataset, ImageDataset
@ -22,7 +32,6 @@ _TORCH_BASIC_DS = dict(
cifar10=CIFAR10,
cifar100=CIFAR100,
mnist=MNIST,
qmist=QMNIST,
kmnist=KMNIST,
fashion_mnist=FashionMNIST,
)
@ -122,7 +131,12 @@ def create_dataset(
elif split in _EVAL_SYNONYM:
split = 'val'
ds = Places365(split=split, **torch_kwargs)
elif name == 'qmnist':
assert has_qmnist, 'Please update to a newer PyTorch and torchvision for QMNIST dataset.'
use_train = split in _TRAIN_SYNONYM
ds = QMNIST(train=use_train, **torch_kwargs)
elif name == 'imagenet':
assert has_imagenet, 'Please update to a newer PyTorch and torchvision for ImageNet dataset.'
if split in _EVAL_SYNONYM:
split = 'val'
ds = ImageNet(split=split, **torch_kwargs)

@ -3,10 +3,12 @@ import logging
import os
from functools import partial
from pathlib import Path
from typing import Union
from tempfile import TemporaryDirectory
from typing import Optional, Union
import torch
from torch.hub import HASH_REGEX, download_url_to_file, urlparse
try:
from torch.hub import get_dir
except ImportError:
@ -15,7 +17,10 @@ except ImportError:
from timm import __version__
try:
from huggingface_hub import HfApi, HfFolder, Repository, hf_hub_download, hf_hub_url
from huggingface_hub import (create_repo, get_hf_file_metadata,
hf_hub_download, hf_hub_url,
repo_type_and_id_from_hf_id, upload_folder)
from huggingface_hub.utils import EntryNotFoundError
hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__)
_has_hf_hub = True
except ImportError:
@ -121,53 +126,45 @@ def save_for_hf(model, save_directory, model_config=None):
def push_to_hf_hub(
model,
local_dir,
repo_namespace_or_url=None,
commit_message='Add model',
use_auth_token=True,
git_email=None,
git_user=None,
revision=None,
model_config=None,
repo_id: str,
commit_message: str ='Add model',
token: Optional[str] = None,
revision: Optional[str] = None,
private: bool = False,
create_pr: bool = False,
model_config: Optional[dict] = None,
):
if repo_namespace_or_url:
repo_owner, repo_name = repo_namespace_or_url.rstrip('/').split('/')[-2:]
else:
if isinstance(use_auth_token, str):
token = use_auth_token
else:
token = HfFolder.get_token()
if token is None:
raise ValueError(
"You must login to the Hugging Face hub on this computer by typing `transformers-cli login` and "
"entering your credentials to use `use_auth_token=True`. Alternatively, you can pass your own "
"token as the `use_auth_token` argument."
)
repo_owner = HfApi().whoami(token)['name']
repo_name = Path(local_dir).name
repo_url = f'https://huggingface.co/{repo_owner}/{repo_name}'
repo = Repository(
local_dir,
clone_from=repo_url,
use_auth_token=use_auth_token,
git_user=git_user,
git_email=git_email,
revision=revision,
)
# Prepare a default model card that includes the necessary tags to enable inference.
readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {repo_name}'
with repo.commit(commit_message):
# Create repo if doesn't exist yet
repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)
# Infer complete repo_id from repo_url
# Can be different from the input `repo_id` if repo_owner was implicit
_, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url)
repo_id = f"{repo_owner}/{repo_name}"
# Check if README file already exist in repo
try:
get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision))
has_readme = True
except EntryNotFoundError:
has_readme = False
# Dump model and push to Hub
with TemporaryDirectory() as tmpdir:
# Save model weights and config.
save_for_hf(model, repo.local_dir, model_config=model_config)
save_for_hf(model, tmpdir, model_config=model_config)
# Save a model card if it doesn't exist.
readme_path = Path(repo.local_dir) / 'README.md'
if not readme_path.exists():
# Add readme if does not exist
if not has_readme:
readme_path = Path(tmpdir) / "README.md"
readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {repo_id}'
readme_path.write_text(readme_text)
return repo.git_remote_url()
# Upload model and return
return upload_folder(
repo_id=repo_id,
folder_path=tmpdir,
revision=revision,
create_pr=create_pr,
commit_message=commit_message,
)

@ -50,7 +50,7 @@ class LayerNorm(nn.LayerNorm):
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self._fast_norm:
x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
else:
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
return x
@ -65,7 +65,7 @@ class LayerNorm2d(nn.LayerNorm):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.permute(0, 2, 3, 1)
if self._fast_norm:
x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
else:
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
x = x.permute(0, 3, 1, 2)

@ -104,7 +104,7 @@ group.add_argument('--class-map', default='', type=str, metavar='FILENAME',
# Model parameters
group = parser.add_argument_group('Model parameters')
group.add_argument('--model', default='resnet50', type=str, metavar='MODEL',
help='Name of model to train (default: "resnet50"')
help='Name of model to train (default: "resnet50")')
group.add_argument('--pretrained', action='store_true', default=False,
help='Start with pretrained version of specified network (if avail)')
group.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
@ -157,7 +157,7 @@ scripting_group.add_argument('--dynamo', default=False, action='store_true',
# Optimizer parameters
group = parser.add_argument_group('Optimizer parameters')
group.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
help='Optimizer (default: "sgd"')
help='Optimizer (default: "sgd")')
group.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON',
help='Optimizer Epsilon (default: None, use opt default)')
group.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
@ -222,7 +222,7 @@ group.add_argument('--warmup-prefix', action='store_true', default=False,
group.add_argument('--cooldown-epochs', type=int, default=0, metavar='N',
help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
group.add_argument('--patience-epochs', type=int, default=10, metavar='N',
help='patience epochs for Plateau LR scheduler (default: 10')
help='patience epochs for Plateau LR scheduler (default: 10)')
group.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
help='LR decay rate (default: 0.1)')
@ -419,7 +419,7 @@ def main():
in_chans = 3
if args.in_chans is not None:
in_chans = args.in_chanes
in_chans = args.in_chans
elif args.input_size is not None:
in_chans = args.input_size[0]

Loading…
Cancel
Save