Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into fix-django-file-upload
Browse files Browse the repository at this point in the history
  • Loading branch information
Wout Feys committed Aug 22, 2024
2 parents 0d8c6ae + 9bcdeaf commit 3fd4bb7
Show file tree
Hide file tree
Showing 14 changed files with 318 additions and 5 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/end2end.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ jobs:
working-directory: ./sample-apps/flask-postgres
run: |
docker compose -f docker-compose.yml -f docker-compose.benchmark.yml up --build -d
- name: Start flask-postgres-xml
working-directory: ./sample-apps/flask-postgres-xml
run: |
docker compose -f docker-compose.yml -f docker-compose.benchmark.yml up --build -d
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
Expand Down
2 changes: 2 additions & 0 deletions aikido_firewall/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ def protect(module="any", server=True):
# Import sources
import aikido_firewall.sources.django
import aikido_firewall.sources.flask
import aikido_firewall.sources.xml
import aikido_firewall.sources.lxml

import aikido_firewall.sources.gunicorn
import aikido_firewall.sources.uwsgi
Expand Down
4 changes: 3 additions & 1 deletion aikido_firewall/context/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .extract_wsgi_headers import extract_wsgi_headers
from .build_url_from_wsgi import build_url_from_wsgi

UINPUT_SOURCES = ["body", "cookies", "query", "headers"]
UINPUT_SOURCES = ["body", "cookies", "query", "headers", "xml"]
local = threading.local()


Expand Down Expand Up @@ -53,6 +53,7 @@ def __init__(self, context_obj=None, body=None, req=None, source=None):
self.user = None
self.remote_address = get_ip_from_request(req["REMOTE_ADDR"], self.headers)
self.parsed_userinput = {}
self.xml = {}

def __reduce__(self):
return (
Expand All @@ -70,6 +71,7 @@ def __reduce__(self):
"route": self.route,
"subdomains": self.subdomains,
"user": self.user,
"xml": self.xml,
},
None,
None,
Expand Down
2 changes: 2 additions & 0 deletions aikido_firewall/context/init_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def test_wsgi_context_1():
"user": None,
"remote_address": "198.51.100.23",
"parsed_userinput": {},
"xml": {},
}


Expand Down Expand Up @@ -78,6 +79,7 @@ def test_wsgi_context_2():
"user": None,
"remote_address": "198.51.100.23",
"parsed_userinput": {},
"xml": {},
}


Expand Down
20 changes: 20 additions & 0 deletions aikido_firewall/helpers/extract_data_from_xml_body.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""Exports extract_data_from_xml_body helper function"""

import aikido_firewall.context as ctx


def extract_data_from_xml_body(user_input, root_element):
"""Extracts all attributes from the xml and adds them to context"""
context = ctx.get_current_context()
if not isinstance(context.body, str) or user_input != context.body:
return

extracted_xml_attrs = context.xml
for el in root_element:
print(extracted_xml_attrs)
for k, v in el.items():
print("Key : %s, Value : %s", k, v)
if not extracted_xml_attrs.get(k):
extracted_xml_attrs[k] = set()
extracted_xml_attrs[k].add(v)
context.set_as_current_context()
110 changes: 110 additions & 0 deletions aikido_firewall/helpers/extract_data_from_xml_body_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import pytest
from unittest.mock import MagicMock
import aikido_firewall.context as ctx
from .extract_data_from_xml_body import (
extract_data_from_xml_body,
) # Replace 'your_module' with the actual module name


@pytest.fixture
def mock_context():
"""Fixture to mock the context."""
mock_ctx = MagicMock()
mock_ctx.body = "valid_input"
mock_ctx.xml = {} # Initialize with an empty dictionary
ctx.get_current_context = MagicMock(return_value=mock_ctx)
return mock_ctx


def test_extract_data_from_xml_body_valid_input(mock_context):
"""Test with valid user input and root_element."""
user_input = "valid_input"
root_element = [
{"attr1": "value1", "attr2": "value2"},
{"attr1": "value3", "attr3": "value4"},
]

extract_data_from_xml_body(user_input, root_element)

assert mock_context.xml == {
"attr1": {"value1", "value3"},
"attr2": {"value2"},
"attr3": {"value4"},
}


def test_extract_data_from_xml_body_invalid_user_input(mock_context):
"""Test with invalid user input."""
user_input = "invalid_input"
root_element = [{"attr1": "value1"}]

extract_data_from_xml_body(user_input, root_element)

assert mock_context.xml == {}


def test_extract_data_from_xml_body_empty_root_element(mock_context):
"""Test with an empty root_element."""
user_input = "valid_input"
root_element = []

extract_data_from_xml_body(user_input, root_element)

assert mock_context.xml == {}


def test_extract_data_from_xml_body_non_string_context_body(mock_context):
"""Test with non-string context body."""
mock_context.body = 123 # Set body to a non-string value
user_input = "valid_input"
root_element = [{"attr1": "value1"}]

extract_data_from_xml_body(user_input, root_element)

assert mock_context.xml == {}


def test_extract_data_from_xml_body_multiple_calls(mock_context):
"""Test multiple calls with the same user input."""
user_input = "valid_input"
root_element1 = [{"attr1": "value1"}]
root_element2 = [{"attr1": "value2"}]

extract_data_from_xml_body(user_input, root_element1)
extract_data_from_xml_body(user_input, root_element2)

assert mock_context.xml == {"attr1": {"value1", "value2"}}


def test_extract_data_from_xml_body_duplicate_attributes(mock_context):
"""Test with duplicate attributes in root_element."""
user_input = "valid_input"
root_element = [
{"attr1": "value1"},
{"attr1": "value1"}, # Duplicate
{"attr2": "value2"},
]

extract_data_from_xml_body(user_input, root_element)

assert mock_context.xml == {"attr1": {"value1"}, "attr2": {"value2"}}


def test_extract_data_from_xml_body_no_attributes(mock_context):
"""Test with elements that have no attributes."""
user_input = "valid_input"
root_element = [{}]

extract_data_from_xml_body(user_input, root_element)

assert mock_context.xml == {}


def test_extract_data_from_xml_body_context_set_as_current(mock_context):
"""Test if context.set_as_current_context is called."""
user_input = "valid_input"
root_element = [{"attr1": "value1"}]

extract_data_from_xml_body(user_input, root_element)

mock_context.set_as_current_context.assert_called_once()
3 changes: 2 additions & 1 deletion aikido_firewall/helpers/extract_strings_from_user_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@


def extract_strings_from_user_input_cached(obj, source):
"""Use the cache to speed up getting user input"""
context = get_current_context()
if context.parsed_userinput and context.parsed_userinput.get(source):
return context.parsed_userinput.get(source)
Expand Down Expand Up @@ -37,7 +38,7 @@ def extract_strings_from_user_input(obj, path_to_payload=None):
).items():
results[k] = v

if isinstance(obj, list):
if isinstance(obj, (set, list, tuple)):
# Add the stringified array as well to the results, there might
# be accidental concatenation if the client expects a string but gets the array
# E.g. HTTP Parameter pollution
Expand Down
7 changes: 4 additions & 3 deletions aikido_firewall/sources/flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,11 @@ def aikido_view_func(*args, **kwargs):
if context:
if req.is_json:
context.body = req.get_json()
context.set_as_current_context()
else:
elif req.form:
context.body = req.form
context.set_as_current_context()
else:
context.body = req.data.decode("utf-8")
context.set_as_current_context()

pre_response = request_handler(stage="pre_response")
if pre_response:
Expand Down
46 changes: 46 additions & 0 deletions aikido_firewall/sources/lxml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""
Sink module for `xml`, python's built-in function
"""

import copy
import importhook
from aikido_firewall.helpers.extract_data_from_xml_body import (
extract_data_from_xml_body,
)
from aikido_firewall.background_process.packages import add_wrapped_package


@importhook.on_import("lxml.etree")
def on_lxml_import(eltree):
"""
Hook 'n wrap on `lxml.etree`.
- Wrap on fromstring() function
- Wrap on
Returns : Modified `lxml.etree` object
"""
modified_eltree = importhook.copy_module(eltree)

former_fromstring = copy.deepcopy(eltree.fromstring)

def aikido_fromstring(text, *args, **kwargs):
res = former_fromstring(text, *args, **kwargs)
extract_data_from_xml_body(user_input=text, root_element=res)
return res

former_fromstringlist = copy.deepcopy(eltree.fromstringlist)

def aikido_fromstringlist(strings, *args, **kwargs):
res = former_fromstringlist(strings, *args, **kwargs)
for string in strings:
extract_data_from_xml_body(user_input=string, root_element=res)
return res

# pylint: disable=no-member
setattr(eltree, "fromstring", aikido_fromstring)
setattr(modified_eltree, "fromstring", aikido_fromstring)

# pylint: disable=no-member
setattr(eltree, "fromstringlist", aikido_fromstringlist)
setattr(modified_eltree, "fromstringlist", aikido_fromstringlist)
add_wrapped_package("lxml")
return modified_eltree
52 changes: 52 additions & 0 deletions aikido_firewall/sources/xml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""
Sink module for `xml`, python's built-in function
"""

import copy
import importhook
from aikido_firewall.helpers.logging import logger
from aikido_firewall.helpers.extract_data_from_xml_body import (
extract_data_from_xml_body,
)


@importhook.on_import("xml.etree.ElementTree")
def on_xml_import(eltree):
"""
Hook 'n wrap on `xml.etree.ElementTree`, python's built-in xml lib
Our goal is to create a new and mutable aikido parser class
Returns : Modified ElementTree object
"""
modified_eltree = importhook.copy_module(eltree)
copy_xml_parser = copy.deepcopy(eltree.XMLParser)

class MutableAikidoXMLParser:
"""Aikido's mutable connection class"""

def __init__(self, *args, **kwargs):
self._former_xml_parser = copy_xml_parser(*args, **kwargs)
self._feed_func_copy = copy.deepcopy(self._former_xml_parser.feed)

def __getattr__(self, name):
if name != "feed":
return getattr(self._former_xml_parser, name)

# Return aa function dynamically
def feed(data):
former_feed_result = self._feed_func_copy(data)

# Fetch the data, this should just return an internal attribute and not close a stream
# Or something that is noticable by the end-user
parsed_xml = self.target.close()
extract_data_from_xml_body(user_input=data, root_element=parsed_xml)

return former_feed_result

return feed

# pylint: disable=no-member
setattr(eltree, "XMLParser", MutableAikidoXMLParser)
setattr(modified_eltree, "XMLParser", MutableAikidoXMLParser)

logger.debug("Wrapped `xml` module")
return modified_eltree
28 changes: 28 additions & 0 deletions end2end/flask_postgres_xml_lxml_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import pytest
import requests
# e2e tests for flask_postgres sample app
post_url_fw = "http://localhost:8092/xml_post_lxml"
post_url_nofw = "http://localhost:8093/xml_post_lxml"

def test_safe_response_with_firewall():
xml_data = '<dogs><dog dog_name="Bobby" /></dogs>'
res = requests.post(post_url_fw, data=xml_data)
assert res.status_code == 200


def test_safe_response_without_firewall():
xml_data = '<dogs><dog dog_name="Bobby" /></dogs>'
res = requests.post(post_url_nofw, data=xml_data)
assert res.status_code == 200


def test_dangerous_response_with_firewall():
xml_data = '<dogs><dog dog_name="Malicious dog\', TRUE); -- " /></dogs>'
res = requests.post(post_url_fw, data=xml_data)
assert res.status_code == 500

def test_dangerous_response_without_firewall():
xml_data = '<dogs><dog dog_name="Malicious dog\', TRUE); -- " /></dogs>'
res = requests.post(post_url_nofw, data=xml_data)
assert res.status_code == 200

28 changes: 28 additions & 0 deletions end2end/flask_postgres_xml_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import pytest
import requests
# e2e tests for flask_postgres sample app
post_url_fw = "http://localhost:8092/xml_post"
post_url_nofw = "http://localhost:8093/xml_post"

def test_safe_response_with_firewall():
xml_data = '<dogs><dog dog_name="Bobby" /></dogs>'
res = requests.post(post_url_fw, data=xml_data)
assert res.status_code == 200


def test_safe_response_without_firewall():
xml_data = '<dogs><dog dog_name="Bobby" /></dogs>'
res = requests.post(post_url_nofw, data=xml_data)
assert res.status_code == 200


def test_dangerous_response_with_firewall():
xml_data = '<dogs><dog dog_name="Malicious dog\', TRUE); -- " /></dogs>'
res = requests.post(post_url_fw, data=xml_data)
assert res.status_code == 500

def test_dangerous_response_without_firewall():
xml_data = '<dogs><dog dog_name="Malicious dog\', TRUE); -- " /></dogs>'
res = requests.post(post_url_nofw, data=xml_data)
assert res.status_code == 200

Loading

0 comments on commit 3fd4bb7

Please sign in to comment.