Skip to content

Commit

Permalink
modbus_server: call execute in a way that those can be either corouti…
Browse files Browse the repository at this point in the history
…nes or normal methods
  • Loading branch information
ilkka-ollakka committed Mar 31, 2024
1 parent 9e9e50e commit 170fa30
Showing 1 changed file with 16 additions and 1 deletion.
17 changes: 16 additions & 1 deletion pymodbus/server/async_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ def __init__(self, owner):
self.receive_queue: asyncio.Queue = asyncio.Queue()
self.handler_task = None # coroutine to be run on asyncio loop
self.framer: ModbusFramer
self.request_tasks: set[asyncio.Task] = set()
self.loop = asyncio.get_running_loop()

def _log_exception(self):
"""Show log exception."""
Expand Down Expand Up @@ -173,17 +175,28 @@ def execute(self, request, *addr):
if self.server.request_tracer:
self.server.request_tracer(request, *addr)

task = self.loop.create_task(self._async_execute(request, *addr))
self.request_tasks.add(task)
task.add_done_callback(self.request_tasks.discard)

async def _async_execute(self, request, *addr):
broadcast = False
response = None
try:
if self.server.broadcast_enable and not request.slave_id:
broadcast = True
# if broadcasting then execute on all slave contexts,
# note response will be ignored
for slave_id in self.server.context.slaves():
response = request.execute(self.server.context[slave_id])
if asyncio.iscoroutine(response):
response = await response
else:
context = self.server.context[request.slave_id]
response = request.execute(context)
if asyncio.iscoroutine(response):
response = await response

except NoSuchSlaveException:
Log.error("requested slave does not exist: {}", request.slave_id)
if self.server.ignore_missing_slaves:
Expand All @@ -196,8 +209,9 @@ def execute(self, request, *addr):
traceback.format_exc(),
)
response = request.doException(merror.SlaveFailure)

# no response when broadcasting
if not broadcast:
if not broadcast and response is not None:
response.transaction_id = request.transaction_id
response.slave_id = request.slave_id
skip_encoding = False
Expand Down Expand Up @@ -305,6 +319,7 @@ def callback_data(self, data: bytes, addr: tuple | None = None) -> int:
Log.debug("callback_data called: {} addr={}", data, ":hex", addr)
return 0


class ModbusTcpServer(ModbusBaseServer):
"""A modbus threaded tcp socket server.
Expand Down

0 comments on commit 170fa30

Please sign in to comment.