Replace all None with nn.Identity() in HRNet modules

pull/141/head
Vyacheslav Shults 5 years ago
parent 3b72ebff51
commit f0eb021620
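
The change swaps every None placeholder for nn.Identity(), PyTorch's no-op module whose forward returns its input unchanged, so a slot that used to hold None can hold a real module and be called without a None check. A minimal sketch of the idea, assuming nothing beyond torch; the helper names are illustrative and do not come from hrnet.py:

import torch
import torch.nn as nn

def forward_transitions_old(transitions, x):
    # old style: every call site branches on None
    return [t(x) if t is not None else x for t in transitions]

def forward_transitions_new(transitions, x):
    # new style: Identity slots simply pass x through
    return [t(x) for t in transitions]

transitions = nn.ModuleList([nn.Conv2d(8, 8, kernel_size=1), nn.Identity()])
x = torch.randn(1, 8, 4, 4)
assert torch.equal(forward_transitions_new(transitions, x)[1], x)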

@@ -13,20 +13,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import logging
import functools
import torch
import torch.nn as nn
import torch._utils
import torch.nn.functional as F
from .resnet import BasicBlock, Bottleneck # leveraging ResNet blocks w/ additional features like SE
from .registry import register_model
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import load_pretrained
from .layers import SelectAdaptivePool2d
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .registry import register_model
from .resnet import BasicBlock, Bottleneck # leveraging ResNet blocks w/ additional features like SE
_BN_MOMENTUM = 0.1
logger = logging.getLogger(__name__)
@@ -101,7 +97,7 @@ cfg_cls = dict(
),
),
hrnet_w18_small_v2 = dict(
hrnet_w18_small_v2=dict(
STEM_WIDTH=64,
STAGE1=dict(
NUM_MODULES=1,
@@ -137,7 +133,7 @@ cfg_cls = dict(
),
),
hrnet_w18 = dict(
hrnet_w18=dict(
STEM_WIDTH=64,
STAGE1=dict(
NUM_MODULES=1,
@@ -173,7 +169,7 @@ cfg_cls = dict(
),
),
hrnet_w30 = dict(
hrnet_w30=dict(
STEM_WIDTH=64,
STAGE1=dict(
NUM_MODULES=1,
@@ -209,7 +205,7 @@ cfg_cls = dict(
),
),
hrnet_w32 = dict(
hrnet_w32=dict(
STEM_WIDTH=64,
STAGE1=dict(
NUM_MODULES=1,
@@ -245,7 +241,7 @@ cfg_cls = dict(
),
),
hrnet_w40 = dict(
hrnet_w40=dict(
STEM_WIDTH=64,
STAGE1=dict(
NUM_MODULES=1,
@@ -281,7 +277,7 @@ cfg_cls = dict(
),
),
hrnet_w44 = dict(
hrnet_w44=dict(
STEM_WIDTH=64,
STAGE1=dict(
NUM_MODULES=1,
@@ -317,7 +313,7 @@ cfg_cls = dict(
),
),
hrnet_w48 = dict(
hrnet_w48=dict(
STEM_WIDTH=64,
STAGE1=dict(
NUM_MODULES=1,
@@ -353,7 +349,7 @@ cfg_cls = dict(
),
),
hrnet_w64 = dict(
hrnet_w64=dict(
STEM_WIDTH=64,
STAGE1=dict(
NUM_MODULES=1,
@@ -456,7 +452,7 @@ class HighResolutionModule(nn.Module):
def _make_fuse_layers(self):
if self.num_branches == 1:
return None
return nn.Identity()
num_branches = self.num_branches
num_inchannels = self.num_inchannels
@@ -470,7 +466,7 @@ class HighResolutionModule(nn.Module):
nn.BatchNorm2d(num_inchannels[i], momentum=_BN_MOMENTUM),
nn.Upsample(scale_factor=2 ** (j - i), mode='nearest')))
elif j == i:
fuse_layer.append(None)
fuse_layer.append(nn.Identity())
else:
conv3x3s = []
for k in range(i - j):
@@ -619,7 +615,7 @@ class HighResolutionNet(nn.Module):
nn.BatchNorm2d(num_channels_cur_layer[i], momentum=_BN_MOMENTUM),
nn.ReLU(inplace=True)))
else:
transition_layers.append(None)
transition_layers.append(nn.Identity())
else:
conv3x3s = []
for j in range(i + 1 - num_branches_pre):
@@ -686,8 +682,11 @@ class HighResolutionNet(nn.Module):
def reset_classifier(self, num_classes, global_pool='avg'):
self.num_classes = num_classes
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.classifier = nn.Linear(
self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None
num_features = self.num_features * self.global_pool.feat_mult()
if num_classes:
self.classifier = nn.Linear(num_features, num_classes)
else:
self.classifier = nn.Identity()
def forward_features(self, x):
x = self.conv1(x)
@@ -699,24 +698,21 @@ class HighResolutionNet(nn.Module):
x = self.layer1(x)
x_list = []
for i in range(self.stage2_cfg['NUM_BRANCHES']):
if self.transition1[i] is not None:
x_list.append(self.transition1[i](x))
else:
x_list.append(x)
for i in range(len(self.transition1)):
x_list.append(self.transition1[i](x))
y_list = self.stage2(x_list)
x_list = []
for i in range(self.stage3_cfg['NUM_BRANCHES']):
if self.transition2[i] is not None:
for i in range(len(self.transition2)):
if not isinstance(self.transition2[i], nn.Identity):
x_list.append(self.transition2[i](y_list[-1]))
else:
x_list.append(y_list[i])
y_list = self.stage3(x_list)
x_list = []
for i in range(self.stage4_cfg['NUM_BRANCHES']):
if self.transition3[i] is not None:
for i in range(len(self.transition3)):
if not isinstance(self.transition3[i], nn.Identity):
x_list.append(self.transition3[i](y_list[-1]))
else:
x_list.append(y_list[i])
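
With this change, transition1 is applied unconditionally, since an Identity slot just forwards its input. For transition2 and transition3 the branch stays, but it now keys on isinstance(..., nn.Identity) rather than an is-None check: non-Identity slots build a new branch from the last (lowest-resolution) feature map, while Identity slots keep the corresponding existing branch. A simplified sketch of that loop, assuming transition is an iterable of modules and y_list holds the previous stage's outputs:

import torch.nn as nn

def apply_transition(transition, y_list):
    x_list = []
    for i, t in enumerate(transition):
        if not isinstance(t, nn.Identity):
            # new branch: derived from the lowest-resolution map
            x_list.append(t(y_list[-1]))
        else:
            # existing branch: pass the matching feature map straight through
            x_list.append(y_list[i])
    return x_list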
