You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
pytorch-image-models/models/vision-transformer/index.html

1692 lines
45 KiB

<!doctype html>
<html lang="en" class="no-js">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width,initial-scale=1">
<meta name="description" content="Pretrained Image Recognition Models">
<link rel="icon" href="../../assets/images/favicon.png">
<meta name="generator" content="mkdocs-1.1.2, mkdocs-material-7.0.6">
<title>Vision Transformer (ViT) - Pytorch Image Models</title>
<link rel="stylesheet" href="../../assets/stylesheets/main.2c0c5eaf.min.css">
<link rel="stylesheet" href="../../assets/stylesheets/palette.7fa14f5b.min.css">
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
<link rel="stylesheet" href="https://fonts.googleapis.com/css?family=Roboto:300,400,400i,700%7CRoboto+Mono&display=fallback">
<style>:root{--md-text-font-family:"Roboto";--md-code-font-family:"Roboto Mono"}</style>
</head>
<body dir="ltr" data-md-color-scheme="" data-md-color-primary="none" data-md-color-accent="none">
<input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer" autocomplete="off">
<input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search" autocomplete="off">
<label class="md-overlay" for="__drawer"></label>
<div data-md-component="skip">
<a href="#vision-transformer-vit" class="md-skip">
Skip to content
</a>
</div>
<div data-md-component="announce">
</div>
<header class="md-header" data-md-component="header">
<nav class="md-header__inner md-grid" aria-label="Header">
<a href="../.." title="Pytorch Image Models" class="md-header__button md-logo" aria-label="Pytorch Image Models" data-md-component="logo">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M12 8a3 3 0 0 0 3-3 3 3 0 0 0-3-3 3 3 0 0 0-3 3 3 3 0 0 0 3 3m0 3.54C9.64 9.35 6.5 8 3 8v11c3.5 0 6.64 1.35 9 3.54 2.36-2.19 5.5-3.54 9-3.54V8c-3.5 0-6.64 1.35-9 3.54z"/></svg>
</a>
<label class="md-header__button md-icon" for="__drawer">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M3 6h18v2H3V6m0 5h18v2H3v-2m0 5h18v2H3v-2z"/></svg>
</label>
<div class="md-header__title" data-md-component="header-title">
<div class="md-header__ellipsis">
<div class="md-header__topic">
<span class="md-ellipsis">
Pytorch Image Models
</span>
</div>
<div class="md-header__topic" data-md-component="header-topic">
<span class="md-ellipsis">
Vision Transformer (ViT)
</span>
</div>
</div>
</div>
<div class="md-header__options">
</div>
<label class="md-header__button md-icon" for="__search">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M9.5 3A6.5 6.5 0 0 1 16 9.5c0 1.61-.59 3.09-1.56 4.23l.27.27h.79l5 5-1.5 1.5-5-5v-.79l-.27-.27A6.516 6.516 0 0 1 9.5 16 6.5 6.5 0 0 1 3 9.5 6.5 6.5 0 0 1 9.5 3m0 2C7 5 5 7 5 9.5S7 14 9.5 14 14 12 14 9.5 12 5 9.5 5z"/></svg>
</label>
<div class="md-search" data-md-component="search" role="dialog">
<label class="md-search__overlay" for="__search"></label>
<div class="md-search__inner" role="search">
<form class="md-search__form" name="search">
<input type="text" class="md-search__input" name="query" aria-label="Search" placeholder="Search" autocapitalize="off" autocorrect="off" autocomplete="off" spellcheck="false" data-md-component="search-query" data-md-state="active" required>
<label class="md-search__icon md-icon" for="__search">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M9.5 3A6.5 6.5 0 0 1 16 9.5c0 1.61-.59 3.09-1.56 4.23l.27.27h.79l5 5-1.5 1.5-5-5v-.79l-.27-.27A6.516 6.516 0 0 1 9.5 16 6.5 6.5 0 0 1 3 9.5 6.5 6.5 0 0 1 9.5 3m0 2C7 5 5 7 5 9.5S7 14 9.5 14 14 12 14 9.5 12 5 9.5 5z"/></svg>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M20 11v2H8l5.5 5.5-1.42 1.42L4.16 12l7.92-7.92L13.5 5.5 8 11h12z"/></svg>
</label>
<button type="reset" class="md-search__icon md-icon" aria-label="Clear" tabindex="-1">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M19 6.41L17.59 5 12 10.59 6.41 5 5 6.41 10.59 12 5 17.59 6.41 19 12 13.41 17.59 19 19 17.59 13.41 12 19 6.41z"/></svg>
</button>
</form>
<div class="md-search__output">
<div class="md-search__scrollwrap" data-md-scrollfix>
<div class="md-search-result" data-md-component="search-result">
<div class="md-search-result__meta">
Initializing search
</div>
<ol class="md-search-result__list"></ol>
</div>
</div>
</div>
</div>
</div>
<div class="md-header__source">
<a href="https://github.com/rwightman/pytorch-image-models/" title="Go to repository" class="md-source" data-md-component="source">
<div class="md-source__icon md-icon">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 448 512"><path d="M439.55 236.05L244 40.45a28.87 28.87 0 0 0-40.81 0l-40.66 40.63 51.52 51.52c27.06-9.14 52.68 16.77 43.39 43.68l49.66 49.66c34.23-11.8 61.18 31 35.47 56.69-26.49 26.49-70.21-2.87-56-37.34L240.22 199v121.85c25.3 12.54 22.26 41.85 9.08 55a34.34 34.34 0 0 1-48.55 0c-17.57-17.6-11.07-46.91 11.25-56v-123c-20.8-8.51-24.6-30.74-18.64-45L142.57 101 8.45 235.14a28.86 28.86 0 0 0 0 40.81l195.61 195.6a28.86 28.86 0 0 0 40.8 0l194.69-194.69a28.86 28.86 0 0 0 0-40.81z"/></svg>
</div>
<div class="md-source__repository">
rwightman/pytorch-image-models
</div>
</a>
</div>
</nav>
</header>
<div class="md-container" data-md-component="container">
<main class="md-main" data-md-component="main">
<div class="md-main__inner md-grid">
<div class="md-sidebar md-sidebar--primary" data-md-component="sidebar" data-md-type="navigation" >
<div class="md-sidebar__scrollwrap">
<div class="md-sidebar__inner">
<nav class="md-nav md-nav--primary" aria-label="Navigation" data-md-level="0">
<label class="md-nav__title" for="__drawer">
<a href="../.." title="Pytorch Image Models" class="md-nav__button md-logo" aria-label="Pytorch Image Models" data-md-component="logo">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M12 8a3 3 0 0 0 3-3 3 3 0 0 0-3-3 3 3 0 0 0-3 3 3 3 0 0 0 3 3m0 3.54C9.64 9.35 6.5 8 3 8v11c3.5 0 6.64 1.35 9 3.54 2.36-2.19 5.5-3.54 9-3.54V8c-3.5 0-6.64 1.35-9 3.54z"/></svg>
</a>
Pytorch Image Models
</label>
<div class="md-nav__source">
<a href="https://github.com/rwightman/pytorch-image-models/" title="Go to repository" class="md-source" data-md-component="source">
<div class="md-source__icon md-icon">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 448 512"><path d="M439.55 236.05L244 40.45a28.87 28.87 0 0 0-40.81 0l-40.66 40.63 51.52 51.52c27.06-9.14 52.68 16.77 43.39 43.68l49.66 49.66c34.23-11.8 61.18 31 35.47 56.69-26.49 26.49-70.21-2.87-56-37.34L240.22 199v121.85c25.3 12.54 22.26 41.85 9.08 55a34.34 34.34 0 0 1-48.55 0c-17.57-17.6-11.07-46.91 11.25-56v-123c-20.8-8.51-24.6-30.74-18.64-45L142.57 101 8.45 235.14a28.86 28.86 0 0 0 0 40.81l195.61 195.6a28.86 28.86 0 0 0 40.8 0l194.69-194.69a28.86 28.86 0 0 0 0-40.81z"/></svg>
</div>
<div class="md-source__repository">
rwightman/pytorch-image-models
</div>
</a>
</div>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item">
<a href="../.." class="md-nav__link">
Getting Started
</a>
</li>
<li class="md-nav__item">
<a href="../" class="md-nav__link">
Model Architectures
</a>
</li>
<li class="md-nav__item">
<a href="../../results/" class="md-nav__link">
Results
</a>
</li>
<li class="md-nav__item">
<a href="../../scripts/" class="md-nav__link">
Scripts
</a>
</li>
<li class="md-nav__item">
<a href="../../training_hparam_examples/" class="md-nav__link">
Training Examples
</a>
</li>
<li class="md-nav__item">
<a href="../../feature_extraction/" class="md-nav__link">
Feature Extraction
</a>
</li>
<li class="md-nav__item">
<a href="../../changes/" class="md-nav__link">
Recent Changes
</a>
</li>
<li class="md-nav__item">
<a href="../../archived_changes/" class="md-nav__link">
Archived Changes
</a>
</li>
<li class="md-nav__item md-nav__item--active md-nav__item--nested">
<input class="md-nav__toggle md-toggle" data-md-toggle="__nav_9" type="checkbox" id="__nav_9" checked>
<label class="md-nav__link" for="__nav_9">
Models
<span class="md-nav__icon md-icon"></span>
</label>
<nav class="md-nav" aria-label="Models" data-md-level="1">
<label class="md-nav__title" for="__nav_9">
<span class="md-nav__icon md-icon"></span>
Models
</label>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item">
<a href="../adversarial-inception-v3/" class="md-nav__link">
Adversarial Inception v3
</a>
</li>
<li class="md-nav__item">
<a href="../advprop/" class="md-nav__link">
AdvProp (EfficientNet)
</a>
</li>
<li class="md-nav__item">
<a href="../big-transfer/" class="md-nav__link">
Big Transfer (BiT)
</a>
</li>
<li class="md-nav__item">
<a href="../csp-darknet/" class="md-nav__link">
CSP-DarkNet
</a>
</li>
<li class="md-nav__item">
<a href="../csp-resnet/" class="md-nav__link">
CSP-ResNet
</a>
</li>
<li class="md-nav__item">
<a href="../csp-resnext/" class="md-nav__link">
CSP-ResNeXt
</a>
</li>
<li class="md-nav__item">
<a href="../densenet/" class="md-nav__link">
DenseNet
</a>
</li>
<li class="md-nav__item">
<a href="../dla/" class="md-nav__link">
Deep Layer Aggregation
</a>
</li>
<li class="md-nav__item">
<a href="../dpn/" class="md-nav__link">
Dual Path Network (DPN)
</a>
</li>
<li class="md-nav__item">
<a href="../ecaresnet/" class="md-nav__link">
ECA-ResNet
</a>
</li>
<li class="md-nav__item">
<a href="../efficientnet-pruned/" class="md-nav__link">
EfficientNet (Knapsack Pruned)
</a>
</li>
<li class="md-nav__item">
<a href="../efficientnet/" class="md-nav__link">
EfficientNet
</a>
</li>
<li class="md-nav__item">
<a href="../ensemble-adversarial/" class="md-nav__link">
Ensemble Adversarial Inception ResNet v2
</a>
</li>
<li class="md-nav__item">
<a href="../ese-vovnet/" class="md-nav__link">
ESE-VoVNet
</a>
</li>
<li class="md-nav__item">
<a href="../fbnet/" class="md-nav__link">
FBNet
</a>
</li>
<li class="md-nav__item">
<a href="../gloun-inception-v3/" class="md-nav__link">
(Gluon) Inception v3
</a>
</li>
<li class="md-nav__item">
<a href="../gloun-resnet/" class="md-nav__link">
(Gluon) ResNet
</a>
</li>
<li class="md-nav__item">
<a href="../gloun-resnext/" class="md-nav__link">
(Gluon) ResNeXt
</a>
</li>
<li class="md-nav__item">
<a href="../gloun-senet/" class="md-nav__link">
(Gluon) SENet
</a>
</li>
<li class="md-nav__item">
<a href="../gloun-seresnext/" class="md-nav__link">
(Gluon) SE-ResNeXt
</a>
</li>
<li class="md-nav__item">
<a href="../gloun-xception/" class="md-nav__link">
(Gluon) Xception
</a>
</li>
<li class="md-nav__item">
<a href="../hrnet/" class="md-nav__link">
HRNet
</a>
</li>
<li class="md-nav__item">
<a href="../ig-resnext/" class="md-nav__link">
Instagram ResNeXt WSL
</a>
</li>
<li class="md-nav__item">
<a href="../inception-resnet-v2/" class="md-nav__link">
Inception ResNet v2
</a>
</li>
<li class="md-nav__item">
<a href="../inception-v3/" class="md-nav__link">
Inception v3
</a>
</li>
<li class="md-nav__item">
<a href="../inception-v4/" class="md-nav__link">
Inception v4
</a>
</li>
<li class="md-nav__item">
<a href="../legacy-se-resnet/" class="md-nav__link">
(Legacy) SE-ResNet
</a>
</li>
<li class="md-nav__item">
<a href="../legacy-se-resnext/" class="md-nav__link">
(Legacy) SE-ResNeXt
</a>
</li>
<li class="md-nav__item">
<a href="../legacy-senet/" class="md-nav__link">
(Legacy) SENet
</a>
</li>
<li class="md-nav__item">
<a href="../mixnet/" class="md-nav__link">
MixNet
</a>
</li>
<li class="md-nav__item">
<a href="../mnasnet/" class="md-nav__link">
MnasNet
</a>
</li>
<li class="md-nav__item">
<a href="../mobilenet-v2/" class="md-nav__link">
MobileNet v2
</a>
</li>
<li class="md-nav__item">
<a href="../mobilenet-v3/" class="md-nav__link">
MobileNet v3
</a>
</li>
<li class="md-nav__item">
<a href="../nasnet/" class="md-nav__link">
NASNet
</a>
</li>
<li class="md-nav__item">
<a href="../noisy-student/" class="md-nav__link">
Noisy Student (EfficientNet)
</a>
</li>
<li class="md-nav__item">
<a href="../pnasnet/" class="md-nav__link">
PNASNet
</a>
</li>
<li class="md-nav__item">
<a href="../regnetx/" class="md-nav__link">
RegNetX
</a>
</li>
<li class="md-nav__item">
<a href="../regnety/" class="md-nav__link">
RegNetY
</a>
</li>
<li class="md-nav__item">
<a href="../res2net/" class="md-nav__link">
Res2Net
</a>
</li>
<li class="md-nav__item">
<a href="../res2next/" class="md-nav__link">
Res2NeXt
</a>
</li>
<li class="md-nav__item">
<a href="../resnest/" class="md-nav__link">
ResNeSt
</a>
</li>
<li class="md-nav__item">
<a href="../resnet-d/" class="md-nav__link">
ResNet-D
</a>
</li>
<li class="md-nav__item">
<a href="../resnet/" class="md-nav__link">
ResNet
</a>
</li>
<li class="md-nav__item">
<a href="../resnext/" class="md-nav__link">
ResNeXt
</a>
</li>
<li class="md-nav__item">
<a href="../rexnet/" class="md-nav__link">
RexNet
</a>
</li>
<li class="md-nav__item">
<a href="../se-resnet/" class="md-nav__link">
SE-ResNet
</a>
</li>
<li class="md-nav__item">
<a href="../selecsls/" class="md-nav__link">
SelecSLS
</a>
</li>
<li class="md-nav__item">
<a href="../seresnext/" class="md-nav__link">
SE-ResNeXt
</a>
</li>
<li class="md-nav__item">
<a href="../skresnet/" class="md-nav__link">
SK-ResNet
</a>
</li>
<li class="md-nav__item">
<a href="../skresnext/" class="md-nav__link">
SK-ResNeXt
</a>
</li>
<li class="md-nav__item">
<a href="../spnasnet/" class="md-nav__link">
SPNASNet
</a>
</li>
<li class="md-nav__item">
<a href="../ssl-resnet/" class="md-nav__link">
SSL ResNet
</a>
</li>
<li class="md-nav__item">
<a href="../ssl-resnext/" class="md-nav__link">
SSL ResNeXT
</a>
</li>
<li class="md-nav__item">
<a href="../swsl-resnet/" class="md-nav__link">
SWSL ResNet
</a>
</li>
<li class="md-nav__item">
<a href="../swsl-resnext/" class="md-nav__link">
SWSL ResNeXt
</a>
</li>
<li class="md-nav__item">
<a href="../tf-efficientnet-condconv/" class="md-nav__link">
(Tensorflow) EfficientNet CondConv
</a>
</li>
<li class="md-nav__item">
<a href="../tf-efficientnet-lite/" class="md-nav__link">
(Tensorflow) EfficientNet Lite
</a>
</li>
<li class="md-nav__item">
<a href="../tf-efficientnet/" class="md-nav__link">
(Tensorflow) EfficientNet
</a>
</li>
<li class="md-nav__item">
<a href="../tf-inception-v3/" class="md-nav__link">
(Tensorflow) Inception v3
</a>
</li>
<li class="md-nav__item">
<a href="../tf-mixnet/" class="md-nav__link">
(Tensorflow) MixNet
</a>
</li>
<li class="md-nav__item">
<a href="../tf-mobilenet-v3/" class="md-nav__link">
(Tensorflow) MobileNet v3
</a>
</li>
<li class="md-nav__item">
<a href="../tresnet/" class="md-nav__link">
TResNet
</a>
</li>
<li class="md-nav__item md-nav__item--active">
<input class="md-nav__toggle md-toggle" data-md-toggle="toc" type="checkbox" id="__toc">
<label class="md-nav__link md-nav__link--active" for="__toc">
Vision Transformer (ViT)
<span class="md-nav__icon md-icon"></span>
</label>
<a href="./" class="md-nav__link md-nav__link--active">
Vision Transformer (ViT)
</a>
<nav class="md-nav md-nav--secondary" aria-label="Table of contents">
<label class="md-nav__title" for="__toc">
<span class="md-nav__icon md-icon"></span>
Table of contents
</label>
<ul class="md-nav__list" data-md-component="toc" data-md-scrollfix>
<li class="md-nav__item">
<a href="#how-do-i-use-this-model-on-an-image" class="md-nav__link">
How do I use this model on an image?
</a>
</li>
<li class="md-nav__item">
<a href="#how-do-i-finetune-this-model" class="md-nav__link">
How do I finetune this model?
</a>
</li>
<li class="md-nav__item">
<a href="#how-do-i-train-this-model" class="md-nav__link">
How do I train this model?
</a>
</li>
<li class="md-nav__item">
<a href="#citation" class="md-nav__link">
Citation
</a>
</li>
</ul>
</nav>
</li>
<li class="md-nav__item">
<a href="../wide-resnet/" class="md-nav__link">
Wide ResNet
</a>
</li>
<li class="md-nav__item">
<a href="../xception/" class="md-nav__link">
Xception
</a>
</li>
</ul>
</nav>
</li>
</ul>
</nav>
</div>
</div>
</div>
<div class="md-sidebar md-sidebar--secondary" data-md-component="sidebar" data-md-type="toc" >
<div class="md-sidebar__scrollwrap">
<div class="md-sidebar__inner">
<nav class="md-nav md-nav--secondary" aria-label="Table of contents">
<label class="md-nav__title" for="__toc">
<span class="md-nav__icon md-icon"></span>
Table of contents
</label>
<ul class="md-nav__list" data-md-component="toc" data-md-scrollfix>
<li class="md-nav__item">
<a href="#how-do-i-use-this-model-on-an-image" class="md-nav__link">
How do I use this model on an image?
</a>
</li>
<li class="md-nav__item">
<a href="#how-do-i-finetune-this-model" class="md-nav__link">
How do I finetune this model?
</a>
</li>
<li class="md-nav__item">
<a href="#how-do-i-train-this-model" class="md-nav__link">
How do I train this model?
</a>
</li>
<li class="md-nav__item">
<a href="#citation" class="md-nav__link">
Citation
</a>
</li>
</ul>
</nav>
</div>
</div>
</div>
<div class="md-content" data-md-component="content">
<article class="md-content__inner md-typeset">
<a href="https://github.com/rwightman/pytorch-image-models/edit/master/docs/models/vision-transformer.md" title="Edit this page" class="md-content__button md-icon">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M20.71 7.04c.39-.39.39-1.04 0-1.41l-2.34-2.34c-.37-.39-1.02-.39-1.41 0l-1.84 1.83 3.75 3.75M3 17.25V21h3.75L17.81 9.93l-3.75-3.75L3 17.25z"/></svg>
</a>
<h1 id="vision-transformer-vit">Vision Transformer (ViT)</h1>
<p>The <strong>Vision Transformer</strong> is a model for image classification that employs a Transformer-like architecture over patches of the image. This includes the use of <a href="https://paperswithcode.com/method/multi-head-attention">Multi-Head Attention</a>, <a href="https://paperswithcode.com/method/scaled">Scaled Dot-Product Attention</a> and other architectural features seen in the <a href="https://paperswithcode.com/method/transformer">Transformer</a> architecture traditionally used for NLP.</p>
<h2 id="how-do-i-use-this-model-on-an-image">How do I use this model on an image?</h2>
<p>To load a pretrained model:</p>
<div class="highlight"><pre><span></span><code><span class="kn">import</span> <span class="nn">timm</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">timm</span><span class="o">.</span><span class="n">create_model</span><span class="p">(</span><span class="s1">&#39;vit_base_patch16_224&#39;</span><span class="p">,</span> <span class="n">pretrained</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="n">model</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span>
</code></pre></div>
<p>To load and preprocess the image:</p>
<div class="highlight"><pre><span></span><code><span class="kn">import</span> <span class="nn">urllib.request</span>
<span class="kn">from</span> <span class="nn">PIL</span> <span class="kn">import</span> <span class="n">Image</span>
<span class="kn">from</span> <span class="nn">timm.data</span> <span class="kn">import</span> <span class="n">resolve_data_config</span>
<span class="kn">from</span> <span class="nn">timm.data.transforms_factory</span> <span class="kn">import</span> <span class="n">create_transform</span>
<span class="n">config</span> <span class="o">=</span> <span class="n">resolve_data_config</span><span class="p">({},</span> <span class="n">model</span><span class="o">=</span><span class="n">model</span><span class="p">)</span>
<span class="n">transform</span> <span class="o">=</span> <span class="n">create_transform</span><span class="p">(</span><span class="o">**</span><span class="n">config</span><span class="p">)</span>
<span class="n">url</span><span class="p">,</span> <span class="n">filename</span> <span class="o">=</span> <span class="p">(</span><span class="s2">&quot;https://github.com/pytorch/hub/raw/master/images/dog.jpg&quot;</span><span class="p">,</span> <span class="s2">&quot;dog.jpg&quot;</span><span class="p">)</span>
<span class="n">urllib</span><span class="o">.</span><span class="n">request</span><span class="o">.</span><span class="n">urlretrieve</span><span class="p">(</span><span class="n">url</span><span class="p">,</span> <span class="n">filename</span><span class="p">)</span>
<span class="n">img</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="n">filename</span><span class="p">)</span><span class="o">.</span><span class="n">convert</span><span class="p">(</span><span class="s1">&#39;RGB&#39;</span><span class="p">)</span>
<span class="n">tensor</span> <span class="o">=</span> <span class="n">transform</span><span class="p">(</span><span class="n">img</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="c1"># transform and add batch dimension</span>
</code></pre></div>
<p>To get the model predictions:</p>
<div class="highlight"><pre><span></span><code><span class="kn">import</span> <span class="nn">torch</span>
<span class="k">with</span> <span class="n">torch</span><span class="o">.</span><span class="n">no_grad</span><span class="p">():</span>
    <span class="n">out</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">tensor</span><span class="p">)</span>
<span class="n">probabilities</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">functional</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">out</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">probabilities</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="c1"># prints: torch.Size([1000])</span>
</code></pre></div>
<p>To get the top-5 predicted class names:</p>
<div class="highlight"><pre><span></span><code><span class="c1"># Get imagenet class mappings</span>
<span class="n">url</span><span class="p">,</span> <span class="n">filename</span> <span class="o">=</span> <span class="p">(</span><span class="s2">&quot;https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt&quot;</span><span class="p">,</span> <span class="s2">&quot;imagenet_classes.txt&quot;</span><span class="p">)</span>
<span class="n">urllib</span><span class="o">.</span><span class="n">request</span><span class="o">.</span><span class="n">urlretrieve</span><span class="p">(</span><span class="n">url</span><span class="p">,</span> <span class="n">filename</span><span class="p">)</span>
<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="s2">&quot;imagenet_classes.txt&quot;</span><span class="p">,</span> <span class="s2">&quot;r&quot;</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
    <span class="n">categories</span> <span class="o">=</span> <span class="p">[</span><span class="n">s</span><span class="o">.</span><span class="n">strip</span><span class="p">()</span> <span class="k">for</span> <span class="n">s</span> <span class="ow">in</span> <span class="n">f</span><span class="o">.</span><span class="n">readlines</span><span class="p">()]</span>
<span class="c1"># Print top categories per image</span>
<span class="n">top5_prob</span><span class="p">,</span> <span class="n">top5_catid</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">topk</span><span class="p">(</span><span class="n">probabilities</span><span class="p">,</span> <span class="mi">5</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">top5_prob</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">0</span><span class="p">)):</span>
    <span class="nb">print</span><span class="p">(</span><span class="n">categories</span><span class="p">[</span><span class="n">top5_catid</span><span class="p">[</span><span class="n">i</span><span class="p">]],</span> <span class="n">top5_prob</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">item</span><span class="p">())</span>
<span class="c1"># prints class names and probabilities like:</span>
<span class="c1"># [(&#39;Samoyed&#39;, 0.6425196528434753), (&#39;Pomeranian&#39;, 0.04062102362513542), (&#39;keeshond&#39;, 0.03186424449086189), (&#39;white wolf&#39;, 0.01739676296710968), (&#39;Eskimo dog&#39;, 0.011717947199940681)]</span>
</code></pre></div>
<p>Replace the model name with the variant you want to use, e.g. <code>vit_base_patch16_224</code>. You can find the IDs in the model summaries at the top of this page.</p>
<p>To extract image features with this model, follow the <a href="https://rwightman.github.io/pytorch-image-models/feature_extraction/">timm feature extraction examples</a>, just change the name of the model you want to use.</p>
<h2 id="how-do-i-finetune-this-model">How do I finetune this model?</h2>
<p>You can finetune any of the pre-trained models just by changing the classifier (the last layer).</p>
<div class="highlight"><pre><span></span><code><span class="n">model</span> <span class="o">=</span> <span class="n">timm</span><span class="o">.</span><span class="n">create_model</span><span class="p">(</span><span class="s1">&#39;vit_base_patch16_224&#39;</span><span class="p">,</span> <span class="n">pretrained</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">num_classes</span><span class="o">=</span><span class="n">NUM_FINETUNE_CLASSES</span><span class="p">)</span>
</code></pre></div>
<p>To finetune on your own dataset, you have to write a training loop or adapt <a href="https://github.com/rwightman/pytorch-image-models/blob/master/train.py">timm's training
script</a> to use your dataset.</p>
<h2 id="how-do-i-train-this-model">How do I train this model?</h2>
<p>You can follow the <a href="https://rwightman.github.io/pytorch-image-models/scripts/">timm recipe scripts</a> for training a new model afresh.</p>
<h2 id="citation">Citation</h2>
<div class="highlight"><pre><span></span><code><span class="nc">@misc</span><span class="p">{</span><span class="nl">dosovitskiy2020image</span><span class="p">,</span>
<span class="na">title</span><span class="p">=</span><span class="s">{An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale}</span><span class="p">,</span>
<span class="na">author</span><span class="p">=</span><span class="s">{Alexey Dosovitskiy and Lucas Beyer and Alexander Kolesnikov and Dirk Weissenborn and Xiaohua Zhai and Thomas Unterthiner and Mostafa Dehghani and Matthias Minderer and Georg Heigold and Sylvain Gelly and Jakob Uszkoreit and Neil Houlsby}</span><span class="p">,</span>
<span class="na">year</span><span class="p">=</span><span class="s">{2020}</span><span class="p">,</span>
<span class="na">eprint</span><span class="p">=</span><span class="s">{2010.11929}</span><span class="p">,</span>
<span class="na">archivePrefix</span><span class="p">=</span><span class="s">{arXiv}</span><span class="p">,</span>
<span class="na">primaryClass</span><span class="p">=</span><span class="s">{cs.CV}</span>
<span class="p">}</span>
</code></pre></div>
<!--
Type: model-index
Collections:
- Name: Vision Transformer
Paper:
Title: 'An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale'
URL: https://paperswithcode.com/paper/an-image-is-worth-16x16-words-transformers-1
Models:
- Name: vit_base_patch16_224
In Collection: Vision Transformer
Metadata:
FLOPs: 67394605056
Parameters: 86570000
File Size: 346292833
Architecture:
- Attention Dropout
- Convolution
- Dense Connections
- Dropout
- GELU
- Layer Normalization
- Multi-Head Attention
- Scaled Dot-Product Attention
- Tanh Activation
Tasks:
- Image Classification
Training Techniques:
- Cosine Annealing
- Gradient Clipping
- SGD with Momentum
Training Data:
- ImageNet
- JFT-300M
Training Resources: TPUv3
ID: vit_base_patch16_224
LR: 0.0008
Epochs: 90
Dropout: 0.0
Crop Pct: '0.9'
Batch Size: 4096
Image Size: '224'
Warmup Steps: 10000
Weight Decay: 0.03
Interpolation: bicubic
Code: https://github.com/rwightman/pytorch-image-models/blob/5f9aff395c224492e9e44248b15f44b5cc095d9c/timm/models/vision_transformer.py#L503
Weights: https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth
Results:
- Task: Image Classification
Dataset: ImageNet
Metrics:
Top 1 Accuracy: 81.78%
Top 5 Accuracy: 96.13%
- Name: vit_base_patch16_384
In Collection: Vision Transformer
Metadata:
FLOPs: 49348245504
Parameters: 86860000
File Size: 347460194
Architecture:
- Attention Dropout
- Convolution
- Dense Connections
- Dropout
- GELU
- Layer Normalization
- Multi-Head Attention
- Scaled Dot-Product Attention
- Tanh Activation
Tasks:
- Image Classification
Training Techniques:
- Cosine Annealing
- Gradient Clipping
- SGD with Momentum
Training Data:
- ImageNet
- JFT-300M
Training Resources: TPUv3
ID: vit_base_patch16_384
Crop Pct: '1.0'
Momentum: 0.9
Batch Size: 512
Image Size: '384'
Weight Decay: 0.0
Interpolation: bicubic
Code: https://github.com/rwightman/pytorch-image-models/blob/5f9aff395c224492e9e44248b15f44b5cc095d9c/timm/models/vision_transformer.py#L522
Weights: https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth
Results:
- Task: Image Classification
Dataset: ImageNet
Metrics:
Top 1 Accuracy: 84.2%
Top 5 Accuracy: 97.22%
- Name: vit_base_patch32_384
In Collection: Vision Transformer
Metadata:
FLOPs: 12656142336
Parameters: 88300000
File Size: 353210979
Architecture:
- Attention Dropout
- Convolution
- Dense Connections
- Dropout
- GELU
- Layer Normalization
- Multi-Head Attention
- Scaled Dot-Product Attention
- Tanh Activation
Tasks:
- Image Classification
Training Techniques:
- Cosine Annealing
- Gradient Clipping
- SGD with Momentum
Training Data:
- ImageNet
- JFT-300M
Training Resources: TPUv3
ID: vit_base_patch32_384
Crop Pct: '1.0'
Momentum: 0.9
Batch Size: 512
Image Size: '384'
Weight Decay: 0.0
Interpolation: bicubic
Code: https://github.com/rwightman/pytorch-image-models/blob/5f9aff395c224492e9e44248b15f44b5cc095d9c/timm/models/vision_transformer.py#L532
Weights: https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p32_384-830016f5.pth
Results:
- Task: Image Classification
Dataset: ImageNet
Metrics:
Top 1 Accuracy: 81.66%
Top 5 Accuracy: 96.13%
- Name: vit_base_resnet50_384
In Collection: Vision Transformer
Metadata:
FLOPs: 49461491712
Parameters: 98950000
File Size: 395854632
Architecture:
- Attention Dropout
- Convolution
- Dense Connections
- Dropout
- GELU
- Layer Normalization
- Multi-Head Attention
- Scaled Dot-Product Attention
- Tanh Activation
Tasks:
- Image Classification
Training Techniques:
- Cosine Annealing
- Gradient Clipping
- SGD with Momentum
Training Data:
- ImageNet
- JFT-300M
Training Resources: TPUv3
ID: vit_base_resnet50_384
Crop Pct: '1.0'
Momentum: 0.9
Batch Size: 512
Image Size: '384'
Weight Decay: 0.0
Interpolation: bicubic
Code: https://github.com/rwightman/pytorch-image-models/blob/5f9aff395c224492e9e44248b15f44b5cc095d9c/timm/models/vision_transformer.py#L653
Weights: https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth
Results:
- Task: Image Classification
Dataset: ImageNet
Metrics:
Top 1 Accuracy: 84.99%
Top 5 Accuracy: 97.3%
- Name: vit_large_patch16_224
In Collection: Vision Transformer
Metadata:
FLOPs: 119294746624
Parameters: 304330000
File Size: 1217350532
Architecture:
- Attention Dropout
- Convolution
- Dense Connections
- Dropout
- GELU
- Layer Normalization
- Multi-Head Attention
- Scaled Dot-Product Attention
- Tanh Activation
Tasks:
- Image Classification
Training Techniques:
- Cosine Annealing
- Gradient Clipping
- SGD with Momentum
Training Data:
- ImageNet
- JFT-300M
Training Resources: TPUv3
ID: vit_large_patch16_224
Crop Pct: '0.9'
Momentum: 0.9
Batch Size: 512
Image Size: '224'
Weight Decay: 0.0
Interpolation: bicubic
Code: https://github.com/rwightman/pytorch-image-models/blob/5f9aff395c224492e9e44248b15f44b5cc095d9c/timm/models/vision_transformer.py#L542
Weights: https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth
Results:
- Task: Image Classification
Dataset: ImageNet
Metrics:
Top 1 Accuracy: 83.06%
Top 5 Accuracy: 96.44%
- Name: vit_large_patch16_384
In Collection: Vision Transformer
Metadata:
FLOPs: 174702764032
Parameters: 304720000
File Size: 1218907013
Architecture:
- Attention Dropout
- Convolution
- Dense Connections
- Dropout
- GELU
- Layer Normalization
- Multi-Head Attention
- Scaled Dot-Product Attention
- Tanh Activation
Tasks:
- Image Classification
Training Techniques:
- Cosine Annealing
- Gradient Clipping
- SGD with Momentum
Training Data:
- ImageNet
- JFT-300M
Training Resources: TPUv3
ID: vit_large_patch16_384
Crop Pct: '1.0'
Momentum: 0.9
Batch Size: 512
Image Size: '384'
Weight Decay: 0.0
Interpolation: bicubic
Code: https://github.com/rwightman/pytorch-image-models/blob/5f9aff395c224492e9e44248b15f44b5cc095d9c/timm/models/vision_transformer.py#L561
Weights: https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth
Results:
- Task: Image Classification
Dataset: ImageNet
Metrics:
Top 1 Accuracy: 85.17%
Top 5 Accuracy: 97.36%
- Name: vit_small_patch16_224
In Collection: Vision Transformer
Metadata:
FLOPs: 28236450816
Parameters: 48750000
File Size: 195031454
Architecture:
- Attention Dropout
- Convolution
- Dense Connections
- Dropout
- GELU
- Layer Normalization
- Multi-Head Attention
- Scaled Dot-Product Attention
- Tanh Activation
Tasks:
- Image Classification
Training Techniques:
- Cosine Annealing
- Gradient Clipping
- SGD with Momentum
Training Data:
- ImageNet
- JFT-300M
Training Resources: TPUv3
ID: vit_small_patch16_224
Crop Pct: '0.9'
Image Size: '224'
Interpolation: bicubic
Code: https://github.com/rwightman/pytorch-image-models/blob/5f9aff395c224492e9e44248b15f44b5cc095d9c/timm/models/vision_transformer.py#L490
Weights: https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth
Results:
- Task: Image Classification
Dataset: ImageNet
Metrics:
Top 1 Accuracy: 77.85%
Top 5 Accuracy: 93.42%
-->
</article>
</div>
</div>
</main>
<!-- Page footer: previous/next page links followed by the theme attribution. -->
<footer class="md-footer">
  <nav class="md-footer__inner md-grid" aria-label="Footer">
    <!-- Link to the previous page. The arrow icon is decorative (the link text
         already names the destination), so it is hidden from assistive technology. -->
    <a href="../tresnet/" class="md-footer__link md-footer__link--prev" rel="prev">
      <div class="md-footer__button md-icon">
        <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" aria-hidden="true"><path d="M20 11v2H8l5.5 5.5-1.42 1.42L4.16 12l7.92-7.92L13.5 5.5 8 11h12z"/></svg>
      </div>
      <div class="md-footer__title">
        <div class="md-ellipsis">
          <span class="md-footer__direction">
            Previous
          </span>
          TResNet
        </div>
      </div>
    </a>
    <!-- Link to the next page. Same structure as above with the icon placed after
         the title; the icon is again decorative and hidden from assistive technology. -->
    <a href="../wide-resnet/" class="md-footer__link md-footer__link--next" rel="next">
      <div class="md-footer__title">
        <div class="md-ellipsis">
          <span class="md-footer__direction">
            Next
          </span>
          Wide ResNet
        </div>
      </div>
      <div class="md-footer__button md-icon">
        <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" aria-hidden="true"><path d="M4 11v2h12l-5.5 5.5 1.42 1.42L19.84 12l-7.92-7.92L10.5 5.5 16 11H4z"/></svg>
      </div>
    </a>
  </nav>
  <!-- Attribution line; the external link opens in a new tab and uses rel="noopener". -->
  <div class="md-footer-meta md-typeset">
    <div class="md-footer-meta__inner md-grid">
      <div class="md-footer-copyright">
        Made with
        <a href="https://squidfunk.github.io/mkdocs-material/" target="_blank" rel="noopener">
          Material for MkDocs
        </a>
      </div>
    </div>
  </div>
</footer>
</div>
<!-- Dialog container whose inner element is empty in the static markup.
     NOTE(review): presumably filled in at runtime by the page's JavaScript
     (e.g. a transient notification); confirm against the theme bundle. -->
<div class="md-dialog" data-md-component="dialog">
<div class="md-dialog__inner md-typeset"></div>
</div>
<!-- Site configuration as a JSON data block: type="application/json" means the
     browser does not execute it. It carries the base path, UI translation strings
     and the search worker path. NOTE(review): presumably read by the bundle
     script that follows; confirm. -->
<script id="__config" type="application/json">{"base": "../..", "features": [], "translations": {"clipboard.copy": "Copy to clipboard", "clipboard.copied": "Copied to clipboard", "search.config.lang": "en", "search.config.pipeline": "trimmer, stopWordFilter", "search.config.separator": "[\\s\\-]+", "search.placeholder": "Search", "search.result.placeholder": "Type to start searching", "search.result.none": "No matching documents", "search.result.one": "1 matching document", "search.result.other": "# matching documents", "search.result.more.one": "1 more on this page", "search.result.more.other": "# more on this page", "search.result.term.missing": "Missing"}, "search": "../../assets/javascripts/workers/search.fb4a9340.min.js", "version": null}</script>
<!-- Main JavaScript bundle, loaded from the site's own assets directory. -->
<script src="../../assets/javascripts/bundle.a1c7c35e.min.js"></script>
<!-- Third-party libraries served from the cdnjs CDN: MathJax 2.7.0 (with the
     TeX-MML-AM_CHTML configuration) and tablesort 5.2.1.
     NOTE(review): neither tag has an integrity attribute; consider adding
     Subresource Integrity hashes together with crossorigin="anonymous". -->
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-MML-AM_CHTML"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/tablesort/5.2.1/tablesort.min.js"></script>
<!-- Site-local script, loaded after tablesort.
     NOTE(review): presumably applies tablesort to the page's tables; confirm
     against tables.js. -->
<script src="../../javascripts/tables.js"></script>
</body>
</html>