|
|
@ -3,13 +3,13 @@
|
|
|
|
Hacked together by / Copyright 2020 Ross Wightman
|
|
|
|
Hacked together by / Copyright 2020 Ross Wightman
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
from itertools import repeat
|
|
|
|
from itertools import repeat
|
|
|
|
from torch._six import container_abcs
|
|
|
|
import collections.abc
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# From PyTorch internals
|
|
|
|
# From PyTorch internals
|
|
|
|
def _ntuple(n):
|
|
|
|
def _ntuple(n):
|
|
|
|
def parse(x):
|
|
|
|
def parse(x):
|
|
|
|
if isinstance(x, container_abcs.Iterable):
|
|
|
|
if isinstance(x, collections.abc.Iterable):
|
|
|
|
return x
|
|
|
|
return x
|
|
|
|
return tuple(repeat(x, n))
|
|
|
|
return tuple(repeat(x, n))
|
|
|
|
return parse
|
|
|
|
return parse
|
|
|
|