|
|
|
""" Image to Patch Embedding using Conv2d
|
|
|
|
|
|
|
|
A convolution based approach to patchifying a 2D image w/ embedding projection.
|
|
|
|
|
|
|
|
Based on code in:
|
|
|
|
* https://github.com/google-research/vision_transformer
|
|
|
|
* https://github.com/google-research/big_vision/tree/main/big_vision
|
|
|
|
|
|
|
|
Hacked together by / Copyright 2020 Ross Wightman
|
|
|
|
"""
|
|
|
|
import logging
|
|
|
|
from typing import List
|
|
|
|
|
|
|
|
import torch
|
|
|
|
from torch import nn as nn
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
|
|
from .helpers import to_2tuple
|
|
|
|
from .trace_utils import _assert
|
|
|
|
|
|
|
|
# Module-level logger; resample_patch_embed logs through it when called with verbose=True.
_logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
class PatchEmbed(nn.Module):
    """ 2D Image to Patch Embedding

    Cuts a ``(B, C, H, W)`` image into non-overlapping patches and projects every
    patch to ``embed_dim`` channels with one strided convolution whose kernel size
    and stride both equal the patch size.

    Args:
        img_size: expected input resolution, int or (height, width); forward
            asserts that inputs match it exactly.
        patch_size: patch resolution, int or (height, width).
        in_chans: number of input image channels.
        embed_dim: number of output embedding channels.
        norm_layer: optional callable invoked as ``norm_layer(embed_dim)`` to build
            the output norm; ``nn.Identity`` is used when it is falsy.
        flatten: if True, output tokens as ``(B, N, embed_dim)``; otherwise the
            raw conv output ``(B, embed_dim, H', W')`` is passed to the norm.
        bias: whether the projection conv has a bias term.
    """

    def __init__(
            self,
            img_size=224,
            patch_size=16,
            in_chans=3,
            embed_dim=768,
            norm_layer=None,
            flatten=True,
            bias=True,
    ):
        super().__init__()
        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)

        # Number of whole patches along each spatial axis (any remainder is dropped).
        rows = self.img_size[0] // self.patch_size[0]
        cols = self.img_size[1] // self.patch_size[1]
        self.grid_size = (rows, cols)
        self.num_patches = rows * cols
        self.flatten = flatten

        # kernel_size == stride, so each output position covers exactly one patch.
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=bias,
        )
        # NOTE(review): the norm runs on whatever layout forward produces; with
        # flatten=False that is (B, C, H', W'), so a LayerNorm(embed_dim) would not
        # line up with the channel dim — confirm callers only combine them with flatten=True.
        self.norm = nn.Identity() if not norm_layer else norm_layer(embed_dim)

    def forward(self, x):
        """Embed a 4-D image batch whose spatial size must equal ``self.img_size``."""
        _, _, H, W = x.shape
        _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
        _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")

        x = self.proj(x)
        if self.flatten:
            # (B, C, H', W') -> (B, C, H'*W') -> (B, N, C)
            x = x.flatten(2).transpose(1, 2)
        return self.norm(x)
|
|
|
|
|
|
|
|
|
|
|
|
def resample_patch_embed(
        patch_embed,
        new_size: List[int],
        interpolation: str = 'bicubic',
        antialias: bool = True,
        verbose: bool = False,
):
    """Resample the weights of the patch embedding kernel to target resolution.

    We resample the patch embedding kernel by approximately inverting the effect
    of patch resizing.

    Code based on:
      https://github.com/google-research/big_vision/blob/b00544b81f8694488d5f36295aeb7972f3755ffe/big_vision/models/proj/flexi/vit.py

    With this resizing, we can for example load a B/8 filter into a B/16 model
    and, on 2x larger input image, the result will match.

    Let ``R`` be the linear map that resizes a flattened ``old_size`` patch to
    ``new_size`` with ``F.interpolate(mode=interpolation, antialias=antialias)``.
    Each 2D kernel slice ``w`` is replaced by ``w' = pinv(R^T) @ w``. This solves
    ``R^T w' = w`` (exactly when ``R^T`` has full row rank, in the least-squares
    sense otherwise), so that ``<R x, w'> ~= <x, w>`` for any patch ``x``.

    Args:
        patch_embed: original parameter to be resized, a 4-D conv weight
            (out_chans, in_chans, h, w).
        new_size (tuple(int, int)): target shape (height, width)-only.
        interpolation (str): interpolation for resize
        antialias (bool): use anti-aliasing filter in resize
        verbose (bool): log operation

    Returns:
        Resized patch embedding kernel of shape (out_chans, in_chans, *new_size),
        on the same device and with the same dtype as ``patch_embed``.
        ``patch_embed`` itself is returned when its spatial size already equals
        ``new_size``.
    """
    import numpy as np

    assert len(patch_embed.shape) == 4, "Four dimensions expected"
    assert len(new_size) == 2, "New shape should only be hw"
    old_size = patch_embed.shape[-2:]
    if tuple(old_size) == tuple(new_size):
        return patch_embed

    if verbose:
        _logger.info(f"Resize patch embedding {patch_embed.shape} to {new_size}, w/ {interpolation} interpolation.")

    # Build R^T, shape (old_h * old_w, new_h * new_w). Row i is the resized,
    # flattened one-hot basis patch e_i. All basis patches go through a single
    # batched interpolate call instead of one call per basis vector.
    old_numel = old_size[0] * old_size[1]
    basis = torch.eye(old_numel, dtype=torch.float32).reshape(old_numel, 1, old_size[0], old_size[1])
    resize_mat_t = F.interpolate(
        basis, size=new_size, mode=interpolation, antialias=antialias).reshape(old_numel, -1)

    # pinv(R^T) has shape (new_h * new_w, old_h * old_w).
    resize_mat_pinv = torch.from_numpy(np.linalg.pinv(resize_mat_t.numpy()))

    # Compute in at least float32 on the kernel's own device. The pinv is built
    # as float32 on CPU, so without this, CUDA or half/bfloat16/float64 weights
    # would fail in the matmul. The result is cast back to the original dtype.
    orig_dtype = patch_embed.dtype
    compute_dtype = torch.promote_types(orig_dtype, torch.float32)
    resize_mat_pinv = resize_mat_pinv.to(device=patch_embed.device, dtype=compute_dtype)

    # The resampling is linear and identical for every (out, in) kernel slice.
    # One batched matmul therefore handles all slices, with no vmap/functorch needed.
    out_chans, in_chans = patch_embed.shape[0], patch_embed.shape[1]
    kernels = patch_embed.to(compute_dtype).reshape(out_chans, in_chans, old_numel)
    resampled = kernels @ resize_mat_pinv.T  # (out_chans, in_chans, new_h * new_w)
    return resampled.reshape(out_chans, in_chans, new_size[0], new_size[1]).to(orig_dtype)
|
|
|
|
|
|
|
|
|
|
|
|
# def divs(n, m=None):
|
|
|
|
# m = m or n // 2
|
|
|
|
# if m == 1:
|
|
|
|
# return [1]
|
|
|
|
# if n % m == 0:
|
|
|
|
# return [m] + divs(n, m - 1)
|
|
|
|
# return divs(n, m - 1)
|
|
|
|
#
|
|
|
|
#
|
|
|
|
# class FlexiPatchEmbed(nn.Module):
|
|
|
|
# """ 2D Image to Patch Embedding w/ Flexible Patch sizes (FlexiViT)
|
|
|
|
# FIXME WIP
|
|
|
|
# """
|
|
|
|
# def __init__(
|
|
|
|
# self,
|
|
|
|
# img_size=240,
|
|
|
|
# patch_size=16,
|
|
|
|
# in_chans=3,
|
|
|
|
# embed_dim=768,
|
|
|
|
# base_img_size=240,
|
|
|
|
# base_patch_size=32,
|
|
|
|
# norm_layer=None,
|
|
|
|
# flatten=True,
|
|
|
|
# bias=True,
|
|
|
|
# ):
|
|
|
|
# super().__init__()
|
|
|
|
# self.img_size = to_2tuple(img_size)
|
|
|
|
# self.patch_size = to_2tuple(patch_size)
|
|
|
|
# self.num_patches = 0
|
|
|
|
#
|
|
|
|
# # full range for 240 = (5, 6, 8, 10, 12, 14, 15, 16, 20, 24, 30, 40, 48)
|
|
|
|
# self.seqhw = (6, 8, 10, 12, 14, 15, 16, 20, 24, 30)
|
|
|
|
#
|
|
|
|
# self.base_img_size = to_2tuple(base_img_size)
|
|
|
|
# self.base_patch_size = to_2tuple(base_patch_size)
|
|
|
|
# self.base_grid_size = tuple([i // p for i, p in zip(self.base_img_size, self.base_patch_size)])
|
|
|
|
# self.base_num_patches = self.base_grid_size[0] * self.base_grid_size[1]
|
|
|
|
#
|
|
|
|
# self.flatten = flatten
|
|
|
|
# self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=bias)
|
|
|
|
# self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
|
|
|
#
|
|
|
|
# def forward(self, x):
|
|
|
|
# B, C, H, W = x.shape
|
|
|
|
#
|
|
|
|
# if self.patch_size == self.base_patch_size:
|
|
|
|
# weight = self.proj.weight
|
|
|
|
# else:
|
|
|
|
# weight = resample_patch_embed(self.proj.weight, self.patch_size)
|
|
|
|
# patch_size = self.patch_size
|
|
|
|
# x = F.conv2d(x, weight, bias=self.proj.bias, stride=patch_size)
|
|
|
|
# if self.flatten:
|
|
|
|
# x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
|
|
|
|
# x = self.norm(x)
|
|
|
|
# return x
|