Skip to content

Commit

Permalink
Merge branch 'main' into send-initial-statistics
Browse files Browse the repository at this point in the history
  • Loading branch information
bitterpanda63 authored Aug 22, 2024
2 parents a3e65ab + 9bcdeaf commit cdf51f3
Show file tree
Hide file tree
Showing 57 changed files with 1,060 additions and 59 deletions.
53 changes: 53 additions & 0 deletions .github/workflows/end2end.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
name: Run End-2-End Tests

on: [pull_request]

jobs:
test:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v2

- name: Start django-mysql
working-directory: ./sample-apps/django-mysql
run: |
docker compose -f docker-compose.yml -f docker-compose.benchmark.yml up --build -d
- name: Start django-mysql-gunicorn
working-directory: ./sample-apps/django-mysql-gunicorn
run: |
docker compose -f docker-compose.yml -f docker-compose.benchmark.yml up --build -d
- name: Start flask-mongo
working-directory: ./sample-apps/flask-mongo
run: |
docker compose -f docker-compose.yml -f docker-compose.benchmark.yml up --build -d
- name: Start flask-mysql
working-directory: ./sample-apps/flask-mysql
run: |
docker compose -f docker-compose.yml -f docker-compose.benchmark.yml up --build -d
- name: Start flask-mysql-uwsgi
working-directory: ./sample-apps/flask-mysql-uwsgi
run: |
docker compose -f docker-compose.yml -f docker-compose.benchmark.yml up --build -d
- name: Start flask-postgres
working-directory: ./sample-apps/flask-postgres
run: |
docker compose -f docker-compose.yml -f docker-compose.benchmark.yml up --build -d
- name: Start flask-postgres-xml
working-directory: ./sample-apps/flask-postgres-xml
run: |
docker compose -f docker-compose.yml -f docker-compose.benchmark.yml up --build -d
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}

- name: Install dependencies
run: |
python -m pip install --upgrade pip
make install
- name: Run end2end tests
run: |
make end2end
8 changes: 6 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,12 @@ install:

.PHONY: test
test:
poetry run pytest
poetry run pytest aikido_firewall/

.PHONY: end2end
end2end:
poetry run pytest end2end/

.PHONY: cov
cov:
poetry run pytest --cov=aikido_firewall --cov-report=xml
poetry run pytest aikido_firewall/ --cov=aikido_firewall --cov-report=xml
2 changes: 2 additions & 0 deletions aikido_firewall/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ def protect(module="any", server=True):
# Import sources
import aikido_firewall.sources.django
import aikido_firewall.sources.flask
import aikido_firewall.sources.xml
import aikido_firewall.sources.lxml

import aikido_firewall.sources.gunicorn
import aikido_firewall.sources.uwsgi
Expand Down
10 changes: 7 additions & 3 deletions aikido_firewall/background_process/aikido_background_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ def send_to_reporter(self, event_scheduler):
)
logger.debug("Checking queue")
while not self.queue.empty():
attack = self.queue.get()
logger.debug("Reporting attack : %s", attack)
self.reporter.on_detected_attack(attack[0], attack[1])
queue_attack_item = self.queue.get()
self.reporter.on_detected_attack(
attack=queue_attack_item[0],
context=queue_attack_item[1],
blocked=queue_attack_item[2],
stack=queue_attack_item[3],
)
2 changes: 1 addition & 1 deletion aikido_firewall/background_process/commands/attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
def process_attack(bg_process, data, conn):
"""
Adds ATTACK data object to queue
Expected data object : [injection_results, context, blocked_or_not]
Expected data object : [injection_results, context, blocked_or_not, stacktrace]
"""
bg_process.queue.put(data)
if bg_process.reporter.statistics:
Expand Down
4 changes: 2 additions & 2 deletions aikido_firewall/background_process/reporter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ def report_initial_stats(self):
if should_report_initial_stats:
self.send_heartbeat()

Check warning on line 78 in aikido_firewall/background_process/reporter/__init__.py

View check run for this annotation

Codecov / codecov/patch

aikido_firewall/background_process/reporter/__init__.py#L77-L78

Added lines #L77 - L78 were not covered by tests

def on_detected_attack(self, attack, context):
def on_detected_attack(self, attack, context, blocked, stack):
"""This will send something to the API when an attack is detected"""
return on_detected_attack(self, attack, context)
return on_detected_attack(self, attack, context, blocked, stack)

def on_start(self):
"""This will send out an Event signalling the start to the server"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from aikido_firewall.helpers.get_ua_from_context import get_ua_from_context


def on_detected_attack(reporter, attack, context):
def on_detected_attack(reporter, attack, context, blocked, stack):
"""
This will send something to the API when an attack is detected
"""
Expand All @@ -18,7 +18,8 @@ def on_detected_attack(reporter, attack, context):
attack["user"] = None
attack["payload"] = json.dumps(attack["payload"])[:4096]
attack["metadata"] = limit_length_metadata(attack["metadata"], 4096)
attack["blocked"] = reporter.block
attack["blocked"] = blocked
attack["stack"] = stack

payload = {
"type": "detected_attack",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class Context:
def test_on_detected_attack_no_token(mock_context):
reporter = MagicMock()
reporter.token = None
on_detected_attack(reporter, {}, mock_context)
on_detected_attack(reporter, {}, mock_context, blocked=False, stack=None)
reporter.api.report.assert_not_called()


Expand All @@ -42,7 +42,7 @@ def test_on_detected_attack_with_long_payload(mock_reporter, mock_context):
"metadata": {"test": "1"},
}

on_detected_attack(mock_reporter, attack, mock_context)
on_detected_attack(mock_reporter, attack, mock_context, blocked=False, stack=None)
assert len(attack["payload"]) == 4096 # Ensure payload is truncated
mock_reporter.api.report.assert_called_once()

Expand All @@ -54,7 +54,7 @@ def test_on_detected_attack_with_long_metadata(mock_reporter, mock_context):
"metadata": {"test": long_metadata},
}

on_detected_attack(mock_reporter, attack, mock_context)
on_detected_attack(mock_reporter, attack, mock_context, blocked=False, stack=None)

assert (
attack["metadata"]["test"] == long_metadata[:4096]
Expand All @@ -68,7 +68,7 @@ def test_on_detected_attack_success(mock_reporter, mock_context):
"metadata": {},
}

on_detected_attack(mock_reporter, attack, mock_context)
on_detected_attack(mock_reporter, attack, mock_context, blocked=False, stack=None)
assert mock_reporter.api.report.call_count == 1


Expand All @@ -81,6 +81,24 @@ def test_on_detected_attack_exception_handling(mock_reporter, mock_context, capl
# Simulate an exception during the API call
mock_reporter.api.report.side_effect = Exception("API error")

on_detected_attack(mock_reporter, attack, mock_context)
on_detected_attack(mock_reporter, attack, mock_context, blocked=False, stack=None)

assert "Failed to report attack" in caplog.text


def test_on_detected_attack_with_blocked_and_stack(mock_reporter, mock_context):
attack = {
"payload": {"key": "value"},
"metadata": {},
}
blocked = True
stack = "sample stack trace"

on_detected_attack(
mock_reporter, attack, mock_context, blocked=blocked, stack=stack
)

# Check that the attack dictionary has the blocked and stack fields set
assert attack["blocked"] is True
assert attack["stack"] == stack
assert mock_reporter.api.report.call_count == 1
4 changes: 3 additions & 1 deletion aikido_firewall/context/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .build_url_from_wsgi import build_url_from_wsgi
from .parse_raw_body import parse_raw_body

UINPUT_SOURCES = ["body", "cookies", "query", "headers"]
UINPUT_SOURCES = ["body", "cookies", "query", "headers", "xml"]
local = threading.local()


Expand Down Expand Up @@ -54,6 +54,7 @@ def __init__(self, context_obj=None, req=None, raw_body=None, source=None):
self.user = None
self.remote_address = get_ip_from_request(req["REMOTE_ADDR"], self.headers)
self.parsed_userinput = {}
self.xml = {}

def __reduce__(self):
return (
Expand All @@ -71,6 +72,7 @@ def __reduce__(self):
"route": self.route,
"subdomains": self.subdomains,
"user": self.user,
"xml": self.xml,
},
None,
None,
Expand Down
2 changes: 2 additions & 0 deletions aikido_firewall/context/init_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def test_wsgi_context_1():
"user": None,
"remote_address": "198.51.100.23",
"parsed_userinput": {},
"xml": {},
}


Expand Down Expand Up @@ -80,6 +81,7 @@ def test_wsgi_context_2():
"user": None,
"remote_address": "198.51.100.23",
"parsed_userinput": {},
"xml": {},
}


Expand Down
1 change: 1 addition & 0 deletions aikido_firewall/context/parse_raw_body.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ def parse_raw_body(raw_body, content_type):
return parsed_body
except Exception as e:
logger.debug("Exception in parse_raw_body : %s", e)
return raw_body
20 changes: 20 additions & 0 deletions aikido_firewall/helpers/extract_data_from_xml_body.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""Exports extract_data_from_xml_body helper function"""

import aikido_firewall.context as ctx


def extract_data_from_xml_body(user_input, root_element):
"""Extracts all attributes from the xml and adds them to context"""
context = ctx.get_current_context()
if not isinstance(context.body, str) or user_input != context.body:
return

extracted_xml_attrs = context.xml
for el in root_element:
print(extracted_xml_attrs)
for k, v in el.items():
print("Key : %s, Value : %s", k, v)
if not extracted_xml_attrs.get(k):
extracted_xml_attrs[k] = set()
extracted_xml_attrs[k].add(v)
context.set_as_current_context()
110 changes: 110 additions & 0 deletions aikido_firewall/helpers/extract_data_from_xml_body_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import pytest
from unittest.mock import MagicMock
import aikido_firewall.context as ctx
from .extract_data_from_xml_body import (
extract_data_from_xml_body,
) # Replace 'your_module' with the actual module name


@pytest.fixture
def mock_context():
"""Fixture to mock the context."""
mock_ctx = MagicMock()
mock_ctx.body = "valid_input"
mock_ctx.xml = {} # Initialize with an empty dictionary
ctx.get_current_context = MagicMock(return_value=mock_ctx)
return mock_ctx


def test_extract_data_from_xml_body_valid_input(mock_context):
"""Test with valid user input and root_element."""
user_input = "valid_input"
root_element = [
{"attr1": "value1", "attr2": "value2"},
{"attr1": "value3", "attr3": "value4"},
]

extract_data_from_xml_body(user_input, root_element)

assert mock_context.xml == {
"attr1": {"value1", "value3"},
"attr2": {"value2"},
"attr3": {"value4"},
}


def test_extract_data_from_xml_body_invalid_user_input(mock_context):
"""Test with invalid user input."""
user_input = "invalid_input"
root_element = [{"attr1": "value1"}]

extract_data_from_xml_body(user_input, root_element)

assert mock_context.xml == {}


def test_extract_data_from_xml_body_empty_root_element(mock_context):
"""Test with an empty root_element."""
user_input = "valid_input"
root_element = []

extract_data_from_xml_body(user_input, root_element)

assert mock_context.xml == {}


def test_extract_data_from_xml_body_non_string_context_body(mock_context):
"""Test with non-string context body."""
mock_context.body = 123 # Set body to a non-string value
user_input = "valid_input"
root_element = [{"attr1": "value1"}]

extract_data_from_xml_body(user_input, root_element)

assert mock_context.xml == {}


def test_extract_data_from_xml_body_multiple_calls(mock_context):
"""Test multiple calls with the same user input."""
user_input = "valid_input"
root_element1 = [{"attr1": "value1"}]
root_element2 = [{"attr1": "value2"}]

extract_data_from_xml_body(user_input, root_element1)
extract_data_from_xml_body(user_input, root_element2)

assert mock_context.xml == {"attr1": {"value1", "value2"}}


def test_extract_data_from_xml_body_duplicate_attributes(mock_context):
"""Test with duplicate attributes in root_element."""
user_input = "valid_input"
root_element = [
{"attr1": "value1"},
{"attr1": "value1"}, # Duplicate
{"attr2": "value2"},
]

extract_data_from_xml_body(user_input, root_element)

assert mock_context.xml == {"attr1": {"value1"}, "attr2": {"value2"}}


def test_extract_data_from_xml_body_no_attributes(mock_context):
"""Test with elements that have no attributes."""
user_input = "valid_input"
root_element = [{}]

extract_data_from_xml_body(user_input, root_element)

assert mock_context.xml == {}


def test_extract_data_from_xml_body_context_set_as_current(mock_context):
"""Test if context.set_as_current_context is called."""
user_input = "valid_input"
root_element = [{"attr1": "value1"}]

extract_data_from_xml_body(user_input, root_element)

mock_context.set_as_current_context.assert_called_once()
3 changes: 2 additions & 1 deletion aikido_firewall/helpers/extract_strings_from_user_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@


def extract_strings_from_user_input_cached(obj, source):
"""Use the cache to speed up getting user input"""
context = get_current_context()
if context.parsed_userinput and context.parsed_userinput.get(source):
return context.parsed_userinput.get(source)
Expand Down Expand Up @@ -37,7 +38,7 @@ def extract_strings_from_user_input(obj, path_to_payload=None):
).items():
results[k] = v

if isinstance(obj, list):
if isinstance(obj, (set, list, tuple)):
# Add the stringified array as well to the results, there might
# be accidental concatenation if the client expects a string but gets the array
# E.g. HTTP Parameter pollution
Expand Down
Loading

0 comments on commit cdf51f3

Please sign in to comment.