Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP fix(core): allow requests to be queued in CONNECTING state (#374) #583

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions kazoo/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,13 +561,14 @@ def _notify_pending(self, state):
except IndexError:
break

while True:
try:
request, async_object = self._queue.popleft()
if async_object:
async_object.set_exception(exc)
except IndexError:
break
if state != KeeperState.CONNECTING:
while True:
try:
request, async_object = self._queue.popleft()
if async_object:
async_object.set_exception(exc)
except IndexError:
break

def _safe_close(self):
self.handler.stop()
Expand All @@ -584,7 +585,7 @@ def _call(self, request, async_object):
the queue if there is.

Returns False if the call short circuits due to AUTH_FAILED,
CLOSED, EXPIRED_SESSION or CONNECTING state.
CLOSED, or EXPIRED_SESSION state.

"""

Expand All @@ -595,8 +596,7 @@ def _call(self, request, async_object):
async_object.set_exception(ConnectionClosedError(
"Connection has been closed"))
return False
elif self._state in (KeeperState.EXPIRED_SESSION,
KeeperState.CONNECTING):
elif self._state == KeeperState.EXPIRED_SESSION:
async_object.set_exception(SessionExpiredError())
return False

Expand Down
17 changes: 9 additions & 8 deletions kazoo/protocol/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@

# Special testing hook objects used to force a session expired error as
# if it came from the server
_SESSION_EXPIRED = object()
_CONNECTION_DROP = object()
_SESSION_EXPIRED = (SessionExpiredError, "Session expired: Testing")
_CONNECTION_DROP = (ConnectionDropped, "Connection dropped: Testing")

STOP_CONNECTING = object()

Expand Down Expand Up @@ -454,12 +454,13 @@ def _send_request(self, read_timeout, connect_timeout):
pass
return

# Special case for testing, if this is a _SessionExpire object
# then throw a SessionExpiration error as if we were dropped
if request is _SESSION_EXPIRED:
raise SessionExpiredError("Session expired: Testing")
if request is _CONNECTION_DROP:
raise ConnectionDropped("Connection dropped: Testing")
# Special case for testing, if this is a _SESSION_EXPIRED or
# _CONNECTION_DROP object, throw the corresponding error.
if request is _SESSION_EXPIRED or request is _CONNECTION_DROP:
client._queue.popleft()
self._read_sock.recv(1)
error, arg = request
raise error(arg)

# Special case for auth packets
if request.type == Auth.type:
Expand Down
58 changes: 55 additions & 3 deletions kazoo/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,15 +592,67 @@ def test_create_on_broken_connection(self):
self.assertRaises(AuthFailedError, client.create,
'/closedpath', b'bar')

client._state = KeeperState.CONNECTING
self.assertRaises(SessionExpiredError, client.create,
'/closedpath', b'bar')
client.stop()
client.close()

self.assertRaises(ConnectionClosedError, client.create,
'/closedpath', b'bar')

def test_create_queued_while_connecting(self):
# This is a bit tricky: we are trying to show that a request
# can be queued while the connection is down. So we create a
# client with a fairly large but stable reconnect delay, and
# perform a number of attempts.
handler = self._makeOne()
sleep_func = handler.sleep_func
client = self._get_client(
handler=handler,
connection_retry=dict(
max_tries=-1,
delay=0.5,
backoff=1,
max_jitter=0.0,
sleep_func=sleep_func
)
)
client.start()

max_attempts = 5
attempts = 0
was_queued = False

while not was_queued and attempts < max_attempts:
attempts += 1

try:
client.delete('/queued')
except NoNodeError:
pass

# Shut the socket down, and wait for the client to
# transition to non-connected.
client._connection._socket.shutdown(socket.SHUT_RDWR)
while client.connected:
sleep_func(0.001)

# Issue an async request, assuming the client hasn't had
# enough time to reconnect.
result = client.create_async('/queued', None)

# Consider the test potentially satisfactory if client
# still hasn't had time to reconnect.
was_queued = not client.connected

# Wait for expected result.
self.assertEqual(result.get(), '/queued')
self.assertEqual(client.connected, True)

# Fail if no "window" was observed despite max_attempts.
self.assertEqual(was_queued, True)

client.stop()
client.close()

def test_create_null_data(self):
client = self.client
client.create("/nulldata", None)
Expand Down
112 changes: 88 additions & 24 deletions kazoo/tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@
import sys

from nose import SkipTest
from nose.tools import eq_
from nose.tools import raises
from nose.tools import eq_, raises, assert_not_equal
import mock

from kazoo.exceptions import ConnectionLoss
from kazoo.exceptions import ConnectionLoss, SessionExpiredError
from kazoo.protocol.serialization import (
Connect,
int_struct,
Expand Down Expand Up @@ -100,53 +99,115 @@ def back(state):

def test_connection_write_timeout(self):
client = self.client
ev = threading.Event()
ev_suspended = threading.Event()
ev_connected = threading.Event()
path = "/" + uuid.uuid4().hex
handler = client.handler
_select = handler.select
_socket = client._connection._socket

def delayed_select(*args, **kwargs):
result = _select(*args, **kwargs)
if _socket in args[1]:
if client._connection._socket in args[1]:
# for any socket write, simulate a timeout
return [], [], []
return result

def back(state):
if state == KazooState.CONNECTED:
ev.set()
client.add_listener(back)
def listener(state):
if state == KazooState.SUSPENDED:
ev_suspended.set()
elif state == KazooState.CONNECTED:
ev_connected.set()
client.add_listener(listener)

try:
handler.select = delayed_select
self.assertRaises(ConnectionLoss, client.create, path)
result = client.create_async(path)
ev_suspended.wait(5)
eq_(ev_suspended.is_set(), True)
assert_not_equal(len(self.client._queue), 0)
finally:
handler.select = _select
# the client reconnects automatically, and the queued request
# is submitted.
ev_connected.wait(5)
eq_(ev_connected.is_set(), True)
eq_(result.get(), path)
assert_not_equal(client.exists(path), None)

def test_connection_lost_empties_queue(self):
client = self.client
ev_suspended = threading.Event()
ev_lost = threading.Event()
ev_connected = threading.Event()
path = "/" + uuid.uuid4().hex
handler = client.handler
_select = handler.select

def delayed_select(*args, **kwargs):
result = _select(*args, **kwargs)
if client._connection._socket in args[1]:
# for any socket write, simulate a timeout
return [], [], []
return result

def expiring_select(*args, **kwargs):
result = _select(*args, **kwargs)
if client._connection._socket in args[1]:
raise SessionExpiredError("Session expired: Testing")
return result

def listener(state):
if state == KazooState.SUSPENDED:
ev_suspended.set()
elif state == KazooState.LOST:
ev_lost.set()
elif state == KazooState.CONNECTED:
ev_connected.set()
client.add_listener(listener)

try:
handler.select = delayed_select
result = client.create_async(path)
ev_suspended.wait(5)
eq_(ev_suspended.is_set(), True)
assert_not_equal(len(self.client._queue), 0)

handler.select = expiring_select
# the client transitions to EXPIRED_SESSION, which is a closed
# state, causing the queue to be flushed.
ev_lost.wait(5)
eq_(ev_lost.is_set(), True)
self.assertRaises(SessionExpiredError, result.get)
eq_(len(self.client._queue), 0)
finally:
handler.select = _select

# the client reconnects automatically
ev.wait(5)
eq_(ev.is_set(), True)
ev_connected.wait(5)
eq_(ev_connected.is_set(), True)
eq_(client.exists(path), None)

def test_connection_deserialize_fail(self):
client = self.client
ev = threading.Event()
ev_suspended = threading.Event()
ev_connected = threading.Event()
path = "/" + uuid.uuid4().hex
handler = client.handler
_select = handler.select
_socket = client._connection._socket

def delayed_select(*args, **kwargs):
result = _select(*args, **kwargs)
if _socket in args[1]:
if client._connection._socket in args[1]:
# for any socket write, simulate a timeout
return [], [], []
return result

def back(state):
if state == KazooState.CONNECTED:
ev.set()
client.add_listener(back)
def listener(state):
if state == KazooState.SUSPENDED:
ev_suspended.set()
elif state == KazooState.CONNECTED:
ev_connected.set()
client.add_listener(listener)

deserialize_ev = threading.Event()

Expand All @@ -163,7 +224,9 @@ def bad_deserialize(_bytes, offset):
mock_deserialize.side_effect = bad_deserialize
try:
handler.select = delayed_select
self.assertRaises(ConnectionLoss, client.create, path)
result = client.create_async(path)
ev_suspended.wait(5)
eq_(ev_suspended.is_set(), True)
finally:
handler.select = _select
# the client reconnects automatically but the first attempt will
Expand All @@ -172,9 +235,10 @@ def bad_deserialize(_bytes, offset):
eq_(deserialize_ev.is_set(), True)

# this time should succeed
ev.wait(5)
eq_(ev.is_set(), True)
eq_(client.exists(path), None)
ev_connected.wait(5)
eq_(ev_connected.is_set(), True)
eq_(result.get(), path)
assert_not_equal(client.exists(path), None)

def test_connection_close(self):
self.assertRaises(Exception, self.client.close)
Expand Down