Skip to content

Commit

Permalink
feat: Add retry_if_exception_cause_type (#362)
Browse files Browse the repository at this point in the history
Add a new retry_base class called `retry_if_exception_cause_type` that
checks that the cause of the raised exception is of a certain type.

Co-authored-by: Guillaume RISBOURG <[email protected]>
  • Loading branch information
Greesb and Greesb authored Sep 21, 2022
1 parent 18d05a6 commit 014b8e6
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
features:
- |
Add a new `retry_base` class called `retry_if_exception_cause_type` that
checks, recursively, if any of the causes of the raised exception is of a certain type.
1 change: 1 addition & 0 deletions tenacity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from .retry import retry_any # noqa
from .retry import retry_if_exception # noqa
from .retry import retry_if_exception_type # noqa
from .retry import retry_if_exception_cause_type # noqa
from .retry import retry_if_not_exception_type # noqa
from .retry import retry_if_not_result # noqa
from .retry import retry_if_result # noqa
Expand Down
27 changes: 27 additions & 0 deletions tenacity/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,33 @@ def __call__(self, retry_state: "RetryCallState") -> bool:
return self.predicate(retry_state.outcome.exception())


class retry_if_exception_cause_type(retry_base):
"""Retries if any of the causes of the raised exception is of one or more types.
The check on the type of the cause of the exception is done recursively (until finding
an exception in the chain that has no `__cause__`)
"""

def __init__(
self,
exception_types: typing.Union[
typing.Type[BaseException],
typing.Tuple[typing.Type[BaseException], ...],
] = Exception,
) -> None:
self.exception_cause_types = exception_types

def __call__(self, retry_state: "RetryCallState") -> bool:
if retry_state.outcome.failed:
exc = retry_state.outcome.exception()
while exc is not None:
if isinstance(exc.__cause__, self.exception_cause_types):
return True
exc = exc.__cause__

return False


class retry_if_result(retry_base):
"""Retries if the result verifies a predicate."""

Expand Down
64 changes: 64 additions & 0 deletions tests/test_tenacity.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,56 @@ def go(self):
return True


class NoNameErrorCauseAfterCount:
"""Holds counter state for invoking a method several times in a row."""

def __init__(self, count):
self.counter = 0
self.count = count

def go2(self):
raise NameError("Hi there, I'm a NameError")

def go(self):
"""Raise an IOError with a NameError as cause until after count threshold has been crossed.
Then return True.
"""
if self.counter < self.count:
self.counter += 1
try:
self.go2()
except NameError as e:
raise IOError() from e

return True


class NoIOErrorCauseAfterCount:
"""Holds counter state for invoking a method several times in a row."""

def __init__(self, count):
self.counter = 0
self.count = count

def go2(self):
raise IOError("Hi there, I'm an IOError")

def go(self):
"""Raise a NameError with an IOError as cause until after count threshold has been crossed.
Then return True.
"""
if self.counter < self.count:
self.counter += 1
try:
self.go2()
except IOError as e:
raise NameError() from e

return True


class NameErrorUntilCount:
"""Holds counter state for invoking a method several times in a row."""

Expand Down Expand Up @@ -783,6 +833,11 @@ def _retryable_test_with_stop(thing):
return thing.go()


@retry(retry=tenacity.retry_if_exception_cause_type(NameError))
def _retryable_test_with_exception_cause_type(thing):
return thing.go()


@retry(retry=tenacity.retry_if_exception_type(IOError))
def _retryable_test_with_exception_type_io(thing):
return thing.go()
Expand Down Expand Up @@ -987,6 +1042,15 @@ def test_retry_if_not_exception_message_match(self):
s = _retryable_test_if_not_exception_message_message.retry.statistics
self.assertTrue(s["attempt_number"] == 1)

def test_retry_if_exception_cause_type(self):
self.assertTrue(_retryable_test_with_exception_cause_type(NoNameErrorCauseAfterCount(5)))

try:
_retryable_test_with_exception_cause_type(NoIOErrorCauseAfterCount(5))
self.fail("Expected exception without NameError as cause")
except NameError:
pass

def test_defaults(self):
self.assertTrue(_retryable_default(NoNameErrorAfterCount(5)))
self.assertTrue(_retryable_default_f(NoNameErrorAfterCount(5)))
Expand Down

0 comments on commit 014b8e6

Please sign in to comment.