You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
123 lines
3.9 KiB
123 lines
3.9 KiB
import copy
|
|
from collections import deque, defaultdict
|
|
from dataclasses import dataclass, field, replace, asdict
|
|
from typing import Any, Deque, Dict, Tuple, Optional, Union
|
|
|
|
|
|
@dataclass
class PretrainedCfg:
    """Pretrained weight locations together with the input / data, head and model
    config that accompanies those weights.

    Field order is part of the constructor signature (dataclass), so new fields
    must only ever be appended.
    """
    # weight locations
    url: Optional[Union[str, Tuple[str, str]]] = None
    file: Optional[str] = None
    hf_hub_id: Optional[str] = None
    hf_hub_filename: Optional[str] = None

    source: Optional[str] = None  # source of cfg / weight location used (url, file, hf-hub)
    architecture: Optional[str] = None  # architecture variant can be set when not implicit
    custom_load: bool = False  # use custom model specific model.load_pretrained() (ie for npz files)

    # input / data config
    input_size: Tuple[int, int, int] = (3, 224, 224)
    test_input_size: Optional[Tuple[int, int, int]] = None
    min_input_size: Optional[Tuple[int, int, int]] = None
    fixed_input_size: bool = False
    interpolation: str = 'bicubic'
    crop_pct: float = 0.875
    test_crop_pct: Optional[float] = None
    crop_mode: str = 'center'
    mean: Tuple[float, ...] = (0.485, 0.456, 0.406)
    std: Tuple[float, ...] = (0.229, 0.224, 0.225)

    # head config
    num_classes: int = 1000
    label_offset: Optional[int] = None

    # model attributes that vary with above or required for pretrained adaptation
    pool_size: Optional[Tuple[int, ...]] = None
    test_pool_size: Optional[Tuple[int, ...]] = None
    first_conv: Optional[str] = None
    classifier: Optional[str] = None

    license: Optional[str] = None
    source_url: Optional[str] = None
    paper: Optional[str] = None
    notes: Optional[str] = None

    @property
    def has_weights(self) -> bool:
        """Return True when any weight location (``url``, ``file`` or ``hf_hub_id``) is set.

        Empty strings / empty tuples count as "not set". Always returns a real
        ``bool`` rather than the location value itself.
        """
        return bool(self.url or self.file or self.hf_hub_id)

    def to_dict(self, remove_source=False, remove_null=True) -> Dict[str, Any]:
        """Return this cfg as a plain dict, filtered by ``filter_pretrained_cfg``.

        Args:
            remove_source: drop the weight-location keys
                (url, file, hf_hub_id, hf_hub_filename, source).
            remove_null: drop entries whose value is ``None``.
        """
        return filter_pretrained_cfg(
            asdict(self),
            remove_source=remove_source,
            remove_null=remove_null,
        )
|
|
|
|
|
|
# Keys describing where weights come from; dropped by ``filter_pretrained_cfg``
# when ``remove_source`` is set.
_PRETRAINED_SOURCE_KEYS = frozenset({'url', 'file', 'hf_hub_id', 'hf_hub_filename', 'source'})


def filter_pretrained_cfg(cfg, remove_source=False, remove_null=True):
    """Return a filtered copy of a pretrained cfg dict.

    Args:
        cfg: mapping of cfg field name -> value (anything with ``.items()``).
        remove_source: drop the weight-location keys
            (url, file, hf_hub_id, hf_hub_filename, source).
        remove_null: drop entries whose value is ``None``.

    Returns:
        A new dict that keeps the input's key order; ``cfg`` itself is not modified.
    """
    return {
        key: value
        for key, value in cfg.items()
        if not (remove_source and key in _PRETRAINED_SOURCE_KEYS)
        and not (remove_null and value is None)
    }
|
|
|
|
|
|
@dataclass
class DefaultCfg:
    """All pretrained cfgs registered for one model, keyed by tag.

    ``tags`` is kept in priority order: ``tags[0]`` names the default cfg.
    """
    tags: Deque[str] = field(default_factory=deque)  # priority queue of tags (first is default)
    cfgs: Dict[str, PretrainedCfg] = field(default_factory=dict)  # pretrained cfgs by tag
    is_pretrained: bool = False  # at least one of the configs has a pretrained source set

    @property
    def default_with_tag(self):
        """Return ``(tag, cfg)`` for the highest-priority tag.

        Raises IndexError when no tags are registered, and KeyError when the
        leading tag has no entry in ``cfgs``.
        """
        lead_tag = self.tags[0]
        return lead_tag, self.cfgs[lead_tag]

    @property
    def default(self):
        """Return only the cfg half of ``default_with_tag`` (same exceptions)."""
        _, lead_cfg = self.default_with_tag
        return lead_cfg
|
|
|
|
|
|
def split_model_name_tag(model_name: str, no_tag=''):
    """Split ``'<model>.<tag>'`` into ``(model, tag)`` at the first ``'.'``.

    Any further dots remain part of the tag. When the name contains no ``'.'``
    at all, ``no_tag`` is returned in place of the tag.
    """
    base_name, separator, tag = model_name.partition('.')
    if not separator:
        return base_name, no_tag
    return base_name, tag
|
|
|
|
|
|
def generate_default_cfgs(cfgs: Dict[str, Union[Dict[str, Any], PretrainedCfg]]):
    """Group ``'<model>.<tag>'``-keyed pretrained cfgs into one ``DefaultCfg`` per model.

    Keys are split at the first ``'.'`` via ``split_model_name_tag``. Any ``'*'``
    characters at either end of the tag are stripped before the tag is stored.
    Dict values are converted with ``PretrainedCfg(**value)``; other values are
    stored as given.

    Each model's ``tags`` deque is ordered so that ``tags[0]`` is its default:
      * "promoted" entries go to the front: an untagged cfg with weights, or a
        ``'*'``-suffixed tag when no earlier entry of that model was promoted;
      * otherwise, a cfg with weights goes to the front while the model has no
        weighted cfg yet;
      * everything else is appended at the back.

    A repeated (model, tag) pair overwrites the stored cfg while the tag is
    queued again.

    Returns:
        ``defaultdict`` mapping model name -> ``DefaultCfg``.
    """
    per_model = defaultdict(DefaultCfg)
    promoted_models = set()  # models that already received a promoted (explicit default) tag

    for full_name, raw_cfg in cfgs.items():
        pretrained = PretrainedCfg(**raw_cfg) if isinstance(raw_cfg, dict) else raw_cfg
        weighted = pretrained.has_weights

        model, tag = split_model_name_tag(full_name)
        starred = tag.endswith('*') and model not in promoted_models
        promote = (weighted and not tag) or starred
        tag = tag.strip('*')

        bucket = per_model[model]
        # Promoted tags always lead; otherwise the first weighted cfg jumps ahead
        # of the (necessarily unweighted) tags queued before it.
        if promote or (weighted and not bucket.is_pretrained):
            bucket.tags.appendleft(tag)
        else:
            bucket.tags.append(tag)
        if promote:
            promoted_models.add(model)

        if weighted:
            bucket.is_pretrained = True

        bucket.cfgs[tag] = pretrained

    return per_model
|