Skip to content

Commit

Permalink
Merge pull request #17 from mikenerone/mikenerone/fix-ensure-use-of-i…
Browse files Browse the repository at this point in the history
…terators

Fix risky uses of `__anext__()`
  • Loading branch information
andersea authored Jan 19, 2024
2 parents bdb467a + 428e033 commit a47bafe
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 6 deletions.
3 changes: 2 additions & 1 deletion slurry/environments/_threading.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ def sync_input():
"""Wrapper for turning an async iterable into a blocking generator."""
if input is None:
return
input_aiter = input.__aiter__()
try:
while True:
yield trio.from_thread.run(input.__anext__)
yield trio.from_thread.run(input_aiter.__anext__)
except StopAsyncIteration:
pass

Expand Down
2 changes: 1 addition & 1 deletion slurry/sections/_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ async def refine(self, input, output):
else:
raise RuntimeError('No input provided.')

async with aclosing(source) as aiter:
async with aclosing(source.__aiter__()) as aiter:
try:
for _ in range(self.count):
await aiter.__anext__()
Expand Down
2 changes: 1 addition & 1 deletion slurry/sections/weld.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ async def pump(section, input: Optional[AsyncIterable[Any]], output: trio.Memory
await section.pump(input, output.send)
except trio.BrokenResourceError:
pass
if input:
if input and hasattr(input, "aclose") and callable(input.aclose):
await input.aclose()
await output.aclose()

Expand Down
7 changes: 7 additions & 0 deletions tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,10 @@ def fibonacci(self, i):
def refine(self, input: Iterable[Any], output: Callable[[Any], None]):
for i in range(self.i):
output(self.fibonacci(i))

class AsyncNonIteratorIterable:
def __init__(self, source_aiterable):
self.source_aiterable = source_aiterable

def __aiter__(self):
return self.source_aiterable.__aiter__()
9 changes: 8 additions & 1 deletion tests/test_filters.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from slurry import Pipeline
from slurry.sections import Merge, RateLimit, Skip, SkipWhile, Filter, Changes

from .fixtures import produce_increasing_integers, produce_mappings
from .fixtures import AsyncNonIteratorIterable, produce_increasing_integers, produce_mappings

async def test_skip(autojump_clock):
async with Pipeline.create(
Expand All @@ -10,6 +10,13 @@ async def test_skip(autojump_clock):
result = [i async for i in aiter]
assert result == [5, 6, 7, 8, 9]

async def test_skip_input_non_iterator_iterable(autojump_clock):
async with Pipeline.create(
Skip(5, AsyncNonIteratorIterable(produce_increasing_integers(1, max=10)))
) as pipeline, pipeline.tap() as aiter:
result = [i async for i in aiter]
assert result == [5, 6, 7, 8, 9]

async def test_skip_short_stream(autojump_clock):
async with Pipeline.create(
Skip(5, produce_increasing_integers(1))
Expand Down
12 changes: 10 additions & 2 deletions tests/test_threading.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from slurry import Pipeline
from slurry.sections import Map

from .fixtures import produce_increasing_integers, SyncSquares
from .fixtures import AsyncNonIteratorIterable, produce_increasing_integers, SyncSquares

async def test_thread_section(autojump_clock):
async with Pipeline.create(
Expand All @@ -12,6 +12,14 @@ async def test_thread_section(autojump_clock):
result = [i async for i in aiter]
assert result == [0, 1, 4, 9, 16]

async def test_thread_section_input_non_iterator_iterable(autojump_clock):
async with Pipeline.create(
AsyncNonIteratorIterable(produce_increasing_integers(1, max=5)),
SyncSquares()
) as pipeline, pipeline.tap() as aiter:
result = [i async for i in aiter]
assert result == [0, 1, 4, 9, 16]

async def test_thread_section_early_break(autojump_clock):
async with Pipeline.create(
produce_increasing_integers(1, max=5),
Expand Down Expand Up @@ -39,4 +47,4 @@ async def test_thread_section_section_input(autojump_clock):
SyncSquares()
) as pipeline, pipeline.tap() as aiter:
result = [i async for i in aiter]
assert result == [0, 1, 4]
assert result == [0, 1, 4]

0 comments on commit a47bafe

Please sign in to comment.