Skip to content

Commit

Permalink
Conditionally define the protocol instead
Browse files Browse the repository at this point in the history
  • Loading branch information
layday committed Aug 19, 2022
1 parent c227b2f commit ce67df1
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 9 deletions.
7 changes: 5 additions & 2 deletions src/attr/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ from typing import (
List,
Mapping,
Optional,
Protocol,
Sequence,
Tuple,
Type,
Expand All @@ -26,8 +27,6 @@ from ._cmp import cmp_using as cmp_using
from ._version_info import VersionInfo
from ._typing_compat import AttrsInstance_

AttrsInstance = AttrsInstance_

if sys.version_info >= (3, 10):
from typing import TypeGuard
else:
Expand Down Expand Up @@ -65,6 +64,10 @@ _FieldTransformer = Callable[
# _ValidatorType from working when passed in a list or tuple.
_ValidatorArgType = Union[_ValidatorType[_T], Sequence[_ValidatorType[_T]]]

# We subclass this here to keep the protocol's qualified name clean.
class AttrsInstance(AttrsInstance_, Protocol):
pass

# _make --

NOTHING: object
Expand Down
12 changes: 6 additions & 6 deletions src/attr/_typing_compat.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ from typing import Any, ClassVar, Protocol

MYPY = False

# A protocol to be able to statically accept an attrs class.
class AttrsInstance(Protocol):
__attrs_attrs__: ClassVar[Any]

if MYPY:
AttrsInstance_ = AttrsInstance
# A protocol to be able to statically accept an attrs class.
class AttrsInstance_(Protocol):
__attrs_attrs__: ClassVar[Any]

else:
AttrsInstance_ = Any
class AttrsInstance_(Protocol):
pass
3 changes: 2 additions & 1 deletion tests/test_pyright.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def test_pyright_attrsinstance_is_any(tmp_path):
"""\
import attrs
foo: attrs.AttrsInstance = object() # We can assign any old object to `AttrsInstance`.
reveal_type(attrs.AttrsInstance)
"""
)
Expand All @@ -99,7 +100,7 @@ def test_pyright_attrsinstance_is_any(tmp_path):
expected_diagnostics = {
PyrightDiagnostic(
severity="information",
message='Type of "attrs.AttrsInstance" is "Any"',
message='Type of "attrs.AttrsInstance" is "Type[AttrsInstance]"',
),
}
assert diagnostics == expected_diagnostics

0 comments on commit ce67df1

Please sign in to comment.