You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

196 lines
7.3 KiB

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.nn.functional import conv2d
class VGG19(nn.Module):
    """Frozen, pretrained VGG19 that returns the output of every stage.

    The ``features`` stack of ``torchvision.models.vgg19`` is split into 16
    consecutive ``nn.Sequential`` stages stored as attributes ``relu1_1`` ...
    ``relu5_4``. ``forward`` runs them in order and returns every stage
    output, which makes the module usable as a perceptual/style feature
    extractor.

    Args:
        resize_input (bool): when True, the normalized input is bilinearly
            resized to 256x256 before being fed to the network.

    Attributes:
        relus (list[str]): stage names, in execution order.
        mean, std (Tensor): per-channel normalization constants. They are
            registered as *non-persistent* buffers so they follow
            ``.to(device)`` / ``.cuda()`` / ``.half()`` together with the
            weights, while the module's ``state_dict`` keys stay unchanged.
    """

    def __init__(self, resize_input=False):
        super(VGG19, self).__init__()
        features = models.vgg19(pretrained=True).features
        self.resize_input = resize_input

        # Previously these were plain tensors created with a hard-coded
        # ``.cuda()``, which crashed on CPU-only machines and ignored
        # ``module.to(...)``. Buffers move with the module instead.
        self.register_buffer(
            'mean', torch.tensor([0.485, 0.456, 0.406]), persistent=False)
        self.register_buffer(
            'std', torch.tensor([0.229, 0.224, 0.225]), persistent=False)

        # (block, index) pairs that make up the stage names relu<block>_<index>.
        prefix = [1, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5]
        posfix = [1, 2, 1, 2, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4]
        # Indices into ``features`` that belong to each stage, in order.
        nums = [[0, 1], [2, 3], [4, 5, 6], [7, 8],
                [9, 10, 11], [12, 13], [14, 15], [16, 17],
                [18, 19, 20], [21, 22], [23, 24], [25, 26],
                [27, 28, 29], [30, 31], [32, 33], [34, 35]]

        self.relus = []
        for pre, pos, indices in zip(prefix, posfix, nums):
            name = 'relu{}_{}'.format(pre, pos)
            stage = torch.nn.Sequential()
            for num in indices:
                # Keep the original index as the sub-module name so parameter
                # names match the torchvision layout.
                stage.add_module(str(num), features[num])
            setattr(self, name, stage)
            self.relus.append(name)

        # don't need the gradients, just want the features
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return a dict mapping each stage name to its activation.

        Args:
            x (Tensor): image batch with 3 channels in dim 1. The first line
                maps it through ``(x + 1) / 2``, so it is presumably expected
                in ``[-1, 1]`` -- confirm against the caller.

        Returns:
            dict[str, Tensor]: ``{'relu1_1': ..., ..., 'relu5_4': ...}`` in
            execution order.
        """
        # resize and normalize input for pretrained vgg19
        x = (x + 1.0) / 2.0
        x = (x - self.mean.view(1, 3, 1, 1)) / self.std.view(1, 3, 1, 1)
        if self.resize_input:
            x = F.interpolate(
                x, size=(256, 256), mode='bilinear', align_corners=True)

        out = {}
        for layer in self.relus:
            x = getattr(self, layer)(x)
            out[layer] = x
        return out
def gaussian(window_size, sigma):
    """Return a normalized 1D Gaussian window.

    Args:
        window_size (int): number of taps; the window is centred on index
            ``window_size // 2``.
        sigma (float): standard deviation of the Gaussian.

    Returns:
        Tensor: 1D tensor of length ``window_size`` whose entries sum to 1.
    """
    center = window_size // 2
    two_sigma_sq = float(2 * sigma**2)

    taps = []
    for pos in range(window_size):
        # Exponent of the unnormalized Gaussian at this tap.
        exponent = -(pos - center)**2 / two_sigma_sq
        taps.append(torch.exp(torch.tensor(exponent)))

    window = torch.stack(taps)
    return window / window.sum()
def get_gaussian_kernel(kernel_size: int, sigma: float) -> torch.Tensor:
    r"""Build the coefficients of a 1D Gaussian filter.

    Args:
        kernel_size (int): filter length; must be an odd, positive integer.
        sigma (float): standard deviation of the Gaussian.

    Returns:
        Tensor: 1D tensor holding the normalized filter coefficients.

    Raises:
        TypeError: if ``kernel_size`` is not an odd positive integer.

    Shape:
        - Output: :math:`(\text{kernel_size})`

    Examples::
        >>> get_gaussian_kernel(3, 2.5)
        tensor([0.3243, 0.3513, 0.3243])
        >>> get_gaussian_kernel(5, 1.5)
        tensor([0.1201, 0.2339, 0.2921, 0.2339, 0.1201])
    """
    # The isinstance test comes first so non-integers never reach the
    # arithmetic checks.
    is_odd_positive_int = (
        isinstance(kernel_size, int)
        and kernel_size % 2 != 0
        and kernel_size > 0
    )
    if not is_odd_positive_int:
        raise TypeError(
            "kernel_size must be an odd positive integer. Got {}".format(kernel_size))
    return gaussian(kernel_size, sigma)
def get_gaussian_kernel2d(kernel_size, sigma):
    r"""Build a 2D Gaussian filter as the outer product of two 1D kernels.

    Args:
        kernel_size (Tuple[int, int]): filter sizes for the first and second
            output dimension. Each should be odd and positive.
        sigma (Tuple[float, float]): gaussian standard deviation for the
            first and second output dimension.

    Returns:
        Tensor: 2D tensor with the gaussian filter matrix coefficients.

    Raises:
        TypeError: if ``kernel_size`` or ``sigma`` is not a tuple of length
            two, or if a kernel size is rejected by ``get_gaussian_kernel``.

    Shape:
        - Output: :math:`(\text{kernel_size}_x, \text{kernel_size}_y)`

    Examples::
        >>> get_gaussian_kernel2d((3, 3), (1.5, 1.5))
        tensor([[0.0947, 0.1183, 0.0947],
                [0.1183, 0.1478, 0.1183],
                [0.0947, 0.1183, 0.0947]])
        >>> get_gaussian_kernel2d((3, 5), (1.5, 1.5))
        tensor([[0.0370, 0.0720, 0.0899, 0.0720, 0.0370],
                [0.0462, 0.0899, 0.1123, 0.0899, 0.0462],
                [0.0370, 0.0720, 0.0899, 0.0720, 0.0370]])
    """
    # Validate both arguments the same way, kernel_size first.
    checks = (
        (kernel_size, "kernel_size must be a tuple of length two. Got {}"),
        (sigma, "sigma must be a tuple of length two. Got {}"),
    )
    for value, message in checks:
        if not (isinstance(value, tuple) and len(value) == 2):
            raise TypeError(message.format(value))

    column: torch.Tensor = get_gaussian_kernel(kernel_size[0], sigma[0])
    row: torch.Tensor = get_gaussian_kernel(kernel_size[1], sigma[1])
    # (Kx, 1) @ (1, Ky) -> (Kx, Ky) outer product.
    return column.unsqueeze(-1) @ row.unsqueeze(0)
class GaussianBlur(nn.Module):
    r"""Module that blurs a batch of images with a fixed Gaussian kernel.

    One 2D Gaussian kernel is built at construction time and applied to each
    channel independently through a grouped convolution. The input is
    zero-padded by ``(k - 1) // 2`` along each spatial dimension. It supports
    batched operation.

    Arguments:
        kernel_size (Tuple[int, int]): the size of the kernel.
        sigma (Tuple[float, float]): the standard deviation of the kernel.

    Returns:
        Tensor: the blurred tensor.

    Shape:
        - Input: :math:`(B, C, H, W)`
        - Output: :math:`(B, C, H, W)`

    Examples::
        >>> input = torch.rand(2, 4, 5, 5)
        >>> gauss = GaussianBlur((3, 3), (1.5, 1.5))
        >>> output = gauss(input)  # 2x4x5x5
    """

    def __init__(self, kernel_size, sigma):
        super(GaussianBlur, self).__init__()
        self.kernel_size = kernel_size
        self.sigma = sigma
        self._padding = self.compute_zero_padding(kernel_size)
        self.kernel = get_gaussian_kernel2d(kernel_size, sigma)

    @staticmethod
    def compute_zero_padding(kernel_size):
        """Return the ``(pad_0, pad_1)`` zero padding for ``kernel_size``."""
        halves = [(size - 1) // 2 for size in kernel_size]
        return halves[0], halves[1]

    def forward(self, x):  # type: ignore
        """Blur ``x`` (shape ``BxCxHxW``) channel by channel.

        Raises:
            TypeError: if ``x`` is not a tensor.
            ValueError: if ``x`` is not 4-dimensional.
        """
        if not torch.is_tensor(x):
            raise TypeError(
                "Input x type is not a torch.Tensor. Got {}".format(type(x)))
        if len(x.shape) != 4:
            raise ValueError(
                "Invalid input shape, we expect BxCxHxW. Got: {}".format(x.shape))

        channels = x.shape[1]
        # Match the input's device and dtype, then replicate the 2D kernel
        # into a (C, 1, Kh, Kw) weight: one single-channel filter per channel.
        weight: torch.Tensor = self.kernel.to(x.device).to(x.dtype)
        weight = weight.repeat(channels, 1, 1, 1)

        # TODO: explore solution when using jit.trace since it raises a warning
        # because the shape is converted to a tensor instead to a int.
        # groups == channels makes the convolution depthwise.
        return conv2d(
            x, weight, padding=self._padding, stride=1, groups=channels)
######################
# functional interface
######################
def gaussian_blur(input, kernel_size, sigma):
    r"""Blur a ``BxCxHxW`` tensor with a Gaussian filter.

    Functional shortcut that builds a one-off :class:`GaussianBlur` from
    ``kernel_size`` and ``sigma`` and applies it to ``input``.

    Args:
        input (Tensor): the tensor to blur.
        kernel_size (Tuple[int, int]): the size of the kernel.
        sigma (Tuple[float, float]): the standard deviation of the kernel.

    Returns:
        Tensor: the blurred tensor.
    """
    blur = GaussianBlur(kernel_size, sigma)
    return blur(input)
if __name__ == '__main__':
    # Manual smoke test: blur ``test.png`` (as grayscale) with a 61x61,
    # sigma=10 Gaussian and write the result to ``gaussian.png``.
    #
    # The script used ``Image``, ``np`` and ``cv2`` without importing them
    # (NameError on the first line). They are imported locally so that merely
    # importing this module does not require the extra packages.
    import cv2
    import numpy as np
    from PIL import Image
    # ``torch.nn.functional`` (``F`` in this file) has no ``to_tensor``; the
    # PIL -> tensor conversion lives in torchvision's functional transforms.
    from torchvision.transforms.functional import to_tensor

    with Image.open('test.png') as src:
        img = src.convert('L')
    # to_tensor yields a (1, H, W) float tensor; add the batch dimension.
    tensor_img = to_tensor(img).unsqueeze(0).float()
    print('tensor_img size: ', tensor_img.size())

    blurred_img = gaussian_blur(tensor_img, (61, 61), (10, 10))
    print(torch.min(blurred_img), torch.max(blurred_img))

    # Rescale to 8-bit range and drop the batch and channel dimensions.
    blurred_img = blurred_img * 255
    img = blurred_img.int().numpy().astype(np.uint8)[0][0]
    print(img.shape, np.min(img), np.max(img), np.unique(img))
    cv2.imwrite('gaussian.png', img)