|
|
|
@ -13,20 +13,16 @@ from __future__ import absolute_import
|
|
|
|
|
from __future__ import division
|
|
|
|
|
from __future__ import print_function
|
|
|
|
|
|
|
|
|
|
import os
|
|
|
|
|
import logging
|
|
|
|
|
import functools
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
import torch._utils
|
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
|
|
|
|
from .resnet import BasicBlock, Bottleneck # leveraging ResNet blocks w/ additional features like SE
|
|
|
|
|
from .registry import register_model
|
|
|
|
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|
|
|
|
from .helpers import load_pretrained
|
|
|
|
|
from .layers import SelectAdaptivePool2d
|
|
|
|
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|
|
|
|
from .registry import register_model
|
|
|
|
|
from .resnet import BasicBlock, Bottleneck # leveraging ResNet blocks w/ additional features like SE
|
|
|
|
|
|
|
|
|
|
_BN_MOMENTUM = 0.1
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
@ -101,7 +97,7 @@ cfg_cls = dict(
|
|
|
|
|
),
|
|
|
|
|
),
|
|
|
|
|
|
|
|
|
|
hrnet_w18_small_v2 = dict(
|
|
|
|
|
hrnet_w18_small_v2=dict(
|
|
|
|
|
STEM_WIDTH=64,
|
|
|
|
|
STAGE1=dict(
|
|
|
|
|
NUM_MODULES=1,
|
|
|
|
@ -137,7 +133,7 @@ cfg_cls = dict(
|
|
|
|
|
),
|
|
|
|
|
),
|
|
|
|
|
|
|
|
|
|
hrnet_w18 = dict(
|
|
|
|
|
hrnet_w18=dict(
|
|
|
|
|
STEM_WIDTH=64,
|
|
|
|
|
STAGE1=dict(
|
|
|
|
|
NUM_MODULES=1,
|
|
|
|
@ -173,7 +169,7 @@ cfg_cls = dict(
|
|
|
|
|
),
|
|
|
|
|
),
|
|
|
|
|
|
|
|
|
|
hrnet_w30 = dict(
|
|
|
|
|
hrnet_w30=dict(
|
|
|
|
|
STEM_WIDTH=64,
|
|
|
|
|
STAGE1=dict(
|
|
|
|
|
NUM_MODULES=1,
|
|
|
|
@ -209,7 +205,7 @@ cfg_cls = dict(
|
|
|
|
|
),
|
|
|
|
|
),
|
|
|
|
|
|
|
|
|
|
hrnet_w32 = dict(
|
|
|
|
|
hrnet_w32=dict(
|
|
|
|
|
STEM_WIDTH=64,
|
|
|
|
|
STAGE1=dict(
|
|
|
|
|
NUM_MODULES=1,
|
|
|
|
@ -245,7 +241,7 @@ cfg_cls = dict(
|
|
|
|
|
),
|
|
|
|
|
),
|
|
|
|
|
|
|
|
|
|
hrnet_w40 = dict(
|
|
|
|
|
hrnet_w40=dict(
|
|
|
|
|
STEM_WIDTH=64,
|
|
|
|
|
STAGE1=dict(
|
|
|
|
|
NUM_MODULES=1,
|
|
|
|
@ -281,7 +277,7 @@ cfg_cls = dict(
|
|
|
|
|
),
|
|
|
|
|
),
|
|
|
|
|
|
|
|
|
|
hrnet_w44 = dict(
|
|
|
|
|
hrnet_w44=dict(
|
|
|
|
|
STEM_WIDTH=64,
|
|
|
|
|
STAGE1=dict(
|
|
|
|
|
NUM_MODULES=1,
|
|
|
|
@ -317,7 +313,7 @@ cfg_cls = dict(
|
|
|
|
|
),
|
|
|
|
|
),
|
|
|
|
|
|
|
|
|
|
hrnet_w48 = dict(
|
|
|
|
|
hrnet_w48=dict(
|
|
|
|
|
STEM_WIDTH=64,
|
|
|
|
|
STAGE1=dict(
|
|
|
|
|
NUM_MODULES=1,
|
|
|
|
@ -353,7 +349,7 @@ cfg_cls = dict(
|
|
|
|
|
),
|
|
|
|
|
),
|
|
|
|
|
|
|
|
|
|
hrnet_w64 = dict(
|
|
|
|
|
hrnet_w64=dict(
|
|
|
|
|
STEM_WIDTH=64,
|
|
|
|
|
STAGE1=dict(
|
|
|
|
|
NUM_MODULES=1,
|
|
|
|
@ -456,7 +452,7 @@ class HighResolutionModule(nn.Module):
|
|
|
|
|
|
|
|
|
|
def _make_fuse_layers(self):
|
|
|
|
|
if self.num_branches == 1:
|
|
|
|
|
return None
|
|
|
|
|
return nn.Identity()
|
|
|
|
|
|
|
|
|
|
num_branches = self.num_branches
|
|
|
|
|
num_inchannels = self.num_inchannels
|
|
|
|
@ -470,7 +466,7 @@ class HighResolutionModule(nn.Module):
|
|
|
|
|
nn.BatchNorm2d(num_inchannels[i], momentum=_BN_MOMENTUM),
|
|
|
|
|
nn.Upsample(scale_factor=2 ** (j - i), mode='nearest')))
|
|
|
|
|
elif j == i:
|
|
|
|
|
fuse_layer.append(None)
|
|
|
|
|
fuse_layer.append(nn.Identity())
|
|
|
|
|
else:
|
|
|
|
|
conv3x3s = []
|
|
|
|
|
for k in range(i - j):
|
|
|
|
@ -619,7 +615,7 @@ class HighResolutionNet(nn.Module):
|
|
|
|
|
nn.BatchNorm2d(num_channels_cur_layer[i], momentum=_BN_MOMENTUM),
|
|
|
|
|
nn.ReLU(inplace=True)))
|
|
|
|
|
else:
|
|
|
|
|
transition_layers.append(None)
|
|
|
|
|
transition_layers.append(nn.Identity())
|
|
|
|
|
else:
|
|
|
|
|
conv3x3s = []
|
|
|
|
|
for j in range(i + 1 - num_branches_pre):
|
|
|
|
@ -686,8 +682,11 @@ class HighResolutionNet(nn.Module):
|
|
|
|
|
def reset_classifier(self, num_classes, global_pool='avg'):
|
|
|
|
|
self.num_classes = num_classes
|
|
|
|
|
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
|
|
|
|
self.classifier = nn.Linear(
|
|
|
|
|
self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None
|
|
|
|
|
num_features = self.num_features * self.global_pool.feat_mult()
|
|
|
|
|
if num_classes:
|
|
|
|
|
self.classifier = nn.Linear(num_features, num_classes)
|
|
|
|
|
else:
|
|
|
|
|
self.classifier = nn.Identity()
|
|
|
|
|
|
|
|
|
|
def forward_features(self, x):
|
|
|
|
|
x = self.conv1(x)
|
|
|
|
@ -699,24 +698,21 @@ class HighResolutionNet(nn.Module):
|
|
|
|
|
x = self.layer1(x)
|
|
|
|
|
|
|
|
|
|
x_list = []
|
|
|
|
|
for i in range(self.stage2_cfg['NUM_BRANCHES']):
|
|
|
|
|
if self.transition1[i] is not None:
|
|
|
|
|
x_list.append(self.transition1[i](x))
|
|
|
|
|
else:
|
|
|
|
|
x_list.append(x)
|
|
|
|
|
for i in range(len(self.transition1)):
|
|
|
|
|
x_list.append(self.transition1[i](x))
|
|
|
|
|
y_list = self.stage2(x_list)
|
|
|
|
|
|
|
|
|
|
x_list = []
|
|
|
|
|
for i in range(self.stage3_cfg['NUM_BRANCHES']):
|
|
|
|
|
if self.transition2[i] is not None:
|
|
|
|
|
for i in range(len(self.transition2)):
|
|
|
|
|
if not isinstance(self.transition2[i], nn.Identity):
|
|
|
|
|
x_list.append(self.transition2[i](y_list[-1]))
|
|
|
|
|
else:
|
|
|
|
|
x_list.append(y_list[i])
|
|
|
|
|
y_list = self.stage3(x_list)
|
|
|
|
|
|
|
|
|
|
x_list = []
|
|
|
|
|
for i in range(self.stage4_cfg['NUM_BRANCHES']):
|
|
|
|
|
if self.transition3[i] is not None:
|
|
|
|
|
for i in range(len(self.transition3)):
|
|
|
|
|
if not isinstance(self.transition3[i], nn.Identity):
|
|
|
|
|
x_list.append(self.transition3[i](y_list[-1]))
|
|
|
|
|
else:
|
|
|
|
|
x_list.append(y_list[i])
|
|
|
|
|