Commit Graph

124 Commits (f2e14685a8729f009ebf4d6128ceb4bcd5b95d22)

Author SHA1 Message Date
Ross Wightman f2e14685a8 Add force-cpu flag for train/validate, fix CPU fallback for device init, remove old force-cpu flag for EMA model weights (3 years ago)
Ross Wightman 0d82876132 Add comment for reference re PyTorch XLA 'race' issue (3 years ago)
Ross Wightman 40457e5691 Transforms, augmentation work for bits, add RandomErasing support for XLA (pushing into transforms), revamp of transform/preproc config, etc., ongoing... (3 years ago)
Ross Wightman c3db5f5801 Worker hack for TFDS eval, add TPU env var setting. (3 years ago)
Ross Wightman f411724de4 Fix checkpoint delete issue. Add README about bits and initial PyTorch XLA usage on TPU-VM. Add some FIXMEs and fold train_cfg into train_state by default. (3 years ago)
Ross Wightman 91ab0b6ce5 Add proper TrainState checkpoint save/load. Some reorg/refactoring and other cleanup. More to go... (3 years ago)
Ross Wightman 5b9c69e80a Add basic training resume based on legacy code (3 years ago)
Ross Wightman cbd4ee737f Fix model init for XLA, remove some prints. (3 years ago)
Ross Wightman 6d90fcf282 Fix distribute_bn and model_ema (3 years ago)
Ross Wightman aa92d7b1c5 Major timm.bits update. Updater and DeviceEnv now dataclasses, after_step closure used, metrics base impl w/ distributed reduce, many tweaks/fixes. (3 years ago)
Ross Wightman 76de984a5f Fix some bugs with XLA support, logger, add hacky xla dist launch script since torch.dist.launch doesn't work (3 years ago)
Ross Wightman 12d9a6d4d2 First timm.bits commit, add initial abstractions, WIP updates to train, val... some of it working (3 years ago)
Ross Wightman e685618f45 Merge pull request #550 from amaarora/wandb (3 years ago)
Ross Wightman 7c97e66f7c Remove commented code, add more consistent seed fn (3 years ago)
Aman Arora 5772c55c57 Make wandb optional (3 years ago)
Aman Arora f54897cc0b Make wandb optional rather than required, as with huggingface_hub (3 years ago)
Aman Arora f13f7508a9 Keep changes minimal and use args.experiment as wandb project name if it exists (3 years ago)
Aman Arora f8bb13f640 Default project name to None (3 years ago)
Aman Arora 3f028ebc0f import wandb in summary.py (3 years ago)
Aman Arora a9e5d9e5ad log loss as before (3 years ago)
Aman Arora 624c9b6949 log to wandb only if using wandb (3 years ago)
Aman Arora 00c8e0b8bd Make use of wandb configurable (3 years ago)
Aman Arora 8e6fb861e4 Add wandb support (3 years ago)
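The wandb commits above make the integration opt-in rather than a hard dependency. A minimal sketch of that pattern, with a guarded import and an explicit flag (the flag name, project name, and structure here are illustrative, not necessarily what train.py actually uses):

```python
import argparse

# Guarded import: the script still runs when wandb isn't installed.
try:
    import wandb
    has_wandb = True
except ImportError:
    has_wandb = False

parser = argparse.ArgumentParser()
# Illustrative flag name; the real script's argument may differ.
parser.add_argument('--log-wandb', action='store_true',
                    help='log metrics to wandb when it is available')
args = parser.parse_args()

if args.log_wandb and has_wandb:
    wandb.init(project='timm-train')  # project name is a placeholder
elif args.log_wandb:
    print('wandb requested but not installed; continuing without it')

# In the training loop, log only when both enabled and available:
if args.log_wandb and has_wandb:
    wandb.log({'loss': 0.5, 'epoch': 0})  # placeholder values
```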
Ross Wightman 37c71a5609 Some further create_optimizer_v2 tweaks, remove some redundant code, add back safe model str. Benchmark step times per batch. (3 years ago)
Ross Wightman 288682796f Update benchmark script to add precision arg. Fix some downstream (DeiT) compat issues with latest changes. Bump version to 0.4.7 (3 years ago)
Ross Wightman a5310a3451 Merge remote-tracking branch 'origin/benchmark-fixes-vit_hybrids' into pit_and_vit_update (3 years ago)
Ross Wightman e2e3290fbf Add '--experiment' to train args for fixed exp name if desired, 'train' not added to output folder if specified. (3 years ago)
Ross Wightman d584e7f617 Support for huggingface hub via create_model and default_cfgs. (3 years ago)
Ross Wightman 2db2d87ff7 Add epoch-repeats arg to multiply the number of dataset passes per epoch. Currently for iterable datasets (read TFDS wrapper) only. (3 years ago)
Ross Wightman 0e16d4e9fb Add benchmark.py script, and update optimizer factory to be more friendly to use outside of argparse interface. (3 years ago)
Ross Wightman 01653db104 Missed clip-mode arg for repo train script (3 years ago)
Ross Wightman 4f49b94311 Initial AGC impl. Still testing. (3 years ago)
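AGC (adaptive gradient clipping) scales a gradient down whenever its norm grows too large relative to the norm of the parameter it belongs to. A simplified per-tensor sketch of the idea (real implementations typically work unit-wise, e.g. per output row, so treat this as an approximation rather than timm's exact code):

```python
import torch

def adaptive_grad_clip(parameters, clip_factor=0.01, eps=1e-3):
    """Clip each gradient so that ||g|| <= clip_factor * max(||p||, eps).

    Per-whole-tensor simplification of AGC for illustration only.
    """
    for p in parameters:
        if p.grad is None:
            continue
        p_norm = p.detach().norm()
        g_norm = p.grad.detach().norm()
        max_norm = clip_factor * torch.clamp(p_norm, min=eps)
        if g_norm > max_norm:
            # Rescale the gradient in place to sit on the allowed boundary.
            p.grad.detach().mul_(max_norm / (g_norm + 1e-6))
```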
Ross Wightman d8e69206be Merge pull request #419 from rwightman/byob_vgg_models (3 years ago)
Ross Wightman 0356e773f5 Default to native PyTorch AMP instead of APEX amp. Too many APEX issues cropping up lately. (3 years ago)
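Unlike APEX, native PyTorch AMP needs no separate install: wrap the forward pass in an autocast context and scale the loss with a GradScaler. A bare-bones sketch of that loop (model and data are placeholders):

```python
import torch

model = torch.nn.Linear(10, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()  # scales losses to avoid fp16 underflow

for _ in range(10):
    x = torch.randn(8, 10, device='cuda')
    target = torch.randint(0, 2, (8,), device='cuda')
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():  # run the forward pass in mixed precision
        loss = torch.nn.functional.cross_entropy(model(x), target)
    scaler.scale(loss).backward()    # backward on the scaled loss
    scaler.step(optimizer)           # unscales grads, then steps the optimizer
    scaler.update()                  # adjusts the loss scale for next iteration
```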
Csaba Kertesz 5114c214fc Change the Python interpreter to Python 3.x in the scripts (3 years ago)
Ross Wightman 4203efa36d Fix #387 so that checkpoint saver works with max history of 1. Add checkpoint-hist arg to train.py. (3 years ago)
Ross Wightman 38d8f67570 Fix potential issue with change to num_classes arg in train/validate.py defaulting to None (rely on model def / default_cfg) (3 years ago)
Ross Wightman 5d4c3d0af3 Add enhanced ParserImageInTar that can read images from tars within tars, folders with multiple tars, etc. Additional comment cleanup. (3 years ago)
Ross Wightman 9d5d4b8df6 Fix silly train.py typo during dataset work (3 years ago)
Ross Wightman 855d6cc217 More dataset work including factories and a TensorFlow Datasets (TFDS) wrapper (3 years ago)
Ross Wightman fd9061dbf7 Remove debug print from train.py (3 years ago)
Ross Wightman 59ec7e6a53 Merge branch 'master' into imagenet21k_datasets_more (3 years ago)
Csaba Kertesz e42b140ade Add --input-size option to scripts to specify full input dimensions from command-line (3 years ago)
Ross Wightman 231d04e91a ResNetV2 pre-act and non-preact model, w/ BiT pretrained weights and support for ViT R50 model. Tweaks for in21k num_classes passing. More to do... tests failing. (3 years ago)
Ross Wightman de6046e213 Initial commit for dataset / parser reorg to support additional datasets / types (3 years ago)
Ross Wightman 2ed8f24715 A few more changes for 0.3.2 maint release. Linear layer change for mobilenetv3 and inception_v3, support no bias for linear wrapper. (4 years ago)
Ross Wightman 460eba7f24 Work around casting issue with combination of native torch AMP and torchscript for Linear layers (4 years ago)
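The workaround above addresses scripted modules not receiving autocast's automatic dtype casting, so an fp16 input could meet an fp32 weight inside a scripted Linear. One plausible shape for such a fix, casting weights to the input dtype only under scripting (a sketch of the pattern, not necessarily the exact change in this commit):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Linear(nn.Linear):
    """nn.Linear that tolerates AMP dtypes under torchscript.

    Scripted modules can bypass autocast's automatic casting, so cast the
    weight/bias to the input dtype when scripting to avoid dtype mismatch.
    """
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if torch.jit.is_scripting():
            bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None
            return F.linear(input, self.weight.to(dtype=input.dtype), bias)
        return F.linear(input, self.weight, self.bias)
```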
Ross Wightman 27bbc70d71 Add back old ModelEma and rename new one to ModelEmaV2 to avoid compat breaks in dependent code. Shuffle train script, add a few comments, remove DataParallel support, support experimental torchscript training. (4 years ago)
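A model EMA keeps a shadow copy of the weights, nudged toward the live weights after each optimizer step; the V2-style approach holds a full module copy and lerps its state. A condensed sketch of the update rule (simplified; see timm's source for the real classes):

```python
import copy
import torch

class ModelEmaSketch:
    """Shadow copy updated as w_ema = d * w_ema + (1 - d) * w."""

    def __init__(self, model, decay=0.9999):
        self.module = copy.deepcopy(model).eval()
        self.decay = decay
        for p in self.module.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model):
        for ema_v, model_v in zip(self.module.state_dict().values(),
                                  model.state_dict().values()):
            if ema_v.dtype.is_floating_point:
                ema_v.mul_(self.decay).add_(model_v, alpha=1 - self.decay)
            else:
                # Integer buffers (e.g. batch counts) are copied, not averaged.
                ema_v.copy_(model_v)
```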
Ross Wightman 9214ca0716 Simplifying EMA... (4 years ago)
Ross Wightman 80078c47bb Add Adafactor and Adahessian optimizers, cleanup optimizer arg passing, add gradient clipping support. (4 years ago)
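Gradient clipping slots in between the backward pass and the optimizer step. A minimal sketch using torch's built-in norm clipping (clip_grad_value_ would be the analogous 'value' mode; model and data are placeholders):

```python
import torch

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
clip_value = 1.0

x = torch.randn(8, 10)
target = torch.randint(0, 2, (8,))
loss = torch.nn.functional.cross_entropy(model(x), target)

optimizer.zero_grad()
loss.backward()
# Clip the global gradient norm before stepping the optimizer.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_value)
optimizer.step()
```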