""" Classifier head and layer factory

Hacked together by / Copyright 2020 Ross Wightman
"""
from torch import nn as nn
from torch.nn import functional as F

from .adaptive_avgmax_pool import SelectAdaptivePool2d


def _check_pool_disabled(num_classes, use_conv):
    """Validate a classifier configuration whose global pooling is disabled (pass-through).

    Without pooling, a Linear classifier would be applied directly to the un-pooled feature map, so
    pooling may only be disabled when the classifier is removed (``num_classes <= 0``, which is what
    ``_create_fc`` treats as "no classifier") or when a conv classifier is used.
    """
    assert num_classes <= 0 or use_conv, \
        'Pooling can only be disabled if classifier is also removed or conv classifier is used'


def _create_pool(num_features, num_classes, pool_type='avg', use_conv=False):
    """Build the global pooling module for a classifier.

    Args:
        num_features: number of input features (channels) fed into the pooling.
        num_classes: number of classifier outputs; only used to validate a disabled ``pool_type``.
        pool_type: pooling type passed to ``SelectAdaptivePool2d``; a falsy value disables pooling.
        use_conv: True when a 1x1 conv classifier follows the pooling (no flattening wanted).

    Returns:
        Tuple ``(global_pool, num_pooled_features)``, where ``num_pooled_features`` is
        ``num_features * global_pool.feat_mult()``.
    """
    flatten_in_pool = not use_conv  # flatten when we use a Linear layer after pooling
    if not pool_type:
        _check_pool_disabled(num_classes, use_conv)
        flatten_in_pool = False  # disable flattening if pooling is pass-through (no pooling)
    global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=flatten_in_pool)
    num_pooled_features = num_features * global_pool.feat_mult()
    return global_pool, num_pooled_features


def _create_fc(num_features, num_classes, use_conv=False):
    """Build the classifier layer.

    Returns ``nn.Identity()`` when ``num_classes <= 0`` (classifier removed), a 1x1 ``nn.Conv2d``
    when ``use_conv`` is set, and an ``nn.Linear`` otherwise. Both real layers use a bias.
    """
    if num_classes <= 0:
        fc = nn.Identity()  # pass-through (no classifier)
    elif use_conv:
        fc = nn.Conv2d(num_features, num_classes, 1, bias=True)
    else:
        fc = nn.Linear(num_features, num_classes, bias=True)
    return fc


def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False):
    """Create a matching ``(global_pool, fc)`` pair.

    The classifier's input width is taken from the pooling module, so pooling types that widen the
    features (``feat_mult() > 1``) get a correctly sized classifier.
    """
    global_pool, num_pooled_features = _create_pool(num_features, num_classes, pool_type, use_conv=use_conv)
    fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
    return global_pool, fc


class ClassifierHead(nn.Module):
    """Classifier head w/ configurable global pooling and dropout."""

    def __init__(self, in_features, num_classes, pool_type='avg', drop_rate=0., use_conv=False):
        """
        Args:
            in_features: number of input features (channels) to the head.
            num_classes: number of outputs; <= 0 installs a pass-through (identity) classifier.
            pool_type: global pooling type; a falsy value disables pooling (see ``_create_pool``).
            drop_rate: dropout probability applied after pooling; 0 disables dropout.
            use_conv: use a 1x1 conv classifier instead of a Linear layer.
        """
        super(ClassifierHead, self).__init__()
        self.drop_rate = drop_rate
        self.in_features = in_features
        self.use_conv = use_conv
        self.global_pool, num_pooled_features = _create_pool(in_features, num_classes, pool_type, use_conv=use_conv)
        self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
        # A conv classifier keeps spatial dims, so flatten after it; only needed when pooling is enabled.
        self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity()

    def reset(self, num_classes, global_pool=None):
        """Replace the classifier for ``num_classes`` outputs, optionally changing the pooling type.

        Args:
            num_classes: new number of outputs; <= 0 installs a pass-through (identity) classifier.
            global_pool: new pooling type, or None to keep the current pooling module.
        """
        if global_pool is not None:
            if global_pool != self.global_pool.pool_type:
                self.global_pool, _ = _create_pool(self.in_features, num_classes, global_pool, use_conv=self.use_conv)
            self.flatten = nn.Flatten(1) if self.use_conv and global_pool else nn.Identity()
        # When the existing pooling module is kept (global_pool None or unchanged), _create_pool is not
        # called, so its pass-through validation would be skipped; re-apply it here so reset() cannot
        # pair disabled pooling with a Linear classifier.
        # NOTE(review): assumes SelectAdaptivePool2d keeps a falsy ``pool_type`` for pass-through pooling.
        if not self.global_pool.pool_type:
            _check_pool_disabled(num_classes, self.use_conv)
        num_pooled_features = self.in_features * self.global_pool.feat_mult()
        self.fc = _create_fc(num_pooled_features, num_classes, use_conv=self.use_conv)

    def forward(self, x, pre_logits: bool = False):
        """Pool, apply dropout (only while training), then classify.

        If ``pre_logits`` is True, return the pooled (and dropped-out) features flattened from dim 1,
        without applying the classifier.
        """
        x = self.global_pool(x)
        if self.drop_rate:
            x = F.dropout(x, p=float(self.drop_rate), training=self.training)
        if pre_logits:
            return x.flatten(1)
        else:
            x = self.fc(x)
            return self.flatten(x)