Add initial AttentionPool2d that's being trialed. Fix comment; still trying to improve reliability of the SGD test.

pull/821/head
Ross Wightman 3 years ago
parent 76881d207b
commit 5f12de4875

@@ -317,10 +317,10 @@ def test_sgd(optimizer):
     # lambda opt: ReduceLROnPlateau(opt)]
     # )
     _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=3e-3, momentum=1)
+        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3, momentum=1)
     )
     _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=3e-3, momentum=1, weight_decay=.1)
+        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3, momentum=1, weight_decay=.1)
     )
     _test_rosenbrock(
         lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)

@@ -246,7 +246,7 @@ def halonet26t(pretrained=False, **kwargs):
 @register_model
 def sehalonet33ts(pretrained=False, **kwargs):
-    """ HaloNet w/ a ResNet26-t backbone. Halo attention in final two stages
+    """ HaloNet w/ a ResNet33-t backbone, SE attn for non Halo blocks, SiLU, 1-2 Halo in stage 2,3,4.
     """
     return _create_byoanet('sehalonet33ts', pretrained=pretrained, **kwargs)
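Since sehalonet33ts is registered via @register_model, it can be built through timm's usual factory. A minimal sketch, assuming no pretrained weights are published yet for this trial model:

import timm

# build the registered model by name; pretrained=False since weights are still being trained
model = timm.create_model('sehalonet33ts', pretrained=False, num_classes=1000)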

@@ -0,0 +1,182 @@
""" Attention Pool 2D
Implementations of 2D spatial feature pooling using multi-head attention instead of average pool.
Based on idea in CLIP by OpenAI, licensed Apache 2.0
https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py
Hacked together by / Copyright 2021 Ross Wightman
"""
import math
from typing import List, Union, Tuple
import torch
import torch.nn as nn
from .helpers import to_2tuple
from .weight_init import trunc_normal_
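
# rotate channel pairs: (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...), the 2D rotation used by rotary embeddings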
def rot(x):
    return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape)


def apply_rot_embed(x: torch.Tensor, sin_emb, cos_emb):
    return x * cos_emb + rot(x) * sin_emb


def apply_rot_embed_list(x: List[torch.Tensor], sin_emb, cos_emb):
    if isinstance(x, torch.Tensor):
        x = [x]
    return [t * cos_emb + rot(t) * sin_emb for t in x]


class RotaryEmbedding(nn.Module):
    """ Rotary position embedding

    NOTE: This is my initial attempt at implementing a rotary embedding for spatial use. It has not
    been well tested, and will likely change. It will be moved to its own file.

    The following impl/resources were referenced for this impl:
    * https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py
    * https://blog.eleuther.ai/rotary-embeddings/
    """
    def __init__(self, dim, max_freq=4):
        super().__init__()
        self.dim = dim
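        # dim // 4 log-spaced frequency bands: 2 spatial axes * (dim // 4) freqs * 2 channels per
        # rotation pair = dim channels of sin/cos embedding per position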
        self.register_buffer('bands', 2 ** torch.linspace(0., max_freq - 1, self.dim // 4), persistent=False)

    def get_embed(self, shape: torch.Size, device: torch.device = None, dtype: torch.dtype = None):
        """
        NOTE: shape arg should include spatial dims only
        """
        device = device or self.bands.device
        dtype = dtype or self.bands.dtype
        if not isinstance(shape, torch.Size):
            shape = torch.Size(shape)
        N = shape.numel()
        grid = torch.stack(torch.meshgrid(
            [torch.linspace(-1., 1., steps=s, device=device, dtype=dtype) for s in shape]), dim=-1).unsqueeze(-1)
        emb = grid * math.pi * self.bands
        sin = emb.sin().reshape(N, -1).repeat_interleave(2, -1)
        cos = emb.cos().reshape(N, -1).repeat_interleave(2, -1)
        return sin, cos

    def forward(self, x):
        # assuming a channel-first tensor where spatial dims are at index >= 2
        sin_emb, cos_emb = self.get_embed(x.shape[2:])
        return apply_rot_embed(x, sin_emb, cos_emb)


class RotAttentionPool2d(nn.Module):
    """ Attention based 2D feature pooling w/ rotary (relative) pos embedding.
    This is a multi-head attention based replacement for (spatial) average pooling in NN architectures.

    Adapted from the AttentionPool2d in CLIP w/ rotary embedding instead of learned embed.
    https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py

    NOTE: While this impl does not require a fixed feature size, performance at differing resolutions from
    train varies widely and falls off dramatically. I'm not sure if there is a way around this... -RW
    """
    def __init__(
            self,
            in_features: int,
            out_features: int = None,
            embed_dim: int = None,
            num_heads: int = 4,
            qkv_bias: bool = True,
    ):
        super().__init__()
        embed_dim = embed_dim or in_features
        out_features = out_features or in_features
        self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(embed_dim, out_features)
        self.num_heads = num_heads
        assert embed_dim % num_heads == 0
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.pos_embed = RotaryEmbedding(self.head_dim)

        trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
        nn.init.zeros_(self.qkv.bias)

    def forward(self, x):
        B, _, H, W = x.shape
        N = H * W
        sin_emb, cos_emb = self.pos_embed.get_embed(x.shape[2:])
        x = x.reshape(B, -1, N).permute(0, 2, 1)
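        # prepend the mean of all spatial tokens; the pooled output is read from this token after attention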
        x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
        x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = x[0], x[1], x[2]

        qc, q = q[:, :, :1], q[:, :, 1:]
        q = apply_rot_embed(q, sin_emb, cos_emb)
        q = torch.cat([qc, q], dim=2)

        kc, k = k[:, :, :1], k[:, :, 1:]
        k = apply_rot_embed(k, sin_emb, cos_emb)
        k = torch.cat([kc, k], dim=2)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)
        x = self.proj(x)
        return x[:, 0]


class AttentionPool2d(nn.Module):
    """ Attention based 2D feature pooling w/ learned (absolute) pos embedding.
    This is a multi-head attention based replacement for (spatial) average pooling in NN architectures.

    It was based on the impl in CLIP by OpenAI
    https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py

    NOTE: This requires feature size upon construction and will prevent adaptive sizing of the network.
    """
    def __init__(
            self,
            in_features: int,
            feat_size: Union[int, Tuple[int, int]],
            out_features: int = None,
            embed_dim: int = None,
            num_heads: int = 4,
            qkv_bias: bool = True,
    ):
        super().__init__()
        embed_dim = embed_dim or in_features
        out_features = out_features or in_features
        assert embed_dim % num_heads == 0
        self.feat_size = to_2tuple(feat_size)
        self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(embed_dim, out_features)
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5

        spatial_dim = self.feat_size[0] * self.feat_size[1]
        self.pos_embed = nn.Parameter(torch.zeros(spatial_dim + 1, in_features))
        trunc_normal_(self.pos_embed, std=in_features ** -0.5)
        trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
        nn.init.zeros_(self.qkv.bias)

    def forward(self, x):
        B, _, H, W = x.shape
        N = H * W
        assert self.feat_size[0] == H
        assert self.feat_size[1] == W
        x = x.reshape(B, -1, N).permute(0, 2, 1)
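        # prepend the mean token and add a learned absolute position embedding (N spatial positions + 1 pooled token)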
        x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
        x = x + self.pos_embed.unsqueeze(0).to(x.dtype)

        x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = x[0], x[1], x[2]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)
        x = self.proj(x)
        return x[:, 0]
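
For a quick sanity check, the two pooling modules can be exercised on a dummy feature map. A minimal sketch, assuming the new file lands under timm/models/layers/ (as the relative imports of to_2tuple and trunc_normal_ suggest); the shapes and parameter names come from the code above:

import torch
from timm.models.layers.attention_pool2d import AttentionPool2d, RotAttentionPool2d

x = torch.randn(2, 512, 7, 7)  # (B, C, H, W) feature map, e.g. the final ResNet stage

# learned absolute pos embed variant: feat_size must match the input H, W
abs_pool = AttentionPool2d(in_features=512, feat_size=7, out_features=1024, num_heads=8)
print(abs_pool(x).shape)  # torch.Size([2, 1024])

# rotary pos embed variant: no fixed feature size required, but see the NOTE above about
# accuracy falling off when eval resolution differs from train
rot_pool = RotAttentionPool2d(in_features=512, out_features=1024, num_heads=8)
print(rot_pool(x).shape)  # torch.Size([2, 1024])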