You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
57 lines
2.3 KiB
57 lines
2.3 KiB
4 years ago
|
""" Classifier head and layer factory
|
||
|
|
||
|
Hacked together by / Copyright 2020 Ross Wightman
|
||
|
"""
|
||
4 years ago
|
from torch import nn as nn
|
||
|
from torch.nn import functional as F
|
||
|
|
||
|
from .adaptive_avgmax_pool import SelectAdaptivePool2d
|
||
|
|
||
|
|
||
4 years ago
|
def _create_pool(num_features, num_classes, pool_type='avg', use_conv=False):
    """Build the global pooling layer used in front of a classifier.

    Args:
        num_features: Number of channels entering the pooling layer.
        num_classes: Number of classifier outputs; ``<= 0`` means the
            classifier is removed (matching ``_create_fc``).
        pool_type: Pool type forwarded to ``SelectAdaptivePool2d``. A falsy
            value ('' / None) disables pooling (pass-through).
        use_conv: True when the following classifier is a 1x1 conv instead of
            a Linear layer.

    Returns:
        Tuple ``(global_pool, num_pooled_features)`` where
        ``num_pooled_features`` is ``num_features * global_pool.feat_mult()``.

    Raises:
        AssertionError: if pooling is disabled while a Linear classifier with
            ``num_classes > 0`` is still requested.
    """
    # Flatten inside the pool only when a Linear layer follows; a conv
    # classifier consumes the un-flattened pooled map.
    flatten_in_pool = not use_conv
    if not pool_type:
        # Use `<= 0` (not `== 0`) so the "classifier removed" test agrees with
        # _create_fc, which returns nn.Identity for any num_classes <= 0.
        assert num_classes <= 0 or use_conv, \
            'Pooling can only be disabled if classifier is also removed or conv classifier is used'
        flatten_in_pool = False  # pooling is pass-through, so nothing to flatten
    global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=flatten_in_pool)
    # NOTE(review): feat_mult() presumably > 1 for pool types that concatenate
    # several poolings — confirm in adaptive_avgmax_pool.
    num_pooled_features = num_features * global_pool.feat_mult()
    return global_pool, num_pooled_features
|
||
|
|
||
|
|
||
4 years ago
|
def _create_fc(num_features, num_classes, use_conv=False):
    """Build the final classification layer.

    Returns ``nn.Identity`` when ``num_classes <= 0`` (classifier removed).
    Otherwise it maps ``num_features`` inputs to ``num_classes`` outputs, with
    bias, using a 1x1 ``nn.Conv2d`` when ``use_conv`` is set or an
    ``nn.Linear`` when it is not.
    """
    if num_classes > 0:
        if use_conv:
            return nn.Conv2d(num_features, num_classes, 1, bias=True)
        return nn.Linear(num_features, num_classes, bias=True)
    # No classes requested: pass features through unchanged.
    return nn.Identity()
|
||
|
|
||
|
|
||
|
def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False):
    """Create a ``(global_pool, fc)`` pair for a network head.

    The pooling layer is built first so the classifier can be sized from the
    pooled feature count (``num_features`` scaled by the pool's
    ``feat_mult()``) rather than from ``num_features`` directly.
    """
    pool, pooled_dim = _create_pool(num_features, num_classes, pool_type, use_conv=use_conv)
    classifier = _create_fc(pooled_dim, num_classes, use_conv=use_conv)
    return pool, classifier
|
||
|
|
||
|
|
||
4 years ago
|
class ClassifierHead(nn.Module):
    """Classifier head w/ configurable global pooling and dropout.

    Forward pipeline: global pool -> dropout (if ``drop_rate`` is non-zero)
    -> classifier (Linear or 1x1 conv) -> flatten (conv classifier with
    pooling enabled only).
    """

    def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0., use_conv=False):
        super().__init__()
        self.drop_rate = drop_rate
        pool, pooled_dim = _create_pool(in_chs, num_classes, pool_type, use_conv=use_conv)
        self.global_pool = pool
        self.fc = _create_fc(pooled_dim, num_classes, use_conv=use_conv)
        # The conv classifier's output is flattened from dim 1 only when real
        # pooling happened; with pooling disabled the spatial output is kept.
        needs_flatten = bool(use_conv and pool_type)
        self.flatten = nn.Flatten(1) if needs_flatten else nn.Identity()

    def forward(self, x, pre_logits: bool = False):
        x = self.global_pool(x)
        if self.drop_rate:
            x = F.dropout(x, p=float(self.drop_rate), training=self.training)
        if pre_logits:
            # Features just before the classifier, flattened from dim 1.
            return x.flatten(1)
        return self.flatten(self.fc(x))
|