Skip to content

Commit

Permalink
feat: trait typing (#818)
Browse files Browse the repository at this point in the history
Co-authored-by: Vidar Tonaas Fauske <[email protected]>
Co-authored-by: Steven Silvester <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
4 people authored Sep 11, 2023
1 parent b1cdea4 commit aa1081b
Show file tree
Hide file tree
Showing 10 changed files with 864 additions and 130 deletions.
19 changes: 5 additions & 14 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,13 @@
HERE = osp.abspath(osp.dirname(__file__))
ROOT = osp.dirname(osp.dirname(HERE))

from traitlets import version_info

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
# sys.path.insert(0, os.path.abspath('.'))

# We load the ipython release info into a dict by explicit execution
_release = {} # type:ignore
exec( # noqa
compile(
open(osp.join(ROOT, "traitlets/_version.py")).read(),
"../../traitlets/_version.py",
"exec",
),
_release,
)

# -- General configuration ------------------------------------------------

# If your documentation needs a minimal Sphinx version, state it here.
Expand All @@ -64,7 +55,7 @@
source_suffix = ".rst"

# Add dev disclaimer.
if _release["version_info"][-1] == "dev":
if version_info[-1] == "dev":
rst_prolog = """
.. note::
Expand All @@ -89,9 +80,9 @@
# built documents.
#
# The short X.Y version.
version = ".".join(map(str, _release["version_info"][:2]))
version = ".".join(map(str, version_info[:2]))
# The full version, including alpha/beta/rc tags.
release = _release["__version__"]
release = "__version__"

# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
Expand Down
2 changes: 1 addition & 1 deletion examples/docs/from_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class App(Application):
)

def start(self):
print(f"key={self.key}")
print(f"key={self.key.decode('utf8')}")


if __name__ == "__main__":
Expand Down
1 change: 1 addition & 0 deletions examples/myapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def start(self):
print("app.config:")
print(self.config)
print("try running with --help-all to see all available flags")
assert self.log is not None
self.log.debug("Debug Message")
self.log.info("Info Message")
self.log.warning("Warning Message")
Expand Down
10 changes: 7 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ classifiers = [
urls = {Homepage = "https://github.com/ipython/traitlets"}
requires-python = ">=3.7"
dynamic = ["version"]
dependencies = ["typing_extensions>=4.0.1"]

[project.optional-dependencies]
test = ["pytest", "pytest-mock", "pre-commit", "argcomplete>=2.0"]
test = ["pytest>=7.0,<7.2", "pytest-mock", "pre-commit", "argcomplete>=2.0", "pytest-mypy-testing", "mypy @ git+https://github.com/python/mypy.git@cb1d1a0baba37f35268cb605b7345726f257f960#egg=mypy"]
docs = [
"myst-parser",
"pydata-sphinx-theme",
Expand All @@ -32,6 +33,9 @@ docs = [
[tool.hatch.version]
path = "traitlets/_version.py"

[tool.hatch.metadata]
allow-direct-references = true

[tool.hatch.envs.docs]
features = ["docs"]
[tool.hatch.envs.docs.scripts]
Expand All @@ -52,7 +56,7 @@ nowarn = "test -W default {args}"

[tool.hatch.envs.typing]
features = ["test"]
dependencies = ["mypy>=0.990"]
dependencies = ["mypy @ git+https://github.com/python/mypy.git@cb1d1a0baba37f35268cb605b7345726f257f960#egg=mypy"]
[tool.hatch.envs.typing.scripts]
test = "mypy --install-types --non-interactive {args:.}"

Expand Down Expand Up @@ -89,7 +93,7 @@ warn_unused_configs = true
warn_redundant_casts = true
warn_return_any = true
warn_unused_ignores = true
exclude = ["examples/docs/configs"]
exclude = ["examples/docs/configs", "traitlets/tests/test_typing.py"]

[tool.pytest.ini_options]
addopts = "--durations=10 -ra --showlocals --doctest-modules --color yes --ignore examples/docs/configs"
Expand Down
36 changes: 22 additions & 14 deletions traitlets/config/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,21 +149,29 @@ class Application(SingletonConfigurable):

# The name of the application, will usually match the name of the command
# line application
name: t.Union[str, Unicode] = Unicode("application")
name: t.Union[str, Unicode[str, t.Union[str, bytes]]] = Unicode("application")

# The description of the application that is printed at the beginning
# of the help.
description: t.Union[str, Unicode] = Unicode("This is an application.")
description: t.Union[str, Unicode[str, t.Union[str, bytes]]] = Unicode(
"This is an application."
)
# default section descriptions
option_description: t.Union[str, Unicode] = Unicode(option_description)
keyvalue_description: t.Union[str, Unicode] = Unicode(keyvalue_description)
subcommand_description: t.Union[str, Unicode] = Unicode(subcommand_description)
option_description: t.Union[str, Unicode[str, t.Union[str, bytes]]] = Unicode(
option_description
)
keyvalue_description: t.Union[str, Unicode[str, t.Union[str, bytes]]] = Unicode(
keyvalue_description
)
subcommand_description: t.Union[str, Unicode[str, t.Union[str, bytes]]] = Unicode(
subcommand_description
)

python_config_loader_class = PyFileConfigLoader
json_config_loader_class = JSONFileConfigLoader

# The usage and example string that goes at the end of the help string.
examples: t.Union[str, Unicode] = Unicode()
examples: t.Union[str, Unicode[str, t.Union[str, bytes]]] = Unicode()

# A sequence of Configurable subclasses whose config=True attributes will
# be exposed at the command line.
Expand All @@ -190,30 +198,30 @@ def _classes_inc_parents(self, classes=None):
yield parent

# The version string of this application.
version: t.Union[str, Unicode] = Unicode("0.0")
version: t.Union[str, Unicode[str, t.Union[str, bytes]]] = Unicode("0.0")

# the argv used to initialize the application
argv: t.Union[t.List[str], List] = List()

# Whether failing to load config files should prevent startup
raise_config_file_errors: t.Union[bool, Bool] = Bool(
raise_config_file_errors: t.Union[bool, Bool[bool, t.Union[bool, int]]] = Bool(
TRAITLETS_APPLICATION_RAISE_CONFIG_FILE_ERROR
)

# The log level for the application
log_level: t.Union[str, int, Enum] = Enum(
log_level: t.Union[str, int, Enum[t.Any, t.Any]] = Enum(
(0, 10, 20, 30, 40, 50, "DEBUG", "INFO", "WARN", "ERROR", "CRITICAL"),
default_value=logging.WARN,
help="Set the log level by value or name.",
).tag(config=True)

_log_formatter_cls = LevelFormatter

log_datefmt: t.Union[str, Unicode] = Unicode(
log_datefmt: t.Union[str, Unicode[str, t.Union[str, bytes]]] = Unicode(
"%Y-%m-%d %H:%M:%S", help="The date format used by logging formatters for %(asctime)s"
).tag(config=True)

log_format: t.Union[str, Unicode] = Unicode(
log_format: t.Union[str, Unicode[str, t.Union[str, bytes]]] = Unicode(
"[%(name)s]%(highlevel)s %(message)s",
help="The Logging format template",
).tag(config=True)
Expand Down Expand Up @@ -420,11 +428,11 @@ def _log_default(self):

_loaded_config_files = List()

show_config: t.Union[bool, Bool] = Bool(
show_config: t.Union[bool, Bool[bool, t.Union[bool, int]]] = Bool(
help="Instead of starting the Application, dump configuration to stdout"
).tag(config=True)

show_config_json: t.Union[bool, Bool] = Bool(
show_config_json: t.Union[bool, Bool[bool, t.Union[bool, int]]] = Bool(
help="Instead of starting the Application, dump configuration to stdout (as JSON)"
).tag(config=True)

Expand All @@ -436,7 +444,7 @@ def _show_config_json_changed(self, change):
def _show_config_changed(self, change):
if change.new:
self._save_start = self.start
self.start = self.start_show_config # type:ignore[method-assign]
self.start = self.start_show_config # type:ignore[assignment]

def __init__(self, **kwargs):
SingletonConfigurable.__init__(self, **kwargs)
Expand Down
2 changes: 2 additions & 0 deletions traitlets/config/configurable.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ def _load_config(self, cfg, section_names=None, traits=None):
from difflib import get_close_matches

if isinstance(self, LoggingConfigurable):
assert self.log is not None
warn = self.log.warning
else:

Expand Down Expand Up @@ -462,6 +463,7 @@ def _validate_log(self, proposal):
@default("log")
def _log_default(self):
if isinstance(self.parent, LoggingConfigurable):
assert self.parent is not None
return self.parent.log
from traitlets import log

Expand Down
6 changes: 3 additions & 3 deletions traitlets/config/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,9 @@ class LazyConfigValue(HasTraits):
_value = None

# list methods
_extend = List()
_prepend = List()
_inserts = List()
_extend: List = List()
_prepend: List = List()
_inserts: List = List()

def append(self, obj):
"""Append an item to a List"""
Expand Down
36 changes: 18 additions & 18 deletions traitlets/tests/test_traitlets.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ class A(HasTraitsStub):
self.assertEqual(a._notify_new, 10)

def test_validate(self):
class MyTT(TraitType):
class MyTT(TraitType[int, int]):
def validate(self, inst, value):
return -1

Expand All @@ -127,7 +127,7 @@ class A(HasTraitsStub):
self.assertEqual(a.tt, -1)

def test_default_validate(self):
class MyIntTT(TraitType):
class MyIntTT(TraitType[int, int]):
def validate(self, obj, value):
if isinstance(value, int):
return value
Expand All @@ -154,7 +154,7 @@ class A(HasTraits):

def test_error(self):
class A(HasTraits):
tt = TraitType()
tt = TraitType[int, int]()

a = A()
self.assertRaises(TraitError, A.tt.error, a, 10)
Expand Down Expand Up @@ -270,14 +270,14 @@ def _default_x(self):
self.assertEqual(a._trait_values, {"x": 11})

def test_tag_metadata(self):
class MyIntTT(TraitType):
class MyIntTT(TraitType[int, int]):
metadata = {"a": 1, "b": 2}

a = MyIntTT(10).tag(b=3, c=4)
self.assertEqual(a.metadata, {"a": 1, "b": 3, "c": 4})

def test_metadata_localized_instance(self):
class MyIntTT(TraitType):
class MyIntTT(TraitType[int, int]):
metadata = {"a": 1, "b": 2}

a = MyIntTT(10)
Expand Down Expand Up @@ -325,7 +325,7 @@ class Foo(HasTraits):
self.assertEqual(Foo().bar, {})

def test_deprecated_metadata_access(self):
class MyIntTT(TraitType):
class MyIntTT(TraitType[int, int]):
metadata = {"a": 1, "b": 2}

a = MyIntTT(10)
Expand Down Expand Up @@ -394,12 +394,12 @@ class C(HasTraits):

def test_this_class(self):
class A(HasTraits):
t = This()
tt = This()
t = This["A"]()
tt = This["A"]()

class B(A):
tt = This()
ttt = This()
tt = This["A"]()
ttt = This["A"]()

self.assertEqual(A.t.this_class, A)
self.assertEqual(B.t.this_class, A)
Expand Down Expand Up @@ -1094,7 +1094,7 @@ class Bar(Foo):
class Bah:
pass

class FooInstance(Instance):
class FooInstance(Instance[Foo]):
klass = Foo

class A(HasTraits):
Expand Down Expand Up @@ -1170,15 +1170,15 @@ class Foo:

def inner():
class A(HasTraits):
inst = Instance(Foo())
inst = Instance(Foo()) # type:ignore

self.assertRaises(TraitError, inner)


class TestThis(TestCase):
def test_this_class(self):
class Foo(HasTraits):
this = This()
this = This["Foo"]()

f = Foo()
self.assertEqual(f.this, None)
Expand All @@ -1189,15 +1189,15 @@ class Foo(HasTraits):

def test_this_inst(self):
class Foo(HasTraits):
this = This()
this = This["Foo"]()

f = Foo()
f.this = Foo()
self.assertTrue(isinstance(f.this, Foo))

def test_subclass(self):
class Foo(HasTraits):
t = This()
t = This["Foo"]()

class Bar(Foo):
pass
Expand All @@ -1211,7 +1211,7 @@ class Bar(Foo):

def test_subclass_override(self):
class Foo(HasTraits):
t = This()
t = This["Foo"]()

class Bar(Foo):
t = This()
Expand Down Expand Up @@ -2423,11 +2423,11 @@ def test_notification_order():
# Traits for Forward Declaration Tests
###
class ForwardDeclaredInstanceTrait(HasTraits):
value = ForwardDeclaredInstance("ForwardDeclaredBar", allow_none=True)
value = ForwardDeclaredInstance["ForwardDeclaredBar"]("ForwardDeclaredBar", allow_none=True)


class ForwardDeclaredTypeTrait(HasTraits):
value = ForwardDeclaredType("ForwardDeclaredBar", allow_none=True)
value = ForwardDeclaredType[t.Any, t.Any]("ForwardDeclaredBar", allow_none=True)


class ForwardDeclaredInstanceListTrait(HasTraits):
Expand Down
Loading

0 comments on commit aa1081b

Please sign in to comment.