|
|
|
@ -4,6 +4,8 @@ import platform
|
|
|
|
|
import os
|
|
|
|
|
import fnmatch
|
|
|
|
|
|
|
|
|
|
_IS_MAC = platform.system() == 'Darwin'
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names, NodePathTracer
|
|
|
|
|
has_fx_feature_extraction = True
|
|
|
|
@ -322,6 +324,9 @@ def test_model_forward_features(model_name, batch_size):
|
|
|
|
|
assert not torch.isnan(o).any()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if not _IS_MAC:
|
|
|
|
|
# MACOS test runners are really slow, only running tests below this point if not on a Darwin runner...
|
|
|
|
|
|
|
|
|
|
def _create_fx_model(model, train=False):
|
|
|
|
|
# This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
|
|
|
|
|
# So we trace once and look at the graph, and get the indices of the nodes that lead into the original fx output
|
|
|
|
|