<!doctype html>
<html lang="en" class="no-js">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width,initial-scale=1">
<meta name="description" content="Pretrained Image Recognition Models">
<link rel="icon" href="../assets/images/favicon.png">
<meta name="generator" content="mkdocs-1.1.2, mkdocs-material-7.0.6">
<title>Feature Extraction - Pytorch Image Models</title>
<link rel="stylesheet" href="../assets/stylesheets/main.2c0c5eaf.min.css">
<link rel="stylesheet" href="../assets/stylesheets/palette.7fa14f5b.min.css">
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
<link rel="stylesheet" href="https://fonts.googleapis.com/css?family=Roboto:300,400,400i,700%7CRoboto+Mono&amp;display=fallback">
<style>:root{--md-text-font-family:"Roboto";--md-code-font-family:"Roboto Mono"}</style>
</head>
< body dir = "ltr" data-md-color-scheme = "" data-md-color-primary = "none" data-md-color-accent = "none" >
< input class = "md-toggle" data-md-toggle = "drawer" type = "checkbox" id = "__drawer" autocomplete = "off" >
< input class = "md-toggle" data-md-toggle = "search" type = "checkbox" id = "__search" autocomplete = "off" >
< label class = "md-overlay" for = "__drawer" > < / label >
< div data-md-component = "skip" >
< a href = "#feature-extraction" class = "md-skip" >
Skip to content
< / a >
< / div >
< div data-md-component = "announce" >
< / div >
< header class = "md-header" data-md-component = "header" >
< nav class = "md-header__inner md-grid" aria-label = "Header" >
< a href = ".." title = "Pytorch Image Models" class = "md-header__button md-logo" aria-label = "Pytorch Image Models" data-md-component = "logo" >
< svg xmlns = "http://www.w3.org/2000/svg" viewBox = "0 0 24 24" > < path d = "M12 8a3 3 0 0 0 3-3 3 3 0 0 0-3-3 3 3 0 0 0-3 3 3 3 0 0 0 3 3m0 3.54C9.64 9.35 6.5 8 3 8v11c3.5 0 6.64 1.35 9 3.54 2.36-2.19 5.5-3.54 9-3.54V8c-3.5 0-6.64 1.35-9 3.54z" / > < / svg >
< / a >
< label class = "md-header__button md-icon" for = "__drawer" >
< svg xmlns = "http://www.w3.org/2000/svg" viewBox = "0 0 24 24" > < path d = "M3 6h18v2H3V6m0 5h18v2H3v-2m0 5h18v2H3v-2z" / > < / svg >
< / label >
< div class = "md-header__title" data-md-component = "header-title" >
< div class = "md-header__ellipsis" >
< div class = "md-header__topic" >
< span class = "md-ellipsis" >
Pytorch Image Models
< / span >
< / div >
< div class = "md-header__topic" data-md-component = "header-topic" >
< span class = "md-ellipsis" >
Feature Extraction
< / span >
< / div >
< / div >
< / div >
< div class = "md-header__options" >
< / div >
< label class = "md-header__button md-icon" for = "__search" >
< svg xmlns = "http://www.w3.org/2000/svg" viewBox = "0 0 24 24" > < path d = "M9.5 3A6.5 6.5 0 0 1 16 9.5c0 1.61-.59 3.09-1.56 4.23l.27.27h.79l5 5-1.5 1.5-5-5v-.79l-.27-.27A6.516 6.516 0 0 1 9.5 16 6.5 6.5 0 0 1 3 9.5 6.5 6.5 0 0 1 9.5 3m0 2C7 5 5 7 5 9.5S7 14 9.5 14 14 12 14 9.5 12 5 9.5 5z" / > < / svg >
< / label >
< div class = "md-search" data-md-component = "search" role = "dialog" >
< label class = "md-search__overlay" for = "__search" > < / label >
< div class = "md-search__inner" role = "search" >
< form class = "md-search__form" name = "search" >
< input type = "text" class = "md-search__input" name = "query" aria-label = "Search" placeholder = "Search" autocapitalize = "off" autocorrect = "off" autocomplete = "off" spellcheck = "false" data-md-component = "search-query" data-md-state = "active" required >
< label class = "md-search__icon md-icon" for = "__search" >
< svg xmlns = "http://www.w3.org/2000/svg" viewBox = "0 0 24 24" > < path d = "M9.5 3A6.5 6.5 0 0 1 16 9.5c0 1.61-.59 3.09-1.56 4.23l.27.27h.79l5 5-1.5 1.5-5-5v-.79l-.27-.27A6.516 6.516 0 0 1 9.5 16 6.5 6.5 0 0 1 3 9.5 6.5 6.5 0 0 1 9.5 3m0 2C7 5 5 7 5 9.5S7 14 9.5 14 14 12 14 9.5 12 5 9.5 5z" / > < / svg >
< svg xmlns = "http://www.w3.org/2000/svg" viewBox = "0 0 24 24" > < path d = "M20 11v2H8l5.5 5.5-1.42 1.42L4.16 12l7.92-7.92L13.5 5.5 8 11h12z" / > < / svg >
< / label >
< button type = "reset" class = "md-search__icon md-icon" aria-label = "Clear" tabindex = "-1" >
< svg xmlns = "http://www.w3.org/2000/svg" viewBox = "0 0 24 24" > < path d = "M19 6.41L17.59 5 12 10.59 6.41 5 5 6.41 10.59 12 5 17.59 6.41 19 12 13.41 17.59 19 19 17.59 13.41 12 19 6.41z" / > < / svg >
< / button >
< / form >
< div class = "md-search__output" >
< div class = "md-search__scrollwrap" data-md-scrollfix >
< div class = "md-search-result" data-md-component = "search-result" >
< div class = "md-search-result__meta" >
Initializing search
< / div >
< ol class = "md-search-result__list" > < / ol >
< / div >
< / div >
< / div >
< / div >
< / div >
< div class = "md-header__source" >
< a href = "https://github.com/rwightman/pytorch-image-models/" title = "Go to repository" class = "md-source" data-md-component = "source" >
< div class = "md-source__icon md-icon" >
< svg xmlns = "http://www.w3.org/2000/svg" viewBox = "0 0 448 512" > < path d = "M439.55 236.05L244 40.45a28.87 28.87 0 0 0-40.81 0l-40.66 40.63 51.52 51.52c27.06-9.14 52.68 16.77 43.39 43.68l49.66 49.66c34.23-11.8 61.18 31 35.47 56.69-26.49 26.49-70.21-2.87-56-37.34L240.22 199v121.85c25.3 12.54 22.26 41.85 9.08 55a34.34 34.34 0 0 1-48.55 0c-17.57-17.6-11.07-46.91 11.25-56v-123c-20.8-8.51-24.6-30.74-18.64-45L142.57 101 8.45 235.14a28.86 28.86 0 0 0 0 40.81l195.61 195.6a28.86 28.86 0 0 0 40.8 0l194.69-194.69a28.86 28.86 0 0 0 0-40.81z" / > < / svg >
< / div >
< div class = "md-source__repository" >
rwightman/pytorch-image-models
< / div >
< / a >
< / div >
< / nav >
< / header >
< div class = "md-container" data-md-component = "container" >
< main class = "md-main" data-md-component = "main" >
< div class = "md-main__inner md-grid" >
< div class = "md-sidebar md-sidebar--primary" data-md-component = "sidebar" data-md-type = "navigation" >
< div class = "md-sidebar__scrollwrap" >
< div class = "md-sidebar__inner" >
< nav class = "md-nav md-nav--primary" aria-label = "Navigation" data-md-level = "0" >
< label class = "md-nav__title" for = "__drawer" >
< a href = ".." title = "Pytorch Image Models" class = "md-nav__button md-logo" aria-label = "Pytorch Image Models" data-md-component = "logo" >
< svg xmlns = "http://www.w3.org/2000/svg" viewBox = "0 0 24 24" > < path d = "M12 8a3 3 0 0 0 3-3 3 3 0 0 0-3-3 3 3 0 0 0-3 3 3 3 0 0 0 3 3m0 3.54C9.64 9.35 6.5 8 3 8v11c3.5 0 6.64 1.35 9 3.54 2.36-2.19 5.5-3.54 9-3.54V8c-3.5 0-6.64 1.35-9 3.54z" / > < / svg >
< / a >
Pytorch Image Models
< / label >
< div class = "md-nav__source" >
< a href = "https://github.com/rwightman/pytorch-image-models/" title = "Go to repository" class = "md-source" data-md-component = "source" >
< div class = "md-source__icon md-icon" >
< svg xmlns = "http://www.w3.org/2000/svg" viewBox = "0 0 448 512" > < path d = "M439.55 236.05L244 40.45a28.87 28.87 0 0 0-40.81 0l-40.66 40.63 51.52 51.52c27.06-9.14 52.68 16.77 43.39 43.68l49.66 49.66c34.23-11.8 61.18 31 35.47 56.69-26.49 26.49-70.21-2.87-56-37.34L240.22 199v121.85c25.3 12.54 22.26 41.85 9.08 55a34.34 34.34 0 0 1-48.55 0c-17.57-17.6-11.07-46.91 11.25-56v-123c-20.8-8.51-24.6-30.74-18.64-45L142.57 101 8.45 235.14a28.86 28.86 0 0 0 0 40.81l195.61 195.6a28.86 28.86 0 0 0 40.8 0l194.69-194.69a28.86 28.86 0 0 0 0-40.81z" / > < / svg >
< / div >
< div class = "md-source__repository" >
rwightman/pytorch-image-models
< / div >
< / a >
< / div >
< ul class = "md-nav__list" data-md-scrollfix >
< li class = "md-nav__item" >
< a href = ".." class = "md-nav__link" >
Getting Started
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/" class = "md-nav__link" >
Model Architectures
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../results/" class = "md-nav__link" >
Results
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../scripts/" class = "md-nav__link" >
Scripts
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../training_hparam_examples/" class = "md-nav__link" >
Training Examples
< / a >
< / li >
< li class = "md-nav__item md-nav__item--active" >
< input class = "md-nav__toggle md-toggle" data-md-toggle = "toc" type = "checkbox" id = "__toc" >
< label class = "md-nav__link md-nav__link--active" for = "__toc" >
Feature Extraction
< span class = "md-nav__icon md-icon" > < / span >
< / label >
< a href = "./" class = "md-nav__link md-nav__link--active" >
Feature Extraction
< / a >
< nav class = "md-nav md-nav--secondary" aria-label = "Table of contents" >
< label class = "md-nav__title" for = "__toc" >
< span class = "md-nav__icon md-icon" > < / span >
Table of contents
< / label >
< ul class = "md-nav__list" data-md-component = "toc" data-md-scrollfix >
< li class = "md-nav__item" >
< a href = "#penultimate-layer-features-pre-classifier-features" class = "md-nav__link" >
Penultimate Layer Features (Pre-Classifier Features)
< / a >
< nav class = "md-nav" aria-label = "Penultimate Layer Features (Pre-Classifier Features)" >
< ul class = "md-nav__list" >
< li class = "md-nav__item" >
< a href = "#unpooled" class = "md-nav__link" >
Unpooled
< / a >
< nav class = "md-nav" aria-label = "Unpooled" >
< ul class = "md-nav__list" >
< li class = "md-nav__item" >
< a href = "#forward_features" class = "md-nav__link" >
forward_features()
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "#create-with-no-classifier-and-pooling" class = "md-nav__link" >
Create with no classifier and pooling
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "#remove-it-later" class = "md-nav__link" >
Remove it later
< / a >
< / li >
< / ul >
< / nav >
< / li >
< li class = "md-nav__item" >
< a href = "#pooled" class = "md-nav__link" >
Pooled
< / a >
< nav class = "md-nav" aria-label = "Pooled" >
< ul class = "md-nav__list" >
< li class = "md-nav__item" >
< a href = "#create-with-no-classifier" class = "md-nav__link" >
Create with no classifier
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "#remove-it-later_1" class = "md-nav__link" >
Remove it later
< / a >
< / li >
< / ul >
< / nav >
< / li >
< / ul >
< / nav >
< / li >
< li class = "md-nav__item" >
< a href = "#multi-scale-feature-maps-feature-pyramid" class = "md-nav__link" >
Multi-scale Feature Maps (Feature Pyramid)
< / a >
< nav class = "md-nav" aria-label = "Multi-scale Feature Maps (Feature Pyramid)" >
< ul class = "md-nav__list" >
< li class = "md-nav__item" >
< a href = "#create-a-feature-map-extraction-model" class = "md-nav__link" >
Create a feature map extraction model
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "#query-the-feature-information" class = "md-nav__link" >
Query the feature information
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "#select-specific-feature-levels-or-limit-the-stride" class = "md-nav__link" >
Select specific feature levels or limit the stride
< / a >
< / li >
< / ul >
< / nav >
< / li >
< / ul >
< / nav >
< / li >
< li class = "md-nav__item" >
< a href = "../changes/" class = "md-nav__link" >
Recent Changes
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../archived_changes/" class = "md-nav__link" >
Archived Changes
< / a >
< / li >
< li class = "md-nav__item md-nav__item--nested" >
< input class = "md-nav__toggle md-toggle" data-md-toggle = "__nav_9" type = "checkbox" id = "__nav_9" >
< label class = "md-nav__link" for = "__nav_9" >
Models
< span class = "md-nav__icon md-icon" > < / span >
< / label >
< nav class = "md-nav" aria-label = "Models" data-md-level = "1" >
< label class = "md-nav__title" for = "__nav_9" >
< span class = "md-nav__icon md-icon" > < / span >
Models
< / label >
< ul class = "md-nav__list" data-md-scrollfix >
< li class = "md-nav__item" >
< a href = "../models/adversarial-inception-v3/" class = "md-nav__link" >
Adversarial Inception v3
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/advprop/" class = "md-nav__link" >
AdvProp (EfficientNet)
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/big-transfer/" class = "md-nav__link" >
Big Transfer (BiT)
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/csp-darknet/" class = "md-nav__link" >
CSP-DarkNet
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/csp-resnet/" class = "md-nav__link" >
CSP-ResNet
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/csp-resnext/" class = "md-nav__link" >
CSP-ResNeXt
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/densenet/" class = "md-nav__link" >
DenseNet
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/dla/" class = "md-nav__link" >
Deep Layer Aggregation
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/dpn/" class = "md-nav__link" >
Dual Path Network (DPN)
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/ecaresnet/" class = "md-nav__link" >
ECA-ResNet
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/efficientnet-pruned/" class = "md-nav__link" >
EfficientNet (Knapsack Pruned)
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/efficientnet/" class = "md-nav__link" >
EfficientNet
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/ensemble-adversarial/" class = "md-nav__link" >
Ensemble Adversarial Inception ResNet v2
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/ese-vovnet/" class = "md-nav__link" >
ESE-VoVNet
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/fbnet/" class = "md-nav__link" >
FBNet
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/gloun-inception-v3/" class = "md-nav__link" >
(Gluon) Inception v3
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/gloun-resnet/" class = "md-nav__link" >
(Gluon) ResNet
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/gloun-resnext/" class = "md-nav__link" >
(Gluon) ResNeXt
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/gloun-senet/" class = "md-nav__link" >
(Gluon) SENet
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/gloun-seresnext/" class = "md-nav__link" >
(Gluon) SE-ResNeXt
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/gloun-xception/" class = "md-nav__link" >
(Gluon) Xception
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/hrnet/" class = "md-nav__link" >
HRNet
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/ig-resnext/" class = "md-nav__link" >
Instagram ResNeXt WSL
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/inception-resnet-v2/" class = "md-nav__link" >
Inception ResNet v2
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/inception-v3/" class = "md-nav__link" >
Inception v3
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/inception-v4/" class = "md-nav__link" >
Inception v4
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/legacy-se-resnet/" class = "md-nav__link" >
(Legacy) SE-ResNet
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/legacy-se-resnext/" class = "md-nav__link" >
(Legacy) SE-ResNeXt
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/legacy-senet/" class = "md-nav__link" >
(Legacy) SENet
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/mixnet/" class = "md-nav__link" >
MixNet
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/mnasnet/" class = "md-nav__link" >
MnasNet
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/mobilenet-v2/" class = "md-nav__link" >
MobileNet v2
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/mobilenet-v3/" class = "md-nav__link" >
MobileNet v3
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/nasnet/" class = "md-nav__link" >
NASNet
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/noisy-student/" class = "md-nav__link" >
Noisy Student (EfficientNet)
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/pnasnet/" class = "md-nav__link" >
PNASNet
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/regnetx/" class = "md-nav__link" >
RegNetX
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/regnety/" class = "md-nav__link" >
RegNetY
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/res2net/" class = "md-nav__link" >
Res2Net
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/res2next/" class = "md-nav__link" >
Res2NeXt
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/resnest/" class = "md-nav__link" >
ResNeSt
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/resnet-d/" class = "md-nav__link" >
ResNet-D
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/resnet/" class = "md-nav__link" >
ResNet
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/resnext/" class = "md-nav__link" >
ResNeXt
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/rexnet/" class = "md-nav__link" >
RexNet
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/se-resnet/" class = "md-nav__link" >
SE-ResNet
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/selecsls/" class = "md-nav__link" >
SelecSLS
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/seresnext/" class = "md-nav__link" >
SE-ResNeXt
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/skresnet/" class = "md-nav__link" >
SK-ResNet
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/skresnext/" class = "md-nav__link" >
SK-ResNeXt
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/spnasnet/" class = "md-nav__link" >
SPNASNet
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/ssl-resnet/" class = "md-nav__link" >
SSL ResNet
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/ssl-resnext/" class = "md-nav__link" >
SSL ResNeXt
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/swsl-resnet/" class = "md-nav__link" >
SWSL ResNet
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/swsl-resnext/" class = "md-nav__link" >
SWSL ResNeXt
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/tf-efficientnet-condconv/" class = "md-nav__link" >
(Tensorflow) EfficientNet CondConv
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/tf-efficientnet-lite/" class = "md-nav__link" >
(Tensorflow) EfficientNet Lite
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/tf-efficientnet/" class = "md-nav__link" >
(Tensorflow) EfficientNet
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/tf-inception-v3/" class = "md-nav__link" >
(Tensorflow) Inception v3
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/tf-mixnet/" class = "md-nav__link" >
(Tensorflow) MixNet
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/tf-mobilenet-v3/" class = "md-nav__link" >
(Tensorflow) MobileNet v3
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/tresnet/" class = "md-nav__link" >
TResNet
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/vision-transformer/" class = "md-nav__link" >
Vision Transformer (ViT)
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/wide-resnet/" class = "md-nav__link" >
Wide ResNet
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "../models/xception/" class = "md-nav__link" >
Xception
< / a >
< / li >
< / ul >
< / nav >
< / li >
< / ul >
< / nav >
< / div >
< / div >
< / div >
< div class = "md-sidebar md-sidebar--secondary" data-md-component = "sidebar" data-md-type = "toc" >
< div class = "md-sidebar__scrollwrap" >
< div class = "md-sidebar__inner" >
< nav class = "md-nav md-nav--secondary" aria-label = "Table of contents" >
< label class = "md-nav__title" for = "__toc" >
< span class = "md-nav__icon md-icon" > < / span >
Table of contents
< / label >
< ul class = "md-nav__list" data-md-component = "toc" data-md-scrollfix >
< li class = "md-nav__item" >
< a href = "#penultimate-layer-features-pre-classifier-features" class = "md-nav__link" >
Penultimate Layer Features (Pre-Classifier Features)
< / a >
< nav class = "md-nav" aria-label = "Penultimate Layer Features (Pre-Classifier Features)" >
< ul class = "md-nav__list" >
< li class = "md-nav__item" >
< a href = "#unpooled" class = "md-nav__link" >
Unpooled
< / a >
< nav class = "md-nav" aria-label = "Unpooled" >
< ul class = "md-nav__list" >
< li class = "md-nav__item" >
< a href = "#forward_features" class = "md-nav__link" >
forward_features()
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "#create-with-no-classifier-and-pooling" class = "md-nav__link" >
Create with no classifier and pooling
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "#remove-it-later" class = "md-nav__link" >
Remove it later
< / a >
< / li >
< / ul >
< / nav >
< / li >
< li class = "md-nav__item" >
< a href = "#pooled" class = "md-nav__link" >
Pooled
< / a >
< nav class = "md-nav" aria-label = "Pooled" >
< ul class = "md-nav__list" >
< li class = "md-nav__item" >
< a href = "#create-with-no-classifier" class = "md-nav__link" >
Create with no classifier
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "#remove-it-later_1" class = "md-nav__link" >
Remove it later
< / a >
< / li >
< / ul >
< / nav >
< / li >
< / ul >
< / nav >
< / li >
< li class = "md-nav__item" >
< a href = "#multi-scale-feature-maps-feature-pyramid" class = "md-nav__link" >
Multi-scale Feature Maps (Feature Pyramid)
< / a >
< nav class = "md-nav" aria-label = "Multi-scale Feature Maps (Feature Pyramid)" >
< ul class = "md-nav__list" >
< li class = "md-nav__item" >
< a href = "#create-a-feature-map-extraction-model" class = "md-nav__link" >
Create a feature map extraction model
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "#query-the-feature-information" class = "md-nav__link" >
Query the feature information
< / a >
< / li >
< li class = "md-nav__item" >
< a href = "#select-specific-feature-levels-or-limit-the-stride" class = "md-nav__link" >
Select specific feature levels or limit the stride
< / a >
< / li >
< / ul >
< / nav >
< / li >
< / ul >
< / nav >
< / div >
< / div >
< / div >
< div class = "md-content" data-md-component = "content" >
< article class = "md-content__inner md-typeset" >
< a href = "https://github.com/rwightman/pytorch-image-models/edit/master/docs/feature_extraction.md" title = "Edit this page" class = "md-content__button md-icon" >
< svg xmlns = "http://www.w3.org/2000/svg" viewBox = "0 0 24 24" > < path d = "M20.71 7.04c.39-.39.39-1.04 0-1.41l-2.34-2.34c-.37-.39-1.02-.39-1.41 0l-1.84 1.83 3.75 3.75M3 17.25V21h3.75L17.81 9.93l-3.75-3.75L3 17.25z" / > < / svg >
< / a >
<h1 id="feature-extraction">Feature Extraction</h1>
<p>All of the models in <code>timm</code> have consistent mechanisms for obtaining various types of features from the model for tasks other than classification.</p>
<h2 id="penultimate-layer-features-pre-classifier-features">Penultimate Layer Features (Pre-Classifier Features)</h2>
<p>The features from the penultimate model layer can be obtained in several ways without requiring model surgery (although feel free to do surgery). One must first decide whether they want pooled or unpooled features.</p>
<h3 id="unpooled">Unpooled</h3>
<p>There are three ways to obtain unpooled features.</p>
<p>Without modifying the network, one can call <code>model.forward_features(input)</code> on any model instead of the usual <code>model(input)</code>. This bypasses the network's global pooling and classifier head, returning the unpooled feature map.</p>
<p>If one wants to explicitly modify the network to return unpooled features, they can either create the model without a classifier and pooling, or remove them later. Both paths remove the parameters associated with the classifier from the network.</p>
<h4 id="forward_features">forward_features()</h4>
<div class="highlight"><pre><span></span><code><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">timm</span>
<span class="hll"><span class="n">m</span> <span class="o">=</span> <span class="n">timm</span><span class="o">.</span><span class="n">create_model</span><span class="p">(</span><span class="s1">'xception41'</span><span class="p">,</span> <span class="n">pretrained</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
</span><span class="n">o</span> <span class="o">=</span> <span class="n">m</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">299</span><span class="p">,</span> <span class="mi">299</span><span class="p">))</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">'Original shape: </span><span class="si">{</span><span class="n">o</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
<span class="hll"><span class="n">o</span> <span class="o">=</span> <span class="n">m</span><span class="o">.</span><span class="n">forward_features</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">299</span><span class="p">,</span> <span class="mi">299</span><span class="p">))</span>
</span><span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">'Unpooled shape: </span><span class="si">{</span><span class="n">o</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
</code></pre></div>
<p>Output:</p>
<div class="highlight"><pre><span></span><code>Original shape: torch.Size([2, 1000])
Unpooled shape: torch.Size([2, 2048, 10, 10])
</code></pre></div>
<h4 id="create-with-no-classifier-and-pooling">Create with no classifier and pooling</h4>
<div class="highlight"><pre><span></span><code><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">timm</span>
<span class="hll"><span class="n">m</span> <span class="o">=</span> <span class="n">timm</span><span class="o">.</span><span class="n">create_model</span><span class="p">(</span><span class="s1">'resnet50'</span><span class="p">,</span> <span class="n">pretrained</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">num_classes</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">global_pool</span><span class="o">=</span><span class="s1">''</span><span class="p">)</span>
</span><span class="n">o</span> <span class="o">=</span> <span class="n">m</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">224</span><span class="p">,</span> <span class="mi">224</span><span class="p">))</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">'Unpooled shape: </span><span class="si">{</span><span class="n">o</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
</code></pre></div>
<p>Output:</p>
<div class="highlight"><pre><span></span><code>Unpooled shape: torch.Size([2, 2048, 7, 7])
</code></pre></div>
<h4 id="remove-it-later">Remove it later</h4>
<div class="highlight"><pre><span></span><code><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">timm</span>
<span class="hll"><span class="n">m</span> <span class="o">=</span> <span class="n">timm</span><span class="o">.</span><span class="n">create_model</span><span class="p">(</span><span class="s1">'densenet121'</span><span class="p">,</span> <span class="n">pretrained</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
</span><span class="n">o</span> <span class="o">=</span> <span class="n">m</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">224</span><span class="p">,</span> <span class="mi">224</span><span class="p">))</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">'Original shape: </span><span class="si">{</span><span class="n">o</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
<span class="hll"><span class="n">m</span><span class="o">.</span><span class="n">reset_classifier</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="s1">''</span><span class="p">)</span>
</span><span class="n">o</span> <span class="o">=</span> <span class="n">m</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">224</span><span class="p">,</span> <span class="mi">224</span><span class="p">))</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">'Unpooled shape: </span><span class="si">{</span><span class="n">o</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span>
</code></pre></div>
<p>Output:</p>
<div class="highlight"><pre><span></span><code>Original shape: torch.Size([2, 1000])
Unpooled shape: torch.Size([2, 1024, 7, 7])
</code></pre></div>
< h3 id = "pooled" > Pooled< / h3 >
< p > To modify the network to return pooled features, you can either call < code > forward_features()< / code > and pool/flatten the result yourself, or modify the network as shown above but keep pooling intact. < / p >
< h4 id = "create-with-no-classifier" > Create with no classifier< / h4 >
< p > < div class = "highlight" > < pre > < span > < / span > < code > < span class = "kn" > import< / span > < span class = "nn" > torch< / span >
< span class = "kn" > import< / span > < span class = "nn" > timm< / span >
< span class = "hll" > < span class = "n" > m< / span > < span class = "o" > =< / span > < span class = "n" > timm< / span > < span class = "o" > .< / span > < span class = "n" > create_model< / span > < span class = "p" > (< / span > < span class = "s1" > ' resnet50' < / span > < span class = "p" > ,< / span > < span class = "n" > pretrained< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > ,< / span > < span class = "n" > num_classes< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span >
< / span > < span class = "n" > o< / span > < span class = "o" > =< / span > < span class = "n" > m< / span > < span class = "p" > (< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > randn< / span > < span class = "p" > (< / span > < span class = "mi" > 2< / span > < span class = "p" > ,< / span > < span class = "mi" > 3< / span > < span class = "p" > ,< / span > < span class = "mi" > 224< / span > < span class = "p" > ,< / span > < span class = "mi" > 224< / span > < span class = "p" > ))< / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "sa" > f< / span > < span class = "s1" > ' Pooled shape: < / span > < span class = "si" > {< / span > < span class = "n" > o< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "si" > }< / span > < span class = "s1" > ' < / span > < span class = "p" > )< / span >
< / code > < / pre > < / div >
Output:
< div class = "highlight" > < pre > < span > < / span > < code > Pooled shape: torch.Size([2, 2048])
< / code > < / pre > < / div > < / p >
< h4 id = "remove-it-later_1" > Remove it later< / h4 >
< p > < div class = "highlight" > < pre > < span > < / span > < code > < span class = "kn" > import< / span > < span class = "nn" > torch< / span >
< span class = "kn" > import< / span > < span class = "nn" > timm< / span >
< span class = "hll" > < span class = "n" > m< / span > < span class = "o" > =< / span > < span class = "n" > timm< / span > < span class = "o" > .< / span > < span class = "n" > create_model< / span > < span class = "p" > (< / span > < span class = "s1" > ' ese_vovnet19b_dw' < / span > < span class = "p" > ,< / span > < span class = "n" > pretrained< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > )< / span >
< / span > < span class = "n" > o< / span > < span class = "o" > =< / span > < span class = "n" > m< / span > < span class = "p" > (< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > randn< / span > < span class = "p" > (< / span > < span class = "mi" > 2< / span > < span class = "p" > ,< / span > < span class = "mi" > 3< / span > < span class = "p" > ,< / span > < span class = "mi" > 224< / span > < span class = "p" > ,< / span > < span class = "mi" > 224< / span > < span class = "p" > ))< / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "sa" > f< / span > < span class = "s1" > ' Original shape: < / span > < span class = "si" > {< / span > < span class = "n" > o< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "si" > }< / span > < span class = "s1" > ' < / span > < span class = "p" > )< / span >
< span class = "hll" > < span class = "n" > m< / span > < span class = "o" > .< / span > < span class = "n" > reset_classifier< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span >
< / span > < span class = "n" > o< / span > < span class = "o" > =< / span > < span class = "n" > m< / span > < span class = "p" > (< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > randn< / span > < span class = "p" > (< / span > < span class = "mi" > 2< / span > < span class = "p" > ,< / span > < span class = "mi" > 3< / span > < span class = "p" > ,< / span > < span class = "mi" > 224< / span > < span class = "p" > ,< / span > < span class = "mi" > 224< / span > < span class = "p" > ))< / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "sa" > f< / span > < span class = "s1" > ' Pooled shape: < / span > < span class = "si" > {< / span > < span class = "n" > o< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "si" > }< / span > < span class = "s1" > ' < / span > < span class = "p" > )< / span >
< / code > < / pre > < / div >
Output:
< div class = "highlight" > < pre > < span > < / span > < code > Pooled shape: torch.Size([2, 1024])
< / code > < / pre > < / div > < / p >
< h2 id = "multi-scale-feature-maps-feature-pyramid" > Multi-scale Feature Maps (Feature Pyramid)< / h2 >
< p > Object detection, segmentation, keypoint, and a variety of dense pixel tasks require access to feature maps from the backbone network at multiple scales. This is often done by modifying the original classification network. Since each network varies quite a bit in structure, it's not uncommon to see only a few backbones supported in any given object detection or segmentation library.< / p >
< p > < code > timm< / code > provides a consistent interface for creating any of the included models as feature backbones that output feature maps for selected levels. < / p >
< p > A feature backbone can be created by adding the argument < code > features_only=True< / code > to any < code > create_model< / code > call. By default, most models output feature maps at 5 strides (not all models have that many), with the first at stride 2 (some start at stride 1 or 4).< / p >
< h3 id = "create-a-feature-map-extraction-model" > Create a feature map extraction model< / h3 >
< p > < div class = "highlight" > < pre > < span > < / span > < code > < span class = "kn" > import< / span > < span class = "nn" > torch< / span >
< span class = "kn" > import< / span > < span class = "nn" > timm< / span >
< span class = "hll" > < span class = "n" > m< / span > < span class = "o" > =< / span > < span class = "n" > timm< / span > < span class = "o" > .< / span > < span class = "n" > create_model< / span > < span class = "p" > (< / span > < span class = "s1" > ' resnest26d' < / span > < span class = "p" > ,< / span > < span class = "n" > features_only< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > ,< / span > < span class = "n" > pretrained< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > )< / span >
< / span > < span class = "n" > o< / span > < span class = "o" > =< / span > < span class = "n" > m< / span > < span class = "p" > (< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > randn< / span > < span class = "p" > (< / span > < span class = "mi" > 2< / span > < span class = "p" > ,< / span > < span class = "mi" > 3< / span > < span class = "p" > ,< / span > < span class = "mi" > 224< / span > < span class = "p" > ,< / span > < span class = "mi" > 224< / span > < span class = "p" > ))< / span >
< span class = "k" > for< / span > < span class = "n" > x< / span > < span class = "ow" > in< / span > < span class = "n" > o< / span > < span class = "p" > :< / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "p" > )< / span >
< / code > < / pre > < / div >
Output:
< div class = "highlight" > < pre > < span > < / span > < code > torch.Size([2, 64, 112, 112])
torch.Size([2, 256, 56, 56])
torch.Size([2, 512, 28, 28])
torch.Size([2, 1024, 14, 14])
torch.Size([2, 2048, 7, 7])
< / code > < / pre > < / div > < / p >
< h3 id = "query-the-feature-information" > Query the feature information< / h3 >
< p > After a feature backbone has been created, it can be queried to provide channel or resolution reduction information to the downstream heads without requiring static config or hardcoded constants. The < code > .feature_info< / code > attribute is a class encapsulating the information about the feature extraction points.< / p >
< p > < div class = "highlight" > < pre > < span > < / span > < code > < span class = "kn" > import< / span > < span class = "nn" > torch< / span >
< span class = "kn" > import< / span > < span class = "nn" > timm< / span >
< span class = "hll" > < span class = "n" > m< / span > < span class = "o" > =< / span > < span class = "n" > timm< / span > < span class = "o" > .< / span > < span class = "n" > create_model< / span > < span class = "p" > (< / span > < span class = "s1" > ' regnety_032' < / span > < span class = "p" > ,< / span > < span class = "n" > features_only< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > ,< / span > < span class = "n" > pretrained< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > )< / span >
< / span > < span class = "hll" > < span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "sa" > f< / span > < span class = "s1" > ' Feature channels: < / span > < span class = "si" > {< / span > < span class = "n" > m< / span > < span class = "o" > .< / span > < span class = "n" > feature_info< / span > < span class = "o" > .< / span > < span class = "n" > channels< / span > < span class = "p" > ()< / span > < span class = "si" > }< / span > < span class = "s1" > ' < / span > < span class = "p" > )< / span >
< / span > < span class = "n" > o< / span > < span class = "o" > =< / span > < span class = "n" > m< / span > < span class = "p" > (< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > randn< / span > < span class = "p" > (< / span > < span class = "mi" > 2< / span > < span class = "p" > ,< / span > < span class = "mi" > 3< / span > < span class = "p" > ,< / span > < span class = "mi" > 224< / span > < span class = "p" > ,< / span > < span class = "mi" > 224< / span > < span class = "p" > ))< / span >
< span class = "k" > for< / span > < span class = "n" > x< / span > < span class = "ow" > in< / span > < span class = "n" > o< / span > < span class = "p" > :< / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "p" > )< / span >
< / code > < / pre > < / div >
Output:
< div class = "highlight" > < pre > < span > < / span > < code > Feature channels: [32, 72, 216, 576, 1512]
torch.Size([2, 32, 112, 112])
torch.Size([2, 72, 56, 56])
torch.Size([2, 216, 28, 28])
torch.Size([2, 576, 14, 14])
torch.Size([2, 1512, 7, 7])
< / code > < / pre > < / div > < / p >
< h3 id = "select-specific-feature-levels-or-limit-the-stride" > Select specific feature levels or limit the stride< / h3 >
< p > There are two additional creation arguments that affect the output features. < / p >
< ul >
< li > < code > out_indices< / code > selects which indices to output< / li >
< li > < code > output_stride< / code > limits the feature output stride of the network (this also works in classification mode)< / li >
< / ul >
< p > < code > out_indices< / code > is supported by all models, but not all models have the same index to feature stride mapping. Look at the code or check < code > feature_info< / code > to compare. The out indices generally correspond to the < code > C(i+1)th< / code > feature level (a < code > 2^(i+1)< / code > reduction). For most models, index 0 is the stride 2 features, and index 4 is stride 32.< / p >
< p > < code > output_stride< / code > is achieved by converting layers to use dilated convolutions. Doing so is not always straightforward; some networks only support < code > output_stride=32< / code > .< / p >
< p > < div class = "highlight" > < pre > < span > < / span > < code > < span class = "kn" > import< / span > < span class = "nn" > torch< / span >
< span class = "kn" > import< / span > < span class = "nn" > timm< / span >
< span class = "hll" > < span class = "n" > m< / span > < span class = "o" > =< / span > < span class = "n" > timm< / span > < span class = "o" > .< / span > < span class = "n" > create_model< / span > < span class = "p" > (< / span > < span class = "s1" > ' ecaresnet101d' < / span > < span class = "p" > ,< / span > < span class = "n" > features_only< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > ,< / span > < span class = "n" > output_stride< / span > < span class = "o" > =< / span > < span class = "mi" > 8< / span > < span class = "p" > ,< / span > < span class = "n" > out_indices< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "mi" > 2< / span > < span class = "p" > ,< / span > < span class = "mi" > 4< / span > < span class = "p" > ),< / span > < span class = "n" > pretrained< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > )< / span >
< / span > < span class = "hll" > < span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "sa" > f< / span > < span class = "s1" > ' Feature channels: < / span > < span class = "si" > {< / span > < span class = "n" > m< / span > < span class = "o" > .< / span > < span class = "n" > feature_info< / span > < span class = "o" > .< / span > < span class = "n" > channels< / span > < span class = "p" > ()< / span > < span class = "si" > }< / span > < span class = "s1" > ' < / span > < span class = "p" > )< / span >
< / span > < span class = "hll" > < span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "sa" > f< / span > < span class = "s1" > ' Feature reduction: < / span > < span class = "si" > {< / span > < span class = "n" > m< / span > < span class = "o" > .< / span > < span class = "n" > feature_info< / span > < span class = "o" > .< / span > < span class = "n" > reduction< / span > < span class = "p" > ()< / span > < span class = "si" > }< / span > < span class = "s1" > ' < / span > < span class = "p" > )< / span >
< / span > < span class = "n" > o< / span > < span class = "o" > =< / span > < span class = "n" > m< / span > < span class = "p" > (< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > randn< / span > < span class = "p" > (< / span > < span class = "mi" > 2< / span > < span class = "p" > ,< / span > < span class = "mi" > 3< / span > < span class = "p" > ,< / span > < span class = "mi" > 320< / span > < span class = "p" > ,< / span > < span class = "mi" > 320< / span > < span class = "p" > ))< / span >
< span class = "k" > for< / span > < span class = "n" > x< / span > < span class = "ow" > in< / span > < span class = "n" > o< / span > < span class = "p" > :< / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span > < span class = "p" > )< / span >
< / code > < / pre > < / div >
Output:
< div class = "highlight" > < pre > < span > < / span > < code > Feature channels: [512, 2048]
Feature reduction: [8, 8]
torch.Size([2, 512, 40, 40])
torch.Size([2, 2048, 40, 40])
< / code > < / pre > < / div > < / p >
< / article >
< / div >
< / div >
< / main >
< footer class = "md-footer" >
< nav class = "md-footer__inner md-grid" aria-label = "Footer" >
< a href = "../training_hparam_examples/" class = "md-footer__link md-footer__link--prev" rel = "prev" >
< div class = "md-footer__button md-icon" >
< svg xmlns = "http://www.w3.org/2000/svg" viewBox = "0 0 24 24" > < path d = "M20 11v2H8l5.5 5.5-1.42 1.42L4.16 12l7.92-7.92L13.5 5.5 8 11h12z" / > < / svg >
< / div >
< div class = "md-footer__title" >
< div class = "md-ellipsis" >
< span class = "md-footer__direction" >
Previous
< / span >
Training Examples
< / div >
< / div >
< / a >
< a href = "../changes/" class = "md-footer__link md-footer__link--next" rel = "next" >
< div class = "md-footer__title" >
< div class = "md-ellipsis" >
< span class = "md-footer__direction" >
Next
< / span >
Recent Changes
< / div >
< / div >
< div class = "md-footer__button md-icon" >
< svg xmlns = "http://www.w3.org/2000/svg" viewBox = "0 0 24 24" > < path d = "M4 11v2h12l-5.5 5.5 1.42 1.42L19.84 12l-7.92-7.92L10.5 5.5 16 11H4z" / > < / svg >
< / div >
< / a >
< / nav >
< div class = "md-footer-meta md-typeset" >
< div class = "md-footer-meta__inner md-grid" >
< div class = "md-footer-copyright" >
Made with
< a href = "https://squidfunk.github.io/mkdocs-material/" target = "_blank" rel = "noopener" >
Material for MkDocs
< / a >
< / div >
< / div >
< / div >
< / footer >
< / div >
< div class = "md-dialog" data-md-component = "dialog" >
< div class = "md-dialog__inner md-typeset" > < / div >
< / div >
< script id = "__config" type = "application/json" > { "base" : ".." , "features" : [ ] , "translations" : { "clipboard.copy" : "Copy to clipboard" , "clipboard.copied" : "Copied to clipboard" , "search.config.lang" : "en" , "search.config.pipeline" : "trimmer, stopWordFilter" , "search.config.separator" : "[\\s\\-]+" , "search.placeholder" : "Search" , "search.result.placeholder" : "Type to start searching" , "search.result.none" : "No matching documents" , "search.result.one" : "1 matching document" , "search.result.other" : "# matching documents" , "search.result.more.one" : "1 more on this page" , "search.result.more.other" : "# more on this page" , "search.result.term.missing" : "Missing" } , "search" : "../assets/javascripts/workers/search.fb4a9340.min.js" , "version" : null } < / script >
< script src = "../assets/javascripts/bundle.a1c7c35e.min.js" > < / script >
< script src = "https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-MML-AM_CHTML" > < / script >
< script src = "https://cdnjs.cloudflare.com/ajax/libs/tablesort/5.2.1/tablesort.min.js" > < / script >
< script src = "../javascripts/tables.js" > < / script >
< / body >
< / html >