You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
32 lines
958 B
32 lines
958 B
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
|
|
def versiontuple(v):
    """Convert a dotted version string into a tuple of up to three ints.

    Only the first three dot-separated components are considered, and only
    the leading decimal digits of each one are parsed. Version strings that
    carry a local build tag (``"1.13.1+cu117"``), a pre-release marker
    (``"1.5.0a0+git1234"``) or extra components (``"2.1.0.dev20230801"``)
    therefore parse cleanly instead of raising ``ValueError``.

    Args:
        v: version string such as ``torch.__version__``.

    Returns:
        A tuple with one int per parsed component (at most three). A
        component that does not start with a digit contributes ``0``.
    """
    # Local import keeps this helper self-contained without touching the
    # file-level import block.
    import re

    components = []
    # Slice *before* converting so trailing components (e.g. "dev2023...")
    # are never passed to int().
    for piece in v.split(".")[:3]:
        leading_digits = re.match(r"[0-9]+", piece)
        components.append(int(leading_digits.group()) if leading_digits else 0)
    return tuple(components)
|
|
|
|
|
|
if versiontuple(torch.__version__) < versiontuple('1.2.0'):
    # Installed torch is older than 1.2.0: define a local stand-in module.
    class Flatten(nn.Module):
        r"""
        Collapse the dimensions ``start_dim`` .. ``end_dim`` of the input
        into a single dimension. Meant to be dropped into an
        :class:`~nn.Sequential` pipeline.

        Args:
            start_dim: index of the first dimension to merge (default = 1).
            end_dim: index of the last dimension to merge (default = -1).

        Shape:
            - Input: :math:`(N, *dims)`
            - Output: :math:`(N, \prod *dims)` (with the default arguments).
        """
        __constants__ = ['start_dim', 'end_dim']

        def __init__(self, start_dim=1, end_dim=-1):
            super(Flatten, self).__init__()
            self.start_dim, self.end_dim = start_dim, end_dim

        def forward(self, input):
            # Delegate to the input's own ``flatten`` over the configured range.
            flattened = input.flatten(self.start_dim, self.end_dim)
            return flattened
else:
    # torch 1.2.0 or newer: reuse the library-provided module directly.
    Flatten = nn.Flatten
|