Skip to content

Commit

Permalink
Create a new set_body function and use it everywhere
Browse files Browse the repository at this point in the history
  • Loading branch information
Wout Feys committed Nov 14, 2024
1 parent b13ccc0 commit 32f1bd8
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 66 deletions.
11 changes: 10 additions & 1 deletion aikido_zen/context/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

import contextvars
import json
from urllib.parse import parse_qs

from aikido_zen.helpers.build_route_from_url import build_route_from_url
Expand Down Expand Up @@ -44,7 +45,7 @@ def __init__(self, context_obj=None, body=None, req=None, source=None):
self.parsed_userinput = {}
self.xml = {}
self.outgoing_req_redirects = []
self.body = body
self.set_body(body)

# Parse WSGI/ASGI/... request :
self.cookies = self.method = self.remote_address = self.query = self.headers = (
Expand Down Expand Up @@ -91,6 +92,14 @@ def set_as_current_context(self):
"""
current_context.set(self)

def set_body(self, body):
"""Sets the body, and verifies the body is okay"""
try:
json.dumps(body)
self.body = body
except (TypeError, OverflowError):
self.body = None

def get_route_metadata(self):
"""Returns a route_metadata object"""
return {
Expand Down
110 changes: 54 additions & 56 deletions aikido_zen/context/init_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,21 @@
from aikido_zen.context import Context, get_current_context, current_context


basic_wsgi_req = {
"REQUEST_METHOD": "GET",
"HTTP_HEADER_1": "header 1 value",
"HTTP_HEADER_2": "Header 2 value",
"RANDOM_VALUE": "Random value",
"HTTP_COOKIE": "sessionId=abc123xyz456;",
"wsgi.url_scheme": "http",
"HTTP_HOST": "localhost:8080",
"PATH_INFO": "/hello",
"QUERY_STRING": "user=JohnDoe&age=30&age=35",
"CONTENT_TYPE": "application/json",
"REMOTE_ADDR": "198.51.100.23",
}


@pytest.fixture(autouse=True)
def run_around_tests():
yield
Expand Down Expand Up @@ -58,20 +73,7 @@ def test_wsgi_context_1():


def test_wsgi_context_2():
wsgi_request = {
"REQUEST_METHOD": "GET",
"HTTP_HEADER_1": "header 1 value",
"HTTP_HEADER_2": "Header 2 value",
"RANDOM_VALUE": "Random value",
"HTTP_COOKIE": "sessionId=abc123xyz456;",
"wsgi.url_scheme": "http",
"HTTP_HOST": "localhost:8080",
"PATH_INFO": "/hello",
"QUERY_STRING": "user=JohnDoe&age=30&age=35",
"CONTENT_TYPE": "application/json",
"REMOTE_ADDR": "198.51.100.23",
}
context = Context(req=wsgi_request, body={"test": True}, source="flask")
context = Context(req=basic_wsgi_req, body={"test": True}, source="flask")
assert context.__dict__ == {
"source": "flask",
"method": "GET",
Expand Down Expand Up @@ -99,59 +101,20 @@ def test_wsgi_context_2():

def test_set_as_current_context(mocker):
# Test set_as_current_context() method
wsgi_request = {
"REQUEST_METHOD": "GET",
"HTTP_HEADER_1": "header 1 value",
"HTTP_HEADER_2": "Header 2 value",
"RANDOM_VALUE": "Random value",
"HTTP_COOKIE": "sessionId=abc123xyz456;",
"wsgi.url_scheme": "http",
"HTTP_HOST": "localhost:8080",
"PATH_INFO": "/hello",
"QUERY_STRING": "user=JohnDoe&age=30&age=35",
"CONTENT_TYPE": "application/json",
"REMOTE_ADDR": "198.51.100.23",
}
context = Context(req=wsgi_request, body=12, source="flask")
context = Context(req=basic_wsgi_req, body=12, source="flask")
context.set_as_current_context()
assert get_current_context() == context


def test_get_current_context_with_context(mocker):
# Test get_current_context() when a context is set
wsgi_request = {
"REQUEST_METHOD": "GET",
"HTTP_HEADER_1": "header 1 value",
"HTTP_HEADER_2": "Header 2 value",
"RANDOM_VALUE": "Random value",
"HTTP_COOKIE": "sessionId=abc123xyz456;",
"wsgi.url_scheme": "http",
"HTTP_HOST": "localhost:8080",
"PATH_INFO": "/hello",
"QUERY_STRING": "user=JohnDoe&age=30&age=35",
"CONTENT_TYPE": "application/json",
"REMOTE_ADDR": "198.51.100.23",
}
context = Context(req=wsgi_request, body=456, source="flask")
context = Context(req=basic_wsgi_req, body=456, source="flask")
context.set_as_current_context()
assert get_current_context() == context


def test_context_is_picklable(mocker):
wsgi_request = {
"REQUEST_METHOD": "GET",
"HTTP_HEADER_1": "header 1 value",
"HTTP_HEADER_2": "Header 2 value",
"RANDOM_VALUE": "Random value",
"HTTP_COOKIE": "sessionId=abc123xyz456;",
"wsgi.url_scheme": "http",
"HTTP_HOST": "localhost:8080",
"PATH_INFO": "/hello",
"QUERY_STRING": "user=JohnDoe&age=30&age=35",
"CONTENT_TYPE": "application/json",
"REMOTE_ADDR": "198.51.100.23",
}
context = Context(req=wsgi_request, body=123, source="flask")
context = Context(req=basic_wsgi_req, body=123, source="flask")
pickled_obj = pickle.dumps(context)
unpickled_obj = pickle.loads(pickled_obj)
assert unpickled_obj.source == "flask"
Expand All @@ -168,3 +131,38 @@ def test_context_is_picklable(mocker):
}
assert unpickled_obj.query == {"user": ["JohnDoe"], "age": ["30", "35"]}
assert unpickled_obj.cookies == {"sessionId": "abc123xyz456"}


def test_set_valid_dict():
valid_body = {"key": "value"}
context = Context(req=basic_wsgi_req, body=valid_body, source="flask")
assert context.body == valid_body


def test_set_valid_list():
valid_body = [1, 2, 3]
context = Context(req=basic_wsgi_req, body=valid_body, source="flask")
assert context.body == valid_body


def test_set_valid_string():
valid_body = "This is a valid string"
context = Context(req=basic_wsgi_req, body=valid_body, source="flask")
assert context.body == valid_body


def test_set_invalid_body():
invalid_body = set([1, 2, 3]) # Sets are not JSON serializable
context = Context(req=basic_wsgi_req, body=invalid_body, source="flask")
assert context.body is None


def test_set_bytestring():
invalid_body = b"Hello World" # Byte strings are not JSON Serializablec
context = Context(req=basic_wsgi_req, body=invalid_body, source="flask")
assert context.body is None


def test_set_none():
context = Context(req=basic_wsgi_req, body=None, source="flask")
assert context.body is None
6 changes: 3 additions & 3 deletions aikido_zen/sources/flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ def extract_and_save_data_from_flask_request(req):
context = get_current_context()
if context:
if req.is_json:
context.body = req.get_json()
context.set_body(req.get_json())

Check warning on line 48 in aikido_zen/sources/flask.py

View check run for this annotation

Codecov / codecov/patch

aikido_zen/sources/flask.py#L48

Added line #L48 was not covered by tests
elif req.form:
context.body = req.form
context.set_body(req.form)

Check warning on line 50 in aikido_zen/sources/flask.py

View check run for this annotation

Codecov / codecov/patch

aikido_zen/sources/flask.py#L50

Added line #L50 was not covered by tests
else:
context.body = req.data.decode("utf-8")
context.set_body(req.data.decode("utf-8"))

Check warning on line 52 in aikido_zen/sources/flask.py

View check run for this annotation

Codecov / codecov/patch

aikido_zen/sources/flask.py#L52

Added line #L52 was not covered by tests

if getattr(req, "view_args"):
context.route_params = dict(req.view_args)
Expand Down
6 changes: 3 additions & 3 deletions aikido_zen/sources/quart.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ async def handle_request_wrapper(former_handle_request, quart_app, req):
if context:
form = await req.form
if req.is_json:
context.body = await req.get_json()
context.set_body(await req.get_json())

Check warning on line 41 in aikido_zen/sources/quart.py

View check run for this annotation

Codecov / codecov/patch

aikido_zen/sources/quart.py#L41

Added line #L41 was not covered by tests
elif form:
context.body = form
context.set_body(form)

Check warning on line 43 in aikido_zen/sources/quart.py

View check run for this annotation

Codecov / codecov/patch

aikido_zen/sources/quart.py#L43

Added line #L43 was not covered by tests
else:
data = await req.data
context.body = data.decode("utf-8")
context.set_body(data.decode("utf-8"))

Check warning on line 46 in aikido_zen/sources/quart.py

View check run for this annotation

Codecov / codecov/patch

aikido_zen/sources/quart.py#L46

Added line #L46 was not covered by tests
context.cookies = req.cookies.to_dict()
context.set_as_current_context()
except Exception as e:
Expand Down
6 changes: 3 additions & 3 deletions aikido_zen/sources/starlette/extract_data_from_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,16 @@ async def extract_data_from_request(request):

# Parse data
try:
context.body = await request.json()
context.set_body(await request.json())

Check warning on line 14 in aikido_zen/sources/starlette/extract_data_from_request.py

View check run for this annotation

Codecov / codecov/patch

aikido_zen/sources/starlette/extract_data_from_request.py#L14

Added line #L14 was not covered by tests
except ValueError:
# Throws error if the body is not json
pass
if not context.body:
form_data = await request.form()
if form_data:
# Convert to dict object :
context.body = {key: value for key, value in form_data.items()}
context.set_body({key: value for key, value in form_data.items()})

Check warning on line 22 in aikido_zen/sources/starlette/extract_data_from_request.py

View check run for this annotation

Codecov / codecov/patch

aikido_zen/sources/starlette/extract_data_from_request.py#L22

Added line #L22 was not covered by tests
if not context.body:
context.body = await request.body()
context.set_body(await request.body())

Check warning on line 24 in aikido_zen/sources/starlette/extract_data_from_request.py

View check run for this annotation

Codecov / codecov/patch

aikido_zen/sources/starlette/extract_data_from_request.py#L24

Added line #L24 was not covered by tests

context.set_as_current_context()

0 comments on commit 32f1bd8

Please sign in to comment.