Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Default parameter overrides. #33

Merged
merged 20 commits into from
Dec 20, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
shove default_parameter EVERYWHERE.
  • Loading branch information
BrianPugh committed Dec 20, 2023
commit 09b27b679ddd5748ec54c5c5eea7936709bb23d0
68 changes: 38 additions & 30 deletions cyclopts/bind.py
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@
import shlex
import sys
from functools import lru_cache
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union, get_origin
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union, get_origin

from cyclopts.coercion import resolve, token_count
from cyclopts.exceptions import (
@@ -13,7 +13,7 @@
RepeatArgumentError,
ValidationError,
)
from cyclopts.parameter import get_hint_parameter, get_names, validate_command
from cyclopts.parameter import Parameter, get_hint_parameter, get_names, validate_command


def normalize_tokens(tokens: Union[None, str, Iterable[str]]) -> List[str]:
@@ -27,7 +27,9 @@ def normalize_tokens(tokens: Union[None, str, Iterable[str]]) -> List[str]:


@lru_cache(maxsize=16)
def cli2parameter(f: Callable) -> Dict[str, Tuple[inspect.Parameter, Any]]:
def cli2parameter(
f: Callable, default_parameter: Optional[Parameter] = None
) -> Dict[str, Tuple[inspect.Parameter, Any]]:
"""Creates a dictionary mapping CLI keywords to python keywords.

Typically the mapping is something like::
@@ -44,14 +46,14 @@ def cli2parameter(f: Callable) -> Dict[str, Tuple[inspect.Parameter, Any]]:
signature = inspect.signature(f)
for iparam in signature.parameters.values():
annotation = str if iparam.annotation is iparam.empty else iparam.annotation
_, cparam = get_hint_parameter(annotation)
_, cparam = get_hint_parameter(annotation, default_parameter=default_parameter)

if cparam.parse is False:
continue

if iparam.kind in (iparam.POSITIONAL_OR_KEYWORD, iparam.KEYWORD_ONLY):
hint = resolve(annotation)
keys = get_names(iparam)
keys = get_names(iparam, default_parameter=default_parameter)

for key in keys:
mapping[key] = (iparam, True if hint is bool else None)
@@ -63,8 +65,8 @@ def cli2parameter(f: Callable) -> Dict[str, Tuple[inspect.Parameter, Any]]:


@lru_cache(maxsize=16)
def parameter2cli(f: Callable) -> Dict[inspect.Parameter, List[str]]:
c2p = cli2parameter(f)
def parameter2cli(f: Callable, default_parameter: Optional[Parameter] = None) -> Dict[inspect.Parameter, List[str]]:
c2p = cli2parameter(f, default_parameter=default_parameter)
p2c = {}

for cli, tup in c2p.items():
@@ -95,8 +97,8 @@ def _cli_kw_to_f_kw(cli_key: str):
return cli_key


def _parse_kw_and_flags(f, tokens, mapping):
cli2kw = cli2parameter(f)
def _parse_kw_and_flags(f, tokens, mapping, default_parameter: Optional[Parameter] = None):
cli2kw = cli2parameter(f, default_parameter=default_parameter)
kwargs_parameter = next((p for p in inspect.signature(f).parameters.values() if p.kind == p.VAR_KEYWORD), None)

if kwargs_parameter:
@@ -140,7 +142,7 @@ def _parse_kw_and_flags(f, tokens, mapping):
if implicit_value is not None:
cli_values.append(implicit_value)
else:
consume_count += max(1, token_count(parameter.annotation)[0])
consume_count += max(1, token_count(parameter.annotation, default_parameter=default_parameter)[0])

try:
for j in range(consume_count):
@@ -150,7 +152,7 @@ def _parse_kw_and_flags(f, tokens, mapping):

skip_next_iterations = consume_count

_, repeatable = token_count(parameter.annotation)
_, repeatable = token_count(parameter.annotation, default_parameter=default_parameter)
if parameter is kwargs_parameter:
assert kwargs_key is not None
if kwargs_key in mapping[parameter] and not repeatable:
@@ -167,16 +169,18 @@ def _parse_kw_and_flags(f, tokens, mapping):
return unused_tokens


def _parse_pos(f: Callable, tokens: Iterable[str], mapping: Dict) -> List[str]:
def _parse_pos(
f: Callable, tokens: Iterable[str], mapping: Dict, default_parameter: Optional[Parameter] = None
) -> List[str]:
tokens = list(tokens)
signature = inspect.signature(f)

def remaining_parameters():
for parameter in signature.parameters.values():
_, cparam = get_hint_parameter(parameter.annotation)
_, cparam = get_hint_parameter(parameter.annotation, default_parameter=default_parameter)
if cparam.parse is False:
continue
_, consume_all = token_count(parameter.annotation)
_, consume_all = token_count(parameter.annotation, default_parameter=default_parameter)
if parameter in mapping and not consume_all:
continue
if parameter.kind is parameter.KEYWORD_ONLY: # pragma: no cover
@@ -193,7 +197,7 @@ def remaining_parameters():
tokens = []
break

consume_count, consume_all = token_count(iparam.annotation)
consume_count, consume_all = token_count(iparam.annotation, default_parameter=default_parameter)
if consume_all:
mapping.setdefault(iparam, [])
mapping[iparam] = tokens + mapping[iparam]
@@ -211,7 +215,7 @@ def remaining_parameters():
return tokens


def _parse_env(f, mapping):
def _parse_env(f, mapping, default_parameter: Optional[Parameter] = None):
"""Populate argument defaults from environment variables.

In cyclopts, arguments are parsed with the following priority:
@@ -225,7 +229,7 @@ def _parse_env(f, mapping):
# Don't check environment variables for already-parsed parameters.
continue

_, cparam = get_hint_parameter(iparam.annotation)
_, cparam = get_hint_parameter(iparam.annotation, default_parameter=default_parameter)
for env_var_name in cparam.env_var:
try:
env_var_value = os.environ[env_var_name]
@@ -241,7 +245,7 @@ def _is_required(parameter: inspect.Parameter) -> bool:
return parameter.default is parameter.empty


def _bind(f: Callable, mapping: Dict[inspect.Parameter, Any]):
def _bind(f: Callable, mapping: Dict[inspect.Parameter, Any], default_parameter: Optional[Parameter] = None):
"""Bind the mapping to the function signature.

Better than directly using ``signature.bind`` because this can handle
@@ -268,7 +272,7 @@ def f_pos_append(p):
# * Parameters before a ``*args`` may have type ``POSITIONAL_OR_KEYWORD``.
# * Only args before a ``/`` are ``POSITIONAL_ONLY``.
for iparam in signature.parameters.values():
_, cparam = get_hint_parameter(iparam.annotation)
_, cparam = get_hint_parameter(iparam.annotation, default_parameter=default_parameter)
if cparam.parse is False:
has_unparsed_parameters |= _is_required(iparam)
continue
@@ -292,10 +296,10 @@ def f_pos_append(p):
return bound


def _convert(mapping: Dict[inspect.Parameter, List[str]]) -> dict:
def _convert(mapping: Dict[inspect.Parameter, List[str]], default_parameter: Optional[Parameter] = None) -> dict:
coerced = {}
for iparam, parameter_tokens in mapping.items():
type_, cparam = get_hint_parameter(iparam.annotation)
type_, cparam = get_hint_parameter(iparam.annotation, default_parameter=default_parameter)

if cparam.parse is False:
continue
@@ -334,7 +338,9 @@ def _convert(mapping: Dict[inspect.Parameter, List[str]]) -> dict:
return coerced


def create_bound_arguments(f: Callable, tokens: List[str]) -> Tuple[inspect.BoundArguments, List[str]]:
def create_bound_arguments(
f: Callable, tokens: List[str], default_parameter: Optional[Parameter] = None
) -> Tuple[inspect.BoundArguments, List[str]]:
"""Parse and coerce CLI tokens to match a function's signature.

Parameters
@@ -343,6 +349,8 @@ def create_bound_arguments(f: Callable, tokens: List[str]) -> Tuple[inspect.Boun
A function with (possibly) annotated parameters.
tokens: List[str]
CLI tokens to parse and coerce to match ``f``'s signature.
default_parameter: Optional[Parameter]
Default Parameter configuration.

Returns
-------
@@ -355,19 +363,19 @@ def create_bound_arguments(f: Callable, tokens: List[str]) -> Tuple[inspect.Boun
# Note: mapping is updated inplace
mapping: Dict[inspect.Parameter, List[str]] = {}

validate_command(f)
validate_command(f, default_parameter=default_parameter)

c2p, p2c = None, None
unused_tokens = []

try:
c2p = cli2parameter(f)
p2c = parameter2cli(f)
unused_tokens = _parse_kw_and_flags(f, tokens, mapping)
unused_tokens = _parse_pos(f, unused_tokens, mapping)
_parse_env(f, mapping)
coerced = _convert(mapping)
bound = _bind(f, coerced)
c2p = cli2parameter(f, default_parameter=default_parameter)
p2c = parameter2cli(f, default_parameter=default_parameter)
unused_tokens = _parse_kw_and_flags(f, tokens, mapping, default_parameter=default_parameter)
unused_tokens = _parse_pos(f, unused_tokens, mapping, default_parameter=default_parameter)
_parse_env(f, mapping, default_parameter=default_parameter)
coerced = _convert(mapping, default_parameter=default_parameter)
bound = _bind(f, coerced, default_parameter=default_parameter)
except CycloptsError as e:
e.target = f
e.root_input_tokens = tokens
47 changes: 30 additions & 17 deletions cyclopts/coercion.py
Original file line number Diff line number Diff line change
@@ -2,12 +2,15 @@
import inspect
from enum import Enum
from inspect import isclass
from typing import Any, Iterable, List, Literal, Optional, Set, Tuple, Type, Union, get_args, get_origin
from typing import TYPE_CHECKING, Any, Iterable, List, Literal, Optional, Set, Tuple, Type, Union, get_args, get_origin

from typing_extensions import Annotated

from cyclopts.exceptions import CoercionError

if TYPE_CHECKING:
from cyclopts.parameter import Parameter

# from types import NoneType is available >=3.10
NoneType = type(None)
AnnotatedType = type(Annotated[int, 0])
@@ -57,30 +60,34 @@ def _bytearray(s: str) -> bytearray:
}


def _convert(type_, element):
def _convert(type_, element, default_parameter=None):
origin_type = get_origin(type_)
inner_types = [resolve(x) for x in get_args(type_)]

if type_ in _implicit_iterable_type_mapping:
return _convert(_implicit_iterable_type_mapping[type_], element)
return _convert(_implicit_iterable_type_mapping[type_], element, default_parameter=default_parameter)
elif origin_type is collections.abc.Iterable:
assert len(inner_types) == 1
return _convert(List[inner_types[0]], element) # pyright: ignore[reportGeneralTypeIssues]
return _convert(
List[inner_types[0]], # pyright: ignore[reportGeneralTypeIssues]
element,
default_parameter=default_parameter,
)

elif origin_type is Union:
for t in inner_types:
if t is NoneType:
continue
try:
return _convert(t, element)
return _convert(t, element, default_parameter=default_parameter)
except Exception:
pass
else:
raise CoercionError(input_value=element, target_type=type_)
elif origin_type is Literal:
for choice in get_args(type_):
try:
res = _convert(type(choice), (element))
res = _convert(type(choice), (element), default_parameter=default_parameter)
except Exception:
continue
if res == choice:
@@ -94,14 +101,16 @@ def _convert(type_, element):
return member
raise CoercionError(input_value=element, target_type=type_)
elif origin_type in _iterable_types:
count, _ = token_count(inner_types[0])
count, _ = token_count(inner_types[0], default_parameter=default_parameter)
if count > 1:
gen = zip(*[iter(element)] * count)
else:
gen = element
return origin_type(_convert(inner_types[0], e) for e in gen) # pyright: ignore[reportOptionalCall]
return origin_type(
_convert(inner_types[0], e, default_parameter=default_parameter) for e in gen
) # pyright: ignore[reportOptionalCall]
elif origin_type is tuple:
return tuple(_convert(t, e) for t, e in zip(inner_types, element))
return tuple(_convert(t, e, default_parameter=default_parameter) for t, e in zip(inner_types, element))
else:
try:
return _converters.get(type_, type_)(element)
@@ -154,7 +163,7 @@ def resolve_annotated(type_: Type) -> Type:
return type_


def coerce(type_: Type, *args: str):
def coerce(type_: Type, *args: str, default_parameter=None):
"""Coerce variables into a specified type.

Internally used to coercing string CLI tokens into python builtin types.
@@ -195,16 +204,18 @@ def coerce(type_: Type, *args: str):
raise ValueError(
f"Number of arguments does not match the tuple structure: expected {len(inner_types)} but got {len(args)}"
)
return tuple(_convert(inner_type, arg) for inner_type, arg in zip(inner_types, args))
return tuple(
_convert(inner_type, arg, default_parameter=default_parameter) for inner_type, arg in zip(inner_types, args)
)
elif (origin_type or type_) in _iterable_types or origin_type is collections.abc.Iterable:
return _convert(type_, args)
return _convert(type_, args, default_parameter=default_parameter)
elif len(args) == 1:
return _convert(type_, args[0])
return _convert(type_, args[0], default_parameter=default_parameter)
else:
return [_convert(type_, item) for item in args]
return [_convert(type_, item, default_parameter=default_parameter) for item in args]


def token_count(type_: Type) -> Tuple[int, bool]:
def token_count(type_: Type, default_parameter: Optional["Parameter"] = None) -> Tuple[int, bool]:
"""The number of tokens after a keyword the parameter should consume.

Returns
@@ -213,13 +224,15 @@ def token_count(type_: Type) -> Tuple[int, bool]:
Number of tokens that constitute a single element.
bool
If this is ``True`` and positional, consume all remaining tokens.
default_parameter: Optional[Parameter]
Default Parameter configuration.
"""
from cyclopts.parameter import get_hint_parameter

if type_ is inspect.Parameter.empty:
return 1, False

type_, param = get_hint_parameter(type_)
type_, param = get_hint_parameter(type_, default_parameter=default_parameter)
if param.token_count is not None:
return abs(param.token_count), param.token_count < 0

@@ -233,7 +246,7 @@ def token_count(type_: Type) -> Tuple[int, bool]:
elif type_ in _iterable_types or (origin_type in _iterable_types and len(get_args(type_)) == 0):
return 1, True
elif (origin_type in _iterable_types or origin_type is collections.abc.Iterable) and len(get_args(type_)):
return token_count(get_args(type_)[0])[0], True
return token_count(get_args(type_)[0], default_parameter=default_parameter)[0], True
else:
return 1, False

Loading