Skip to content

Commit

Permalink
Tiny ghostwriter fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Zac-HD committed Mar 22, 2021
1 parent 3d5761e commit 76855d7
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 16 deletions.
5 changes: 5 additions & 0 deletions hypothesis-python/RELEASE.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
RELEASE_TYPE: patch

This patch improves the :doc:`Ghostwriter's <ghostwriter>` handling
of strategies to generate various fiddly types including frozensets,
keysviews, valuesviews, regex matches and patterns, and so on.
47 changes: 36 additions & 11 deletions hypothesis-python/src/hypothesis/extra/ghostwriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
import contextlib
import enum
import inspect
import os
import re
import sys
import types
Expand Down Expand Up @@ -105,6 +106,7 @@
FilteredStrategy,
MappedSearchStrategy,
OneOfStrategy,
SampledFromStrategy,
)
from hypothesis.strategies._internal.types import _global_type_lookup
from hypothesis.utils.conventions import InferType, infer
Expand All @@ -130,6 +132,7 @@ def test_{test_kind}_{func_name}({arg_names}):
"""

Except = Union[Type[Exception], Tuple[Type[Exception], ...]]
RE_TYPES = (type(re.compile(".")), type(re.match(".", "abc")))


def _check_except(except_: Except) -> Tuple[Type[Exception], ...]:
Expand Down Expand Up @@ -310,20 +313,28 @@ def _assert_eq(style, a, b):

def _imports_for_object(obj):
"""Return the imports for `obj`, which may be empty for e.g. lambdas"""
if isinstance(obj, RE_TYPES):
return {"re"}
try:
if (not callable(obj)) or obj.__name__ == "<lambda>":
return set()
name = _get_qualname(obj).split(".")[0]
return {(_get_module(obj), name)}
except Exception:
with contextlib.suppress(AttributeError):
if obj.__module__ == "typing": # only on CPython 3.6
return {("typing", getattr(obj, "__name__", obj.name))}
return set()


def _imports_for_strategy(strategy):
# If we have a lazy from_type strategy, because unwrapping it gives us an
# error or invalid syntax, import that type and we're done.
if isinstance(strategy, LazyStrategy) and strategy.function is st.from_type:
return _imports_for_object(strategy._LazyStrategy__args[0])
if isinstance(strategy, LazyStrategy):
if strategy.function is st.from_type:
return _imports_for_object(strategy._LazyStrategy__args[0])
elif _get_module(strategy.function).startswith("hypothesis.extra."):
return {(_get_module(strategy.function), strategy.function.__name__)}

imports = set()
strategy = unwrap_strategies(strategy)
Expand Down Expand Up @@ -353,6 +364,10 @@ def _imports_for_strategy(strategy):
for s in strategy.kwargs.values():
imports |= _imports_for_strategy(s)

if isinstance(strategy, SampledFromStrategy):
for obj in strategy.elements:
imports |= _imports_for_object(obj)

return imports


Expand All @@ -367,6 +382,8 @@ def _valid_syntax_repr(strategy):
seen = set()
elems = []
for s in strategy.element_strategies:
if isinstance(s, SampledFromStrategy) and s.elements == (os.environ,):
continue
if repr(s) not in seen:
elems.append(s)
seen.add(repr(s))
Expand Down Expand Up @@ -436,6 +453,16 @@ def _write_call(func: Callable, *pass_variables: str) -> str:
return f"{_get_qualname(func, include_module=True)}({args})"


def _st_strategy_names(s: str) -> str:
"""Replace strategy name() with st.name().
Uses a tricky re.sub() to avoid problems with frozensets() matching
sets() too.
"""
names = "|".join(sorted(st.__all__, key=len, reverse=True))
return re.sub(pattern=rf"\b(?:{names})\(", repl=r"st.\g<0>", string=s)


def _make_test_body(
*funcs: Callable,
ghost: str,
Expand All @@ -457,8 +484,7 @@ def _make_test_body(
reprs = [((k,) + _valid_syntax_repr(v)) for k, v in given_strategies.items()]
imports = imports.union(*(imp for _, imp, _ in reprs))
given_args = ", ".join(f"{k}={v}" for k, _, v in reprs)
for name in st.__all__:
given_args = given_args.replace(f"{name}(", f"st.{name}(")
given_args = _st_strategy_names(given_args)

if except_:
# This is reminiscent of de-duplication logic I wrote for flake8-bugbear,
Expand Down Expand Up @@ -596,10 +622,13 @@ def magic(
if hasattr(thing, "__all__"):
funcs = [getattr(thing, name, None) for name in thing.__all__] # type: ignore
else:
pkg = thing.__package__
funcs = [
v
for k, v in vars(thing).items()
if callable(v) and not k.startswith("_")
if callable(v)
and (getattr(v, "__module__", pkg) == pkg or not pkg)
and not k.startswith("_")
]
for f in funcs:
try:
Expand Down Expand Up @@ -1044,8 +1073,7 @@ def maker(
)

_, operands_repr = _valid_syntax_repr(operands)
for name in st.__all__:
operands_repr = operands_repr.replace(f"{name}(", f"st.{name}(")
operands_repr = _st_strategy_names(operands_repr)
classdef = ""
if style == "unittest":
classdef = f"class TestBinaryOperation{func.__name__}(unittest.TestCase):\n "
Expand Down Expand Up @@ -1100,7 +1128,7 @@ def _make_ufunc_body(func, *, except_, style):
type_assert=_assert_eq(style, "result.dtype.char", "expected_dtype"),
)

imports, body = _make_test_body(
return _make_test_body(
func,
test_body=dedent(body).strip(),
except_=except_,
Expand All @@ -1113,6 +1141,3 @@ def _make_ufunc_body(func, *, except_, style):
".filter(lambda sig: 'O' not in sig)",
},
)
imports.add("hypothesis.extra.numpy as npst")
body = body.replace("mutually_broadcastable", "npst.mutually_broadcastable")
return imports, body
3 changes: 3 additions & 0 deletions hypothesis-python/src/hypothesis/internal/reflection.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,9 @@ def get_pretty_function_description(f):
# their module as __self__. This might include c-extensions generally?
if not (self is None or inspect.isclass(self) or inspect.ismodule(self)):
return f"{self!r}.{name}"
elif getattr(dict, name, object()) is f:
# special case for keys/values views in from_type() / ghostwriter output
return f"dict.{name}"
return name


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ def _networks(bits):
st.none() | st.integers(),
),
range: st.one_of(
st.integers(min_value=0).map(range),
st.builds(range, st.integers(min_value=0)),
st.builds(range, st.integers(), st.integers()),
st.builds(range, st.integers(), st.integers(), st.integers().filter(bool)),
),
Expand Down
4 changes: 2 additions & 2 deletions hypothesis-python/tests/ghostwriter/recorded/magic_gufunc.txt
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
# This test code was written by the `hypothesis.extra.ghostwriter` module
# and is provided under the Creative Commons Zero public domain dedication.

import hypothesis.extra.numpy as npst
import numpy
from hypothesis import given, strategies as st
from hypothesis.extra.numpy import mutually_broadcastable_shapes


@given(
data=st.data(),
shapes=npst.mutually_broadcastable_shapes(signature="(n?,k),(k,m?)->(n?,m?)"),
shapes=mutually_broadcastable_shapes(signature="(n?,k),(k,m?)->(n?,m?)"),
types=st.sampled_from(numpy.matmul.types).filter(lambda sig: "O" not in sig),
)
def test_gufunc_matmul(data, shapes, types):
Expand Down
57 changes: 55 additions & 2 deletions hypothesis-python/tests/ghostwriter/test_ghostwriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,19 @@
import unittest.mock
from decimal import Decimal
from types import ModuleType
from typing import Any, List, Sequence, Set, Union
from typing import (
Any,
FrozenSet,
KeysView,
List,
Match,
Pattern,
Sequence,
Set,
Sized,
Union,
ValuesView,
)

import pytest

Expand All @@ -38,7 +50,11 @@ def get_test_function(source_code):
# Note that this also tests that the module is syntatically-valid,
# AND free from undefined names, import problems, and so on.
namespace = {}
exec(source_code, namespace)
try:
exec(source_code, namespace)
except Exception:
print(f"************\n{source_code}\n************")
raise
tests = [
v
for k, v in namespace.items()
Expand Down Expand Up @@ -123,6 +139,30 @@ def test_flattens_one_of_repr():
assert ghostwriter._valid_syntax_repr(strat)[1].count("one_of(") == 1


def takes_keys(x: KeysView[int]) -> None:
pass


def takes_values(x: ValuesView[int]) -> None:
pass


def takes_match(x: Match[bytes]) -> None:
pass


def takes_pattern(x: Pattern[str]) -> None:
pass


def takes_sized(x: Sized) -> None:
pass


def takes_frozensets(a: FrozenSet[int], b: FrozenSet[int]) -> None:
pass


@varied_excepts
@pytest.mark.parametrize(
"func",
Expand All @@ -136,13 +176,26 @@ def test_flattens_one_of_repr():
annotated_any,
space_in_name,
non_resolvable_arg,
takes_keys,
takes_values,
takes_match,
takes_pattern,
takes_sized,
takes_frozensets,
],
)
def test_ghostwriter_fuzz(func, ex):
source_code = ghostwriter.fuzz(func, except_=ex)
get_test_function(source_code)


def test_binary_op_also_handles_frozensets():
# Using str.replace in a loop would convert `frozensets()` into
# `st.frozenst.sets()` instead of `st.frozensets()`; fixed with re.sub.
source_code = ghostwriter.binary_operation(takes_frozensets)
exec(source_code, {})


@varied_excepts
@pytest.mark.parametrize(
"func", [re.compile, json.loads, json.dump, timsort, ast.literal_eval]
Expand Down

0 comments on commit 76855d7

Please sign in to comment.