Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Default parameter overrides. #33

Merged
merged 20 commits into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions cyclopts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
"DocstringError",
"InvalidCommandError",
"MissingArgumentError",
"MultipleParameterAnnotationError",
"Parameter",
"UnusedCliTokensError",
"ValidationError",
Expand All @@ -31,7 +30,6 @@
DocstringError,
InvalidCommandError,
MissingArgumentError,
MultipleParameterAnnotationError,
UnusedCliTokensError,
ValidationError,
)
Expand Down
82 changes: 45 additions & 37 deletions cyclopts/bind.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import shlex
import sys
from functools import lru_cache
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union, get_origin
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union, get_origin

from cyclopts.coercion import resolve, token_count
from cyclopts.exceptions import (
Expand All @@ -13,7 +13,7 @@
RepeatArgumentError,
ValidationError,
)
from cyclopts.parameter import get_hint_parameter, get_names, validate_command
from cyclopts.parameter import Parameter, get_hint_parameter, get_names, validate_command


def normalize_tokens(tokens: Union[None, str, Iterable[str]]) -> List[str]:
Expand All @@ -27,7 +27,9 @@ def normalize_tokens(tokens: Union[None, str, Iterable[str]]) -> List[str]:


@lru_cache(maxsize=16)
def cli2parameter(f: Callable) -> Dict[str, Tuple[inspect.Parameter, Any]]:
def cli2parameter(
f: Callable, default_parameter: Optional[Parameter] = None
) -> Dict[str, Tuple[inspect.Parameter, Any]]:
"""Creates a dictionary mapping CLI keywords to python keywords.

Typically the mapping is something like::
Expand All @@ -44,14 +46,14 @@ def cli2parameter(f: Callable) -> Dict[str, Tuple[inspect.Parameter, Any]]:
signature = inspect.signature(f)
for iparam in signature.parameters.values():
annotation = str if iparam.annotation is iparam.empty else iparam.annotation
_, cparam = get_hint_parameter(annotation)
_, cparam = get_hint_parameter(annotation, default_parameter=default_parameter)

if not cparam.parse:
if cparam.parse is False:
continue

if iparam.kind in (iparam.POSITIONAL_OR_KEYWORD, iparam.KEYWORD_ONLY):
hint = resolve(annotation)
keys = get_names(iparam)
keys = get_names(iparam, default_parameter=default_parameter)

for key in keys:
mapping[key] = (iparam, True if hint is bool else None)
Expand All @@ -63,8 +65,8 @@ def cli2parameter(f: Callable) -> Dict[str, Tuple[inspect.Parameter, Any]]:


@lru_cache(maxsize=16)
def parameter2cli(f: Callable) -> Dict[inspect.Parameter, List[str]]:
c2p = cli2parameter(f)
def parameter2cli(f: Callable, default_parameter: Optional[Parameter] = None) -> Dict[inspect.Parameter, List[str]]:
c2p = cli2parameter(f, default_parameter=default_parameter)
p2c = {}

for cli, tup in c2p.items():
Expand All @@ -75,14 +77,14 @@ def parameter2cli(f: Callable) -> Dict[inspect.Parameter, List[str]]:
signature = inspect.signature(f)
for iparam in signature.parameters.values():
annotation = str if iparam.annotation is iparam.empty else iparam.annotation
_, cparam = get_hint_parameter(annotation)
_, cparam = get_hint_parameter(annotation, default_parameter=default_parameter)

if not cparam.parse:
if cparam.parse is False:
continue

# POSITIONAL_OR_KEYWORD and KEYWORD_ONLY already handled in cli2parameter
if iparam.kind in (iparam.POSITIONAL_ONLY, iparam.VAR_KEYWORD, iparam.VAR_POSITIONAL):
p2c[iparam] = get_names(iparam)
p2c[iparam] = get_names(iparam, default_parameter=default_parameter)

return p2c

Expand All @@ -95,8 +97,8 @@ def _cli_kw_to_f_kw(cli_key: str):
return cli_key


def _parse_kw_and_flags(f, tokens, mapping):
cli2kw = cli2parameter(f)
def _parse_kw_and_flags(f, tokens, mapping, default_parameter: Optional[Parameter] = None):
cli2kw = cli2parameter(f, default_parameter=default_parameter)
kwargs_parameter = next((p for p in inspect.signature(f).parameters.values() if p.kind == p.VAR_KEYWORD), None)

if kwargs_parameter:
Expand Down Expand Up @@ -140,7 +142,7 @@ def _parse_kw_and_flags(f, tokens, mapping):
if implicit_value is not None:
cli_values.append(implicit_value)
else:
consume_count += max(1, token_count(parameter.annotation)[0])
consume_count += max(1, token_count(parameter.annotation, default_parameter=default_parameter)[0])

try:
for j in range(consume_count):
Expand All @@ -150,7 +152,7 @@ def _parse_kw_and_flags(f, tokens, mapping):

skip_next_iterations = consume_count

_, repeatable = token_count(parameter.annotation)
_, repeatable = token_count(parameter.annotation, default_parameter=default_parameter)
if parameter is kwargs_parameter:
assert kwargs_key is not None
if kwargs_key in mapping[parameter] and not repeatable:
Expand All @@ -167,16 +169,18 @@ def _parse_kw_and_flags(f, tokens, mapping):
return unused_tokens


def _parse_pos(f: Callable, tokens: Iterable[str], mapping: Dict) -> List[str]:
def _parse_pos(
f: Callable, tokens: Iterable[str], mapping: Dict, default_parameter: Optional[Parameter] = None
) -> List[str]:
tokens = list(tokens)
signature = inspect.signature(f)

def remaining_parameters():
for parameter in signature.parameters.values():
_, cparam = get_hint_parameter(parameter.annotation)
if not cparam.parse:
_, cparam = get_hint_parameter(parameter.annotation, default_parameter=default_parameter)
if cparam.parse is False:
continue
_, consume_all = token_count(parameter.annotation)
_, consume_all = token_count(parameter.annotation, default_parameter=default_parameter)
if parameter in mapping and not consume_all:
continue
if parameter.kind is parameter.KEYWORD_ONLY: # pragma: no cover
Expand All @@ -193,7 +197,7 @@ def remaining_parameters():
tokens = []
break

consume_count, consume_all = token_count(iparam.annotation)
consume_count, consume_all = token_count(iparam.annotation, default_parameter=default_parameter)
if consume_all:
mapping.setdefault(iparam, [])
mapping[iparam] = tokens + mapping[iparam]
Expand All @@ -211,7 +215,7 @@ def remaining_parameters():
return tokens


def _parse_env(f, mapping):
def _parse_env(f, mapping, default_parameter: Optional[Parameter] = None):
"""Populate argument defaults from environment variables.

In cyclopts, arguments are parsed with the following priority:
Expand All @@ -225,7 +229,7 @@ def _parse_env(f, mapping):
# Don't check environment variables for already-parsed parameters.
continue

_, cparam = get_hint_parameter(iparam.annotation)
_, cparam = get_hint_parameter(iparam.annotation, default_parameter=default_parameter)
for env_var_name in cparam.env_var:
try:
env_var_value = os.environ[env_var_name]
Expand All @@ -241,7 +245,7 @@ def _is_required(parameter: inspect.Parameter) -> bool:
return parameter.default is parameter.empty


def _bind(f: Callable, mapping: Dict[inspect.Parameter, Any]):
def _bind(f: Callable, mapping: Dict[inspect.Parameter, Any], default_parameter: Optional[Parameter] = None):
"""Bind the mapping to the function signature.

Better than directly using ``signature.bind`` because this can handle
Expand All @@ -268,8 +272,8 @@ def f_pos_append(p):
# * Parameters before a ``*args`` may have type ``POSITIONAL_OR_KEYWORD``.
# * Only args before a ``/`` are ``POSITIONAL_ONLY``.
for iparam in signature.parameters.values():
_, cparam = get_hint_parameter(iparam.annotation)
if not cparam.parse:
_, cparam = get_hint_parameter(iparam.annotation, default_parameter=default_parameter)
if cparam.parse is False:
has_unparsed_parameters |= _is_required(iparam)
continue

Expand All @@ -292,12 +296,12 @@ def f_pos_append(p):
return bound


def _convert(mapping: Dict[inspect.Parameter, List[str]]) -> dict:
def _convert(mapping: Dict[inspect.Parameter, List[str]], default_parameter: Optional[Parameter] = None) -> dict:
coerced = {}
for iparam, parameter_tokens in mapping.items():
type_, cparam = get_hint_parameter(iparam.annotation)
type_, cparam = get_hint_parameter(iparam.annotation, default_parameter=default_parameter)

if not cparam.parse:
if cparam.parse is False:
continue

# Checking if parameter_token is a string is a little jank,
Expand Down Expand Up @@ -334,7 +338,9 @@ def _convert(mapping: Dict[inspect.Parameter, List[str]]) -> dict:
return coerced


def create_bound_arguments(f: Callable, tokens: List[str]) -> Tuple[inspect.BoundArguments, List[str]]:
def create_bound_arguments(
f: Callable, tokens: List[str], default_parameter: Optional[Parameter] = None
) -> Tuple[inspect.BoundArguments, List[str]]:
"""Parse and coerce CLI tokens to match a function's signature.

Parameters
Expand All @@ -343,6 +349,8 @@ def create_bound_arguments(f: Callable, tokens: List[str]) -> Tuple[inspect.Boun
A function with (possibly) annotated parameters.
tokens: List[str]
CLI tokens to parse and coerce to match ``f``'s signature.
default_parameter: Optional[Parameter]
Default Parameter configuration.

Returns
-------
Expand All @@ -355,19 +363,19 @@ def create_bound_arguments(f: Callable, tokens: List[str]) -> Tuple[inspect.Boun
# Note: mapping is updated inplace
mapping: Dict[inspect.Parameter, List[str]] = {}

validate_command(f)
validate_command(f, default_parameter=default_parameter)

c2p, p2c = None, None
unused_tokens = []

try:
c2p = cli2parameter(f)
p2c = parameter2cli(f)
unused_tokens = _parse_kw_and_flags(f, tokens, mapping)
unused_tokens = _parse_pos(f, unused_tokens, mapping)
_parse_env(f, mapping)
coerced = _convert(mapping)
bound = _bind(f, coerced)
c2p = cli2parameter(f, default_parameter=default_parameter)
p2c = parameter2cli(f, default_parameter=default_parameter)
unused_tokens = _parse_kw_and_flags(f, tokens, mapping, default_parameter=default_parameter)
unused_tokens = _parse_pos(f, unused_tokens, mapping, default_parameter=default_parameter)
_parse_env(f, mapping, default_parameter=default_parameter)
coerced = _convert(mapping, default_parameter=default_parameter)
bound = _bind(f, coerced, default_parameter=default_parameter)
except CycloptsError as e:
e.target = f
e.root_input_tokens = tokens
Expand Down
Loading