Skip to content

Commit

Permalink
Support datetime.timedelta as a valid wait unit type (#342)
Browse files Browse the repository at this point in the history
* Support `datetime.timedelta` as a valid wait unit type

Signed-off-by: Noam Bloom <[email protected]>

* Add datetime.timedelta support tests

Signed-off-by: Noam Bloom <[email protected]>

Co-authored-by: Noam Bloom <[email protected]>
Co-authored-by: Julien Danjou <[email protected]>
  • Loading branch information
3 people authored May 30, 2022
1 parent f6465c0 commit 18d05a6
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 47 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
---
features:
- Add ``datetime.timedelta`` as accepted wait unit type.
37 changes: 22 additions & 15 deletions tenacity/wait.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,19 @@
import abc
import random
import typing
from datetime import timedelta

from tenacity import _utils

if typing.TYPE_CHECKING:
from tenacity import RetryCallState

wait_unit_type = typing.Union[int, float, timedelta]


def to_seconds(wait_unit: wait_unit_type) -> float:
return float(wait_unit.total_seconds() if isinstance(wait_unit, timedelta) else wait_unit)


class wait_base(abc.ABC):
"""Abstract base class for wait strategies."""
Expand All @@ -44,8 +51,8 @@ def __radd__(self, other: "wait_base") -> typing.Union["wait_combine", "wait_bas
class wait_fixed(wait_base):
"""Wait strategy that waits a fixed amount of time between each retry."""

def __init__(self, wait: float) -> None:
self.wait_fixed = wait
def __init__(self, wait: wait_unit_type) -> None:
self.wait_fixed = to_seconds(wait)

def __call__(self, retry_state: "RetryCallState") -> float:
return self.wait_fixed
Expand All @@ -61,9 +68,9 @@ def __init__(self) -> None:
class wait_random(wait_base):
"""Wait strategy that waits a random amount of time between min/max."""

def __init__(self, min: typing.Union[int, float] = 0, max: typing.Union[int, float] = 1) -> None: # noqa
self.wait_random_min = min
self.wait_random_max = max
def __init__(self, min: wait_unit_type = 0, max: wait_unit_type = 1) -> None: # noqa
self.wait_random_min = to_seconds(min)
self.wait_random_max = to_seconds(max)

def __call__(self, retry_state: "RetryCallState") -> float:
return self.wait_random_min + (random.random() * (self.wait_random_max - self.wait_random_min))
Expand Down Expand Up @@ -113,13 +120,13 @@ class wait_incrementing(wait_base):

def __init__(
self,
start: typing.Union[int, float] = 0,
increment: typing.Union[int, float] = 100,
max: typing.Union[int, float] = _utils.MAX_WAIT, # noqa
start: wait_unit_type = 0,
increment: wait_unit_type = 100,
max: wait_unit_type = _utils.MAX_WAIT, # noqa
) -> None:
self.start = start
self.increment = increment
self.max = max
self.start = to_seconds(start)
self.increment = to_seconds(increment)
self.max = to_seconds(max)

def __call__(self, retry_state: "RetryCallState") -> float:
result = self.start + (self.increment * (retry_state.attempt_number - 1))
Expand All @@ -142,13 +149,13 @@ class wait_exponential(wait_base):
def __init__(
self,
multiplier: typing.Union[int, float] = 1,
max: typing.Union[int, float] = _utils.MAX_WAIT, # noqa
max: wait_unit_type = _utils.MAX_WAIT, # noqa
exp_base: typing.Union[int, float] = 2,
min: typing.Union[int, float] = 0, # noqa
min: wait_unit_type = 0, # noqa
) -> None:
self.multiplier = multiplier
self.min = min
self.max = max
self.min = to_seconds(min)
self.max = to_seconds(max)
self.exp_base = exp_base

def __call__(self, retry_state: "RetryCallState") -> float:
Expand Down
72 changes: 40 additions & 32 deletions tests/test_tenacity.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import logging
import re
import sys
Expand All @@ -29,7 +30,6 @@
import tenacity
from tenacity import RetryCallState, RetryError, Retrying, retry


_unset = object()


Expand Down Expand Up @@ -180,28 +180,34 @@ def test_no_sleep(self):
self.assertEqual(0, r.wait(make_retry_state(18, 9879)))

def test_fixed_sleep(self):
r = Retrying(wait=tenacity.wait_fixed(1))
self.assertEqual(1, r.wait(make_retry_state(12, 6546)))
for wait in (1, datetime.timedelta(seconds=1)):
with self.subTest():
r = Retrying(wait=tenacity.wait_fixed(wait))
self.assertEqual(1, r.wait(make_retry_state(12, 6546)))

def test_incrementing_sleep(self):
r = Retrying(wait=tenacity.wait_incrementing(start=500, increment=100))
self.assertEqual(500, r.wait(make_retry_state(1, 6546)))
self.assertEqual(600, r.wait(make_retry_state(2, 6546)))
self.assertEqual(700, r.wait(make_retry_state(3, 6546)))
for start, increment in ((500, 100), (datetime.timedelta(seconds=500), datetime.timedelta(seconds=100))):
with self.subTest():
r = Retrying(wait=tenacity.wait_incrementing(start=start, increment=increment))
self.assertEqual(500, r.wait(make_retry_state(1, 6546)))
self.assertEqual(600, r.wait(make_retry_state(2, 6546)))
self.assertEqual(700, r.wait(make_retry_state(3, 6546)))

def test_random_sleep(self):
r = Retrying(wait=tenacity.wait_random(min=1, max=20))
times = set()
for x in range(1000):
times.add(r.wait(make_retry_state(1, 6546)))

# this is kind of non-deterministic...
self.assertTrue(len(times) > 1)
for t in times:
self.assertTrue(t >= 1)
self.assertTrue(t < 20)

def test_random_sleep_without_min(self):
for min_, max_ in ((1, 20), (datetime.timedelta(seconds=1), datetime.timedelta(seconds=20))):
with self.subTest():
r = Retrying(wait=tenacity.wait_random(min=min_, max=max_))
times = set()
for _ in range(1000):
times.add(r.wait(make_retry_state(1, 6546)))

# this is kind of non-deterministic...
self.assertTrue(len(times) > 1)
for t in times:
self.assertTrue(t >= 1)
self.assertTrue(t < 20)

def test_random_sleep_withoutmin_(self):
r = Retrying(wait=tenacity.wait_random(max=2))
times = set()
times.add(r.wait(make_retry_state(1, 6546)))
Expand Down Expand Up @@ -274,18 +280,20 @@ def test_exponential_with_min_wait_and_multiplier(self):
self.assertEqual(r.wait(make_retry_state(8, 0)), 256)
self.assertEqual(r.wait(make_retry_state(20, 0)), 1048576)

def test_exponential_with_min_wait_and_max_wait(self):
r = Retrying(wait=tenacity.wait_exponential(min=10, max=100))
self.assertEqual(r.wait(make_retry_state(1, 0)), 10)
self.assertEqual(r.wait(make_retry_state(2, 0)), 10)
self.assertEqual(r.wait(make_retry_state(3, 0)), 10)
self.assertEqual(r.wait(make_retry_state(4, 0)), 10)
self.assertEqual(r.wait(make_retry_state(5, 0)), 16)
self.assertEqual(r.wait(make_retry_state(6, 0)), 32)
self.assertEqual(r.wait(make_retry_state(7, 0)), 64)
self.assertEqual(r.wait(make_retry_state(8, 0)), 100)
self.assertEqual(r.wait(make_retry_state(9, 0)), 100)
self.assertEqual(r.wait(make_retry_state(20, 0)), 100)
def test_exponential_with_min_wait_andmax__wait(self):
for min_, max_ in ((10, 100), (datetime.timedelta(seconds=10), datetime.timedelta(seconds=100))):
with self.subTest():
r = Retrying(wait=tenacity.wait_exponential(min=min_, max=max_))
self.assertEqual(r.wait(make_retry_state(1, 0)), 10)
self.assertEqual(r.wait(make_retry_state(2, 0)), 10)
self.assertEqual(r.wait(make_retry_state(3, 0)), 10)
self.assertEqual(r.wait(make_retry_state(4, 0)), 10)
self.assertEqual(r.wait(make_retry_state(5, 0)), 16)
self.assertEqual(r.wait(make_retry_state(6, 0)), 32)
self.assertEqual(r.wait(make_retry_state(7, 0)), 64)
self.assertEqual(r.wait(make_retry_state(8, 0)), 100)
self.assertEqual(r.wait(make_retry_state(9, 0)), 100)
self.assertEqual(r.wait(make_retry_state(20, 0)), 100)

def test_legacy_explicit_wait_type(self):
Retrying(wait="exponential_sleep")
Expand Down Expand Up @@ -335,7 +343,7 @@ def test_wait_arbitrary_sum(self):
)
)
# Test it a few time since it's random
for i in range(1000):
for _ in range(1000):
w = r.wait(make_retry_state(1, 5))
self.assertLess(w, 9)
self.assertGreaterEqual(w, 6)
Expand Down

0 comments on commit 18d05a6

Please sign in to comment.