|
|
@ -38,13 +38,24 @@ def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False
|
|
|
|
class ClassifierHead(nn.Module):
|
|
|
|
class ClassifierHead(nn.Module):
|
|
|
|
"""Classifier head w/ configurable global pooling and dropout."""
|
|
|
|
"""Classifier head w/ configurable global pooling and dropout."""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0., use_conv=False):
|
|
|
|
def __init__(self, in_features, num_classes, pool_type='avg', drop_rate=0., use_conv=False):
|
|
|
|
super(ClassifierHead, self).__init__()
|
|
|
|
super(ClassifierHead, self).__init__()
|
|
|
|
self.drop_rate = drop_rate
|
|
|
|
self.drop_rate = drop_rate
|
|
|
|
self.global_pool, num_pooled_features = _create_pool(in_chs, num_classes, pool_type, use_conv=use_conv)
|
|
|
|
self.in_features = in_features
|
|
|
|
|
|
|
|
self.use_conv = use_conv
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.global_pool, num_pooled_features = _create_pool(in_features, num_classes, pool_type, use_conv=use_conv)
|
|
|
|
self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
|
|
|
|
self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
|
|
|
|
self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity()
|
|
|
|
self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def reset(self, num_classes, global_pool=None):
|
|
|
|
|
|
|
|
if global_pool is not None:
|
|
|
|
|
|
|
|
if global_pool != self.global_pool.pool_type:
|
|
|
|
|
|
|
|
self.global_pool, _ = _create_pool(self.in_features, num_classes, global_pool, use_conv=self.use_conv)
|
|
|
|
|
|
|
|
self.flatten = nn.Flatten(1) if self.use_conv and global_pool else nn.Identity()
|
|
|
|
|
|
|
|
num_pooled_features = self.in_features * self.global_pool.feat_mult()
|
|
|
|
|
|
|
|
self.fc = _create_fc(num_pooled_features, num_classes, use_conv=self.use_conv)
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x, pre_logits: bool = False):
|
|
|
|
def forward(self, x, pre_logits: bool = False):
|
|
|
|
x = self.global_pool(x)
|
|
|
|
x = self.global_pool(x)
|
|
|
|
if self.drop_rate:
|
|
|
|
if self.drop_rate:
|
|
|
|