diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 69dca3ba..7271d560 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -34,6 +34,6 @@ jobs: flake8 . --count --max-complexity=10 --max-line-length=127 --statistics - name: Test with pytest run: | - pytest --cov=flux_led --cov-report term-missing --cov-report xml -- tests.py + pytest --cov=flux_led --cov-report term-missing --cov-report xml -- tests.py tests_aio.py - name: Upload codecov uses: codecov/codecov-action@v2 diff --git a/.vscode/tasks.json b/.vscode/tasks.json index ce172958..90c032f9 100644 --- a/.vscode/tasks.json +++ b/.vscode/tasks.json @@ -4,7 +4,7 @@ { "label": "Pytest", "type": "shell", - "command": "pytest tests.py", + "command": "pytest tests.py tests_aio.py", "dependsOn": ["Install all Test Requirements"], "group": { "kind": "test", @@ -49,7 +49,7 @@ "label": "Code Coverage", "detail": "Generate code coverage report", "type": "shell", - "command": "pytest tests.py/ --cov=flux_led --cov-report term-missing", + "command": "pytest tests.py tests_aio.py --cov=flux_led --cov-report term-missing", "group": { "kind": "test", "isDefault": true diff --git a/flux_led/aiodevice.py b/flux_led/aiodevice.py index 9335b887..fe0d1cc9 100644 --- a/flux_led/aiodevice.py +++ b/flux_led/aiodevice.py @@ -72,18 +72,19 @@ async def _async_turn_off_on(self) -> None: async def async_turn_on(self) -> None: """Turn on the device.""" - await self._async_turn_on_with_retry() - self._set_power_state_ignore_next_push(self._protocol.on_byte) + if await self._async_turn_on_with_retry(): + self._set_power_state_ignore_next_push(self._protocol.on_byte) - async def _async_turn_on_with_retry(self) -> None: + async def _async_turn_on_with_retry(self) -> bool: calls = (self._async_turn_on, self._async_turn_off_on, self._async_turn_on) for idx, call in enumerate(calls): if ( await self._async_execute_and_wait_for(self._on_futures, call) or self.is_on ): - return + return True _LOGGER.debug("Failed to turn on (%s/%s)", 1 + idx, len(calls)) + return False async def _async_turn_off(self) -> None: await self._async_send_msg(self._protocol.construct_state_change(False)) @@ -94,18 +95,19 @@ async def _async_turn_on_off(self) -> None: async def async_turn_off(self) -> None: """Turn off the device.""" - await self._async_turn_off_with_retry() - self._set_power_state_ignore_next_push(self._protocol.off_byte) + if await self._async_turn_off_with_retry(): + self._set_power_state_ignore_next_push(self._protocol.off_byte) - async def _async_turn_off_with_retry(self) -> None: + async def _async_turn_off_with_retry(self) -> bool: calls = (self._async_turn_off, self._async_turn_on_off, self._async_turn_off) for idx, call in enumerate(calls): if ( await self._async_execute_and_wait_for(self._off_futures, call) or not self.is_on ): - return + return True _LOGGER.debug("Failed to turn off (%s/%s)", 1 + idx, len(calls)) + return False async def async_set_white_temp(self, temperature, brightness, persist=True) -> None: """Set the white tempature.""" @@ -189,43 +191,24 @@ def _async_data_recieved(self, data): """New data on the socket.""" start_empty_buffer = not self._buffer self._buffer += data - buffer = self._buffer self._updates_without_response = 0 - msg_length = len(buffer) - protocol = self._protocol - # Some of the older bulbs respond to a state request in - # multiple packets so we have to reassemble. - if protocol.is_start_of_state_response(buffer): - if not protocol.is_valid_state_response( - buffer - ) and not protocol.is_longer_than_state_response(buffer): - return - msg_length = protocol.state_response_length - if protocol.is_start_of_power_state_response(buffer): - if not protocol.is_valid_power_state_response( - buffer - ) and not protocol.is_longer_than_power_state_response(buffer): + msg_length = len(self._buffer) + while msg_length: + expected_length = self._protocol.expected_response_length(self._buffer) + if msg_length < expected_length: + # need more bytes return - msg_length = protocol.state_response_length - elif self.addressable: - # The addressable bulbs can send a state response inside an addressable response - if protocol.is_start_of_addressable_response(buffer): - if not protocol.is_valid_addressable_response( - buffer - ) and not protocol.is_longer_than_addressable_response(buffer): - return - msg_length = protocol.addressable_response_length - - msg = buffer[:msg_length] - self._buffer = buffer[msg_length:] - if not start_empty_buffer: - _LOGGER.debug( - "%s <= Reassembled (%s) (%d)", - self._aio_protocol.peername, - " ".join(f"0x{x:02X}" for x in msg), - len(msg), - ) - self._async_process_message(msg) + msg = self._buffer[:expected_length] + self._buffer = self._buffer[expected_length:] + msg_length = len(self._buffer) + if not start_empty_buffer: + _LOGGER.debug( + "%s <= Reassembled (%s) (%d)", + self._aio_protocol.peername, + " ".join(f"0x{x:02X}" for x in msg), + len(msg), + ) + self._async_process_message(msg) def _async_process_message(self, msg): """Process a full message (maybe reassembled).""" diff --git a/flux_led/protocol.py b/flux_led/protocol.py index 13204412..cf27e329 100755 --- a/flux_led/protocol.py +++ b/flux_led/protocol.py @@ -42,6 +42,33 @@ LEDENET_ORIGINAL_STATE_RESPONSE_LEN = 11 LEDENET_STATE_RESPONSE_LEN = 14 +LEDENET_POWER_RESPONSE_LEN = 4 +LEDENET_ADDRESSABLE_STATE_RESPONSE_LEN = 25 + +MSG_ORIGINAL_POWER_STATE = "original_power_state" +MSG_ORIGINAL_STATE = "original_state" + +MSG_POWER_STATE = "power_state" +MSG_STATE = "state" + +MSG_ADDRESSABLE_STATE = "addressable_state" + +MSG_FIRST_BYTE = { + 0xF0: MSG_POWER_STATE, + 0x00: MSG_POWER_STATE, + 0x0F: MSG_POWER_STATE, + 0x78: MSG_ORIGINAL_POWER_STATE, + 0x66: MSG_ORIGINAL_STATE, + 0x81: MSG_STATE, + 0xB0: MSG_ADDRESSABLE_STATE, +} +MSG_LENGTHS = { + MSG_POWER_STATE: LEDENET_POWER_RESPONSE_LEN, + MSG_ORIGINAL_POWER_STATE: LEDENET_POWER_RESPONSE_LEN, + MSG_ORIGINAL_STATE: LEDENET_ORIGINAL_STATE_RESPONSE_LEN, + MSG_STATE: LEDENET_STATE_RESPONSE_LEN, + MSG_ADDRESSABLE_STATE: LEDENET_ADDRESSABLE_STATE_RESPONSE_LEN, +} LEDENET_BASE_STATE = [ STATE_HEAD, @@ -116,7 +143,15 @@ class ProtocolBase: """The base protocol.""" - power_state_response_length = 4 + power_state_response_length = MSG_LENGTHS[MSG_POWER_STATE] + + def expected_response_length(self, data): + """Return the number of bytes expected in the response. + + If the response is unknown, we assume the response is + a complete message since we have no way of knowing otherwise. + """ + return MSG_LENGTHS.get(MSG_FIRST_BYTE.get(data[0]), len(data)) @abstractmethod def construct_state_query(self): @@ -130,10 +165,6 @@ def is_valid_state_response(self, raw_state): def is_start_of_state_response(self, data): """Check if a message is the start of a state response.""" - def is_longer_than_state_response(self, data): - """Check if a message is longer than a valid state response.""" - return len(data) > self.state_response_length - def is_checksum_correct(self, msg): """Check a checksum of a message.""" expected_sum = sum(msg[0:-1]) & 0xFF @@ -152,10 +183,6 @@ def is_valid_power_state_response(self, msg): def is_start_of_power_state_response(self, data): """Check if a message is the start of a power response.""" - def is_longer_than_power_state_response(self, data): - """Check if a message is longer than a valid power response.""" - return len(data) > self.state_response_length - @property def on_byte(self): """The on byte.""" @@ -250,25 +277,23 @@ def state_response_length(self): def is_valid_power_state_response(self, msg): """Check if a power state response is valid.""" - # We do not have dumps of the original ledenet - # protocol (these devices are no longer made). - # If we get them in the future, we can - # implement push updates for these devices by - # matching how is_valid_power_state_response works - # for the newer protocol - return False + return len(msg) == self.power_state_response_length and msg[0] == 0x78 def is_valid_state_response(self, raw_state): """Check if a state response is valid.""" - return len(raw_state) == self.state_response_length and raw_state[1] == 0x01 + return ( + len(raw_state) == self.state_response_length + and raw_state[0] == 0x66 + and raw_state[1] == 0x01 + ) def is_start_of_state_response(self, data): """Check if a message is the start of a state response.""" - return False + return data[0] == 0x66 def is_start_of_power_state_response(self, data): """Check if a message is the start of a state response.""" - return False + return data[0] == 0x78 def construct_state_query(self): """The bytes to send for a query request.""" @@ -330,7 +355,7 @@ def is_valid_power_state_response(self, msg): def is_start_of_power_state_response(self, data): """Check if a message is the start of a state response.""" - return len(data) >= 1 and data[0] in (0xF0, 0x00, 0x0F) + return len(data) >= 1 and MSG_FIRST_BYTE[data[0]] == MSG_POWER_STATE def is_start_of_state_response(self, data): """Check if a message is the start of a state response.""" @@ -465,7 +490,7 @@ def construct_preset_pattern(self, pattern, speed): class ProtocolLEDENETAddressable(ProtocolLEDENET9Byte): ADDRESSABLE_HEADER = [0xB0, 0xB1, 0xB2, 0xB3, 0x00, 0x01, 0x01] - addressable_response_length = 25 + addressable_response_length = MSG_LENGTHS[MSG_ADDRESSABLE_STATE] def __init__(self): self._counter = 0 @@ -483,10 +508,6 @@ def is_valid_addressable_response(self, data): return False return self.is_checksum_correct(data) - def is_longer_than_addressable_response(self, data): - """Check if a message is longer than a valid addressable response.""" - return len(data) > self.addressable_response_length - @property def name(self): """The name of the protocol.""" diff --git a/requirements_test.txt b/requirements_test.txt index c009f548..697f0184 100644 --- a/requirements_test.txt +++ b/requirements_test.txt @@ -1,3 +1,4 @@ pylint==2.11.1 +pytest-asyncio==0.16.0 pytest-cov==3.0.0 -flake8==4.0.1 \ No newline at end of file +flake8==4.0.1 diff --git a/tests.py b/tests.py index fd6b25e0..6d2e67c8 100644 --- a/tests.py +++ b/tests.py @@ -554,7 +554,7 @@ def read_data(expected): return bytearray(b"") if calls == 2: self.assertEqual(expected, 2) - return bytearray(b"\f\x01") + return bytearray(b"f\x01") if calls == 3: self.assertEqual(expected, 9) return bytearray(b"#A!\x08\xff\x80*\x01\x99") diff --git a/tests_aio.py b/tests_aio.py new file mode 100644 index 00000000..64710f78 --- /dev/null +++ b/tests_aio.py @@ -0,0 +1,188 @@ +import asyncio +from unittest.mock import MagicMock, patch +import logging + +import pytest +import contextlib + +from flux_led import aiodevice +from flux_led.aio import AIOWifiLedBulb +from flux_led.aioprotocol import AIOLEDENETProtocol +from flux_led.const import COLOR_MODE_CCT, COLOR_MODE_RGBWW +from flux_led.protocol import PROTOCOL_LEDENET_9BYTE + + +@pytest.fixture +async def mock_aio_protocol(): + """Fixture to mock an asyncio connection.""" + loop = asyncio.get_running_loop() + future = asyncio.Future() + + async def _wait_for_connection(): + await future + await asyncio.sleep(0) + await asyncio.sleep(0) + + async def _mock_create_connection(func, ip, port): + protocol: AIOLEDENETProtocol = func() + transport = MagicMock() + protocol.connection_made(transport) + with contextlib.suppress(asyncio.InvalidStateError): + future.set_result(True) + return transport, protocol + + with patch.object(loop, "create_connection", _mock_create_connection): + yield _wait_for_connection + + +@pytest.mark.asyncio +async def test_reassemble(mock_aio_protocol): + """Test we can reassemble.""" + light = AIOWifiLedBulb("192.168.1.166") + + def _updated_callback(*args, **kwargs): + pass + + task = asyncio.create_task(light.async_setup(_updated_callback)) + await mock_aio_protocol() + light._aio_protocol.data_received( + b"\x81\x25\x23\x61\x05\x10\xb6\x00\x98\x19\x04\x25\x0f\xde" + ) + await task + assert light.color_modes == {COLOR_MODE_RGBWW, COLOR_MODE_CCT} + assert light.protocol == PROTOCOL_LEDENET_9BYTE + assert light.model_num == 0x25 + assert light.model == "WiFi RGBCW Controller (0x25)" + assert light.is_on is True + assert len(light.effect_list) == 20 + + light._aio_protocol.data_received( + b"\x81\x25\x23\x61\x05\x10\xb6\x00\x98\x19\x04\x25\x0f\xde" + b"\x81\x25\x24\x61\x05\x10\xb6\x00\x98\x19\x04\x25\x0f\xdf" + ) + await asyncio.sleep(0) + assert light.is_on is False + + light._aio_protocol.data_received(b"\x81") + light._aio_protocol.data_received( + b"\x25\x23\x61\x05\x10\xb6\x00\x98\x19\x04\x25\x0f" + ) + light._aio_protocol.data_received(b"\xde") + await asyncio.sleep(0) + assert light.is_on is True + + +@pytest.mark.asyncio +async def test_turn_on_off(mock_aio_protocol, caplog: pytest.LogCaptureFixture): + """Test we can turn on and off.""" + light = AIOWifiLedBulb("192.168.1.166") + + def _updated_callback(*args, **kwargs): + pass + + task = asyncio.create_task(light.async_setup(_updated_callback)) + await mock_aio_protocol() + light._aio_protocol.data_received( + b"\x81\x25\x23\x61\x05\x10\xb6\x00\x98\x19\x04\x25\x0f\xde" + ) + await task + + task = asyncio.create_task(light.async_turn_off()) + # Wait for the future to get added + await asyncio.sleep(0) + light._aio_protocol.data_received( + b"\x81\x25\x24\x61\x05\x10\xb6\x00\x98\x19\x04\x25\x0f\xdf" + ) + await asyncio.sleep(0) + assert light.is_on is False + await task + + task = asyncio.create_task(light.async_turn_on()) + await asyncio.sleep(0) + light._aio_protocol.data_received( + b"\x81\x25\x23\x61\x05\x10\xb6\x00\x98\x19\x04\x25\x0f\xde" + ) + await asyncio.sleep(0) + assert light.is_on is True + await task + + await asyncio.sleep(0) + caplog.clear() + caplog.set_level(logging.DEBUG) + # Handle the failure case + with patch.object(aiodevice, "POWER_STATE_TIMEOUT", 0.05): + await asyncio.create_task(light.async_turn_off()) + assert light.is_on is True + assert "Failed to turn off (1/3)" in caplog.text + assert "Failed to turn off (2/3)" in caplog.text + assert "Failed to turn off (3/3)" in caplog.text + + with patch.object(aiodevice, "POWER_STATE_TIMEOUT", 0.05): + task = asyncio.create_task(light.async_turn_off()) + # Do NOT wait for the future to get added, we know the retry logic works + light._aio_protocol.data_received( + b"\x81\x25\x24\x61\x05\x10\xb6\x00\x98\x19\x04\x25\x0f\xdf" + ) + await asyncio.sleep(0) + assert light.is_on is False + await task + + await asyncio.sleep(0) + caplog.clear() + caplog.set_level(logging.DEBUG) + # Handle the failure case + with patch.object(aiodevice, "POWER_STATE_TIMEOUT", 0.05): + await asyncio.create_task(light.async_turn_on()) + assert light.is_on is False + assert "Failed to turn on (1/3)" in caplog.text + assert "Failed to turn on (2/3)" in caplog.text + assert "Failed to turn on (3/3)" in caplog.text + + +@pytest.mark.asyncio +async def test_shutdown(mock_aio_protocol): + """Test we can shutdown.""" + light = AIOWifiLedBulb("192.168.1.166") + + def _updated_callback(*args, **kwargs): + pass + + task = asyncio.create_task(light.async_setup(_updated_callback)) + await mock_aio_protocol() + light._aio_protocol.data_received( + b"\x81\x25\x23\x61\x05\x10\xb6\x00\x98\x19\x04\x25\x0f\xde" + ) + await task + + await light.async_stop() + await asyncio.sleep(0) # make sure nothing throws + + +@pytest.mark.asyncio +async def test_handling_connection_lost(mock_aio_protocol): + """Test we can reconnect.""" + light = AIOWifiLedBulb("192.168.1.166") + + def _updated_callback(*args, **kwargs): + pass + + task = asyncio.create_task(light.async_setup(_updated_callback)) + await mock_aio_protocol() + light._aio_protocol.data_received( + b"\x81\x25\x23\x61\x05\x10\xb6\x00\x98\x19\x04\x25\x0f\xde" + ) + await task + + light._aio_protocol.connection_lost(None) + await asyncio.sleep(0) # make sure nothing throws + + # Test we reconnect and can turn off + task = asyncio.create_task(light.async_turn_off()) + # Wait for the future to get added + await asyncio.sleep(0.1) # wait for reconnect + light._aio_protocol.data_received( + b"\x81\x25\x24\x61\x05\x10\xb6\x00\x98\x19\x04\x25\x0f\xdf" + ) + await asyncio.sleep(0) + assert light.is_on is False + await task