Ross Wightman
e861b74cf8
Pass through --model-kwargs (and --opt-kwargs for train) from command line through to model __init__. Update some models to improve arg overlay. Cleanup along the way.
2 years ago
Ross Wightman
d5e7d6b27e
Merge remote-tracking branch 'origin/main' into refactor-imports
2 years ago
Lorenzo Baraldi
3d6bc42aa1
Put validation loss under amp_autocast
...
Secured the loss evaluation under the amp, avoiding function to operate on float16
2 years ago
Ross Wightman
927f031293
Major module / path restructure, timm.models.layers -> timm.layers, add _ prefix to all non model modules in timm.models
2 years ago
Ross Wightman
dbe7531aa3
Update scripts to support torch.compile(). Make --results_file arg more consistent across benchmark/validate/inference. Fix #1570
2 years ago
Ross Wightman
9da7e3a799
Add crop_mode for pretraind config / image transforms. Add support for dynamo compilation to benchmark/train/validate
2 years ago
Ross Wightman
4714a4910e
Merge pull request #1525 from TianyiFranklinWang/main
...
✏️ fix typo
2 years ago
klae01
ddd6361904
Update train.py
...
fix typo args.in_chanes
2 years ago
NPU-Franklin
9152b10478
✏️ fix typo
2 years ago
hova88
29baf32327
fix typo : miss back quote
2 years ago
Simon Schrodi
aceb79e002
Fix typo
2 years ago
Ross Wightman
285771972e
Change --amp flags, no more --apex-amp and --native-amp, add --amp-impl to select apex, and --amp-dtype to allow bfloat16 AMP dtype
2 years ago
Ross Wightman
b1b024dfed
Scheduler update, add v2 factory method, support scheduling on updates instead of just epochs. Add LR to summary csv. Add lr_base scaling calculations to train script. Fix #1168
2 years ago
Ross Wightman
b8c8550841
Data improvements. Improve train support for in_chans != 3. Add wds dataset support from bits_and_tpu branch w/ fixes and tweaks. TFDS tweaks.
2 years ago
Ross Wightman
87939e6fab
Refactor device handling in scripts, distributed init to be less 'cuda' centric. More device args passed through where needed.
2 years ago
Ross Wightman
ff6a919cf5
Add --fast-norm arg to benchmark.py, train.py, validate.py
2 years ago
Xiao Wang
11060f84c5
make train.py compatible with torchrun
2 years ago
Ross Wightman
a29fba307d
disable dist_bn when sync_bn active
3 years ago
Ross Wightman
879df47c0a
Support BatchNormAct2d for sync-bn use. Fix #1254
3 years ago
Ross Wightman
037e5e6c09
Fix #1309 , move wandb init after distributed init, only init on rank == 0 process
3 years ago
Jakub Kaczmarzyk
9e12530433
use utils namespace instead of function/classnames
...
This fixes buggy behavior introduced by
https://github.com/rwightman/pytorch-image-models/pull/1266 .
Related to https://github.com/rwightman/pytorch-image-models/pull/1273 .
3 years ago
Xiao Wang
ca991c1fa5
add --aot-autograd
3 years ago
Ross Wightman
fd360ac951
Merge pull request #1266 from kaczmarj/enh/no-star-imports
...
ENH: replace star imports with imported names in train.py
3 years ago
Jakub Kaczmarzyk
ce5578bc3a
replace star imports with imported names
3 years ago
Jakub Kaczmarzyk
dcad288fd6
use argparse groups to group arguments
3 years ago
Jakub Kaczmarzyk
e1e4c9bbae
rm whitespace
3 years ago
han
a16171335b
fix: change milestones to decay-milestones
...
- change argparser option `milestone` to `decay-milestone`
3 years ago
han
57a988df30
fix: multistep lr decay epoch bugs
...
- add milestones arguments
- change decay_epochs to milestones variable
3 years ago
Ross Wightman
b049a5c5c6
Merge remote-tracking branch 'origin/master' into norm_norm_norm
3 years ago
Ross Wightman
04db5833eb
Merge pull request #986 from hankyul2/master
...
fix: typo of argment parser desc in train.py
3 years ago
Ross Wightman
0557c8257d
Fix bug introduced in non layer_decay weight_decay application. Remove debug print, fix arg desc.
3 years ago
Ross Wightman
372ad5fa0d
Significant model refactor and additions:
...
* All models updated with revised foward_features / forward_head interface
* Vision transformer and MLP based models consistently output sequence from forward_features (pooling or token selection considered part of 'head')
* WIP param grouping interface to allow consistent grouping of parameters for layer-wise decay across all model types
* Add gradient checkpointing support to a significant % of models, especially popular architectures
* Formatting and interface consistency improvements across models
* layer-wise LR decay impl part of optimizer factory w/ scale support in scheduler
* Poolformer and Volo architectures added
3 years ago
Ross Wightman
95cfc9b3e8
Merge remote-tracking branch 'origin/master' into norm_norm_norm
3 years ago
Ross Wightman
abc9ba2544
Transitioning default_cfg -> pretrained_cfg. Improving handling of pretrained_cfg source (HF-Hub, files, timm config, etc). Checkpoint handling tweaks.
3 years ago
Ross Wightman
f0f9eccda8
Add --fuser arg to train/validate/benchmark scripts to select jit fuser type
3 years ago
Ross Wightman
5ccf682a8f
Remove deprecated bn-tf train arg and create_model handler. Add evos/evob models back into fx test filter until norm_norm_norm branch merged.
3 years ago
han
ab5ae32f75
fix: typo of argment parser desc in train.py
...
- Remove duplicated `of`
3 years ago
Ross Wightman
ba65dfe2c6
Dataset work
...
* support some torchvision datasets
* improvements to TFDS wrapper for subsplit handling (fix #942 ), shuffle seed
* add class-map support to train (fix #957 )
3 years ago
Ross Wightman
cd638d50a5
Merge pull request #880 from rwightman/fixes_bce_regnet
...
A collection of fixes, model experiments, etc
3 years ago
Ross Wightman
d9abfa48df
Make broadcast_buffers disable its own flag for now (needs more testing on interaction with dist_bn)
3 years ago
Ross Wightman
80075b0b8a
Add worker_seeding arg to allow selecting old vs updated data loader worker seed for (old) experiment repeatability
3 years ago
Shoufa Chen
908563d060
fix `use_amp`
...
Fix https://github.com/rwightman/pytorch-image-models/issues/881
3 years ago
Ross Wightman
0387e6057e
Update binary cross ent impl to use thresholding as an option (convert soft targets from mixup/cutmix to 0, 1)
3 years ago
Ross Wightman
0639d9a591
Fix updated validation_batch_size fallback
3 years ago
Ross Wightman
5db057dca0
Fix misnamed arg, tweak other train script args for better defaults.
3 years ago
Ross Wightman
fb94350896
Update training script and loader factory to allow use of scheduler updates, repeat augment, and bce loss
3 years ago
SamuelGabriel
7c19c35d9f
Global instead of local rank.
4 years ago
Ross Wightman
e15e68d881
Fix #566 , summary.csv writing to pwd on local_rank != 0. Tweak benchmark mem handling to see if it reduces likelihood of 'bad' exceptions on OOM.
4 years ago
Ross Wightman
e685618f45
Merge pull request #550 from amaarora/wandb
...
Wandb Support
4 years ago
Ross Wightman
7c97e66f7c
Remove commented code, add more consistent seed fn
4 years ago