diff --git a/timm/models/byoanet.py b/timm/models/byoanet.py index d296d4ba..c7a5c53e 100644 --- a/timm/models/byoanet.py +++ b/timm/models/byoanet.py @@ -68,7 +68,8 @@ default_cfgs = { fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.94), 'haloregnetz_b': _cfg( - url='', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/haloregnetz_c_raa_256-c8ad7616.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), first_conv='stem.conv', input_size=(3, 224, 224), pool_size=(7, 7), min_input_size=(3, 224, 224), crop_pct=0.94), 'trionet50ts_256': _cfg( url='', diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index 4363709f..93898209 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -139,13 +139,13 @@ default_cfgs = { 'regnetz_b': _cfgr( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_b_raa-677d9606.pth', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), - input_size=(3, 224, 224), pool_size=(7, 7), test_input_size=(3, 288, 288), first_conv='stem.conv', crop_pct=0.95), + input_size=(3, 224, 224), pool_size=(7, 7), test_input_size=(3, 288, 288), first_conv='stem.conv', crop_pct=0.94), 'regnetz_c': _cfgr( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_c_rab_256-6bdb3c01.pth', - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), first_conv='stem.conv', crop_pct=0.95), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_c_rab2_256-a54bf36a.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), first_conv='stem.conv', crop_pct=0.94), 'regnetz_d': _cfgr( - url='', - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_d_rab_256-b8073a89.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), crop_pct=0.95), } diff --git a/timm/models/layers/halo_attn.py b/timm/models/layers/halo_attn.py index 846c12ff..4149e812 100644 --- a/timm/models/layers/halo_attn.py +++ b/timm/models/layers/halo_attn.py @@ -183,7 +183,9 @@ class HaloAttn(nn.Module): # B * num_heads, num_blocks, block_size ** 2, dim_head kv = self.kv(x) - # generate overlapping windows for kv + # Generate overlapping windows for kv. This approach is good for GPU and CPU. However, unfold() is not + # lowered for PyTorch XLA so it will be very slow. See code at bottom of file for XLA friendly approach. + # FIXME figure out how to switch impl between this and conv2d if XLA being used. kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]) kv = kv.unfold(2, self.win_size, self.block_size).unfold(3, self.win_size, self.block_size).reshape( B * self.num_heads, self.dim_head_qk + self.dim_head_v, num_blocks, -1).permute(0, 2, 3, 1) @@ -207,17 +209,24 @@ class HaloAttn(nn.Module): return out -""" Two alternatives for overlapping windows. +""" Three alternatives for overlapping windows. `.unfold().unfold()` is same speed as stride tricks with similar clarity as F.unfold() - if self.stride_tricks: + if is_xla: + # This code achieves haloing on PyTorch XLA with reasonable runtime trade-off, it is + # EXTREMELY slow for backward on a GPU though so I need a way of selecting based on environment. + WW = self.win_size ** 2 + pw = torch.eye(WW, dtype=x.dtype, device=x.device).reshape(WW, 1, self.win_size, self.win_size) + kv = F.conv2d(kv.reshape(-1, 1, H, W), pw, stride=self.block_size, padding=self.halo_size) + elif self.stride_tricks: kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]).contiguous() kv = kv.as_strided(( B, self.dim_out_qk + self.dim_out_v, self.win_size, self.win_size, num_h_blocks, num_w_blocks), stride=(kv.stride(0), kv.stride(1), kv.shape[-1], 1, self.block_size * kv.shape[-1], self.block_size)) else: kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size) + kv = kv.reshape( B * self.num_heads, self.dim_head_qk + self.dim_head_v, -1, num_blocks).transpose(1, 3) """