Skip to content

Commit

Permalink
Merge pull request #97 from AikidoSec/bugfix-multiple-reqs-on-start
Browse files Browse the repository at this point in the history
Bugfix: remove multiple reqs on start
  • Loading branch information
willem-delbare authored Aug 21, 2024
2 parents 95b3d93 + f313d6d commit 4c5a293
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 58 deletions.
22 changes: 7 additions & 15 deletions aikido_firewall/background_process/heartbeats.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

from aikido_firewall.helpers.logging import logger
from aikido_firewall.helpers.create_interval import create_interval


def send_heartbeats_every_x_secs(reporter, interval_in_secs, event_scheduler):
Expand All @@ -18,19 +19,10 @@ def send_heartbeats_every_x_secs(reporter, interval_in_secs, event_scheduler):

logger.debug("Starting heartbeats")

# Start the interval by booting the first settimeout
send_heartbeat_wrapper(reporter, interval_in_secs, event_scheduler)


def send_heartbeat_wrapper(rep, interval_in_secs, event_scheduler):
"""
Wrapper function for send_heartbeat so we get an interval
"""
event_scheduler.enter(
interval_in_secs,
1,
send_heartbeat_wrapper,
(rep, interval_in_secs, event_scheduler),
# Create an interval for "interval_in_secs" seconds :
create_interval(
event_scheduler=event_scheduler,
interval_in_secs=interval_in_secs,
function=lambda reporter: reporter.send_heartbeat(),
args=(reporter,),
)
logger.debug("Heartbeat...")
rep.send_heartbeat()
18 changes: 1 addition & 17 deletions aikido_firewall/background_process/heartbeats_test.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import pytest
from unittest.mock import Mock, patch
from aikido_firewall.background_process.heartbeats import (
send_heartbeats_every_x_secs,
send_heartbeat_wrapper,
)
from aikido_firewall.background_process.heartbeats import send_heartbeats_every_x_secs


def test_send_heartbeats_serverless():
Expand Down Expand Up @@ -42,16 +39,3 @@ def test_send_heartbeats_success():

with patch("aikido_firewall.helpers.logging.logger.debug") as mock_debug:
send_heartbeats_every_x_secs(reporter, 5, event_scheduler)


def test_send_heartbeat_wrapper():
reporter = Mock()
reporter.send_heartbeat = Mock()
event_scheduler = Mock()

send_heartbeat_wrapper(reporter, 5, event_scheduler)

reporter.send_heartbeat.assert_called_once()
event_scheduler.enter.assert_called_once_with(
5, 1, send_heartbeat_wrapper, (reporter, 5, event_scheduler)
)
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,28 @@
POLL_FOR_CONFIG_CHANGES_INTERVAL = 60 # Poll for config changes every 60 seconds


def start_polling_for_changes(on_config_update, serverless, token, event_scheduler):
def start_polling_for_changes(reporter, event_scheduler):
"""
Arguments :
- on_config_update : A function that will run with the new config if changed
- serverless, token : Attributes from the reporter
- last_updated_at : The last time the config was updated (unixtime in ms)
This function will check if the config was updated or not
"""
if not isinstance(token, Token):
if not isinstance(reporter.token, Token):
logger.info("No token provided, not polling for config updates")
return
if serverless:
if reporter.serverless:
logger.info("Running in serverless environment, not polling for config updates")
return

# Start the interval by booting the first settimeout
poll_for_changes(on_config_update, token, 0, event_scheduler)
poll_for_changes(
on_config_update=reporter.update_service_config,
token=reporter.token,
former_last_updated=reporter.conf.last_updated_at,
event_scheduler=event_scheduler,
)


def poll_for_changes(on_config_update, token, former_last_updated, event_scheduler):
Expand All @@ -34,14 +39,15 @@ def poll_for_changes(on_config_update, token, former_last_updated, event_schedul
"""
# If something went wrong, or we don't know when the config was
# last updated, set to prev value
config_last_updated_at = former_last_updated
last_updated = former_last_updated
try:
config_last_updated_at = realtime.get_config_last_updated_at(token)
if (
isinstance(former_last_updated, int)
and config_last_updated_at > former_last_updated
):
last_updated = realtime.get_config_last_updated_at(token)
config_changed = (
isinstance(former_last_updated, int) and last_updated > former_last_updated
)
if config_changed:
# The config changed
logger.debug("According to realtime: Config changed")
config = realtime.get_config(token)
on_config_update({**config, "success": True})
except Exception as e:
Expand All @@ -53,5 +59,5 @@ def poll_for_changes(on_config_update, token, former_last_updated, event_schedul
POLL_FOR_CONFIG_CHANGES_INTERVAL,
1,
poll_for_changes,
(on_config_update, token, config_last_updated_at, event_scheduler),
(on_config_update, token, last_updated, event_scheduler),
)
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,15 @@ def enter(self, delay, priority, action, argument):
self.events.append((delay, priority, action, argument))

def run(self):
self.events[0][2](*self.events[0][3])
if self.events:
self.events[0][2](*self.events[0][3])

return EventScheduler()


def test_no_token(event_scheduler, caplog):
start_polling_for_changes(
on_config_update=lambda config: pytest.fail("Should not be called"),
serverless=None,
token=None,
reporter=MagicMock(token=None, serverless=None),
event_scheduler=event_scheduler,
)

Expand All @@ -32,9 +31,7 @@ def test_no_token(event_scheduler, caplog):

def test_serverless_environment(event_scheduler, caplog):
start_polling_for_changes(
on_config_update=lambda config: pytest.fail("Should not be called"),
serverless=True,
token=Token("123"),
reporter=MagicMock(token=Token("123"), serverless=True),
event_scheduler=event_scheduler,
)

Expand Down Expand Up @@ -64,10 +61,15 @@ def mock_get_config(token):

config_updates = []

start_polling_for_changes(
on_config_update=lambda config: config_updates.append(config),
serverless=None,
reporter = MagicMock(
update_service_config=lambda config: config_updates.append(config),
token=Token("123"),
conf=MagicMock(last_updated_at=config_updated_at),
serverless=None,
)

start_polling_for_changes(
reporter=reporter,
event_scheduler=event_scheduler,
)

Expand All @@ -78,6 +80,7 @@ def mock_get_config(token):

# Simulate a config update
config_updated_at = 1
reporter.conf.last_updated_at = config_updated_at
event_scheduler.run()

assert config_updates == [
Expand All @@ -99,10 +102,15 @@ def mock_get_config_last_updated_at(token):
):
config_updates = []

start_polling_for_changes(
on_config_update=lambda config: config_updates.append(config),
reporter = MagicMock(
update_service_config=lambda config: config_updates.append(config),
token=Token("123"),
conf=MagicMock(last_updated_at=0),
serverless=None,
)

start_polling_for_changes(
reporter=reporter,
event_scheduler=event_scheduler,
)

Expand Down
4 changes: 1 addition & 3 deletions aikido_firewall/background_process/reporter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,7 @@ def start(self, event_scheduler):
)
return
send_heartbeats_every_x_secs(self, self.heartbeat_secs, event_scheduler)
start_polling_for_changes(
self.update_service_config, self.serverless, self.token, event_scheduler
)
start_polling_for_changes(self, event_scheduler)

def on_detected_attack(self, attack, context):
"""This will send something to the API when an attack is detected"""
Expand Down
31 changes: 31 additions & 0 deletions aikido_firewall/helpers/create_interval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""Exports create_interval"""


def create_interval(event_scheduler, interval_in_secs, function, args):
"""
This function creates an interval which first runs "function"
after waiting "interval_in_secs" seconds and after that keeps
executing the function every "interval_in_secs" seconds.
"""
# Sleep interval_in_secs seconds before starting the loop :
event_scheduler.enter(
interval_in_secs,
1,
interval_loop,
(event_scheduler, interval_in_secs, function, args),
)


def interval_loop(event_scheduler, interval_in_secs, function, args):
"""
This is the actual interval loop which executes and schedules the function
"""
# Execute function :
function(*args)
# Schedule the execution of the function in interval_in_secs seconds :
event_scheduler.enter(
interval_in_secs,
1,
interval_loop,
(event_scheduler, interval_in_secs, function, args),
)
75 changes: 75 additions & 0 deletions aikido_firewall/helpers/create_interval_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import pytest
from .create_interval import create_interval, interval_loop
from unittest.mock import Mock, call


def test_create_interval_schedules_function():
# Arrange
event_scheduler = Mock()
function = Mock()
args = (1, 2)
interval_in_secs = 5

# Act
create_interval(event_scheduler, interval_in_secs, function, args)

# Assert
event_scheduler.enter.assert_called_once_with(
interval_in_secs,
1,
interval_loop,
(event_scheduler, interval_in_secs, function, args),
)


def test_interval_loop_executes_function_and_reschedules():
# Arrange
event_scheduler = Mock()
function = Mock()
args = (1, 2)
interval_in_secs = 5

# Act
interval_loop(event_scheduler, interval_in_secs, function, args)

# Assert
function.assert_called_once_with(*args)
event_scheduler.enter.assert_called_once_with(
interval_in_secs,
1,
interval_loop,
(event_scheduler, interval_in_secs, function, args),
)


def test_multiple_calls_to_create_interval():
# Arrange
event_scheduler = Mock()
function = Mock()
args = (1, 2)
interval_in_secs = 5

# Act
create_interval(event_scheduler, interval_in_secs, function, args)
create_interval(event_scheduler, interval_in_secs, function, args)

# Assert
assert event_scheduler.enter.call_count == 2
assert event_scheduler.enter.call_args_list == [
call(
interval_in_secs,
1,
interval_loop,
(event_scheduler, interval_in_secs, function, args),
),
call(
interval_in_secs,
1,
interval_loop,
(event_scheduler, interval_in_secs, function, args),
),
]


if __name__ == "__main__":
pytest.main()

0 comments on commit 4c5a293

Please sign in to comment.