Skip to content

Commit

Permalink
clean up DelayProxy, fix comments
Browse files Browse the repository at this point in the history
  • Loading branch information
kristjanvalur committed Apr 12, 2023
1 parent 6f6b6f6 commit cf2ae67
Showing 1 changed file with 30 additions and 32 deletions.
62 changes: 30 additions & 32 deletions tests/test_asyncio/test_cwe_404.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,31 +17,12 @@ def redis_addr(request):
return host, int(port)


async def pipe(
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter,
proxy: "DelayProxy",
name="",
event: asyncio.Event = None,
):
while True:
data = await reader.read(1000)
if not data:
break
if event:
event.set()
await asyncio.sleep(proxy.delay)
writer.write(data)
await writer.drain()


class DelayProxy:
def __init__(self, addr, redis_addr, delay: float):
self.addr = addr
self.redis_addr = redis_addr
self.delay = delay
self.send_event = asyncio.Event()
self.redis_streams = None

async def start(self):
# test that we can connect to redis
Expand All @@ -52,31 +33,48 @@ async def start(self):
self.ROUTINE = asyncio.create_task(self.server.serve_forever())

@contextlib.contextmanager
def override(self, delay: float = 0.0):
def set_delay(self, delay: float = 0.0):
"""
Allow to override the delay for parts of tests which aren't time dependent,
to speed up execution.
"""
old = self.delay
old_delay = self.delay
self.delay = delay
try:
yield
finally:
self.delay = old
self.delay = old_delay

async def handle(self, reader, writer):
# establish connection to redis
redis_reader, redis_writer = await asyncio.open_connection(*self.redis_addr)
try:
pipe1 = asyncio.create_task(
pipe(reader, redis_writer, self, "to redis:", self.send_event)
self.pipe(reader, redis_writer, "to redis:", self.send_event)
)
pipe2 = asyncio.create_task(pipe(redis_reader, writer, self, "from redis:"))
pipe2 = asyncio.create_task(self.pipe(redis_reader, writer, "from redis:"))
await asyncio.gather(pipe1, pipe2)
finally:
redis_writer.close()
redis_reader.close()

async def pipe(
self,
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter,
name="",
event: asyncio.Event = None,
):
while True:
data = await reader.read(1000)
if not data:
break
if event:
event.set()
await asyncio.sleep(self.delay)
writer.write(data)
await writer.drain()

async def stop(self):
# clean up enough so that we can reuse the looper
self.ROUTINE.cancel()
Expand All @@ -101,7 +99,7 @@ async def test_standalone(delay, redis_addr):
# note that we connect to proxy, rather than to Redis directly
async with Redis(host="127.0.0.1", port=5380, single_connection_client=b) as r:

with dp.override():
with dp.set_delay(0):
await r.set("foo", "foo")
await r.set("bar", "bar")

Expand All @@ -117,7 +115,7 @@ async def test_standalone(delay, redis_addr):

# make sure that our previous request, cancelled while waiting for
# a repsponse, didn't leave the connection open andin a bad state
with dp.override():
with dp.set_delay(0):
assert await r.get("bar") == b"bar"
assert await r.ping()
assert await r.get("foo") == b"foo"
Expand All @@ -132,7 +130,7 @@ async def test_standalone_pipeline(delay, redis_addr):
await dp.start()
for b in [True, False]:
async with Redis(host="127.0.0.1", port=5380, single_connection_client=b) as r:
with dp.override():
with dp.set_delay(0):
await r.set("foo", "foo")
await r.set("bar", "bar")

Expand All @@ -154,7 +152,7 @@ async def test_standalone_pipeline(delay, redis_addr):

# we have now cancelled the pieline in the middle of a request, make sure
# that the connection is still usable
with dp.override():
with dp.set_delay(0):
pipe.get("bar")
pipe.ping()
pipe.get("foo")
Expand Down Expand Up @@ -205,10 +203,10 @@ async def any_wait():
)

@contextlib.contextmanager
def all_override(delay: int = 0):
def set_delay(delay: int = 0):
with contextlib.ExitStack() as stack:
for p in proxies:
stack.enter_context(p.override(delay=delay))
stack.enter_context(p.delay_as(delay))
yield

# start proxies
Expand All @@ -222,9 +220,9 @@ def all_override(delay: int = 0):
await r.set("bar", "bar")

all_clear()
with all_override(delay=delay):
with set_delay(delay=delay):
t = asyncio.create_task(r.get("foo"))
# cannot wait on the send event, we don't know which node will be used
# One of the proxies will handle our request, wait for it to send
await any_wait()
await asyncio.sleep(delay)
t.cancel()
Expand Down

0 comments on commit cf2ae67

Please sign in to comment.