diff --git a/timm/models/_factory.py b/timm/models/_factory.py index aaaa7c73..0b12f9dd 100644 --- a/timm/models/_factory.py +++ b/timm/models/_factory.py @@ -56,30 +56,16 @@ def create_model( Args: - model_name (str): - Name of model to instantiate. - pretrained (`bool`, *optional*, defaults to `False`): - If set to `True`, load pretrained ImageNet-1k weights. - pretrained_cfg (Union[str, dict, PretrainedCfg], *optional*): - Pass in an external pretrained_cfg for model. - pretrained_cfg_overlay (dict, *optional*): - Replace key-values in base pretrained_cfg with these. - checkpoint_path (str, *optional*): - Path of checkpoint to load _after_ the model is initialized. - scriptable (bool, *optional*): - Set layer config so that model is jit scriptable (not working for all models yet). - exportable (bool, *optional*): - Set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet). - no_jit (bool, *optional*): - Set layer config so that model doesn't utilize jit scripted layers (so far activations only). - - **Keyword Args**: - - - **drop_rate** (float, *optional*, defaults to `0.0`): - Dropout rate for training. - - **global_pool** (str, *optional*, defaults to `'avg'`): - Global pooling type. - - All other kwargs are consumed by builder or model ``__init__()``. + model_name (`str`): Name of model to instantiate. + pretrained (`bool`): If set to `True`, load pretrained ImageNet-1k weights. + pretrained_cfg (`Union[str, dict, PretrainedCfg]`): Pass in an external pretrained_cfg for model. + pretrained_cfg_overlay (`dict`): Replace key-values in base pretrained_cfg with these. + checkpoint_path (`str`): Path of checkpoint to load _after_ the model is initialized. + scriptable (`bool`): Set layer config so that model is jit scriptable (not working for all models yet). + exportable (`bool`): Set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet). + no_jit (`bool`): Set layer config so that model doesn't utilize jit scripted layers (so far activations only). + **drop_rate (`float`): Dropout rate for training. Defaults to `0.0`. + **global_pool (`str`): Global pooling type. Defaults to `'avg'`. Example: