|
|
|
@ -1,3 +1,4 @@
|
|
|
|
|
import torch
|
|
|
|
|
from torchbench.image_classification import ImageNet
|
|
|
|
|
from timm import create_model
|
|
|
|
|
from timm.data import resolve_data_config, create_transform
|
|
|
|
@ -77,7 +78,7 @@ model_list = [
|
|
|
|
|
_entry('mixnet_m', 'MixNet-M', '1907.09595'),
|
|
|
|
|
_entry('mixnet_s', 'MixNet-S', '1907.09595'),
|
|
|
|
|
_entry('mnasnet_100', 'MnasNet-B1', '1807.11626'),
|
|
|
|
|
_entry('mobilenetv3_100', 'MobileNet V3(1.0)', '1905.02244',
|
|
|
|
|
_entry('mobilenetv3_100', 'MobileNet V3-Large 1.0', '1905.02244',
|
|
|
|
|
model_desc='Trained in PyTorch with RMSProp, exponential LR decay, and hyper-params matching '
|
|
|
|
|
'paper as closely as possible.'),
|
|
|
|
|
_entry('resnet18', 'ResNet-18', '1812.01187'),
|
|
|
|
@ -216,4 +217,6 @@ for m in model_list:
|
|
|
|
|
data_root=os.environ.get('IMAGENET_DIR', './imagenet')
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
torch.cuda.empty_cache()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|