Skip to content

Commit

Permalink
Merge pull request #19 from mikenerone/mikenerone/safe_aclosing
Browse files Browse the repository at this point in the history
Full support for all async iterables/iterators as inputs and sources (even if they don't have an `aclose()`)
  • Loading branch information
andersea authored Apr 16, 2024
2 parents 9d862bb + ca05d55 commit c4fef46
Show file tree
Hide file tree
Showing 20 changed files with 193 additions and 137 deletions.
1 change: 0 additions & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
Sphinx >= 6.2.1
sphinx-rtd-theme >= 2.0.0
trio >= 0.23.0
async-generator >= 1.10
13 changes: 1 addition & 12 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ classifiers = [


[tool.poetry.dependencies]
async-generator = "^1.10"
python = "^3.8"
trio = ">=0.23.0"

Expand Down
9 changes: 5 additions & 4 deletions slurry/_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
"""Contains the main Slurry ``Pipeline`` class."""

import math
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator

import trio
from async_generator import aclosing, asynccontextmanager

from .sections.abc import PipelineSection
from .sections.weld import weld
from ._tap import Tap
from .sections.abc import PipelineSection
from ._utils import safe_aclose, safe_aclosing

class Pipeline:
"""The main Slurry ``Pipeline`` class.
Expand Down Expand Up @@ -53,7 +54,7 @@ async def _pump(self):
output = weld(nursery, *self.sections)

# Output to taps
async with aclosing(output) as aiter:
async with safe_aclosing(output) as aiter:
async for item in aiter:
self._taps = set(filter(lambda tap: not tap.closed, self._taps))
if not self._taps:
Expand All @@ -64,7 +65,7 @@ async def _pump(self):

# There is no more output to send. Close the taps.
for tap in self._taps:
await tap.send_channel.aclose()
await safe_aclose(tap.send_channel)

def tap(self, *,
max_buffer_size: int = 0,
Expand Down
34 changes: 34 additions & 0 deletions slurry/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from typing import (
AsyncGenerator,
AsyncIterable,
AsyncIterator,
Awaitable,
Protocol,
TypeVar,
Union,
runtime_checkable,
)

from contextlib import asynccontextmanager

_T_co = TypeVar("_T_co", covariant=True)

@asynccontextmanager
async def safe_aclosing(
obj: Union[AsyncIterable[_T_co], AsyncIterator[_T_co]]
) -> AsyncGenerator[AsyncIterator[_T_co], None]:
if not isinstance(obj, AsyncIterator):
obj = obj.__aiter__()
try:
yield obj
finally:
await safe_aclose(obj)

async def safe_aclose(obj: AsyncIterator[_T_co]) -> None:
if isinstance(obj, _SupportsAclose):
await obj.aclose()

@runtime_checkable
class _SupportsAclose(Protocol):
def aclose(self) -> Awaitable[object]:
...
8 changes: 4 additions & 4 deletions slurry/sections/_buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from typing import Any, AsyncIterable, Callable, Optional, Sequence

import trio
from async_generator import aclosing

from ..environments import TrioSection
from .._utils import safe_aclosing

class Window(TrioSection):
"""Window buffer with size and age limits.
Expand Down Expand Up @@ -51,7 +51,7 @@ async def refine(self, input, output):

buf = deque()

async with aclosing(source) as aiter:
async with safe_aclosing(source) as aiter:
async for item in aiter:
now = trio.current_time()
buf.append((item, now))
Expand Down Expand Up @@ -111,7 +111,7 @@ async def refine(self, input, output):

send_channel, receive_channel = trio.open_memory_channel(0)
async def pull_task():
async with send_channel, aclosing(source) as aiter:
async with send_channel, safe_aclosing(source) as aiter:
async for item in aiter:
await send_channel.send(item)
nursery.start_soon(pull_task)
Expand Down Expand Up @@ -169,7 +169,7 @@ async def refine(self, input, output):
buffer_input_channel, buffer_output_channel = trio.open_memory_channel(math.inf)

async def pull_task():
async with buffer_input_channel, aclosing(source) as aiter:
async with buffer_input_channel, safe_aclosing(source) as aiter:
async for item in aiter:
await buffer_input_channel.send((item, trio.current_time() + self.interval))

Expand Down
10 changes: 5 additions & 5 deletions slurry/sections/_combiners.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
import itertools

import trio
from async_generator import aclosing

from ..environments import TrioSection
from .abc import PipelineSection
from .weld import weld
from .._utils import safe_aclose, safe_aclosing

class Chain(TrioSection):
"""Chains input from one or more sources. Any valid ``PipelineSection`` is an allowed source.
Expand Down Expand Up @@ -41,7 +41,7 @@ async def refine(self, input, output):
sources = self.sources
async with trio.open_nursery() as nursery:
for source in sources:
async with aclosing(weld(nursery, source)) as agen:
async with safe_aclosing(weld(nursery, source)) as agen:
async for item in agen:
await output(item)

Expand All @@ -67,7 +67,7 @@ async def refine(self, input, output):
async with trio.open_nursery() as nursery:

async def pull_task(source):
async with aclosing(weld(nursery, source)) as aiter:
async with safe_aclosing(weld(nursery, source)) as aiter:
async for item in aiter:
await output(item)

Expand Down Expand Up @@ -126,7 +126,7 @@ async def pull_task(source, index, results: list):
await output(tuple(result for i, result in sorted(results, key=lambda packet: packet[0])))

for source in sources:
await source.aclose()
await safe_aclose(source)

class ZipLatest(TrioSection):
"""Zips input from multiple sources and outputs a result on every received item. Any valid
Expand Down Expand Up @@ -204,7 +204,7 @@ async def refine(self, input, output):
async with trio.open_nursery() as nursery:

async def pull_task(index, source, monitor=False):
async with aclosing(weld(nursery, source)) as aiter:
async with safe_aclosing(weld(nursery, source)) as aiter:
async for item in aiter:
results[index] = item
ready[index] = True
Expand Down
12 changes: 6 additions & 6 deletions slurry/sections/_filters.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""Pipeline sections that filters the incoming items."""
from typing import Any, AsyncIterable, Callable, Hashable, Optional, Union

from async_generator import aclosing
import trio

from ..environments import TrioSection
from .._utils import safe_aclosing

class Skip(TrioSection):
"""Skips the first ``count`` items in an asynchronous sequence.
Expand All @@ -29,7 +29,7 @@ async def refine(self, input, output):
else:
raise RuntimeError('No input provided.')

async with aclosing(source.__aiter__()) as aiter:
async with safe_aclosing(source) as aiter:
try:
for _ in range(self.count):
await aiter.__anext__()
Expand Down Expand Up @@ -64,7 +64,7 @@ async def refine(self, input, output):
else:
raise RuntimeError('No input provided.')

async with aclosing(source) as aiter:
async with safe_aclosing(source) as aiter:
async for item in aiter:
if not self.pred(item):
await output(item)
Expand Down Expand Up @@ -98,7 +98,7 @@ async def refine(self, input, output):
else:
raise RuntimeError('No input provided.')

async with aclosing(source) as aiter:
async with safe_aclosing(source) as aiter:
async for item in aiter:
if self.func(item):
await output(item)
Expand Down Expand Up @@ -133,7 +133,7 @@ async def refine(self, input, output):

token = object()
last = token
async with aclosing(source) as aiter:
async with safe_aclosing(source) as aiter:
async for item in aiter:
if last is token or item != last:
last = item
Expand Down Expand Up @@ -185,7 +185,7 @@ async def refine(self, input, output):
get_subject = lambda item: item[self.subject]

timestamps = {}
async with aclosing(source) as aiter:
async with safe_aclosing(source) as aiter:
async for item in aiter:
now = trio.current_time()
subject = get_subject(item)
Expand Down
4 changes: 2 additions & 2 deletions slurry/sections/_producers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from typing import Any

import trio
from async_generator import aclosing

from ..environments import TrioSection
from .._utils import safe_aclosing

class Repeat(TrioSection):
"""Yields a single item repeatedly at regular intervals.
Expand Down Expand Up @@ -54,7 +54,7 @@ async def repeater(item, *, task_status=trio.TASK_STATUS_IGNORED):
running_repeater = await nursery.start(repeater, self.default)

if input:
async with aclosing(input) as aiter:
async with safe_aclosing(input) as aiter:
async for item in aiter:
if running_repeater:
running_repeater.cancel()
Expand Down
5 changes: 2 additions & 3 deletions slurry/sections/_refiners.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
"""Sections for transforming an input into a different output."""
from typing import Any, AsyncIterable, Optional

from async_generator import aclosing

from ..environments import TrioSection
from .._utils import safe_aclosing

class Map(TrioSection):
"""Maps over an asynchronous sequence.
Expand All @@ -27,6 +26,6 @@ async def refine(self, input, output):
else:
raise RuntimeError('No input provided.')

async with aclosing(source) as aiter:
async with safe_aclosing(source) as aiter:
async for item in aiter:
await output(self.func(item))
7 changes: 4 additions & 3 deletions slurry/sections/weld.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import trio

from .abc import PipelineSection, Section
from .._utils import safe_aclose

def weld(nursery, *sections: PipelineSection) -> AsyncIterable[Any]:
"""
Expand All @@ -21,9 +22,9 @@ async def pump(section, input: Optional[AsyncIterable[Any]], output: trio.Memory
await section.pump(input, output.send)
except trio.BrokenResourceError:
pass
if input and hasattr(input, "aclose") and callable(input.aclose):
await input.aclose()
await output.aclose()
if input:
await safe_aclose(input)
await safe_aclose(output)

section_input = None
output = None
Expand Down
75 changes: 75 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import string
from functools import wraps

import pytest
import trio

from .fixtures import AsyncIteratorWithoutAclose

def fixture_gen_with_and_without_aclose(async_gen):

def fixture_func(_with_aclose):
if _with_aclose:
new_async_gen = async_gen
else:
@wraps(async_gen)
def new_async_gen(*args, **kwargs):
return AsyncIteratorWithoutAclose(async_gen(*args, **kwargs))
return new_async_gen

fixture_func.__name__ = async_gen.__name__
fixture_func.__qualname__ = async_gen.__name__

return pytest.fixture(fixture_func)

@pytest.fixture(params=[True, False], ids=["with_aclose", "without_aclose"])
def _with_aclose(request):
return request.param

@fixture_gen_with_and_without_aclose
async def produce_increasing_integers(interval, *, max=3, delay=0):
await trio.sleep(delay)
for i in range(max):
yield i
if i == max-1:
break
await trio.sleep(interval)

@fixture_gen_with_and_without_aclose
async def produce_alphabet(interval, *, max=3, delay=0):
await trio.sleep(delay)
for i, c in enumerate(string.ascii_lowercase):
yield c
if i == max - 1:
break
await trio.sleep(interval)

@pytest.fixture()
def spam_wait_spam_integers(produce_increasing_integers):
async def spam_wait_spam_integers(interval):
async for i in produce_increasing_integers(.1, max=5, delay=.1):
yield i
await trio.sleep(interval)
async for i in produce_increasing_integers(.1, max=5, delay=.1):
yield i

return spam_wait_spam_integers

@fixture_gen_with_and_without_aclose
async def produce_mappings(interval):
vehicles = [
{'vehicle': 'motorcycle'},
{'vehicle': 'car'},
{'vehicle': 'motorcycle'},
{'vehicle': 'autocamper'},
{'vehicle': 'car'},
{'vehicle': 'car'},
{'vehicle': 'truck'},
{'vehicle': 'car'},
{'vehicle': 'motorcycle'},
]

for i, vehicle in enumerate(vehicles):
vehicle['number'] = i
yield vehicle
await trio.sleep(interval)
Loading

0 comments on commit c4fef46

Please sign in to comment.