Skip to content

Commit

Permalink
Fixed asyncio.Task.cancelling issues (#790)
Browse files Browse the repository at this point in the history
* Shrink TaskState to save a little memory
* Fix uncancel() being called too early
* Refactor to avoid duplicate computation
* Test TaskInfo.has_pending_cancellation in cleanup code
* Fix TaskInfo.has_pending_cancellation in cleanup code on asyncio
* Test that uncancel() isn't called too early

Co-authored-by: Thomas Grainger <[email protected]>
Co-authored-by: Alex Grönholm <[email protected]>
  • Loading branch information
3 people authored Dec 5, 2024
1 parent 39cf394 commit 93a5746
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 60 deletions.
6 changes: 6 additions & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.
- Fixed the return type annotations of ``readinto()`` and ``readinto1()`` methods in the
``anyio.AsyncFile`` class
(`#825 <https://github.com/agronholm/anyio/issues/825>`_)
- Fixed ``TaskInfo.has_pending_cancellation()`` on asyncio returning false positives in
cleanup code on Python >= 3.11
(`#832 <https://github.com/agronholm/anyio/issues/832>`_; PR by @gschaffner)
- Fixed cancelled cancel scopes on asyncio calling ``asyncio.Task.uncancel`` when
propagating a ``CancelledError`` on exit to a cancelled parent scope
(`#790 <https://github.com/agronholm/anyio/pull/790>`_; PR by @gschaffner)

**4.6.2**

Expand Down
119 changes: 59 additions & 60 deletions src/anyio/_backends/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,11 +372,22 @@ def _task_started(task: asyncio.Task) -> bool:


def is_anyio_cancellation(exc: CancelledError) -> bool:
return (
bool(exc.args)
and isinstance(exc.args[0], str)
and exc.args[0].startswith("Cancelled by cancel scope ")
)
# Sometimes third party frameworks catch a CancelledError and raise a new one, so as
# a workaround we have to look at the previous ones in __context__ too for a
# matching cancel message
while True:
if (
exc.args
and isinstance(exc.args[0], str)
and exc.args[0].startswith("Cancelled by cancel scope ")
):
return True

if isinstance(exc.__context__, CancelledError):
exc = exc.__context__
continue

return False


class CancelScope(BaseCancelScope):
Expand All @@ -397,8 +408,10 @@ def __init__(self, deadline: float = math.inf, shield: bool = False):
self._cancel_handle: asyncio.Handle | None = None
self._tasks: set[asyncio.Task] = set()
self._host_task: asyncio.Task | None = None
self._cancel_calls: int = 0
self._cancelling: int | None = None
if sys.version_info >= (3, 11):
self._pending_uncancellations: int | None = 0
else:
self._pending_uncancellations = None

def __enter__(self) -> CancelScope:
if self._active:
Expand All @@ -424,8 +437,6 @@ def __enter__(self) -> CancelScope:

self._timeout()
self._active = True
if sys.version_info >= (3, 11):
self._cancelling = self._host_task.cancelling()

# Start cancelling the host task if the scope was cancelled before entering
if self._cancel_called:
Expand Down Expand Up @@ -470,30 +481,41 @@ def __exit__(

host_task_state.cancel_scope = self._parent_scope

# Undo all cancellations done by this scope
if self._cancelling is not None:
while self._cancel_calls:
self._cancel_calls -= 1
if self._host_task.uncancel() <= self._cancelling:
break
# Restart the cancellation effort in the closest visible, cancelled parent
# scope if necessary
self._restart_cancellation_in_parent()

# We only swallow the exception iff it was an AnyIO CancelledError, either
# directly as exc_val or inside an exception group and there are no cancelled
# parent cancel scopes visible to us here
not_swallowed_exceptions = 0
swallow_exception = False
if exc_val is not None:
for exc in iterate_exceptions(exc_val):
if self._cancel_called and isinstance(exc, CancelledError):
if not (swallow_exception := self._uncancel(exc)):
not_swallowed_exceptions += 1
else:
not_swallowed_exceptions += 1
if self._cancel_called and not self._parent_cancellation_is_visible_to_us:
# For each level-cancel() call made on the host task, call uncancel()
while self._pending_uncancellations:
self._host_task.uncancel()
self._pending_uncancellations -= 1

# Update cancelled_caught and check for exceptions we must not swallow
cannot_swallow_exc_val = False
if exc_val is not None:
for exc in iterate_exceptions(exc_val):
if isinstance(exc, CancelledError) and is_anyio_cancellation(
exc
):
self._cancelled_caught = True
else:
cannot_swallow_exc_val = True

# Restart the cancellation effort in the closest visible, cancelled parent
# scope if necessary
self._restart_cancellation_in_parent()
return swallow_exception and not not_swallowed_exceptions
return self._cancelled_caught and not cannot_swallow_exc_val
else:
if self._pending_uncancellations:
assert self._parent_scope is not None
assert self._parent_scope._pending_uncancellations is not None
self._parent_scope._pending_uncancellations += (
self._pending_uncancellations
)
self._pending_uncancellations = 0

return False
finally:
self._host_task = None
del exc_val
Expand All @@ -520,31 +542,6 @@ def _parent_cancellation_is_visible_to_us(self) -> bool:
and self._parent_scope._effectively_cancelled
)

def _uncancel(self, cancelled_exc: CancelledError) -> bool:
if self._host_task is None:
self._cancel_calls = 0
return True

while True:
if is_anyio_cancellation(cancelled_exc):
# Only swallow the cancellation exception if it's an AnyIO cancel
# exception and there are no other cancel scopes down the line pending
# cancellation
self._cancelled_caught = (
self._effectively_cancelled
and not self._parent_cancellation_is_visible_to_us
)
return self._cancelled_caught

# Sometimes third party frameworks catch a CancelledError and raise a new
# one, so as a workaround we have to look at the previous ones in
# __context__ too for a matching cancel message
if isinstance(cancelled_exc.__context__, CancelledError):
cancelled_exc = cancelled_exc.__context__
continue

return False

def _timeout(self) -> None:
if self._deadline != math.inf:
loop = get_running_loop()
Expand Down Expand Up @@ -576,8 +573,11 @@ def _deliver_cancellation(self, origin: CancelScope) -> bool:
waiter = task._fut_waiter # type: ignore[attr-defined]
if not isinstance(waiter, asyncio.Future) or not waiter.done():
task.cancel(f"Cancelled by cancel scope {id(origin):x}")
if task is origin._host_task:
origin._cancel_calls += 1
if (
task is origin._host_task
and origin._pending_uncancellations is not None
):
origin._pending_uncancellations += 1

# Deliver cancellation to child scopes that aren't shielded or running their own
# cancellation callbacks
Expand Down Expand Up @@ -2154,12 +2154,11 @@ def has_pending_cancellation(self) -> bool:
# If the task isn't around anymore, it won't have a pending cancellation
return False

if sys.version_info >= (3, 11):
if task.cancelling():
return True
if task._must_cancel: # type: ignore[attr-defined]
return True
elif (
isinstance(task._fut_waiter, asyncio.Future)
and task._fut_waiter.cancelled()
isinstance(task._fut_waiter, asyncio.Future) # type: ignore[attr-defined]
and task._fut_waiter.cancelled() # type: ignore[attr-defined]
):
return True

Expand Down
52 changes: 52 additions & 0 deletions tests/test_taskgroups.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,38 @@ async def test_cancel_shielded_scope() -> None:
await checkpoint()


async def test_shielded_cleanup_after_cancel() -> None:
"""Regression test for #832."""
with CancelScope() as outer_scope:
outer_scope.cancel()
try:
await checkpoint()
finally:
assert current_effective_deadline() == -math.inf
assert get_current_task().has_pending_cancellation()

with CancelScope(shield=True): # noqa: ASYNC100
assert current_effective_deadline() == math.inf
assert not get_current_task().has_pending_cancellation()

assert current_effective_deadline() == -math.inf
assert get_current_task().has_pending_cancellation()


@pytest.mark.parametrize("anyio_backend", ["asyncio"])
async def test_cleanup_after_native_cancel() -> None:
"""Regression test for #832."""
# See also https://github.com/python/cpython/pull/102815.
task = asyncio.current_task()
assert task
task.cancel()
with pytest.raises(asyncio.CancelledError):
try:
await checkpoint()
finally:
assert not get_current_task().has_pending_cancellation()


async def test_cancelled_not_caught() -> None:
with CancelScope() as scope: # noqa: ASYNC100
scope.cancel()
Expand Down Expand Up @@ -1488,6 +1520,26 @@ async def taskfunc() -> None:
assert str(exc_info.value.exceptions[0]) == "dummy error"
assert not cast(asyncio.Task, asyncio.current_task()).cancelling()

async def test_uncancel_cancelled_scope_based_checkpoint(self) -> None:
"""See also test_cancelled_scope_based_checkpoint."""
task = asyncio.current_task()
assert task

with CancelScope() as outer_scope:
outer_scope.cancel()

try:
# The following three lines are a way to implement a checkpoint
# function. See also https://github.com/python-trio/trio/issues/860.
with CancelScope() as inner_scope:
inner_scope.cancel()
await sleep_forever()
finally:
assert isinstance(sys.exc_info()[1], asyncio.CancelledError)
assert task.cancelling()

assert not task.cancelling()


async def test_cancel_before_entering_task_group() -> None:
with CancelScope() as scope:
Expand Down

0 comments on commit 93a5746

Please sign in to comment.