From 32dda8087c6fd2cc1ede70c56c023d698b5a3ba1 Mon Sep 17 00:00:00 2001 From: Mike Nerone Date: Wed, 3 Jan 2024 16:53:14 -0600 Subject: [PATCH 1/4] Support with or without `aclose()` on all inputs/sources Note: also applying to outputs so async_generator dep no longer needed. --- slurry/_pipeline.py | 9 ++-- slurry/_utils.py | 34 +++++++++++++ slurry/sections/_buffers.py | 8 ++-- slurry/sections/_combiners.py | 10 ++-- slurry/sections/_filters.py | 12 ++--- slurry/sections/_producers.py | 4 +- slurry/sections/_refiners.py | 5 +- slurry/sections/weld.py | 7 +-- tests/conftest.py | 89 +++++++++++++++++++++++++++++++++++ tests/fixtures.py | 59 ++++++----------------- tests/test_buffers.py | 8 ++-- tests/test_combiners.py | 22 ++++----- tests/test_filters.py | 20 ++++---- tests/test_pipeline.py | 13 +++-- tests/test_producers.py | 10 ++-- tests/test_refiners.py | 7 +-- tests/test_threading.py | 12 ++--- 17 files changed, 207 insertions(+), 122 deletions(-) create mode 100644 slurry/_utils.py create mode 100644 tests/conftest.py diff --git a/slurry/_pipeline.py b/slurry/_pipeline.py index a30922f..6452033 100644 --- a/slurry/_pipeline.py +++ b/slurry/_pipeline.py @@ -1,14 +1,15 @@ """Contains the main Slurry ``Pipeline`` class.""" import math +from contextlib import asynccontextmanager from typing import Any, AsyncGenerator import trio -from async_generator import aclosing, asynccontextmanager +from .sections.abc import PipelineSection from .sections.weld import weld from ._tap import Tap -from .sections.abc import PipelineSection +from ._utils import safe_aclose, safe_aclosing class Pipeline: """The main Slurry ``Pipeline`` class. @@ -53,7 +54,7 @@ async def _pump(self): output = weld(nursery, *self.sections) # Output to taps - async with aclosing(output) as aiter: + async with safe_aclosing(output) as aiter: async for item in aiter: self._taps = set(filter(lambda tap: not tap.closed, self._taps)) if not self._taps: @@ -64,7 +65,7 @@ async def _pump(self): # There is no more output to send. Close the taps. for tap in self._taps: - await tap.send_channel.aclose() + await safe_aclose(tap.send_channel) def tap(self, *, max_buffer_size: int = 0, diff --git a/slurry/_utils.py b/slurry/_utils.py new file mode 100644 index 0000000..99d1a77 --- /dev/null +++ b/slurry/_utils.py @@ -0,0 +1,34 @@ +from typing import ( + AsyncGenerator, + AsyncIterable, + AsyncIterator, + Awaitable, + Protocol, + TypeVar, + Union, + runtime_checkable, +) + +from contextlib import asynccontextmanager + +_T_co = TypeVar("_T_co", covariant=True) + +@asynccontextmanager +async def safe_aclosing( + obj: Union[AsyncIterable[_T_co], AsyncIterator[_T_co]] +) -> AsyncGenerator[AsyncIterator[_T_co], None]: + if not isinstance(obj, AsyncIterator): + obj = obj.__aiter__() + try: + yield obj + finally: + await safe_aclose(obj) + +async def safe_aclose(obj: AsyncIterator[_T_co]) -> None: + if isinstance(obj, _SupportsAclose): + await obj.aclose() + +@runtime_checkable +class _SupportsAclose(Protocol): + def aclose(self) -> Awaitable[object]: + ... diff --git a/slurry/sections/_buffers.py b/slurry/sections/_buffers.py index 6d51567..48e1140 100644 --- a/slurry/sections/_buffers.py +++ b/slurry/sections/_buffers.py @@ -4,9 +4,9 @@ from typing import Any, AsyncIterable, Callable, Optional, Sequence import trio -from async_generator import aclosing from ..environments import TrioSection +from .._utils import safe_aclosing class Window(TrioSection): """Window buffer with size and age limits. @@ -51,7 +51,7 @@ async def refine(self, input, output): buf = deque() - async with aclosing(source) as aiter: + async with safe_aclosing(source) as aiter: async for item in aiter: now = trio.current_time() buf.append((item, now)) @@ -111,7 +111,7 @@ async def refine(self, input, output): send_channel, receive_channel = trio.open_memory_channel(0) async def pull_task(): - async with send_channel, aclosing(source) as aiter: + async with send_channel, safe_aclosing(source) as aiter: async for item in aiter: await send_channel.send(item) nursery.start_soon(pull_task) @@ -169,7 +169,7 @@ async def refine(self, input, output): buffer_input_channel, buffer_output_channel = trio.open_memory_channel(math.inf) async def pull_task(): - async with buffer_input_channel, aclosing(source) as aiter: + async with buffer_input_channel, safe_aclosing(source) as aiter: async for item in aiter: await buffer_input_channel.send((item, trio.current_time() + self.interval)) diff --git a/slurry/sections/_combiners.py b/slurry/sections/_combiners.py index 7c79383..b11ff2e 100644 --- a/slurry/sections/_combiners.py +++ b/slurry/sections/_combiners.py @@ -3,11 +3,11 @@ import itertools import trio -from async_generator import aclosing from ..environments import TrioSection from .abc import PipelineSection from .weld import weld +from .._utils import safe_aclose, safe_aclosing class Chain(TrioSection): """Chains input from one or more sources. Any valid ``PipelineSection`` is an allowed source. @@ -41,7 +41,7 @@ async def refine(self, input, output): sources = self.sources async with trio.open_nursery() as nursery: for source in sources: - async with aclosing(weld(nursery, source)) as agen: + async with safe_aclosing(weld(nursery, source)) as agen: async for item in agen: await output(item) @@ -67,7 +67,7 @@ async def refine(self, input, output): async with trio.open_nursery() as nursery: async def pull_task(source): - async with aclosing(weld(nursery, source)) as aiter: + async with safe_aclosing(weld(nursery, source)) as aiter: async for item in aiter: await output(item) @@ -126,7 +126,7 @@ async def pull_task(source, index, results: list): await output(tuple(result for i, result in sorted(results, key=lambda packet: packet[0]))) for source in sources: - await source.aclose() + await safe_aclose(source) class ZipLatest(TrioSection): """Zips input from multiple sources and outputs a result on every received item. Any valid @@ -204,7 +204,7 @@ async def refine(self, input, output): async with trio.open_nursery() as nursery: async def pull_task(index, source, monitor=False): - async with aclosing(weld(nursery, source)) as aiter: + async with safe_aclosing(weld(nursery, source)) as aiter: async for item in aiter: results[index] = item ready[index] = True diff --git a/slurry/sections/_filters.py b/slurry/sections/_filters.py index ad20ea0..bf95bb2 100644 --- a/slurry/sections/_filters.py +++ b/slurry/sections/_filters.py @@ -1,10 +1,10 @@ """Pipeline sections that filters the incoming items.""" from typing import Any, AsyncIterable, Callable, Hashable, Optional, Union -from async_generator import aclosing import trio from ..environments import TrioSection +from .._utils import safe_aclosing class Skip(TrioSection): """Skips the first ``count`` items in an asynchronous sequence. @@ -29,7 +29,7 @@ async def refine(self, input, output): else: raise RuntimeError('No input provided.') - async with aclosing(source.__aiter__()) as aiter: + async with safe_aclosing(source) as aiter: try: for _ in range(self.count): await aiter.__anext__() @@ -64,7 +64,7 @@ async def refine(self, input, output): else: raise RuntimeError('No input provided.') - async with aclosing(source) as aiter: + async with safe_aclosing(source) as aiter: async for item in aiter: if not self.pred(item): await output(item) @@ -98,7 +98,7 @@ async def refine(self, input, output): else: raise RuntimeError('No input provided.') - async with aclosing(source) as aiter: + async with safe_aclosing(source) as aiter: async for item in aiter: if self.func(item): await output(item) @@ -133,7 +133,7 @@ async def refine(self, input, output): token = object() last = token - async with aclosing(source) as aiter: + async with safe_aclosing(source) as aiter: async for item in aiter: if last is token or item != last: last = item @@ -185,7 +185,7 @@ async def refine(self, input, output): get_subject = lambda item: item[self.subject] timestamps = {} - async with aclosing(source) as aiter: + async with safe_aclosing(source) as aiter: async for item in aiter: now = trio.current_time() subject = get_subject(item) diff --git a/slurry/sections/_producers.py b/slurry/sections/_producers.py index 1f1272c..4e3f8f5 100644 --- a/slurry/sections/_producers.py +++ b/slurry/sections/_producers.py @@ -3,9 +3,9 @@ from typing import Any import trio -from async_generator import aclosing from ..environments import TrioSection +from .._utils import safe_aclosing class Repeat(TrioSection): """Yields a single item repeatedly at regular intervals. @@ -54,7 +54,7 @@ async def repeater(item, *, task_status=trio.TASK_STATUS_IGNORED): running_repeater = await nursery.start(repeater, self.default) if input: - async with aclosing(input) as aiter: + async with safe_aclosing(input) as aiter: async for item in aiter: if running_repeater: running_repeater.cancel() diff --git a/slurry/sections/_refiners.py b/slurry/sections/_refiners.py index f5da718..ed1cc73 100644 --- a/slurry/sections/_refiners.py +++ b/slurry/sections/_refiners.py @@ -1,9 +1,8 @@ """Sections for transforming an input into a different output.""" from typing import Any, AsyncIterable, Optional -from async_generator import aclosing - from ..environments import TrioSection +from .._utils import safe_aclosing class Map(TrioSection): """Maps over an asynchronous sequence. @@ -27,6 +26,6 @@ async def refine(self, input, output): else: raise RuntimeError('No input provided.') - async with aclosing(source) as aiter: + async with safe_aclosing(source) as aiter: async for item in aiter: await output(self.func(item)) diff --git a/slurry/sections/weld.py b/slurry/sections/weld.py index c8934e2..f7e9144 100644 --- a/slurry/sections/weld.py +++ b/slurry/sections/weld.py @@ -5,6 +5,7 @@ import trio from .abc import PipelineSection, Section +from .._utils import safe_aclose def weld(nursery, *sections: PipelineSection) -> AsyncIterable[Any]: """ @@ -21,9 +22,9 @@ async def pump(section, input: Optional[AsyncIterable[Any]], output: trio.Memory await section.pump(input, output.send) except trio.BrokenResourceError: pass - if input and hasattr(input, "aclose") and callable(input.aclose): - await input.aclose() - await output.aclose() + if input: + await safe_aclose(input) + await safe_aclose(output) section_input = None output = None diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..fd9f80e --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,89 @@ +import string +from functools import wraps + +import pytest +import trio + +from slurry._utils import safe_aclose + +def fixture_gen_with_and_without_aclose(async_gen): + + def fixture_func(_with_aclose): + if _with_aclose: + new_async_gen = async_gen + else: + @wraps(async_gen) + def new_async_gen(*args, **kwargs): + return AsyncIteratorWithoutAclose(async_gen(*args, **kwargs)) + return new_async_gen + + fixture_func.__name__ = async_gen.__name__ + fixture_func.__qualname__ = async_gen.__name__ + + return pytest.fixture(fixture_func) + +@pytest.fixture(params=[True, False], ids=["with_aclose", "without_aclose"]) +def _with_aclose(request): + return request.param + +@fixture_gen_with_and_without_aclose +async def produce_increasing_integers(interval, *, max=3, delay=0): + await trio.sleep(delay) + for i in range(max): + yield i + if i == max-1: + break + await trio.sleep(interval) + +@fixture_gen_with_and_without_aclose +async def produce_alphabet(interval, *, max=3, delay=0): + await trio.sleep(delay) + for i, c in enumerate(string.ascii_lowercase): + yield c + if i == max - 1: + break + await trio.sleep(interval) + +@pytest.fixture() +def spam_wait_spam_integers(produce_increasing_integers): + async def spam_wait_spam_integers(interval): + async for i in produce_increasing_integers(.1, max=5, delay=.1): + yield i + await trio.sleep(interval) + async for i in produce_increasing_integers(.1, max=5, delay=.1): + yield i + + return spam_wait_spam_integers + +@fixture_gen_with_and_without_aclose +async def produce_mappings(interval): + vehicles = [ + {'vehicle': 'motorcycle'}, + {'vehicle': 'car'}, + {'vehicle': 'motorcycle'}, + {'vehicle': 'autocamper'}, + {'vehicle': 'car'}, + {'vehicle': 'car'}, + {'vehicle': 'truck'}, + {'vehicle': 'car'}, + {'vehicle': 'motorcycle'}, + ] + + for i, vehicle in enumerate(vehicles): + vehicle['number'] = i + yield vehicle + await trio.sleep(interval) + +class AsyncIteratorWithoutAclose: + def __init__(self, source_aiterable): + self.source_aiter = source_aiterable.__aiter__() + + def __aiter__(self): + return self + + async def __anext__(self): + try: + return await self.source_aiter.__anext__() + except StopAsyncIteration: + await safe_aclose(self.source_aiter) + raise diff --git a/tests/fixtures.py b/tests/fixtures.py index 9d81da5..258aa29 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -1,54 +1,10 @@ """Asynchronous generators for testing sections.""" import math -import string from typing import Any, Callable, Iterable -import trio - from slurry.environments import ThreadSection, ProcessSection -async def produce_increasing_integers(interval, *, max=3, delay=0): - await trio.sleep(delay) - for i in range(max): - yield i - if i == max-1: - break - await trio.sleep(interval) - -async def produce_alphabet(interval, *, max=3, delay=0): - await trio.sleep(delay) - for i, c in enumerate(string.ascii_lowercase): - yield c - if i == max - 1: - break - await trio.sleep(interval) - -async def spam_wait_spam_integers(interval): - async for i in produce_increasing_integers(.1, max=5, delay=.1): - yield i - await trio.sleep(interval) - async for i in produce_increasing_integers(.1, max=5, delay=.1): - yield i - -async def produce_mappings(interval): - vehicles = [ - {'vehicle': 'motorcycle'}, - {'vehicle': 'car'}, - {'vehicle': 'motorcycle'}, - {'vehicle': 'autocamper'}, - {'vehicle': 'car'}, - {'vehicle': 'car'}, - {'vehicle': 'truck'}, - {'vehicle': 'car'}, - {'vehicle': 'motorcycle'}, - ] - - for i, vehicle in enumerate(vehicles): - vehicle['number'] = i - yield vehicle - await trio.sleep(interval) - class SyncSquares(ThreadSection): def __init__(self, raise_after=math.inf) -> None: self.raise_after = raise_after @@ -86,3 +42,18 @@ def __init__(self, source_aiterable): def __aiter__(self): return self.source_aiterable.__aiter__() + +class AsyncIteratorWithoutAclose: + def __init__(self, source_aiterable): + self.source_aiter = source_aiterable.__aiter__() + + def __aiter__(self): + return self + + async def __anext__(self): + try: + return await self.source_aiter.__anext__() + except StopAsyncIteration: + if hasattr(self.source_aiter, "aclose"): + await self.source_aiter.aclose() + raise diff --git a/tests/test_buffers.py b/tests/test_buffers.py index 9c5ef62..d4f4c3e 100644 --- a/tests/test_buffers.py +++ b/tests/test_buffers.py @@ -3,23 +3,21 @@ from slurry import Pipeline from slurry.sections import Window, Group, Delay -from .fixtures import produce_increasing_integers, spam_wait_spam_integers - -async def test_window(autojump_clock): +async def test_window(produce_increasing_integers, autojump_clock): async with Pipeline.create( Window(3, produce_increasing_integers(1, max=5)) ) as pipeline, pipeline.tap() as aiter: result = [item async for item in aiter] assert result == [(0,), (0, 1), (0, 1, 2), (1, 2, 3), (2, 3, 4)] -async def test_group_max_size(autojump_clock): +async def test_group_max_size(produce_increasing_integers, autojump_clock): async with Pipeline.create( Group(2.5, produce_increasing_integers(1, max=5), max_size=3) ) as pipeline, pipeline.tap() as aiter: result = [item async for item in aiter] assert result == [(0, 1, 2), (3, 4)] -async def test_group_timeout(autojump_clock): +async def test_group_timeout(spam_wait_spam_integers, autojump_clock): async with Pipeline.create( Group(2.5, spam_wait_spam_integers(5)) ) as pipeline, pipeline.tap() as aiter: diff --git a/tests/test_combiners.py b/tests/test_combiners.py index 4a731fc..458fbb0 100644 --- a/tests/test_combiners.py +++ b/tests/test_combiners.py @@ -1,23 +1,21 @@ from slurry import Pipeline from slurry.sections import Chain, Merge, Zip, ZipLatest, Repeat, Map, Skip -from .fixtures import produce_increasing_integers, produce_alphabet - -async def test_chain(autojump_clock): +async def test_chain(produce_increasing_integers, produce_alphabet, autojump_clock): async with Pipeline.create( Chain(produce_increasing_integers(1, max=3), produce_alphabet(1, max=3)) ) as pipeline, pipeline.tap() as aiter: result = [i async for i in aiter] assert result == [0, 1, 2, 'a', 'b', 'c'] -async def test_merge(autojump_clock): +async def test_merge(produce_increasing_integers, produce_alphabet, autojump_clock): async with Pipeline.create( Merge(produce_increasing_integers(1, max=3), produce_alphabet(1, max=3, delay=0.1)) ) as pipeline, pipeline.tap() as aiter: result = [i async for i in aiter] assert result == [0, 'a', 1, 'b', 2, 'c'] -async def test_merge_section(autojump_clock): +async def test_merge_section(produce_increasing_integers, autojump_clock): async with Pipeline.create( Merge(produce_increasing_integers(1, max=3, delay=0.5), Repeat(1, default='a')) ) as pipeline, pipeline.tap() as aiter: @@ -28,7 +26,7 @@ async def test_merge_section(autojump_clock): break assert result == ['a', 0, 'a', 1, 'a', 2] -async def test_merge_pipeline_section(autojump_clock): +async def test_merge_pipeline_section(produce_increasing_integers, autojump_clock): async with Pipeline.create( Merge(produce_increasing_integers(1, max=3, delay=0.5), ( @@ -43,7 +41,7 @@ async def test_merge_pipeline_section(autojump_clock): break assert result == ['ax', 0, 'ax', 1, 'ax', 2] -async def test_zip(autojump_clock): +async def test_zip(produce_increasing_integers, produce_alphabet, autojump_clock): async with Pipeline.create( Zip(produce_increasing_integers(1), produce_alphabet(0.9)) ) as pipeline: @@ -51,7 +49,7 @@ async def test_zip(autojump_clock): results = [item async for item in aiter] assert results == [(0,'a'), (1, 'b'), (2, 'c')] -async def test_zip_pipeline_section(autojump_clock): +async def test_zip_pipeline_section(produce_increasing_integers, produce_alphabet, autojump_clock): async with Pipeline.create( Zip( ( @@ -63,10 +61,10 @@ async def test_zip_pipeline_section(autojump_clock): Map(lambda item: item + 'x') )) ) as pipeline, pipeline.tap() as aiter: - results = [item async for item in aiter] - assert results == [(2,'ax'), (3, 'bx'), (4, 'cx')] + results = [item async for item in aiter] + assert results == [(2,'ax'), (3, 'bx'), (4, 'cx')] -async def test_zip_latest(autojump_clock): +async def test_zip_latest(produce_increasing_integers, produce_alphabet, autojump_clock): async with Pipeline.create( ZipLatest( produce_increasing_integers(1, max=3), @@ -75,7 +73,7 @@ async def test_zip_latest(autojump_clock): result = [item async for item in aiter] assert result == [(0, None), (0, 'a'), (1, 'a'), (1, 'b'), (2, 'b')] -async def test_zip_latest_pipeline_section(autojump_clock): +async def test_zip_latest_pipeline_section(produce_increasing_integers, produce_alphabet, autojump_clock): async with Pipeline.create( ZipLatest( ( diff --git a/tests/test_filters.py b/tests/test_filters.py index 72ac907..e7e7469 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -1,44 +1,44 @@ from slurry import Pipeline from slurry.sections import Merge, RateLimit, Skip, SkipWhile, Filter, Changes -from .fixtures import AsyncNonIteratorIterable, produce_increasing_integers, produce_mappings +from .fixtures import AsyncNonIteratorIterable -async def test_skip(autojump_clock): +async def test_skip(produce_increasing_integers, autojump_clock): async with Pipeline.create( Skip(5, produce_increasing_integers(1, max=10)) ) as pipeline, pipeline.tap() as aiter: result = [i async for i in aiter] assert result == [5, 6, 7, 8, 9] -async def test_skip_input_non_iterator_iterable(autojump_clock): +async def test_skip_input_non_iterator_iterable(produce_increasing_integers, autojump_clock): async with Pipeline.create( Skip(5, AsyncNonIteratorIterable(produce_increasing_integers(1, max=10))) ) as pipeline, pipeline.tap() as aiter: result = [i async for i in aiter] assert result == [5, 6, 7, 8, 9] -async def test_skip_short_stream(autojump_clock): +async def test_skip_short_stream(produce_increasing_integers, autojump_clock): async with Pipeline.create( Skip(5, produce_increasing_integers(1)) ) as pipeline, pipeline.tap() as aiter: result = [i async for i in aiter] assert result == [] -async def test_skipwhile(autojump_clock): +async def test_skipwhile(produce_increasing_integers, autojump_clock): async with Pipeline.create( SkipWhile(lambda x: x < 3, produce_increasing_integers(1, max=5)) ) as pipeline, pipeline.tap() as aiter: result = [i async for i in aiter] assert result == [3, 4] -async def test_filter(autojump_clock): +async def test_filter(produce_increasing_integers, autojump_clock): async with Pipeline.create( Filter(lambda x: x%2, produce_increasing_integers(1, max=10)) ) as pipeline, pipeline.tap() as aiter: result = [i async for i in aiter] assert result == [1, 3, 5, 7, 9] -async def test_changes(autojump_clock): +async def test_changes(produce_increasing_integers, autojump_clock): async with Pipeline.create( Merge( produce_increasing_integers(1, max=5), @@ -49,21 +49,21 @@ async def test_changes(autojump_clock): result = [i async for i in aiter] assert result == [0, 1, 2, 3, 4] -async def test_ratelimit(autojump_clock): +async def test_ratelimit(produce_mappings, autojump_clock): async with Pipeline.create( RateLimit(1, produce_mappings(0.5)) ) as pipeline, pipeline.tap() as aiter: result = [item['number'] async for item in aiter] assert result == [0, 3, 6] -async def test_ratelimit_str_subject(autojump_clock): +async def test_ratelimit_str_subject(produce_mappings, autojump_clock): async with Pipeline.create( RateLimit(1, produce_mappings(0.5), subject='vehicle') ) as pipeline, pipeline.tap() as aiter: result = [item['number'] async for item in aiter] assert result == [0,1,3,4,6,7,8] -async def test_ratelimit_callable_subject(autojump_clock): +async def test_ratelimit_callable_subject(produce_mappings, autojump_clock): async with Pipeline.create( RateLimit(1, produce_mappings(0.5), subject=lambda item: item['vehicle'][2]) ) as pipeline, pipeline.tap() as aiter: diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index b294b9f..599c1e9 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -5,13 +5,13 @@ from slurry.sections import Map from slurry.environments import TrioSection -from .fixtures import produce_increasing_integers +from .fixtures import AsyncIteratorWithoutAclose async def test_pipeline_create(autojump_clock): async with Pipeline.create(None): await trio.sleep(1) -async def test_pipeline_passthrough(autojump_clock): +async def test_pipeline_passthrough(produce_increasing_integers, autojump_clock): async with Pipeline.create(produce_increasing_integers(1)) as pipeline: async with pipeline.tap() as aiter: result = [i async for i in aiter] @@ -47,20 +47,19 @@ async def refine(self, input, output): assert isinstance(i, int) break -async def test_welding(autojump_clock): +async def test_welding(produce_increasing_integers, autojump_clock): async with Pipeline.create( produce_increasing_integers(1), - (Map(lambda i: i+1),) - ) as pipeline: + (Map(lambda i: i + 1),) + ) as pipeline: async with pipeline.tap() as aiter: result = [i async for i in aiter] assert result == [1, 2, 3] -async def test_welding_two_generator_functions_not_allowed(autojump_clock): +async def test_welding_two_generator_functions_not_allowed(produce_increasing_integers, autojump_clock): with pytest.raises(ValueError): async with Pipeline.create( produce_increasing_integers(1), produce_increasing_integers(1), ) as pipeline, pipeline.tap() as aiter: result = [i async for i in aiter] - diff --git a/tests/test_producers.py b/tests/test_producers.py index 095691b..3b94222 100644 --- a/tests/test_producers.py +++ b/tests/test_producers.py @@ -4,8 +4,6 @@ from slurry import Pipeline from slurry.sections import Repeat, Metronome, InsertValue, _producers -from .fixtures import produce_alphabet - async def test_repeat_valid_args(): with pytest.raises(RuntimeError): async with Pipeline.create( @@ -38,7 +36,7 @@ async def test_repeat_kwargs(autojump_clock): break assert results == [('a', 0), ('a', 1), ('a', 2), ('a', 3), ('a', 4)] -async def test_repeat_input(autojump_clock): +async def test_repeat_input(produce_alphabet, autojump_clock): results = [] async with Pipeline.create( produce_alphabet(1.5, max=3, delay=1), @@ -51,7 +49,7 @@ async def test_repeat_input(autojump_clock): break assert results == [('a', 1), ('a', 2), ('b', 2.5), ('b', 3.5), ('c', 4)] -async def test_metronome(autojump_clock, monkeypatch): +async def test_metronome(produce_alphabet, autojump_clock, monkeypatch): monkeypatch.setattr(_producers, "time", trio.current_time) async with Pipeline.create( produce_alphabet(5, max=6, delay=1), @@ -62,7 +60,7 @@ async def test_metronome(autojump_clock, monkeypatch): results.append((item, trio.current_time())) assert results == [(letter, 5.0 * (i + 1)) for i, letter in enumerate("abcde")] -async def test_metronome_no_input(autojump_clock, monkeypatch): +async def test_metronome_no_input(produce_alphabet, autojump_clock, monkeypatch): monkeypatch.setattr(_producers, "time", trio.current_time) async with Pipeline.create( Metronome(5, "a") @@ -73,7 +71,7 @@ async def test_metronome_no_input(autojump_clock, monkeypatch): results.append((item, trio.current_time())) assert results == [("a", 5.0 * (i + 1)) for i in range(5)] -async def test_insert_value(autojump_clock): +async def test_insert_value(produce_alphabet, autojump_clock): async with Pipeline.create( produce_alphabet(1, max=3, delay=1), InsertValue('n') diff --git a/tests/test_refiners.py b/tests/test_refiners.py index ccfa42f..130e7de 100644 --- a/tests/test_refiners.py +++ b/tests/test_refiners.py @@ -1,10 +1,7 @@ from slurry import Pipeline -from slurry.sections import Map, Delay -import trio +from slurry.sections import Map -from .fixtures import produce_increasing_integers - -async def test_map(autojump_clock): +async def test_map(produce_increasing_integers, autojump_clock): async with Pipeline.create( Map(lambda x: x*x, produce_increasing_integers(1, max=5)) ) as pipeline, pipeline.tap() as aiter: diff --git a/tests/test_threading.py b/tests/test_threading.py index add6744..51b9ad5 100644 --- a/tests/test_threading.py +++ b/tests/test_threading.py @@ -2,9 +2,9 @@ from slurry import Pipeline from slurry.sections import Map -from .fixtures import AsyncNonIteratorIterable, produce_increasing_integers, SyncSquares +from .fixtures import AsyncNonIteratorIterable, SyncSquares -async def test_thread_section(autojump_clock): +async def test_thread_section(produce_increasing_integers, autojump_clock): async with Pipeline.create( produce_increasing_integers(1, max=5), SyncSquares() @@ -12,7 +12,7 @@ async def test_thread_section(autojump_clock): result = [i async for i in aiter] assert result == [0, 1, 4, 9, 16] -async def test_thread_section_input_non_iterator_iterable(autojump_clock): +async def test_thread_section_input_non_iterator_iterable(produce_increasing_integers, autojump_clock): async with Pipeline.create( AsyncNonIteratorIterable(produce_increasing_integers(1, max=5)), SyncSquares() @@ -20,7 +20,7 @@ async def test_thread_section_input_non_iterator_iterable(autojump_clock): result = [i async for i in aiter] assert result == [0, 1, 4, 9, 16] -async def test_thread_section_early_break(autojump_clock): +async def test_thread_section_early_break(produce_increasing_integers, autojump_clock): async with Pipeline.create( produce_increasing_integers(1, max=5), SyncSquares() @@ -30,7 +30,7 @@ async def test_thread_section_early_break(autojump_clock): break assert i == 4 -async def test_thread_section_exception(autojump_clock): +async def test_thread_section_exception(produce_increasing_integers, autojump_clock): with pytest.raises(RuntimeError): async with Pipeline.create( produce_increasing_integers(1, max=5), @@ -40,7 +40,7 @@ async def test_thread_section_exception(autojump_clock): pass assert i == 9 -async def test_thread_section_section_input(autojump_clock): +async def test_thread_section_section_input(produce_increasing_integers, autojump_clock): async with Pipeline.create( produce_increasing_integers(1), Map(lambda i: i), From 23b572a950417724454393b8066529e301c5645d Mon Sep 17 00:00:00 2001 From: Mike Nerone Date: Wed, 7 Feb 2024 22:09:55 -0600 Subject: [PATCH 2/4] Remove async-generator dependency --- docs/requirements.txt | 1 - poetry.lock | 13 +------------ pyproject.toml | 1 - 3 files changed, 1 insertion(+), 14 deletions(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index f90dbc1..0db6675 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,4 +1,3 @@ Sphinx >= 6.2.1 sphinx-rtd-theme >= 2.0.0 trio >= 0.23.0 -async-generator >= 1.10 \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index 7b3d501..dfcee57 100644 --- a/poetry.lock +++ b/poetry.lock @@ -25,17 +25,6 @@ files = [ [package.dependencies] typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""} -[[package]] -name = "async-generator" -version = "1.10" -description = "Async generators and context managers for Python 3.5+" -optional = false -python-versions = ">=3.5" -files = [ - {file = "async_generator-1.10-py3-none-any.whl", hash = "sha256:01c7bf666359b4967d2cda0000cc2e4af16a0ae098cbffcb8472fb9e8ad6585b"}, - {file = "async_generator-1.10.tar.gz", hash = "sha256:6ebb3d106c12920aaae42ccb6f787ef5eefdcdd166ea3d628fa8476abe712144"}, -] - [[package]] name = "attrs" version = "23.2.0" @@ -928,4 +917,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "a1fc3dc1690cb6a67e4aa96f402c5f80a4ab2dec899ef2fd3c4a4e304a2dd54e" +content-hash = "c904e42805d647e8fce60e9b667283fa912850a651f8d671311728a03597a671" diff --git a/pyproject.toml b/pyproject.toml index 752c89d..d42c76c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,6 @@ classifiers = [ [tool.poetry.dependencies] -async-generator = "^1.10" python = "^3.8" trio = ">=0.23.0" From a3683e11914b87b9db9451626ae001aa05f88e23 Mon Sep 17 00:00:00 2001 From: Mike Nerone Date: Fri, 15 Mar 2024 21:37:52 -0500 Subject: [PATCH 3/4] Unused import --- tests/test_pipeline.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 599c1e9..c22e6fb 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -5,8 +5,6 @@ from slurry.sections import Map from slurry.environments import TrioSection -from .fixtures import AsyncIteratorWithoutAclose - async def test_pipeline_create(autojump_clock): async with Pipeline.create(None): await trio.sleep(1) From ca05d557792b6c5216e968ce99d5e5072e13d4b3 Mon Sep 17 00:00:00 2001 From: Mike Nerone Date: Fri, 15 Mar 2024 21:54:30 -0500 Subject: [PATCH 4/4] Redundant copy of AsyncIteratorWithoutAclose class --- tests/conftest.py | 16 +--------------- tests/fixtures.py | 4 ++-- 2 files changed, 3 insertions(+), 17 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index fd9f80e..d73fb90 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,7 +4,7 @@ import pytest import trio -from slurry._utils import safe_aclose +from .fixtures import AsyncIteratorWithoutAclose def fixture_gen_with_and_without_aclose(async_gen): @@ -73,17 +73,3 @@ async def produce_mappings(interval): vehicle['number'] = i yield vehicle await trio.sleep(interval) - -class AsyncIteratorWithoutAclose: - def __init__(self, source_aiterable): - self.source_aiter = source_aiterable.__aiter__() - - def __aiter__(self): - return self - - async def __anext__(self): - try: - return await self.source_aiter.__anext__() - except StopAsyncIteration: - await safe_aclose(self.source_aiter) - raise diff --git a/tests/fixtures.py b/tests/fixtures.py index 258aa29..f1f9484 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -4,6 +4,7 @@ from typing import Any, Callable, Iterable from slurry.environments import ThreadSection, ProcessSection +from slurry._utils import safe_aclose class SyncSquares(ThreadSection): def __init__(self, raise_after=math.inf) -> None: @@ -54,6 +55,5 @@ async def __anext__(self): try: return await self.source_aiter.__anext__() except StopAsyncIteration: - if hasattr(self.source_aiter, "aclose"): - await self.source_aiter.aclose() + await safe_aclose(self.source_aiter) raise