Skip to content

Commit

Permalink
Merge branch 'main' into publish-to-pypi
Browse files Browse the repository at this point in the history
  • Loading branch information
Wout Feys committed Aug 22, 2024
2 parents f16d706 + ef3ab26 commit 78cc80b
Show file tree
Hide file tree
Showing 19 changed files with 352 additions and 52 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/end2end.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ jobs:
working-directory: ./sample-apps/flask-postgres
run: |
docker compose -f docker-compose.yml -f docker-compose.benchmark.yml up --build -d
- name: Start flask-postgres-xml
working-directory: ./sample-apps/flask-postgres-xml
run: |
docker compose -f docker-compose.yml -f docker-compose.benchmark.yml up --build -d
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,5 @@ cov:
poetry run pytest aikido_firewall/ --cov=aikido_firewall --cov-report=xml

.PHONY: benchmark
e2e:
benchmark:
k6 run -q ./benchmarks/flask-mysql-benchmarks.js
2 changes: 2 additions & 0 deletions aikido_firewall/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ def protect(module="any", server=True):
# Import sources
import aikido_firewall.sources.django
import aikido_firewall.sources.flask
import aikido_firewall.sources.xml
import aikido_firewall.sources.lxml

import aikido_firewall.sources.gunicorn
import aikido_firewall.sources.uwsgi
Expand Down
11 changes: 9 additions & 2 deletions aikido_firewall/background_process/reporter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,21 @@ class Reporter:

timeout_in_sec = 5 # Timeout of API calls to Aikido Server
heartbeat_secs = 600 # Heartbeat every 10 minutes
initial_stats_timeout = 60 # Wait 60 seconds after startup for initial stats

def __init__(self, block, api, token, serverless):
self.block = block
self.api = api
self.token = token # Should be instance of the Token class!
self.routes = Routes(200)
self.hostnames = Hostnames(200)
self.conf = ServiceConfig([], get_unixtime_ms(), [], [], True)
self.conf = ServiceConfig(
endpoints=[],
last_updated_at=get_unixtime_ms(),
blocked_uids=[],
bypassed_ips=[],
received_any_stats=True,
)
self.rate_limiter = RateLimiter(
max_items=5000, time_to_live_in_ms=120 * 60 * 1000 # 120 minutes
)
Expand All @@ -55,7 +62,7 @@ def start(self, event_scheduler):
"Token was invalid, not starting heartbeats and realtime polling."
)
return
event_scheduler.enter(60, 1, self.report_initial_stats)
event_scheduler.enter(self.initial_stats_timeout, 1, self.report_initial_stats)
send_heartbeats_every_x_secs(self, self.heartbeat_secs, event_scheduler)
start_polling_for_changes(self, event_scheduler)

Expand Down
2 changes: 1 addition & 1 deletion aikido_firewall/background_process/service_config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
?
Exports ServiceConfig class
"""

from aikido_firewall.helpers.match_endpoint import match_endpoint
Expand Down
9 changes: 5 additions & 4 deletions aikido_firewall/context/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@
from .parse_cookies import parse_cookies
from .extract_wsgi_headers import extract_wsgi_headers
from .build_url_from_wsgi import build_url_from_wsgi
from .parse_raw_body import parse_raw_body

UINPUT_SOURCES = ["body", "cookies", "query", "headers"]
UINPUT_SOURCES = ["body", "cookies", "query", "headers", "xml"]
local = threading.local()


Expand All @@ -32,7 +31,7 @@ class Context:
for vulnerability detection
"""

def __init__(self, context_obj=None, req=None, raw_body=None, source=None):
def __init__(self, context_obj=None, body=None, req=None, source=None):
if context_obj:
logger.debug("Creating Context instance based on dict object.")
self.__dict__.update(context_obj)
Expand All @@ -48,12 +47,13 @@ def __init__(self, context_obj=None, req=None, raw_body=None, source=None):
self.url = build_url_from_wsgi(req)
self.query = parse_qs(req["QUERY_STRING"])
content_type = req.get("CONTENT_TYPE", None)
self.body = parse_raw_body(raw_body, content_type)
self.body = body
self.route = build_route_from_url(self.url)
self.subdomains = get_subdomains_from_url(self.url)
self.user = None
self.remote_address = get_ip_from_request(req["REMOTE_ADDR"], self.headers)
self.parsed_userinput = {}
self.xml = {}

def __reduce__(self):
return (
Expand All @@ -71,6 +71,7 @@ def __reduce__(self):
"route": self.route,
"subdomains": self.subdomains,
"user": self.user,
"xml": self.xml,
},
None,
None,
Expand Down
23 changes: 10 additions & 13 deletions aikido_firewall/context/init_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ def test_wsgi_context_1():
"CONTENT_TYPE": "application/x-www-form-urlencoded",
"REMOTE_ADDR": "198.51.100.23",
}
wsgi_raw_body = "dog_name=Doggo 1&test=Test 1"
context = Context(req=wsgi_request, raw_body=wsgi_raw_body, source="django")
context = Context(req=wsgi_request, body=123, source="django")
assert context.__dict__ == {
"source": "django",
"method": "POST",
Expand All @@ -37,12 +36,13 @@ def test_wsgi_context_1():
"cookies": {"sessionId": "abc123xyz456"},
"url": "https://example.com/hello",
"query": {"user": ["JohnDoe"], "age": ["30", "35"]},
"body": {"dog_name": ["Doggo 1"], "test": ["Test 1"]},
"body": 123,
"route": "/hello",
"subdomains": [],
"user": None,
"remote_address": "198.51.100.23",
"parsed_userinput": {},
"xml": {},
}


Expand All @@ -60,8 +60,7 @@ def test_wsgi_context_2():
"CONTENT_TYPE": "application/json",
"REMOTE_ADDR": "198.51.100.23",
}
wsgi_raw_body = '{"a": 23, "b": 45, "Hello": [1, 2, 3]}'
context = Context(req=wsgi_request, raw_body=wsgi_raw_body, source="flask")
context = Context(req=wsgi_request, body={"test": True}, source="flask")
assert context.__dict__ == {
"source": "flask",
"method": "GET",
Expand All @@ -74,12 +73,13 @@ def test_wsgi_context_2():
"cookies": {"sessionId": "abc123xyz456"},
"url": "http://localhost:8080/hello",
"query": {"user": ["JohnDoe"], "age": ["30", "35"]},
"body": {"a": 23, "b": 45, "Hello": [1, 2, 3]},
"body": {"test": True},
"route": "/hello",
"subdomains": [],
"user": None,
"remote_address": "198.51.100.23",
"parsed_userinput": {},
"xml": {},
}


Expand All @@ -98,8 +98,7 @@ def test_set_as_current_context(mocker):
"CONTENT_TYPE": "application/json",
"REMOTE_ADDR": "198.51.100.23",
}
wsgi_raw_body = '{"a": 23, "b": 45, "Hello": [1, 2, 3]}'
context = Context(req=wsgi_request, raw_body=wsgi_raw_body, source="flask")
context = Context(req=wsgi_request, body=12, source="flask")
context.set_as_current_context()
assert get_current_context() == context

Expand All @@ -119,8 +118,7 @@ def test_get_current_context_with_context(mocker):
"CONTENT_TYPE": "application/json",
"REMOTE_ADDR": "198.51.100.23",
}
wsgi_raw_body = '{"a": 23, "b": 45, "Hello": [1, 2, 3]}'
context = Context(req=wsgi_request, raw_body=wsgi_raw_body, source="flask")
context = Context(req=wsgi_request, body=456, source="flask")
context.set_as_current_context()
assert get_current_context() == context

Expand All @@ -139,15 +137,14 @@ def test_context_is_picklable(mocker):
"CONTENT_TYPE": "application/json",
"REMOTE_ADDR": "198.51.100.23",
}
wsgi_raw_body = '{"a": 23, "b": 45, "Hello": [1, 2, 3]}'
context = Context(req=wsgi_request, raw_body=wsgi_raw_body, source="flask")
context = Context(req=wsgi_request, body=123, source="flask")
pickled_obj = pickle.dumps(context)
unpickled_obj = pickle.loads(pickled_obj)
assert unpickled_obj.source == "flask"
assert unpickled_obj.method == "GET"
assert unpickled_obj.remote_address == "198.51.100.23"
assert unpickled_obj.url == "http://localhost:8080/hello"
assert unpickled_obj.body == {"a": 23, "b": 45, "Hello": [1, 2, 3]}
assert unpickled_obj.body == 123
assert unpickled_obj.headers == {
"HEADER_1": "header 1 value",
"HEADER_2": "Header 2 value",
Expand Down
23 changes: 0 additions & 23 deletions aikido_firewall/context/parse_raw_body.py

This file was deleted.

20 changes: 20 additions & 0 deletions aikido_firewall/helpers/extract_data_from_xml_body.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""Exports extract_data_from_xml_body helper function"""

import aikido_firewall.context as ctx


def extract_data_from_xml_body(user_input, root_element):
    """Collect XML element attributes and store them on the current context.

    Nothing is recorded unless ``context.body`` is a string that equals
    ``user_input``.

    Args:
        user_input: The raw XML input that was handed to the parser.
        root_element: Iterable of elements; every element must provide an
            ``items()`` method yielding ``(attribute_name, value)`` pairs.

    Returns:
        None. ``context.xml`` is updated in place, mapping each attribute
        name to the set of values seen for it, after which
        ``context.set_as_current_context()`` is called.
    """
    context = ctx.get_current_context()
    if context is None:
        # Assumption (confirm against get_current_context): there may be no
        # context when XML is parsed outside of a request. Bail out instead of
        # raising AttributeError inside the application's XML parsing call.
        return
    if not isinstance(context.body, str) or user_input != context.body:
        return

    extracted_xml_attrs = context.xml
    for el in root_element:
        for k, v in el.items():
            if not extracted_xml_attrs.get(k):
                extracted_xml_attrs[k] = set()
            extracted_xml_attrs[k].add(v)
    # Store the context again so the updated xml attributes are kept
    context.set_as_current_context()
110 changes: 110 additions & 0 deletions aikido_firewall/helpers/extract_data_from_xml_body_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import pytest
from unittest.mock import MagicMock
import aikido_firewall.context as ctx
from .extract_data_from_xml_body import (
extract_data_from_xml_body,
)


@pytest.fixture
def mock_context(monkeypatch):
    """Provide a mocked context and make ``ctx.get_current_context`` return it.

    The replacement is installed through ``monkeypatch`` so that pytest
    restores the real ``get_current_context`` once the test finishes; a plain
    module attribute assignment would stay in place for the rest of the
    test session.
    """
    mock_ctx = MagicMock()
    mock_ctx.body = "valid_input"
    mock_ctx.xml = {}  # Start without any extracted attributes
    monkeypatch.setattr(
        ctx, "get_current_context", MagicMock(return_value=mock_ctx)
    )
    return mock_ctx


def test_extract_data_from_xml_body_valid_input(mock_context):
    """Attributes of all elements are grouped per attribute name."""
    elements = [
        {"attr1": "value1", "attr2": "value2"},
        {"attr1": "value3", "attr3": "value4"},
    ]
    expected_xml = {
        "attr1": {"value1", "value3"},
        "attr2": {"value2"},
        "attr3": {"value4"},
    }

    extract_data_from_xml_body("valid_input", elements)

    assert mock_context.xml == expected_xml


def test_extract_data_from_xml_body_invalid_user_input(mock_context):
    """Input that differs from the context body leaves context.xml empty."""
    elements = [{"attr1": "value1"}]

    extract_data_from_xml_body("invalid_input", elements)

    assert mock_context.xml == {}


def test_extract_data_from_xml_body_empty_root_element(mock_context):
    """A root element without children yields no extracted attributes."""
    no_elements = []

    extract_data_from_xml_body("valid_input", no_elements)

    assert mock_context.xml == {}


def test_extract_data_from_xml_body_non_string_context_body(mock_context):
    """Nothing is extracted when the context body is not a string."""
    mock_context.body = 123  # Any non-string body
    elements = [{"attr1": "value1"}]

    extract_data_from_xml_body("valid_input", elements)

    assert mock_context.xml == {}


def test_extract_data_from_xml_body_multiple_calls(mock_context):
    """Values from consecutive calls accumulate under the same attribute."""
    first_elements = [{"attr1": "value1"}]
    second_elements = [{"attr1": "value2"}]

    for elements in (first_elements, second_elements):
        extract_data_from_xml_body("valid_input", elements)

    assert mock_context.xml == {"attr1": {"value1", "value2"}}


def test_extract_data_from_xml_body_duplicate_attributes(mock_context):
    """A repeated attribute/value pair is stored only once."""
    repeated = {"attr1": "value1"}
    elements = [
        repeated,
        dict(repeated),  # Same attribute and value a second time
        {"attr2": "value2"},
    ]

    extract_data_from_xml_body("valid_input", elements)

    assert mock_context.xml == {"attr1": {"value1"}, "attr2": {"value2"}}


def test_extract_data_from_xml_body_no_attributes(mock_context):
    """Elements without attributes do not add anything to context.xml."""
    attribute_less_elements = [{}]

    extract_data_from_xml_body("valid_input", attribute_less_elements)

    assert mock_context.xml == {}


def test_extract_data_from_xml_body_context_set_as_current(mock_context):
    """The context is stored again exactly once after extraction."""
    elements = [{"attr1": "value1"}]

    extract_data_from_xml_body("valid_input", elements)

    mock_context.set_as_current_context.assert_called_once()
3 changes: 2 additions & 1 deletion aikido_firewall/helpers/extract_strings_from_user_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@


def extract_strings_from_user_input_cached(obj, source):
"""Use the cache to speed up getting user input"""
context = get_current_context()
if context.parsed_userinput and context.parsed_userinput.get(source):
return context.parsed_userinput.get(source)
Expand Down Expand Up @@ -37,7 +38,7 @@ def extract_strings_from_user_input(obj, path_to_payload=None):
).items():
results[k] = v

if isinstance(obj, list):
if isinstance(obj, (set, list, tuple)):
# Add the stringified array as well to the results, there might
# be accidental concatenation if the client expects a string but gets the array
# E.g. HTTP Parameter pollution
Expand Down
15 changes: 12 additions & 3 deletions aikido_firewall/sources/django.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,18 @@ def gen_aikido_middleware_function(former__middleware_chain):
"""

def aikido_middleware_function(request):
context = Context(
req=request.META, raw_body=request.body.decode("utf-8"), source="django"
)
# Get a parsed body from Django :
body = request.POST.dict()
if len(body) == 0 and request.content_type == "application/json":
try:
body = json.loads(request.body)
except Exception:
pass
if len(body) == 0:
# E.g. XML Data
body = request.body

context = Context(req=request.META, body=body, source="django")
context.set_as_current_context()
request_handler(stage="init")

Expand Down
Loading

0 comments on commit 78cc80b

Please sign in to comment.