Commit Graph

34 Commits (b995ca3c314598b2b91ca1755232652c6df07335)

Author SHA1 Message Date
Fredo Guan 81ca323751
Davit update formatting and fix grad checkpointing (#7)
1 year ago
Ross Wightman b1b024dfed Scheduler update, add v2 factory method, support scheduling on updates instead of just epochs. Add LR to summary csv. Add lr_base scaling calculations to train script. Fix #1168
2 years ago
Ross Wightman 87939e6fab Refactor device handling in scripts, distributed init to be less 'cuda' centric. More device args passed through where needed.
2 years ago
Ross Wightman 0dbd9352ce Add bulk_runner script and updates to benchmark.py and validate.py for better error handling in bulk runs (used for benchmark and validation result runs). Improved batch size decay stepping on retry...
2 years ago
Ross Wightman 324a4e58b6 disable nvfuser for jit te/legacy modes (for PT 1.12+)
2 years ago
Ross Wightman 2f2b22d8c7 Disable nvfuser fma / opt level overrides per #1244
2 years ago
jjsjann123 f88c606fcf fixing channels_last on cond_conv2d; update nvfuser debug env variable
2 years ago
Ross Wightman f0f9eccda8 Add --fuser arg to train/validate/benchmark scripts to select jit fuser type
2 years ago
Ross Wightman 57992509f9 Fix some formatting in utils/model.py
3 years ago
Ross Wightman e5da481073 Small post-merge tweak for freeze/unfreeze, add to __init__ for utils
3 years ago
Alexander Soare 431e60c83f Add acknowledgements for freeze_batch_norm inspiration
3 years ago
Alexander Soare 65c3d78b96 Freeze unfreeze functionality finalized. Tests added
3 years ago
Alexander Soare 0cb8ea432c wip
3 years ago
Ross Wightman d667351eac Tweak accuracy topk safety. Fix #807
3 years ago
Yohann Lereclus 35c9740826 Fix accuracy when topk > num_classes
3 years ago
Ross Wightman e685618f45
Merge pull request #550 from amaarora/wandb
3 years ago
Ross Wightman 7c97e66f7c Remove commented code, add more consistent seed fn
3 years ago
Aman Arora 5772c55c57 Make wandb optional
3 years ago
Aman Arora f54897cc0b make wandb not required but rather optional as huggingface_hub
3 years ago
Aman Arora 3f028ebc0f import wandb in summary.py
3 years ago
Aman Arora 624c9b6949 log to wandb only if using using wandb
3 years ago
Aman Arora 6b18061773 Add GIST to docstring for quick access
3 years ago
Aman Arora 92b1db9a79 update docstrings and add check on and
3 years ago
Aman Arora b85be24054 update to work with fnmatch
3 years ago
Aman Arora 20626e8387 Add to extract stats for SPP
3 years ago
Ross Wightman 4f49b94311 Initial AGC impl. Still testing.
3 years ago
Ross Wightman 4203efa36d Fix #387 so that checkpoint saver works with max history of 1. Add checkpoint-hist arg to train.py.
3 years ago
Ross Wightman 4ca52d73d8 Add separate set and update method to ModelEmaV2
4 years ago
Ross Wightman 27bbc70d71 Add back old ModelEma and rename new one to ModelEmaV2 to avoid compat breaks in dependant code. Shuffle train script, add a few comments, remove DataParallel support, support experimental torchscript training.
4 years ago
Ross Wightman 9214ca0716 Simplifying EMA...
4 years ago
Ross Wightman 4a3df7842a Fix topn metric view regression on PyTorch 1.7
4 years ago
Ross Wightman 80078c47bb Add Adafactor and Adahessian optimizers, cleanup optimizer arg passing, add gradient clipping support.
4 years ago
Ross Wightman fcb6258877 Add missing leaky_relu layer factory defn, update Apex/Native loss scaler interfaces to support unscaled grad clipping. Bump ver to 0.2.2 for pending release.
4 years ago
Ross Wightman 532e3b417d Reorg of utils into separate modules
4 years ago