# Quickstart

This quickstart is intended for developers who are ready to dive into the code and see an example of how to integrate `timm` into their model training workflow.

First, you'll need to install `timm`. For more information on installation, see [Installation](installation).

```bash
pip install timm
```

## Load a Pretrained Model

Pretrained models can be loaded using `timm.create_model`.

```py
>>> import timm
>>> m = timm.create_model('mobilenetv3_large_100', pretrained=True)
>>> m.eval()  # switch to inference mode (disables dropout, uses BatchNorm running statistics)
```

## Fine-Tune a Pretrained Model

You can fine-tune any of the pretrained models by replacing the classifier (the last layer) with one sized for your number of classes.

```py
>>> NUM_FINETUNE_CLASSES = 10  # placeholder: set this to the number of classes in your dataset
>>> model = timm.create_model('mobilenetv3_large_100', pretrained=True, num_classes=NUM_FINETUNE_CLASSES)
```

To fine-tune on your own dataset, you'll need to write a PyTorch training loop or adapt `timm`'s [training script](training_script) to load your data.

## Use a Pretrained Model for Feature Extraction

Without modifying the network, you can call `model.forward_features(input)` on any model instead of the usual `model(input)`. This skips the global pooling and the classifier head and returns the unpooled features from the backbone.

For a more in-depth guide to using `timm` for feature extraction, see [Feature Extraction](feature_extraction).

```py
>>> import timm
>>> import torch
>>> x = torch.randn(1, 3, 224, 224)
>>> model = timm.create_model('mobilenetv3_large_100', pretrained=True)
>>> features = model.forward_features(x)
>>> print(features.shape)
torch.Size([1, 960, 7, 7])
```

## List Models with Pretrained Weights

You can list all models with pretrained weights using `timm.list_models(pretrained=True)`.

```py
>>> import timm
>>> from pprint import pprint
>>> model_names = timm.list_models(pretrained=True)
>>> pprint(model_names)
[
    'adv_inception_v3',
    'cspdarknet53',
    'cspresnext50',
    'densenet121',
    'densenet161',
    'densenet169',
    'densenet201',
    'densenetblur121d',
    'dla34',
    'dla46_c',
    ...
]
```

You can also list only the models whose names match a wildcard pattern.

```py
>>> import timm
>>> from pprint import pprint
>>> model_names = timm.list_models('*resne*t*')
>>> pprint(model_names)
[
    'cspresnet50',
    'cspresnet50d',
    'cspresnet50w',
    'cspresnext50',
    ...
]
```

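The pattern uses Unix shell-style wildcards, where `*` matches any run of characters. The sketch below reproduces the matching with Python's standard `fnmatch` module. It runs on a small, hand-picked list of names chosen for illustration, not the full `timm` registry. It shows why `'*resne*t*'` matches both the "resnet" and "resnext" families but not `densenet121`.

```py
import fnmatch

# Illustrative subset of model names (not the full registry)
names = [
    'cspresnet50', 'cspresnext50', 'densenet121', 'dla34',
    'resnet50', 'resnetv2_50', 'seresnext26d_32x4d',
]

# 'resne' must appear somewhere, followed later by a 't'
matches = fnmatch.filter(names, '*resne*t*')
print(matches)
# ['cspresnet50', 'cspresnext50', 'resnet50', 'resnetv2_50', 'seresnext26d_32x4d']
```
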