Skip to content

Commit

Permalink
Merge branch 'main' into send-initial-statistics
Browse files Browse the repository at this point in the history
  • Loading branch information
bitterpanda63 authored Aug 21, 2024
2 parents d6d1f88 + 4c5a293 commit ce381ee
Show file tree
Hide file tree
Showing 8 changed files with 160 additions and 59 deletions.
22 changes: 7 additions & 15 deletions aikido_firewall/background_process/heartbeats.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

from aikido_firewall.helpers.logging import logger
from aikido_firewall.helpers.create_interval import create_interval


def send_heartbeats_every_x_secs(reporter, interval_in_secs, event_scheduler):
Expand All @@ -18,19 +19,10 @@ def send_heartbeats_every_x_secs(reporter, interval_in_secs, event_scheduler):

logger.debug("Starting heartbeats")

# Start the interval by booting the first settimeout
send_heartbeat_wrapper(reporter, interval_in_secs, event_scheduler)


def send_heartbeat_wrapper(rep, interval_in_secs, event_scheduler):
"""
Wrapper function for send_heartbeat so we get an interval
"""
event_scheduler.enter(
interval_in_secs,
1,
send_heartbeat_wrapper,
(rep, interval_in_secs, event_scheduler),
# Create an interval for "interval_in_secs" seconds :
create_interval(
event_scheduler=event_scheduler,
interval_in_secs=interval_in_secs,
function=lambda reporter: reporter.send_heartbeat(),
args=(reporter,),
)
logger.debug("Heartbeat...")
rep.send_heartbeat()
18 changes: 1 addition & 17 deletions aikido_firewall/background_process/heartbeats_test.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import pytest
from unittest.mock import Mock, patch
from aikido_firewall.background_process.heartbeats import (
send_heartbeats_every_x_secs,
send_heartbeat_wrapper,
)
from aikido_firewall.background_process.heartbeats import send_heartbeats_every_x_secs


def test_send_heartbeats_serverless():
Expand Down Expand Up @@ -42,16 +39,3 @@ def test_send_heartbeats_success():

with patch("aikido_firewall.helpers.logging.logger.debug") as mock_debug:
send_heartbeats_every_x_secs(reporter, 5, event_scheduler)


def test_send_heartbeat_wrapper():
reporter = Mock()
reporter.send_heartbeat = Mock()
event_scheduler = Mock()

send_heartbeat_wrapper(reporter, 5, event_scheduler)

reporter.send_heartbeat.assert_called_once()
event_scheduler.enter.assert_called_once_with(
5, 1, send_heartbeat_wrapper, (reporter, 5, event_scheduler)
)
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,28 @@
POLL_FOR_CONFIG_CHANGES_INTERVAL = 60 # Poll for config changes every 60 seconds


def start_polling_for_changes(on_config_update, serverless, token, event_scheduler):
def start_polling_for_changes(reporter, event_scheduler):
"""
Arguments :
- on_config_update : A function that will run with the new config if changed
- serverless, token : Attributes from the reporter
- last_updated_at : The last time the config was updated (unixtime in ms)
This function will check if the config was updated or not
"""
if not isinstance(token, Token):
if not isinstance(reporter.token, Token):
logger.info("No token provided, not polling for config updates")
return
if serverless:
if reporter.serverless:
logger.info("Running in serverless environment, not polling for config updates")
return

# Start the interval by booting the first settimeout
poll_for_changes(on_config_update, token, 0, event_scheduler)
poll_for_changes(
on_config_update=reporter.update_service_config,
token=reporter.token,
former_last_updated=reporter.conf.last_updated_at,
event_scheduler=event_scheduler,
)


def poll_for_changes(on_config_update, token, former_last_updated, event_scheduler):
Expand All @@ -34,14 +39,15 @@ def poll_for_changes(on_config_update, token, former_last_updated, event_schedul
"""
# If something went wrong, or we don't know when the config was
# last updated, set to prev value
config_last_updated_at = former_last_updated
last_updated = former_last_updated
try:
config_last_updated_at = realtime.get_config_last_updated_at(token)
if (
isinstance(former_last_updated, int)
and config_last_updated_at > former_last_updated
):
last_updated = realtime.get_config_last_updated_at(token)
config_changed = (
isinstance(former_last_updated, int) and last_updated > former_last_updated
)
if config_changed:
# The config changed
logger.debug("According to realtime: Config changed")
config = realtime.get_config(token)
on_config_update({**config, "success": True})
except Exception as e:
Expand All @@ -53,5 +59,5 @@ def poll_for_changes(on_config_update, token, former_last_updated, event_schedul
POLL_FOR_CONFIG_CHANGES_INTERVAL,
1,
poll_for_changes,
(on_config_update, token, config_last_updated_at, event_scheduler),
(on_config_update, token, last_updated, event_scheduler),
)
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,15 @@ def enter(self, delay, priority, action, argument):
self.events.append((delay, priority, action, argument))

def run(self):
self.events[0][2](*self.events[0][3])
if self.events:
self.events[0][2](*self.events[0][3])

return EventScheduler()


def test_no_token(event_scheduler, caplog):
start_polling_for_changes(
on_config_update=lambda config: pytest.fail("Should not be called"),
serverless=None,
token=None,
reporter=MagicMock(token=None, serverless=None),
event_scheduler=event_scheduler,
)

Expand All @@ -32,9 +31,7 @@ def test_no_token(event_scheduler, caplog):

def test_serverless_environment(event_scheduler, caplog):
start_polling_for_changes(
on_config_update=lambda config: pytest.fail("Should not be called"),
serverless=True,
token=Token("123"),
reporter=MagicMock(token=Token("123"), serverless=True),
event_scheduler=event_scheduler,
)

Expand Down Expand Up @@ -64,10 +61,15 @@ def mock_get_config(token):

config_updates = []

start_polling_for_changes(
on_config_update=lambda config: config_updates.append(config),
serverless=None,
reporter = MagicMock(
update_service_config=lambda config: config_updates.append(config),
token=Token("123"),
conf=MagicMock(last_updated_at=config_updated_at),
serverless=None,
)

start_polling_for_changes(
reporter=reporter,
event_scheduler=event_scheduler,
)

Expand All @@ -78,6 +80,7 @@ def mock_get_config(token):

# Simulate a config update
config_updated_at = 1
reporter.conf.last_updated_at = config_updated_at
event_scheduler.run()

assert config_updates == [
Expand All @@ -99,10 +102,15 @@ def mock_get_config_last_updated_at(token):
):
config_updates = []

start_polling_for_changes(
on_config_update=lambda config: config_updates.append(config),
reporter = MagicMock(
update_service_config=lambda config: config_updates.append(config),
token=Token("123"),
conf=MagicMock(last_updated_at=0),
serverless=None,
)

start_polling_for_changes(
reporter=reporter,
event_scheduler=event_scheduler,
)

Expand Down
12 changes: 8 additions & 4 deletions aikido_firewall/background_process/reporter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from aikido_firewall.background_process.heartbeats import send_heartbeats_every_x_secs
from aikido_firewall.background_process.routes import Routes
from aikido_firewall.ratelimiting.rate_limiter import RateLimiter
from aikido_firewall.helpers.logging import logger
from ..service_config import ServiceConfig
from ..users import Users
from ..hostnames import Hostnames
Expand Down Expand Up @@ -49,12 +50,15 @@ def __init__(self, block, api, token, serverless):

def start(self, event_scheduler):
"""Send out start event and add heartbeats"""
self.on_start()
res = self.on_start()
if res.get("error", None) == "invalid_token":
logger.info(
"Token was invalid, not starting heartbeats and realtime polling."
)
return
event_scheduler.enter(self.initial_stats_timeout, 1, self.report_initial_stats)

Check warning on line 59 in aikido_firewall/background_process/reporter/__init__.py

View check run for this annotation

Codecov / codecov/patch

aikido_firewall/background_process/reporter/__init__.py#L59

Added line #L59 was not covered by tests
send_heartbeats_every_x_secs(self, self.heartbeat_secs, event_scheduler)
start_polling_for_changes(
self.update_service_config, self.serverless, self.token, event_scheduler
)
start_polling_for_changes(self, event_scheduler)

def report_initial_stats(self):
"""
Expand Down
1 change: 1 addition & 0 deletions aikido_firewall/background_process/reporter/on_start.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@ def on_start(reporter):
else:
reporter.update_service_config(res)
logger.info("Established connection with Aikido Server")
return res
31 changes: 31 additions & 0 deletions aikido_firewall/helpers/create_interval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""Exports create_interval"""


def create_interval(event_scheduler, interval_in_secs, function, args):
"""
This function creates an interval which first runs "function"
after waiting "interval_in_secs" seconds and after that keeps
executing the function every "interval_in_secs" seconds.
"""
# Sleep interval_in_secs seconds before starting the loop :
event_scheduler.enter(
interval_in_secs,
1,
interval_loop,
(event_scheduler, interval_in_secs, function, args),
)


def interval_loop(event_scheduler, interval_in_secs, function, args):
"""
This is the actual interval loop which executes and schedules the function
"""
# Execute function :
function(*args)
# Schedule the execution of the function in interval_in_secs seconds :
event_scheduler.enter(
interval_in_secs,
1,
interval_loop,
(event_scheduler, interval_in_secs, function, args),
)
75 changes: 75 additions & 0 deletions aikido_firewall/helpers/create_interval_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import pytest
from .create_interval import create_interval, interval_loop
from unittest.mock import Mock, call


def test_create_interval_schedules_function():
# Arrange
event_scheduler = Mock()
function = Mock()
args = (1, 2)
interval_in_secs = 5

# Act
create_interval(event_scheduler, interval_in_secs, function, args)

# Assert
event_scheduler.enter.assert_called_once_with(
interval_in_secs,
1,
interval_loop,
(event_scheduler, interval_in_secs, function, args),
)


def test_interval_loop_executes_function_and_reschedules():
# Arrange
event_scheduler = Mock()
function = Mock()
args = (1, 2)
interval_in_secs = 5

# Act
interval_loop(event_scheduler, interval_in_secs, function, args)

# Assert
function.assert_called_once_with(*args)
event_scheduler.enter.assert_called_once_with(
interval_in_secs,
1,
interval_loop,
(event_scheduler, interval_in_secs, function, args),
)


def test_multiple_calls_to_create_interval():
# Arrange
event_scheduler = Mock()
function = Mock()
args = (1, 2)
interval_in_secs = 5

# Act
create_interval(event_scheduler, interval_in_secs, function, args)
create_interval(event_scheduler, interval_in_secs, function, args)

# Assert
assert event_scheduler.enter.call_count == 2
assert event_scheduler.enter.call_args_list == [
call(
interval_in_secs,
1,
interval_loop,
(event_scheduler, interval_in_secs, function, args),
),
call(
interval_in_secs,
1,
interval_loop,
(event_scheduler, interval_in_secs, function, args),
),
]


if __name__ == "__main__":
pytest.main()

0 comments on commit ce381ee

Please sign in to comment.