You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
pytorch-image-models/timm/models/layers/se.py

67 lines
2.7 KiB

import torch
from torch import nn as nn
from .helpers import make_divisible
class SqueezeExcite(nn.Module):
    """ Squeeze-and-Excitation module as used in Pytorch SENet, SE-ResNeXt implementations

    The input is globally average-pooled to 1x1, passed through a 1x1 conv
    bottleneck (reduce -> activation -> expand), squashed with a sigmoid and
    used to rescale the input channel-wise.

    Args:
        channels (int): number of input and output channels
        reduction (int, float): divisor for attention (squeezed) channels
        act_layer (nn.Module): override the default ReLU activation; it is
            instantiated as ``act_layer(inplace=True)`` so it must accept an
            ``inplace`` keyword argument
        divisible_by (int): forwarded as the second argument to
            ``make_divisible`` when computing the squeezed channel count
            (presumably the multiple that count is rounded to -- confirm
            against ``.helpers.make_divisible``)
    """

    def __init__(self, channels, reduction=16, act_layer=nn.ReLU, divisible_by=1):
        super().__init__()
        # NOTE: submodules are created in the original order (pool, fc1, act, fc2)
        # so that state_dict keys and seeded weight initialisation stay the same.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        reduction_channels = make_divisible(channels // reduction, divisible_by)
        self.fc1 = nn.Conv2d(
            channels, reduction_channels, kernel_size=1, padding=0, bias=True)
        self.act = act_layer(inplace=True)
        self.fc2 = nn.Conv2d(
            reduction_channels, channels, kernel_size=1, padding=0, bias=True)

    def forward(self, x):
        """Return ``x`` rescaled per channel by the learned sigmoid gate.

        Args:
            x (torch.Tensor): feature map whose channel dimension equals
                ``channels`` (e.g. ``(N, C, H, W)``).

        Returns:
            torch.Tensor: tensor with the same shape as ``x``.
        """
        # Squeeze: collapse the spatial dimensions to 1x1 per channel.
        x_se = self.avg_pool(x)
        # Excite: bottleneck of two 1x1 convolutions around the activation.
        x_se = self.fc1(x_se)
        x_se = self.act(x_se)
        x_se = self.fc2(x_se)
        # The 1x1 gate broadcasts over the spatial dimensions of ``x``.
        return x * x_se.sigmoid()
class SqueezeExciteV2(nn.Module):
    """ Squeeze-and-Excitation module as used in EfficientNet, MobileNetV3, related models

    Differs from the original SqueezeExcite impl in that:
      * reduction is specified as a float multiplier instead of divisor (se_ratio)
      * gate function is changeable from sigmoid to alternate (ie hard_sigmoid)
      * layer names match those in weights for the EfficientNet/MobileNetV3 families

    Args:
        in_chs (int): number of input and output channels
        se_ratio (float): multiplier for attention (squeezed) channels
        reduced_base_chs (int): specify alternate channel count to base the
            reduction channels on; any falsy value (``None`` or ``0``) falls
            back to ``in_chs``
        act_layer (nn.Module): override the default ReLU activation; it is
            instantiated as ``act_layer(inplace=True)`` so it must accept an
            ``inplace`` keyword argument
        gate_fn (callable): override the default gate function
        divisible_by (int): forwarded as the second argument to
            ``make_divisible`` when computing the squeezed channel count
            (presumably the multiple that count is rounded to -- confirm
            against ``.helpers.make_divisible``)
        **_: any additional keyword arguments are accepted and ignored
    """

    def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None,
                 act_layer=nn.ReLU, gate_fn=torch.sigmoid, divisible_by=1, **_):
        super().__init__()
        self.gate_fn = gate_fn
        reduced_chs = make_divisible((reduced_base_chs or in_chs) * se_ratio, divisible_by)
        # NOTE: submodules are created in the original order (pool, reduce, act,
        # expand) so that state_dict keys and seeded weight initialisation stay
        # the same.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
        self.act1 = act_layer(inplace=True)
        self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)

    def forward(self, x):
        """Return ``x`` rescaled per channel by ``gate_fn`` of the SE branch.

        Args:
            x (torch.Tensor): feature map whose channel dimension equals
                ``in_chs`` (e.g. ``(N, C, H, W)``).

        Returns:
            torch.Tensor: ``x`` multiplied by the gate output.
        """
        # Squeeze: collapse the spatial dimensions to 1x1 per channel.
        x_se = self.avg_pool(x)
        # Excite: 1x1 reduce -> activation -> 1x1 expand.
        x_se = self.conv_reduce(x_se)
        x_se = self.act1(x_se)
        x_se = self.conv_expand(x_se)
        # The gate output is multiplied into ``x`` (broadcast over space).
        x = x * self.gate_fn(x_se)
        return x