From 014b8e6c39d0052d9bb80ad85bae9a390d1aad09 Mon Sep 17 00:00:00 2001 From: Greesb Date: Wed, 21 Sep 2022 14:20:00 +0200 Subject: [PATCH] feat: Add retry_if_exception_cause_type (#362) Add a new retry_base class called `retry_if_exception_cause_type` that checks that the cause of the raised exception is of a certain type. Co-authored-by: Guillaume RISBOURG --- ...exception_cause_type-d16b918ace4ae0ad.yaml | 5 ++ tenacity/__init__.py | 1 + tenacity/retry.py | 27 ++++++++ tests/test_tenacity.py | 64 +++++++++++++++++++ 4 files changed, 97 insertions(+) create mode 100644 releasenotes/notes/add_retry_if_exception_cause_type-d16b918ace4ae0ad.yaml diff --git a/releasenotes/notes/add_retry_if_exception_cause_type-d16b918ace4ae0ad.yaml b/releasenotes/notes/add_retry_if_exception_cause_type-d16b918ace4ae0ad.yaml new file mode 100644 index 00000000..8b5a420f --- /dev/null +++ b/releasenotes/notes/add_retry_if_exception_cause_type-d16b918ace4ae0ad.yaml @@ -0,0 +1,5 @@ +--- +features: + - | + Add a new `retry_base` class called `retry_if_exception_cause_type` that + checks, recursively, if any of the causes of the raised exception is of a certain type. diff --git a/tenacity/__init__.py b/tenacity/__init__.py index fd403761..008049a8 100644 --- a/tenacity/__init__.py +++ b/tenacity/__init__.py @@ -33,6 +33,7 @@ from .retry import retry_any # noqa from .retry import retry_if_exception # noqa from .retry import retry_if_exception_type # noqa +from .retry import retry_if_exception_cause_type # noqa from .retry import retry_if_not_exception_type # noqa from .retry import retry_if_not_result # noqa from .retry import retry_if_result # noqa diff --git a/tenacity/retry.py b/tenacity/retry.py index dd271175..1305d3f0 100644 --- a/tenacity/retry.py +++ b/tenacity/retry.py @@ -117,6 +117,33 @@ def __call__(self, retry_state: "RetryCallState") -> bool: return self.predicate(retry_state.outcome.exception()) +class retry_if_exception_cause_type(retry_base): + """Retries if any of the causes of the raised exception is of one or more types. + + The check on the type of the cause of the exception is done recursively (until finding + an exception in the chain that has no `__cause__`) + """ + + def __init__( + self, + exception_types: typing.Union[ + typing.Type[BaseException], + typing.Tuple[typing.Type[BaseException], ...], + ] = Exception, + ) -> None: + self.exception_cause_types = exception_types + + def __call__(self, retry_state: "RetryCallState") -> bool: + if retry_state.outcome.failed: + exc = retry_state.outcome.exception() + while exc is not None: + if isinstance(exc.__cause__, self.exception_cause_types): + return True + exc = exc.__cause__ + + return False + + class retry_if_result(retry_base): """Retries if the result verifies a predicate.""" diff --git a/tests/test_tenacity.py b/tests/test_tenacity.py index b6f6bbb0..2e5febd8 100644 --- a/tests/test_tenacity.py +++ b/tests/test_tenacity.py @@ -676,6 +676,56 @@ def go(self): return True +class NoNameErrorCauseAfterCount: + """Holds counter state for invoking a method several times in a row.""" + + def __init__(self, count): + self.counter = 0 + self.count = count + + def go2(self): + raise NameError("Hi there, I'm a NameError") + + def go(self): + """Raise an IOError with a NameError as cause until after count threshold has been crossed. + + Then return True. + """ + if self.counter < self.count: + self.counter += 1 + try: + self.go2() + except NameError as e: + raise IOError() from e + + return True + + +class NoIOErrorCauseAfterCount: + """Holds counter state for invoking a method several times in a row.""" + + def __init__(self, count): + self.counter = 0 + self.count = count + + def go2(self): + raise IOError("Hi there, I'm an IOError") + + def go(self): + """Raise a NameError with an IOError as cause until after count threshold has been crossed. + + Then return True. + """ + if self.counter < self.count: + self.counter += 1 + try: + self.go2() + except IOError as e: + raise NameError() from e + + return True + + class NameErrorUntilCount: """Holds counter state for invoking a method several times in a row.""" @@ -783,6 +833,11 @@ def _retryable_test_with_stop(thing): return thing.go() +@retry(retry=tenacity.retry_if_exception_cause_type(NameError)) +def _retryable_test_with_exception_cause_type(thing): + return thing.go() + + @retry(retry=tenacity.retry_if_exception_type(IOError)) def _retryable_test_with_exception_type_io(thing): return thing.go() @@ -987,6 +1042,15 @@ def test_retry_if_not_exception_message_match(self): s = _retryable_test_if_not_exception_message_message.retry.statistics self.assertTrue(s["attempt_number"] == 1) + def test_retry_if_exception_cause_type(self): + self.assertTrue(_retryable_test_with_exception_cause_type(NoNameErrorCauseAfterCount(5))) + + try: + _retryable_test_with_exception_cause_type(NoIOErrorCauseAfterCount(5)) + self.fail("Expected exception without NameError as cause") + except NameError: + pass + def test_defaults(self): self.assertTrue(_retryable_default(NoNameErrorAfterCount(5))) self.assertTrue(_retryable_default_f(NoNameErrorAfterCount(5)))