Add weights for regnetz_d and haloregnetz_c, update regnetz_c weights. Add commented PyTorch XLA code for halo attention

pull/933/head
Ross Wightman 3 years ago
parent 02daf2ab94
commit c02334d9fa

@ -68,7 +68,8 @@ default_cfgs = {
fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.94), fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.94),
'haloregnetz_b': _cfg( 'haloregnetz_b': _cfg(
url='', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/haloregnetz_c_raa_256-c8ad7616.pth',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
first_conv='stem.conv', input_size=(3, 224, 224), pool_size=(7, 7), min_input_size=(3, 224, 224), crop_pct=0.94), first_conv='stem.conv', input_size=(3, 224, 224), pool_size=(7, 7), min_input_size=(3, 224, 224), crop_pct=0.94),
'trionet50ts_256': _cfg( 'trionet50ts_256': _cfg(
url='', url='',

@ -139,13 +139,13 @@ default_cfgs = {
'regnetz_b': _cfgr( 'regnetz_b': _cfgr(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_b_raa-677d9606.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_b_raa-677d9606.pth',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
input_size=(3, 224, 224), pool_size=(7, 7), test_input_size=(3, 288, 288), first_conv='stem.conv', crop_pct=0.95), input_size=(3, 224, 224), pool_size=(7, 7), test_input_size=(3, 288, 288), first_conv='stem.conv', crop_pct=0.94),
'regnetz_c': _cfgr( 'regnetz_c': _cfgr(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_c_rab_256-6bdb3c01.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_c_rab2_256-a54bf36a.pth',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), first_conv='stem.conv', crop_pct=0.95), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), first_conv='stem.conv', crop_pct=0.94),
'regnetz_d': _cfgr( 'regnetz_d': _cfgr(
url='', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_d_rab_256-b8073a89.pth',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), crop_pct=0.95),
} }

@ -183,7 +183,9 @@ class HaloAttn(nn.Module):
# B * num_heads, num_blocks, block_size ** 2, dim_head # B * num_heads, num_blocks, block_size ** 2, dim_head
kv = self.kv(x) kv = self.kv(x)
# generate overlapping windows for kv # Generate overlapping windows for kv. This approach is good for GPU and CPU. However, unfold() is not
# lowered for PyTorch XLA so it will be very slow. See code at bottom of file for XLA friendly approach.
# FIXME figure out how to switch impl between this and conv2d if XLA being used.
kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]) kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size])
kv = kv.unfold(2, self.win_size, self.block_size).unfold(3, self.win_size, self.block_size).reshape( kv = kv.unfold(2, self.win_size, self.block_size).unfold(3, self.win_size, self.block_size).reshape(
B * self.num_heads, self.dim_head_qk + self.dim_head_v, num_blocks, -1).permute(0, 2, 3, 1) B * self.num_heads, self.dim_head_qk + self.dim_head_v, num_blocks, -1).permute(0, 2, 3, 1)
@ -207,17 +209,24 @@ class HaloAttn(nn.Module):
return out return out
""" Two alternatives for overlapping windows. """ Three alternatives for overlapping windows.
`.unfold().unfold()` is same speed as stride tricks with similar clarity as F.unfold() `.unfold().unfold()` is same speed as stride tricks with similar clarity as F.unfold()
if self.stride_tricks: if is_xla:
# This code achieves haloing on PyTorch XLA with reasonable runtime trade-off, it is
# EXTREMELY slow for backward on a GPU though so I need a way of selecting based on environment.
WW = self.win_size ** 2
pw = torch.eye(WW, dtype=x.dtype, device=x.device).reshape(WW, 1, self.win_size, self.win_size)
kv = F.conv2d(kv.reshape(-1, 1, H, W), pw, stride=self.block_size, padding=self.halo_size)
elif self.stride_tricks:
kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]).contiguous() kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]).contiguous()
kv = kv.as_strided(( kv = kv.as_strided((
B, self.dim_out_qk + self.dim_out_v, self.win_size, self.win_size, num_h_blocks, num_w_blocks), B, self.dim_out_qk + self.dim_out_v, self.win_size, self.win_size, num_h_blocks, num_w_blocks),
stride=(kv.stride(0), kv.stride(1), kv.shape[-1], 1, self.block_size * kv.shape[-1], self.block_size)) stride=(kv.stride(0), kv.stride(1), kv.shape[-1], 1, self.block_size * kv.shape[-1], self.block_size))
else: else:
kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size) kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size)
kv = kv.reshape( kv = kv.reshape(
B * self.num_heads, self.dim_head_qk + self.dim_head_v, -1, num_blocks).transpose(1, 3) B * self.num_heads, self.dim_head_qk + self.dim_head_v, -1, num_blocks).transpose(1, 3)
""" """

Loading…
Cancel
Save