You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
25 lines
841 B
25 lines
841 B
from torch import nn as nn
|
|
from torch.nn import functional as F
|
|
|
|
from .adaptive_avgmax_pool import SelectAdaptivePool2d
|
|
|
|
|
|
class ClassifierHead(nn.Module):
    """Global-pool a feature map, optionally apply dropout, then classify.

    Forward pipeline: ``global_pool`` -> ``flatten(1)`` -> functional dropout
    (only when ``drop_rate`` is truthy) -> ``fc``.

    Args:
        in_chs: Channel count of the incoming feature map.
        num_classes: Output width of the final linear layer. When this is
            ``<= 0`` no linear layer is built; ``fc`` is ``nn.Identity`` and
            the pooled, flattened features are returned (after dropout).
        pool_type: Pooling mode passed straight to ``SelectAdaptivePool2d``.
        drop_rate: Dropout probability used before ``fc``; a falsy value
            (e.g. ``0.``) skips dropout entirely.
    """

    def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0.):
        super().__init__()
        self.drop_rate = drop_rate
        self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
        if num_classes <= 0:
            # No classifier requested: pass pooled features through as-is.
            self.fc = nn.Identity()
        else:
            # The pool reports a channel multiplier; the linear layer's input
            # width must account for it.
            pooled_chs = in_chs * self.global_pool.feat_mult()
            self.fc = nn.Linear(pooled_chs, num_classes, bias=True)

    def forward(self, x):
        """Pool ``x``, flatten from dim 1, maybe drop out, and apply ``fc``."""
        pooled = self.global_pool(x)
        feats = pooled.flatten(1)
        if self.drop_rate:
            # Functional dropout follows the module's train/eval mode.
            feats = F.dropout(feats, p=float(self.drop_rate), training=self.training)
        return self.fc(feats)
|