Skip to content

Commit

Permalink
Implement TypeIs (PEP 742) (#16898)
Browse files Browse the repository at this point in the history
Co-authored-by: Marc Mueller <[email protected]>
  • Loading branch information
JelleZijlstra and cdce8p authored Mar 1, 2024
1 parent 3c87af2 commit bcb3747
Show file tree
Hide file tree
Showing 19 changed files with 962 additions and 19 deletions.
16 changes: 16 additions & 0 deletions docs/source/error_code_list2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -555,3 +555,19 @@ Correct usage:
When this code is enabled, using ``reveal_locals`` is always an error,
because there's no way one can import it.

.. _code-narrowed-type-not-subtype:

Check that ``TypeIs`` narrows types [narrowed-type-not-subtype]
---------------------------------------------------------------

:pep:`742` requires that when ``TypeIs`` is used, the narrowed
type must be a subtype of the original type::

from typing_extensions import TypeIs

def f(x: int) -> TypeIs[str]: # Error, str is not a subtype of int
...

def g(x: object) -> TypeIs[str]: # OK
...
7 changes: 6 additions & 1 deletion mypy/applytype.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,15 @@ def apply_generic_arguments(
arg_types=[expand_type(at, id_to_type) for at in callable.arg_types]
)

# Apply arguments to TypeGuard if any.
# Apply arguments to TypeGuard and TypeIs if any.
if callable.type_guard is not None:
type_guard = expand_type(callable.type_guard, id_to_type)
else:
type_guard = None
if callable.type_is is not None:
type_is = expand_type(callable.type_is, id_to_type)
else:
type_is = None

# The callable may retain some type vars if only some were applied.
# TODO: move apply_poly() logic from checkexpr.py here when new inference
Expand All @@ -164,4 +168,5 @@ def apply_generic_arguments(
ret_type=expand_type(callable.ret_type, id_to_type),
variables=remaining_tvars,
type_guard=type_guard,
type_is=type_is,
)
42 changes: 38 additions & 4 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1203,6 +1203,22 @@ def check_func_def(
# visible from *inside* of this function/method.
ref_type: Type | None = self.scope.active_self_type()

if typ.type_is:
arg_index = 0
# For methods and classmethods, we want the second parameter
if ref_type is not None and (not defn.is_static or defn.name == "__new__"):
arg_index = 1
if arg_index < len(typ.arg_types) and not is_subtype(
typ.type_is, typ.arg_types[arg_index]
):
self.fail(
message_registry.NARROWED_TYPE_NOT_SUBTYPE.format(
format_type(typ.type_is, self.options),
format_type(typ.arg_types[arg_index], self.options),
),
item,
)

# Store argument types.
for i in range(len(typ.arg_types)):
arg_type = typ.arg_types[i]
Expand Down Expand Up @@ -2178,6 +2194,8 @@ def check_override(
elif isinstance(original, CallableType) and isinstance(override, CallableType):
if original.type_guard is not None and override.type_guard is None:
fail = True
if original.type_is is not None and override.type_is is None:
fail = True

if is_private(name):
fail = False
Expand Down Expand Up @@ -5643,7 +5661,7 @@ def combine_maps(list_maps: list[TypeMap]) -> TypeMap:
def find_isinstance_check(self, node: Expression) -> tuple[TypeMap, TypeMap]:
"""Find any isinstance checks (within a chain of ands). Includes
implicit and explicit checks for None and calls to callable.
Also includes TypeGuard functions.
Also includes TypeGuard and TypeIs functions.
Return value is a map of variables to their types if the condition
is true and a map of variables to their types if the condition is false.
Expand Down Expand Up @@ -5695,7 +5713,7 @@ def find_isinstance_check_helper(self, node: Expression) -> tuple[TypeMap, TypeM
if literal(expr) == LITERAL_TYPE and attr and len(attr) == 1:
return self.hasattr_type_maps(expr, self.lookup_type(expr), attr[0])
elif isinstance(node.callee, RefExpr):
if node.callee.type_guard is not None:
if node.callee.type_guard is not None or node.callee.type_is is not None:
# TODO: Follow *args, **kwargs
if node.arg_kinds[0] != nodes.ARG_POS:
# the first argument might be used as a kwarg
Expand All @@ -5721,15 +5739,31 @@ def find_isinstance_check_helper(self, node: Expression) -> tuple[TypeMap, TypeM
# we want the idx-th variable to be narrowed
expr = collapse_walrus(node.args[idx])
else:
self.fail(message_registry.TYPE_GUARD_POS_ARG_REQUIRED, node)
kind = (
"guard" if node.callee.type_guard is not None else "narrower"
)
self.fail(
message_registry.TYPE_GUARD_POS_ARG_REQUIRED.format(kind), node
)
return {}, {}
if literal(expr) == LITERAL_TYPE:
# Note: we wrap the target type, so that we can special case later.
# Namely, for isinstance() we use a normal meet, while TypeGuard is
# considered "always right" (i.e. even if the types are not overlapping).
# Also note that a care must be taken to unwrap this back at read places
# where we use this to narrow down declared type.
return {expr: TypeGuardedType(node.callee.type_guard)}, {}
if node.callee.type_guard is not None:
return {expr: TypeGuardedType(node.callee.type_guard)}, {}
else:
assert node.callee.type_is is not None
return conditional_types_to_typemaps(
expr,
*self.conditional_types_with_intersection(
self.lookup_type(expr),
[TypeRange(node.callee.type_is, is_upper_bound=False)],
expr,
),
)
elif isinstance(node, ComparisonExpr):
# Step 1: Obtain the types of each operand and whether or not we can
# narrow their types. (For example, we shouldn't try narrowing the
Expand Down
13 changes: 6 additions & 7 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1451,13 +1451,12 @@ def check_call_expr_with_callee_type(
object_type=object_type,
)
proper_callee = get_proper_type(callee_type)
if (
isinstance(e.callee, RefExpr)
and isinstance(proper_callee, CallableType)
and proper_callee.type_guard is not None
):
if isinstance(e.callee, RefExpr) and isinstance(proper_callee, CallableType):
# Cache it for find_isinstance_check()
e.callee.type_guard = proper_callee.type_guard
if proper_callee.type_guard is not None:
e.callee.type_guard = proper_callee.type_guard
if proper_callee.type_is is not None:
e.callee.type_is = proper_callee.type_is
return ret_type

def check_union_call_expr(self, e: CallExpr, object_type: UnionType, member: str) -> Type:
Expand Down Expand Up @@ -5283,7 +5282,7 @@ def infer_lambda_type_using_context(
# is a constructor -- but this fallback doesn't make sense for lambdas.
callable_ctx = callable_ctx.copy_modified(fallback=self.named_type("builtins.function"))

if callable_ctx.type_guard is not None:
if callable_ctx.type_guard is not None or callable_ctx.type_is is not None:
# Lambda's return type cannot be treated as a `TypeGuard`,
# because it is implicit. And `TypeGuard`s must be explicit.
# See https://github.com/python/mypy/issues/9927
Expand Down
16 changes: 14 additions & 2 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -1018,10 +1018,22 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
param_spec = template.param_spec()

template_ret_type, cactual_ret_type = template.ret_type, cactual.ret_type
if template.type_guard is not None:
if template.type_guard is not None and cactual.type_guard is not None:
template_ret_type = template.type_guard
if cactual.type_guard is not None:
cactual_ret_type = cactual.type_guard
elif template.type_guard is not None:
template_ret_type = AnyType(TypeOfAny.special_form)
elif cactual.type_guard is not None:
cactual_ret_type = AnyType(TypeOfAny.special_form)

if template.type_is is not None and cactual.type_is is not None:
template_ret_type = template.type_is
cactual_ret_type = cactual.type_is
elif template.type_is is not None:
template_ret_type = AnyType(TypeOfAny.special_form)
elif cactual.type_is is not None:
cactual_ret_type = AnyType(TypeOfAny.special_form)

res.extend(infer_constraints(template_ret_type, cactual_ret_type, self.direction))

if param_spec is None:
Expand Down
6 changes: 6 additions & 0 deletions mypy/errorcodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,5 +281,11 @@ def __hash__(self) -> int:
sub_code_of=MISC,
)

NARROWED_TYPE_NOT_SUBTYPE: Final[ErrorCode] = ErrorCode(
"narrowed-type-not-subtype",
"Warn if a TypeIs function's narrowed type is not a subtype of the original type",
"General",
)

# This copy will not include any error codes defined later in the plugins.
mypy_error_codes = error_codes.copy()
2 changes: 2 additions & 0 deletions mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ def visit_callable_type(self, t: CallableType) -> CallableType:
arg_names=t.arg_names[:-2] + repl.arg_names,
ret_type=t.ret_type.accept(self),
type_guard=(t.type_guard.accept(self) if t.type_guard is not None else None),
type_is=(t.type_is.accept(self) if t.type_is is not None else None),
imprecise_arg_kinds=(t.imprecise_arg_kinds or repl.imprecise_arg_kinds),
variables=[*repl.variables, *t.variables],
)
Expand Down Expand Up @@ -384,6 +385,7 @@ def visit_callable_type(self, t: CallableType) -> CallableType:
arg_types=arg_types,
ret_type=t.ret_type.accept(self),
type_guard=(t.type_guard.accept(self) if t.type_guard is not None else None),
type_is=(t.type_is.accept(self) if t.type_is is not None else None),
)
if needs_normalization:
return expanded.with_normalized_var_args()
Expand Down
2 changes: 2 additions & 0 deletions mypy/fixup.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,8 @@ def visit_callable_type(self, ct: CallableType) -> None:
arg.accept(self)
if ct.type_guard is not None:
ct.type_guard.accept(self)
if ct.type_is is not None:
ct.type_is.accept(self)

def visit_overloaded(self, t: Overloaded) -> None:
for ct in t.items:
Expand Down
5 changes: 4 additions & 1 deletion mypy/message_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def with_additional_msg(self, info: str) -> ErrorMessage:

CONTIGUOUS_ITERABLE_EXPECTED: Final = ErrorMessage("Contiguous iterable with same type expected")
ITERABLE_TYPE_EXPECTED: Final = ErrorMessage("Invalid type '{}' for *expr (iterable expected)")
TYPE_GUARD_POS_ARG_REQUIRED: Final = ErrorMessage("Type guard requires positional argument")
TYPE_GUARD_POS_ARG_REQUIRED: Final = ErrorMessage("Type {} requires positional argument")

# Match Statement
MISSING_MATCH_ARGS: Final = 'Class "{}" doesn\'t define "__match_args__"'
Expand Down Expand Up @@ -324,3 +324,6 @@ def with_additional_msg(self, info: str) -> ErrorMessage:
ARG_NAME_EXPECTED_STRING_LITERAL: Final = ErrorMessage(
"Expected string literal for argument name, got {}", codes.SYNTAX
)
NARROWED_TYPE_NOT_SUBTYPE: Final = ErrorMessage(
"Narrowed type {} is not a subtype of input type {}", codes.NARROWED_TYPE_NOT_SUBTYPE
)
4 changes: 4 additions & 0 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -2643,6 +2643,8 @@ def format_literal_value(typ: LiteralType) -> str:
elif isinstance(func, CallableType):
if func.type_guard is not None:
return_type = f"TypeGuard[{format(func.type_guard)}]"
elif func.type_is is not None:
return_type = f"TypeIs[{format(func.type_is)}]"
else:
return_type = format(func.ret_type)
if func.is_ellipsis_args:
Expand Down Expand Up @@ -2859,6 +2861,8 @@ def [T <: int] f(self, x: int, y: T) -> None
s += " -> "
if tp.type_guard is not None:
s += f"TypeGuard[{format_type_bare(tp.type_guard, options)}]"
elif tp.type_is is not None:
s += f"TypeIs[{format_type_bare(tp.type_is, options)}]"
else:
s += format_type_bare(tp.ret_type, options)

Expand Down
3 changes: 3 additions & 0 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1755,6 +1755,7 @@ class RefExpr(Expression):
"is_inferred_def",
"is_alias_rvalue",
"type_guard",
"type_is",
)

def __init__(self) -> None:
Expand All @@ -1776,6 +1777,8 @@ def __init__(self) -> None:
self.is_alias_rvalue = False
# Cache type guard from callable_type.type_guard
self.type_guard: mypy.types.Type | None = None
# And same for TypeIs
self.type_is: mypy.types.Type | None = None

@property
def fullname(self) -> str:
Expand Down
7 changes: 7 additions & 0 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,6 +881,13 @@ def analyze_func_def(self, defn: FuncDef) -> None:
)
# in this case, we just kind of just ... remove the type guard.
result = result.copy_modified(type_guard=None)
if result.type_is and ARG_POS not in result.arg_kinds[skip_self:]:
self.fail(
'"TypeIs" functions must have a positional argument',
result,
code=codes.VALID_TYPE,
)
result = result.copy_modified(type_is=None)

result = self.remove_unpack_kwargs(defn, result)
if has_self_type and self.type is not None:
Expand Down
13 changes: 13 additions & 0 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,10 +683,23 @@ def visit_callable_type(self, left: CallableType) -> bool:
if left.type_guard is not None and right.type_guard is not None:
if not self._is_subtype(left.type_guard, right.type_guard):
return False
elif left.type_is is not None and right.type_is is not None:
# For TypeIs we have to check both ways; it is unsafe to pass
# a TypeIs[Child] when a TypeIs[Parent] is expected, because
# if the narrower returns False, we assume that the narrowed value is
# *not* a Parent.
if not self._is_subtype(left.type_is, right.type_is) or not self._is_subtype(
right.type_is, left.type_is
):
return False
elif right.type_guard is not None and left.type_guard is None:
# This means that one function has `TypeGuard` and other does not.
# They are not compatible. See https://github.com/python/mypy/issues/11307
return False
elif right.type_is is not None and left.type_is is None:
# Similarly, if one function has `TypeIs` and the other does not,
# they are not compatible.
return False
return is_callable_compatible(
left,
right,
Expand Down
27 changes: 24 additions & 3 deletions mypy/typeanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,7 +668,10 @@ def try_analyze_special_unbound_type(self, t: UnboundType, fullname: str) -> Typ
)
return AnyType(TypeOfAny.from_error)
return RequiredType(self.anal_type(t.args[0]), required=False)
elif self.anal_type_guard_arg(t, fullname) is not None:
elif (
self.anal_type_guard_arg(t, fullname) is not None
or self.anal_type_is_arg(t, fullname) is not None
):
# In most contexts, TypeGuard[...] acts as an alias for bool (ignoring its args)
return self.named_type("builtins.bool")
elif fullname in ("typing.Unpack", "typing_extensions.Unpack"):
Expand Down Expand Up @@ -986,7 +989,8 @@ def visit_callable_type(self, t: CallableType, nested: bool = True) -> Type:
variables = t.variables
else:
variables, _ = self.bind_function_type_variables(t, t)
special = self.anal_type_guard(t.ret_type)
type_guard = self.anal_type_guard(t.ret_type)
type_is = self.anal_type_is(t.ret_type)
arg_kinds = t.arg_kinds
if len(arg_kinds) >= 2 and arg_kinds[-2] == ARG_STAR and arg_kinds[-1] == ARG_STAR2:
arg_types = self.anal_array(t.arg_types[:-2], nested=nested) + [
Expand Down Expand Up @@ -1041,7 +1045,8 @@ def visit_callable_type(self, t: CallableType, nested: bool = True) -> Type:
# its type will be the falsey FakeInfo
fallback=(t.fallback if t.fallback.type else self.named_type("builtins.function")),
variables=self.anal_var_defs(variables),
type_guard=special,
type_guard=type_guard,
type_is=type_is,
unpack_kwargs=unpacked_kwargs,
)
return ret
Expand All @@ -1064,6 +1069,22 @@ def anal_type_guard_arg(self, t: UnboundType, fullname: str) -> Type | None:
return self.anal_type(t.args[0])
return None

def anal_type_is(self, t: Type) -> Type | None:
if isinstance(t, UnboundType):
sym = self.lookup_qualified(t.name, t)
if sym is not None and sym.node is not None:
return self.anal_type_is_arg(t, sym.node.fullname)
# TODO: What if it's an Instance? Then use t.type.fullname?
return None

def anal_type_is_arg(self, t: UnboundType, fullname: str) -> Type | None:
if fullname in ("typing_extensions.TypeIs", "typing.TypeIs"):
if len(t.args) != 1:
self.fail("TypeIs must have exactly one type argument", t, code=codes.VALID_TYPE)
return AnyType(TypeOfAny.from_error)
return self.anal_type(t.args[0])
return None

def anal_star_arg_type(self, t: Type, kind: ArgKind, nested: bool) -> Type:
"""Analyze signature argument type for *args and **kwargs argument."""
if isinstance(t, UnboundType) and t.name and "." in t.name and not t.args:
Expand Down
Loading

0 comments on commit bcb3747

Please sign in to comment.