# Quickstart
This quickstart is intended for developers who are ready to dive into the code and see an example of how to integrate `timm` into their model training workflow.
First, you'll need to install `timm`. For more information on installation, see [Installation](installation).
```bash
pip install timm
```
## Load a Pretrained Model
Load a pretrained model with `timm.create_model`, then switch it to evaluation mode before running inference:
```py
>>> import timm
>>> m = timm.create_model('mobilenetv3_large_100', pretrained=True)
>>> m.eval()
```
## Fine-Tune a Pretrained Model
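You can fine-tune any of the pretrained models by changing the classifier (the last layer). Passing `num_classes` to `timm.create_model` replaces the classifier with a newly initialized one of that size; in the example below, `NUM_FINETUNE_CLASSES` stands for the number of classes in your dataset.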
```py
>>> model = timm.create_model('mobilenetv3_large_100', pretrained=True, num_classes=NUM_FINETUNE_CLASSES)
```
To fine-tune on your own dataset, you have to write a PyTorch training loop or adapt `timm`'s [training script](training_script) to use your dataset.
## Use a Pretrained Model for Feature Extraction
Without modifying the network, you can call `model.forward_features(input)` on any model instead of the usual `model(input)`. This bypasses the classifier head and the global pooling, so the output is the unpooled feature map.
For a more in-depth guide to using `timm` for feature extraction, see [Feature Extraction](feature_extraction).
```py
>>> import timm
>>> import torch
>>> x = torch.randn(1, 3, 224, 224)
>>> model = timm.create_model('mobilenetv3_large_100', pretrained=True)
>>> features = model.forward_features(x)
>>> print(features.shape)
torch.Size([1, 960, 7, 7])
```
## List Models with Pretrained Weights
You can list all models with pretrained weights using `timm.list_models`.
```py
>>> import timm
>>> from pprint import pprint
>>> model_names = timm.list_models(pretrained=True)
>>> pprint(model_names)
[
'adv_inception_v3',
'cspdarknet53',
'cspresnext50',
'densenet121',
'densenet161',
'densenet169',
'densenet201',
'densenetblur121d',
'dla34',
'dla46_c',
...
]
```
You can also list models with a specific pattern in their name.
```py
>>> import timm
>>> from pprint import pprint
>>> model_names = timm.list_models('*resne*t*')
>>> pprint(model_names)
[
'cspresnet50',
'cspresnet50d',
'cspresnet50w',
'cspresnext50',
...
]
```