diff --git a/aikido_firewall/helpers/get_port_from_url.py b/aikido_firewall/helpers/get_port_from_url.py index a154bef4..2cee6a03 100644 --- a/aikido_firewall/helpers/get_port_from_url.py +++ b/aikido_firewall/helpers/get_port_from_url.py @@ -5,11 +5,14 @@ from urllib.parse import urlparse -def get_port_from_url(url): +def get_port_from_url(url, parsed=False): """ Tries to retrieve a port number from the given url """ - parsed_url = urlparse(url) + if not parsed: + parsed_url = urlparse(url) + else: + parsed_url = url # Check if the port is specified and is a valid integer if parsed_url.port is not None: diff --git a/aikido_firewall/helpers/get_port_from_url_test.py b/aikido_firewall/helpers/get_port_from_url_test.py index ec65efe0..18dd45df 100644 --- a/aikido_firewall/helpers/get_port_from_url_test.py +++ b/aikido_firewall/helpers/get_port_from_url_test.py @@ -1,5 +1,6 @@ import pytest from .get_port_from_url import get_port_from_url +from urllib.parse import urlparse def test_get_port_from_url(): @@ -8,3 +9,13 @@ def test_get_port_from_url(): assert get_port_from_url("https://test.com:8080/test?abc=123") == 8080 assert get_port_from_url("https://localhost") == 443 assert get_port_from_url("ftp://localhost") is None + + +def test_get_port_from_parsed_url(): + assert get_port_from_url(urlparse("http://localhost:4000"), True) == 4000 + assert get_port_from_url(urlparse("http://localhost"), True) == 80 + assert ( + get_port_from_url(urlparse("https://test.com:8080/test?abc=123"), True) == 8080 + ) + assert get_port_from_url(urlparse("https://localhost"), True) == 443 + assert get_port_from_url(urlparse("ftp://localhost"), True) is None diff --git a/aikido_firewall/helpers/urls/__init__.py b/aikido_firewall/helpers/urls/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/aikido_firewall/helpers/urls/normalize_url.py b/aikido_firewall/helpers/urls/normalize_url.py new file mode 100644 index 00000000..35f01c22 --- /dev/null +++ b/aikido_firewall/helpers/urls/normalize_url.py @@ -0,0 +1,26 @@ +"""Helper function file, exports normalize_url""" + +from urllib.parse import urlparse, urlunparse + + +def normalize_url(url): + """Normalizes the url""" + # Parse the URL + parsed_url = urlparse(url) + + # Normalize components + scheme = parsed_url.scheme.lower() # Lowercase scheme + netloc = parsed_url.netloc.lower() # Lowercase netloc + path = parsed_url.path.rstrip("/") # Remove trailing slash + query = parsed_url.query # Keep query as is + fragment = parsed_url.fragment # Keep fragment as is + + # Remove default ports (80 for http, 443 for https) + if scheme == "http" and parsed_url.port == 80: + netloc = netloc.replace(":80", "") + elif scheme == "https" and parsed_url.port == 443: + netloc = netloc.replace(":443", "") + + # Reconstruct the normalized URL + normalized_url = urlunparse((scheme, netloc, path, "", query, fragment)) + return normalized_url diff --git a/aikido_firewall/helpers/urls/normalize_url_test.py b/aikido_firewall/helpers/urls/normalize_url_test.py new file mode 100644 index 00000000..e70a07d7 --- /dev/null +++ b/aikido_firewall/helpers/urls/normalize_url_test.py @@ -0,0 +1,58 @@ +import pytest +from .normalize_url import normalize_url + + +def test_normalize_url(): + # Test with standard URLs + assert normalize_url("http://example.com") == "http://example.com" + assert normalize_url("https://example.com") == "https://example.com" + assert normalize_url("http://example.com/") == "http://example.com" + assert normalize_url("http://example.com/path/") == "http://example.com/path" + assert normalize_url("http://example.com/path") == "http://example.com/path" + + # Test with lowercase and uppercase schemes + assert normalize_url("HTTP://EXAMPLE.COM") == "http://example.com" + assert normalize_url("Https://EXAMPLE.COM") == "https://example.com" + + # Test with default ports + assert normalize_url("http://example.com:80/path") == "http://example.com/path" + assert normalize_url("https://example.com:443/path") == "https://example.com/path" + + # Test with non-default ports + assert ( + normalize_url("http://example.com:8080/path") == "http://example.com:8080/path" + ) + assert ( + normalize_url("https://example.com:8443/path") + == "https://example.com:8443/path" + ) + + # Test with query parameters + assert ( + normalize_url("http://example.com/path?query=1") + == "http://example.com/path?query=1" + ) + assert ( + normalize_url("http://example.com/path/?query=1") + == "http://example.com/path?query=1" + ) + + # Test with fragments + assert ( + normalize_url("http://example.com/path#fragment") + == "http://example.com/path#fragment" + ) + assert ( + normalize_url("http://example.com/path/?query=1#fragment") + == "http://example.com/path?query=1#fragment" + ) + + # Test with URLs that have trailing slashes and mixed cases + assert normalize_url("http://Example.com/Path/") == "http://example.com/Path" + assert ( + normalize_url("http://example.com/path/another/") + == "http://example.com/path/another" + ) + + # Test with empty URL + assert normalize_url("") == "" diff --git a/aikido_firewall/sinks/http_client.py b/aikido_firewall/sinks/http_client.py index 88f8ee61..da3ac815 100644 --- a/aikido_firewall/sinks/http_client.py +++ b/aikido_firewall/sinks/http_client.py @@ -26,22 +26,9 @@ def on_http_import(http): former_getresponse = copy.deepcopy(http.HTTPConnection.getresponse) def aik_new_putrequest(_self, method, path, *args, **kwargs): - # Aikido putrequest, gets called before the request went through - try: - # Set path for aik_new_getresponse : - _self.aikido_attr_path = path - - # Create a URL Object : - assembled_url = f"http://{_self.host}:{_self.port}{path}" - url_object = try_parse_url(assembled_url) - - run_vulnerability_scan( - kind="ssrf", op="http.client.putrequest", args=(url_object, _self.port) - ) - except AikidoException as e: - raise e - except Exception as e: - logger.debug("Exception occured in custom putrequest function : %s", e) + # Aikido putrequest, gets called before the request goes through + # Set path for aik_new_getresponse : + _self.aikido_attr_path = path return former_putrequest(_self, method, path, *args, **kwargs) def aik_new_getresponse(_self): diff --git a/aikido_firewall/sinks/socket.py b/aikido_firewall/sinks/socket.py index cf374016..2a4aaa5a 100644 --- a/aikido_firewall/sinks/socket.py +++ b/aikido_firewall/sinks/socket.py @@ -5,9 +5,7 @@ import copy import importhook from aikido_firewall.helpers.logging import logger -from aikido_firewall.vulnerabilities.ssrf.inspect_getaddrinfo_result import ( - inspect_getaddrinfo_result, -) +from aikido_firewall.vulnerabilities import run_vulnerability_scan SOCKET_OPERATIONS = [ "gethostbyname", @@ -25,8 +23,9 @@ def generate_aikido_function(former_func, op): def aik_new_func(*args, **kwargs): res = former_func(*args, **kwargs) if op == "getaddrinfo": - inspect_getaddrinfo_result(dns_results=res, hostname=args[0], port=args[1]) - logger.debug("Res %s", res) + run_vulnerability_scan( + kind="ssrf", op="socket.getaddrinfo", args=(res, args[0], args[1]) + ) return res return aik_new_func @@ -36,9 +35,7 @@ def aik_new_func(*args, **kwargs): def on_socket_import(socket): """ Hook 'n wrap on `socket` - Our goal is to wrap the following socket functions that take a hostname : - - gethostbyname() -- map a hostname to its IP number - - gethostbyaddr() -- map an IP number or hostname to DNS info + Our goal is to wrap the getaddrinfo socket function https://github.com/python/cpython/blob/8f19be47b6a50059924e1d7b64277ad3cef4dac7/Lib/socket.py#L10 Returns : Modified socket object """ diff --git a/aikido_firewall/vulnerabilities/__init__.py b/aikido_firewall/vulnerabilities/__init__.py index ae20b3c0..80086277 100644 --- a/aikido_firewall/vulnerabilities/__init__.py +++ b/aikido_firewall/vulnerabilities/__init__.py @@ -19,7 +19,7 @@ from aikido_firewall.background_process.ipc_lifecycle_cache import get_cache from .sql_injection.context_contains_sql_injection import context_contains_sql_injection from .nosql_injection.check_context import check_context_for_nosql_injection -from .ssrf import scan_for_ssrf_in_request +from .ssrf.inspect_getaddrinfo_result import inspect_getaddrinfo_result from .shell_injection.check_context_for_shell_injection import ( check_context_for_shell_injection, ) @@ -36,9 +36,14 @@ def run_vulnerability_scan(kind, op, args): context = get_current_context() comms = get_comms() lifecycle_cache = get_cache() - if not context or not lifecycle_cache: - logger.debug("Not running a vulnerability scan due to incomplete data.") - logger.debug("%s : %s", kind, op) + if not context and kind != "ssrf": + # Make a special exception for SSRF, which checks itself if context is set. + # This is because some scans/tests for SSRF do not require a context to be set. + logger.debug("Not running scans due to incomplete data %s : %s", kind, op) + return + + if not lifecycle_cache: + logger.debug("Not running scans due to incomplete data %s : %s", kind, op) return if lifecycle_cache.protection_forced_off(): @@ -74,9 +79,10 @@ def run_vulnerability_scan(kind, op, args): ) error_type = AikidoPathTraversal elif kind == "ssrf": - # args[0] : URL object, args[1] : Port - # Report hostname and port to background process : - injection_results = scan_for_ssrf_in_request(args[0], args[1], op, context) + # args[0] : DNS Results, args[1] : Hostname, args[2] : Port + injection_results = inspect_getaddrinfo_result( + dns_results=args[0], hostname=args[1], port=args[2] + ) error_type = AikidoSSRF blocked_request = is_blocking_enabled() and injection_results if not blocked_request: diff --git a/aikido_firewall/vulnerabilities/ssrf/__init__.py b/aikido_firewall/vulnerabilities/ssrf/__init__.py index 29aa3ee4..e69de29b 100644 --- a/aikido_firewall/vulnerabilities/ssrf/__init__.py +++ b/aikido_firewall/vulnerabilities/ssrf/__init__.py @@ -1,29 +0,0 @@ -"""Exports scan_for_ssrf_in_request function""" - -from aikido_firewall.helpers.logging import logger -from .check_context_for_ssrf import check_context_for_ssrf -from .is_redirect_to_private_ip import is_redirect_to_private_ip - - -def scan_for_ssrf_in_request(url, port, operation, context): - """Scans for SSRF attacks""" - - # Check if the request is a SSRF : - context_contains_ssrf_results = check_context_for_ssrf( - url.hostname, port, operation, context - ) - if context_contains_ssrf_results: - return context_contains_ssrf_results - - # Check if the request is a SSRF with redirects : - logger.debug("Redirects : %s", context.outgoing_req_redirects) - redirected_ssrf_results = is_redirect_to_private_ip(url, context) - if redirected_ssrf_results: - return { - "operation": operation, - "kind": "ssrf", - "source": redirected_ssrf_results["source"], - "pathToPayload": redirected_ssrf_results["pathToPayload"], - "metadata": {}, - "payload": redirected_ssrf_results["payload"], - } diff --git a/aikido_firewall/vulnerabilities/ssrf/check_context_for_ssrf.py b/aikido_firewall/vulnerabilities/ssrf/check_context_for_ssrf.py deleted file mode 100644 index 2f5bb446..00000000 --- a/aikido_firewall/vulnerabilities/ssrf/check_context_for_ssrf.py +++ /dev/null @@ -1,35 +0,0 @@ -"""Exports check_context_for_ssrf""" - -from aikido_firewall.helpers.extract_strings_from_user_input import ( - extract_strings_from_user_input_cached, -) -from aikido_firewall.helpers.logging import logger -from aikido_firewall.context import UINPUT_SOURCES as SOURCES -from .find_hostname_in_userinput import find_hostname_in_userinput -from .contains_private_ip_address import contains_private_ip_address - - -def check_context_for_ssrf(hostname, port, operation, context): - """ - This will check the context for SSRF - """ - if not isinstance(hostname, str) or not isinstance(port, int): - # Validate hostname and port input - return {} - for source in SOURCES: - if hasattr(context, source): - user_inputs = extract_strings_from_user_input_cached( - getattr(context, source), source - ) - for user_input, path in user_inputs.items(): - found = find_hostname_in_userinput(user_input, hostname, port) - if found and contains_private_ip_address(hostname): - return { - "operation": operation, - "kind": "ssrf", - "source": source, - "pathToPayload": path, - "metadata": {"hostname": hostname}, - "payload": user_input, - } - return {} diff --git a/aikido_firewall/vulnerabilities/ssrf/check_context_for_ssrf_test.py b/aikido_firewall/vulnerabilities/ssrf/check_context_for_ssrf_test.py deleted file mode 100644 index 000b6ca3..00000000 --- a/aikido_firewall/vulnerabilities/ssrf/check_context_for_ssrf_test.py +++ /dev/null @@ -1,77 +0,0 @@ -import pytest -from aikido_firewall.context import Context -from .check_context_for_ssrf import check_context_for_ssrf - - -class Context2(Context): - def __init__(self): - self.cookies = {} - self.headers = {} - self.remote_address = "ip" - self.method = "POST" - self.url = "url" - self.body = {} - self.query = { - "domain": "www.example`whoami`.com", - } - self.source = "express" - self.route = "/" - self.parsed_userinput = {} - - -@pytest.mark.parametrize( - "invalid_input", - [ - None, - 123456789, # Integer - 45.67, # Float - [], # Empty list - [1, 2, 3], # List of integers - {}, # Empty dictionary - {"key": "value"}, # Dictionary - set(), # Empty set - {1, 2, 3}, # Set of integers - object(), # Instance of a generic object - lambda x: x, # Lambda function - (1, 2), # Tuple - b"bytes", # Bytes - ], -) -def test_doesnt_crash_with_invalid_hostname(invalid_input): - context = Context2() - result = check_context_for_ssrf( - hostname=invalid_input, - port=8080, - operation="http.putrequest", - context=context, - ) - assert result == {} - - -@pytest.mark.parametrize( - "invalid_input", - [ - None, - "test", # String - 45.67, # Float - [], # Empty list - [1, 2, 3], # List of integers - {}, # Empty dictionary - {"key": "value"}, # Dictionary - set(), # Empty set - {1, 2, 3}, # Set of integers - object(), # Instance of a generic object - lambda x: x, # Lambda function - (1, 2), # Tuple - b"bytes", # Bytes - ], -) -def test_doesnt_crash_with_invalid_port(invalid_input): - context = Context2() - result = check_context_for_ssrf( - hostname="example.com", - port=invalid_input, - operation="http.putrequest", - context=context, - ) - assert result == {} diff --git a/aikido_firewall/vulnerabilities/ssrf/contains_private_ip_address.py b/aikido_firewall/vulnerabilities/ssrf/contains_private_ip_address.py deleted file mode 100644 index a62caf01..00000000 --- a/aikido_firewall/vulnerabilities/ssrf/contains_private_ip_address.py +++ /dev/null @@ -1,25 +0,0 @@ -"""exports contains_private_ip_address""" - -from aikido_firewall.helpers.try_parse_url import try_parse_url -from .is_private_ip import is_private_ip - - -def contains_private_ip_address(hostname): - """ - Checks if the hostname contains an IP that's private - """ - if hostname == "localhost": - return True - - # Attempt to parse the URL - url = try_parse_url(f"http://{hostname}") - if url is None: - return False - - # Check for IPv6 addresses enclosed in square brackets - if url.hostname.startswith("[") and url.hostname.endswith("]"): - ipv6 = url.hostname[1:-1] # Extract the IPv6 address - if is_private_ip(ipv6): - return True - - return is_private_ip(url.hostname) diff --git a/aikido_firewall/vulnerabilities/ssrf/contains_private_ip_address_test.py b/aikido_firewall/vulnerabilities/ssrf/contains_private_ip_address_test.py deleted file mode 100644 index 3d9ef0f2..00000000 --- a/aikido_firewall/vulnerabilities/ssrf/contains_private_ip_address_test.py +++ /dev/null @@ -1,183 +0,0 @@ -import pytest -from .contains_private_ip_address import contains_private_ip_address - - -public_ips = [ - "44.37.112.180", - "46.192.247.73", - "71.12.102.112", - "101.0.26.90", - "111.211.73.40", - "156.238.194.84", - "164.101.185.82", - "223.231.138.242", - "::1fff:0.0.0.0", - "::1fff:10.0.0.0", - "::1fff:0:0.0.0.0", - "::1fff:0:10.0.0.0", - "2001:2:ffff:ffff:ffff:ffff:ffff:ffff", - "64:ff9a::0.0.0.0", - "64:ff9a::255.255.255.255", - "99::", - "99::ffff:ffff:ffff:ffff", - "101::", - "101::ffff:ffff:ffff:ffff", - "2000::", - "2000::ffff:ffff:ffff:ffff:ffff:ffff", - "2001:10::", - "2001:1f:ffff:ffff:ffff:ffff:ffff:ffff", - "2001:db7::", - "2001:db7:ffff:ffff:ffff:ffff:ffff:ffff", - "2001:db9::", - "fb00::", - "fbff:ffff:ffff:ffff:ffff:ffff:ffff:ffff", - "fec0::", -] - -private_ips = [ - "0.0.0.0", - "0000.0000.0000.0000", - "0.0.0.1", - "0.0.0.7", - "0.0.0.255", - "0.0.255.255", - "0.1.255.255", - "0.15.255.255", - "0.63.255.255", - "0.255.255.254", - "0.255.255.255", - "10.0.0.0", - "10.0.0.1", - "10.0.0.01", - "10.0.0.001", - "10.255.255.254", - "10.255.255.255", - "100.64.0.0", - "100.64.0.1", - "100.127.255.254", - "100.127.255.255", - "127.0.0.0", - "127.0.0.1", - "127.0.0.01", - "127.000.000.1", - "127.255.255.254", - "127.255.255.255", - "169.254.0.0", - "169.254.0.1", - "169.254.255.254", - "169.254.255.255", - "172.16.0.0", - "172.16.0.1", - "172.16.0.001", - "172.31.255.254", - "172.31.255.255", - "192.0.0.0", - "192.0.0.1", - "192.0.0.6", - "192.0.0.7", - "192.0.0.8", - "192.0.0.9", - "192.0.0.10", - "192.0.0.11", - "192.0.0.170", - "192.0.0.171", - "192.0.0.254", - "192.0.0.255", - "192.0.2.0", - "192.0.2.1", - "192.0.2.254", - "192.0.2.255", - "192.31.196.0", - "192.31.196.1", - "192.31.196.254", - "192.31.196.255", - "192.52.193.0", - "192.52.193.1", - "192.52.193.254", - "192.52.193.255", - "192.88.99.0", - "192.88.99.1", - "192.88.99.254", - "192.88.99.255", - "192.168.0.0", - "192.168.0.1", - "192.168.255.254", - "192.168.255.255", - "192.175.48.0", - "192.175.48.1", - "192.175.48.254", - "192.175.48.255", - "198.18.0.0", - "198.18.0.1", - "198.19.255.254", - "198.19.255.255", - "198.51.100.0", - "198.51.100.1", - "198.51.100.254", - "198.51.100.255", - "203.0.113.0", - "203.0.113.1", - "203.0.113.254", - "203.0.113.255", - "240.0.0.0", - "240.0.0.1", - "224.0.0.0", - "224.0.0.1", - "255.0.0.0", - "255.192.0.0", - "255.240.0.0", - "255.254.0.0", - "255.255.0.0", - "255.255.255.0", - "255.255.255.248", - "255.255.255.254", - "255.255.255.255", - "0000:0000:0000:0000:0000:0000:0000:0000", - "::", - "::1", - "::ffff:0.0.0.0", - "::ffff:127.0.0.1", - "fe80::", - "fe80::1", - "fe80::abc:1", - "febf:ffff:ffff:ffff:ffff:ffff:ffff:ffff", - "fc00::", - "fc00::1", - "fc00::abc:1", - "fdff:ffff:ffff:ffff:ffff:ffff:ffff:ffff", - "fd00:ec2::254", - "169.254.169.254", - "localhost", -] - -invalid_ips = [ - "100::ffff::", - "::ffff:0.0.255.255.255", - "::ffff:0.255.255.255.255", - "0000.0000", - "127.1", - "127.0.1", - "2130706433", - "0x7f000001", -] - - -def test_public_ips(): - for ip in public_ips: - if ":" in ip: - ip = f"[{ip}]" # IPv6 are enclosed in brackets - assert not contains_private_ip_address(ip), f"Expected {ip} to be public" - - -def test_private_ips(): - for ip in private_ips: - if ":" in ip: - ip = f"[{ip}]" # IPv6 are enclosed in brackets - assert contains_private_ip_address(ip), f"Expected {ip} to be private" - - -def test_invalid_ips(): - for ip in invalid_ips: - if ":" in ip: - ip = f"[{ip}]" # IPv6 are enclosed in brackets - assert not contains_private_ip_address(ip), f"Expected {ip} to be invalid" diff --git a/aikido_firewall/vulnerabilities/ssrf/find_hostname_in_context.py b/aikido_firewall/vulnerabilities/ssrf/find_hostname_in_context.py index ebf5134b..1254eda4 100644 --- a/aikido_firewall/vulnerabilities/ssrf/find_hostname_in_context.py +++ b/aikido_firewall/vulnerabilities/ssrf/find_hostname_in_context.py @@ -11,6 +11,9 @@ def find_hostname_in_context(hostname, context, port): """Tries to locate the given hostname from context""" + if not isinstance(hostname, str) or not isinstance(port, int): + # Validate hostname and port input + return None # Punycode detected in hostname, while user input may not be in Punycode # We need to convert it to ensure we compare the right values diff --git a/aikido_firewall/vulnerabilities/ssrf/find_hostname_in_context_test.py b/aikido_firewall/vulnerabilities/ssrf/find_hostname_in_context_test.py index b9f77dea..16899766 100644 --- a/aikido_firewall/vulnerabilities/ssrf/find_hostname_in_context_test.py +++ b/aikido_firewall/vulnerabilities/ssrf/find_hostname_in_context_test.py @@ -101,3 +101,53 @@ def test_find_hostname_in_context_no_sources(monkeypatch): # To run the tests, use the command: pytest .py + + +@pytest.mark.parametrize( + "invalid_input", + [ + None, + 123456789, # Integer + 45.67, # Float + [], # Empty list + [1, 2, 3], # List of integers + {}, # Empty dictionary + {"key": "value"}, # Dictionary + set(), # Empty set + {1, 2, 3}, # Set of integers + object(), # Instance of a generic object + lambda x: x, # Lambda function + (1, 2), # Tuple + b"bytes", # Bytes + ], +) +def test_doesnt_crash_with_invalid_hostname(invalid_input, monkeypatch): + context = MagicMock() # No attributes + monkeypatch.setattr("aikido_firewall.context.get_current_context", lambda: None) + result = find_hostname_in_context(invalid_input, context, 8080) + assert result == None + + +@pytest.mark.parametrize( + "invalid_input", + [ + None, + "test", # String + 45.67, # Float + [], # Empty list + [1, 2, 3], # List of integers + {}, # Empty dictionary + {"key": "value"}, # Dictionary + set(), # Empty set + {1, 2, 3}, # Set of integers + object(), # Instance of a generic object + lambda x: x, # Lambda function + (1, 2), # Tuple + b"bytes", # Bytes + ], +) +def test_doesnt_crash_with_invalid_port(invalid_input, monkeypatch): + context = MagicMock() # No attributes + monkeypatch.setattr("aikido_firewall.context.get_current_context", lambda: None) + result = find_hostname_in_context("https://example.com", context, invalid_input) + assert result == None diff --git a/aikido_firewall/vulnerabilities/ssrf/get_redirect_origin.py b/aikido_firewall/vulnerabilities/ssrf/get_redirect_origin.py index 93acad86..6f8e8d19 100644 --- a/aikido_firewall/vulnerabilities/ssrf/get_redirect_origin.py +++ b/aikido_firewall/vulnerabilities/ssrf/get_redirect_origin.py @@ -1,43 +1,23 @@ """Exports get_redirect_origin function""" import copy -from urllib.parse import urlparse, urlunparse +from aikido_firewall.helpers.get_port_from_url import get_port_from_url +from aikido_firewall.helpers.urls.normalize_url import normalize_url -def normalize_url(url): - """Normalizes the url""" - # Parse the URL - parsed_url = urlparse(url) - - # Normalize components - scheme = parsed_url.scheme.lower() # Lowercase scheme - netloc = parsed_url.netloc.lower() # Lowercase netloc - path = parsed_url.path.rstrip("/") # Remove trailing slash - query = parsed_url.query # Keep query as is - fragment = parsed_url.fragment # Keep fragment as is - - # Remove default ports (80 for http, 443 for https) - if scheme == "http" and parsed_url.port == 80: - netloc = netloc.replace(":80", "") - elif scheme == "https" and parsed_url.port == 443: - netloc = netloc.replace(":443", "") - - # We do not care about the scheme (Isn't extracted) : - scheme = "http" - - # Reconstruct the normalized URL - normalized_url = urlunparse((scheme, netloc, path, "", query, fragment)) - return normalized_url - - -def compare_urls(url1, url2): +def compare_urls(dst, src): """Compares normalized urls""" - normalized_url1 = normalize_url(url1.geturl()) - normalized_url2 = normalize_url(url2.geturl()) - return normalized_url1 == normalized_url2 + if len(src) == 2: + # Source is a hostname, port tuple. Check if it matches : + port_matches = get_port_from_url(dst, parsed=True) == src[1] + return dst.hostname == src[0] and port_matches + normalized_dst = normalize_url(dst.geturl()) + normalized_src = normalize_url(src.geturl()) + return normalized_dst == normalized_src -def get_redirect_origin(redirects, url): + +def get_redirect_origin(redirects, hostname, port): """ This function checks if the given URL is part of a redirect chain that is passed in the redirects parameter. @@ -52,8 +32,7 @@ def get_redirect_origin(redirects, url): """ if not isinstance(redirects, list): return None - - current_url = copy.deepcopy(url) + current_url = copy.deepcopy((hostname, port)) # Follow the redirect chain until we reach the origin or don't find a redirect @@ -66,4 +45,5 @@ def get_redirect_origin(redirects, url): break current_url = redirect["source"] - return current_url if not compare_urls(current_url, url) else None + current_url_changed = current_url != (hostname, port) + return current_url if current_url_changed else None diff --git a/aikido_firewall/vulnerabilities/ssrf/get_redirect_origin_test.py b/aikido_firewall/vulnerabilities/ssrf/get_redirect_origin_test.py index 9e974f1a..412dc5df 100644 --- a/aikido_firewall/vulnerabilities/ssrf/get_redirect_origin_test.py +++ b/aikido_firewall/vulnerabilities/ssrf/get_redirect_origin_test.py @@ -17,7 +17,8 @@ def test_get_redirect_origin(): "destination": create_url("https://hackers.com"), }, ], - create_url("https://hackers.com"), + "hackers.com", + 443, ) == create_url("https://example.com") assert get_redirect_origin( @@ -31,13 +32,28 @@ def test_get_redirect_origin(): "destination": create_url("https://hackers.com/test"), }, ], - create_url("https://hackers.com/test"), + "hackers.com", + 443, ) == create_url("https://example.com") def test_get_redirect_origin_no_redirects(): - assert get_redirect_origin([], create_url("https://hackers.com")) is None - assert get_redirect_origin(None, create_url("https://hackers.com")) is None + assert ( + get_redirect_origin( + [], + "hackers.com", + 443, + ) + is None + ) + assert ( + get_redirect_origin( + None, + "hackers.com", + 443, + ) + is None + ) def test_get_redirect_origin_not_a_destination(): @@ -49,7 +65,8 @@ def test_get_redirect_origin_not_a_destination(): "destination": create_url("https://hackers.com"), }, ], - create_url("https://example.com"), + "example.com", + 443, ) is None ) @@ -64,7 +81,8 @@ def test_get_redirect_origin_not_in_redirects(): "destination": create_url("https://hackers.com"), }, ], - create_url("https://example.com"), + "example.com", + 443, ) is None ) @@ -86,7 +104,8 @@ def test_get_redirect_origin_multiple_redirects(): "destination": create_url("https://another.com"), }, ], - create_url("https://hackers.com/test"), + "hackers.com", + 443, ) == create_url("https://example.com") diff --git a/aikido_firewall/vulnerabilities/ssrf/inspect_getaddrinfo_result.py b/aikido_firewall/vulnerabilities/ssrf/inspect_getaddrinfo_result.py index 604f54e2..06698288 100644 --- a/aikido_firewall/vulnerabilities/ssrf/inspect_getaddrinfo_result.py +++ b/aikido_firewall/vulnerabilities/ssrf/inspect_getaddrinfo_result.py @@ -2,18 +2,16 @@ Mainly exports inspect_getaddrinfo_result function """ -import traceback from aikido_firewall.helpers.try_parse_url import try_parse_url from aikido_firewall.context import get_current_context from aikido_firewall.helpers.logging import logger -from aikido_firewall.background_process import get_comms from aikido_firewall.errors import AikidoSSRF from aikido_firewall.helpers.blocking_enabled import is_blocking_enabled -from aikido_firewall.helpers.get_clean_stacktrace import get_clean_stacktrace from .imds import resolves_to_imds_ip from .is_private_ip import is_private_ip from .find_hostname_in_context import find_hostname_in_context from .extract_ip_array_from_results import extract_ip_array_from_results +from .is_redirect_to_private_ip import is_redirect_to_private_ip # gets called when the result of the DNS resolution has come in @@ -25,46 +23,43 @@ def inspect_getaddrinfo_result(dns_results, hostname, port): return context = get_current_context() + if not inspect_dns_results(dns_results, hostname): + return + + if not context: + return + # attack_findings is an object containing source, pathToPayload and payload. + attack_findings = find_hostname_in_context(hostname, context, port) + if not attack_findings: + # Hostname/port not found in context, checking for redirects + logger.debug("Redirects : %s", context.outgoing_req_redirects) + attack_findings = is_redirect_to_private_ip(hostname, context, port) + + if attack_findings: + return { + "module": "socket", + "operation": "socket.getaddrinfo", + "kind": "ssrf", + "source": attack_findings["source"], + "path": attack_findings["pathToPayload"], + "metadata": {"hostname": hostname}, + "payload": attack_findings["payload"], + } + + +def inspect_dns_results(dns_results, hostname): + """ + Blocks stored SSRF attack that target IMDS IP addresses and returns True + if a private_ip is present. + This function gets called by inspect_getaddrinfo_result after parsing the hostname. + """ ip_addresses = extract_ip_array_from_results(dns_results) if resolves_to_imds_ip(ip_addresses, hostname): - # Block stored SSRF attack that target IMDS IP addresses # An attacker could have stored a hostname in a database that points to an IMDS IP address # We don't check if the user input contains the hostname because there's no context if is_blocking_enabled(): raise AikidoSSRF() - if not context: - return - private_ip = next((ip for ip in ip_addresses if is_private_ip(ip)), None) - if not private_ip: - return - - found = find_hostname_in_context(hostname, context, port) - if not found: - return - - should_block = is_blocking_enabled() - stack = " ".join(traceback.format_stack()) - attack = { - "module": "socket", - "operation": "socket.getaddrinfo", - "kind": "ssrf", - "source": found["source"], - "blocked": should_block, - "stack": stack, - "path": found["pathToPayload"], - "metadata": {"hostname": hostname}, - "payload": found["payload"], - } - logger.debug("Attack results : %s", attack) - - logger.debug("Sending data to bg process :") - stack = get_clean_stacktrace() - get_comms().send_data_to_bg_process( - "ATTACK", (attack, context, should_block, stack) - ) - - if should_block: - raise AikidoSSRF() + return private_ip diff --git a/aikido_firewall/vulnerabilities/ssrf/is_redirect_to_private_ip.py b/aikido_firewall/vulnerabilities/ssrf/is_redirect_to_private_ip.py index 4252bee7..3124f47e 100644 --- a/aikido_firewall/vulnerabilities/ssrf/is_redirect_to_private_ip.py +++ b/aikido_firewall/vulnerabilities/ssrf/is_redirect_to_private_ip.py @@ -1,12 +1,11 @@ """Exports is_redirect_to_private_ip""" from aikido_firewall.helpers.get_port_from_url import get_port_from_url -from .contains_private_ip_address import contains_private_ip_address from .get_redirect_origin import get_redirect_origin from .find_hostname_in_context import find_hostname_in_context -def is_redirect_to_private_ip(url, context): +def is_redirect_to_private_ip(hostname, context, port): """ This function is called before an outgoing request is made. It's used to prevent requests to private IP addresses after a redirect with @@ -18,11 +17,13 @@ def is_redirect_to_private_ip(url, context): - The redirect origin, so the user-supplied hostname and port that caused the first redirect, is found in the context of the incoming request """ - if context.outgoing_req_redirects and contains_private_ip_address(url.hostname): - redirect_origin = get_redirect_origin(context.outgoing_req_redirects, url) + if context.outgoing_req_redirects: + redirect_origin = get_redirect_origin( + context.outgoing_req_redirects, hostname, port + ) if redirect_origin: - hostname = getattr(redirect_origin, "hostname") - port = get_port_from_url(redirect_origin.geturl()) - return find_hostname_in_context(hostname, context, port) + origin_hostname = getattr(redirect_origin, "hostname") + origin_port = get_port_from_url(redirect_origin, parsed=True) + return find_hostname_in_context(origin_hostname, context, origin_port) return None diff --git a/aikido_firewall/vulnerabilities/ssrf/is_redirect_to_private_ip_test.py b/aikido_firewall/vulnerabilities/ssrf/is_redirect_to_private_ip_test.py index cd310078..936b00e0 100644 --- a/aikido_firewall/vulnerabilities/ssrf/is_redirect_to_private_ip_test.py +++ b/aikido_firewall/vulnerabilities/ssrf/is_redirect_to_private_ip_test.py @@ -10,7 +10,6 @@ def create_url(href): def test_is_redirect_to_private_ip_success(): - url = create_url("http://192.168.0.1/") # Private IP context = MagicMock() context.outgoing_req_redirects = [ { @@ -21,7 +20,7 @@ def test_is_redirect_to_private_ip_success(): context.parsed_userinput = {} context.body = {"field": ["http://example.com"]} with patch("aikido_firewall.context.get_current_context", return_value=context): - result = is_redirect_to_private_ip(url, context) + result = is_redirect_to_private_ip("192.168.0.1", context, 80) assert result == { "pathToPayload": ".field.[0]", "payload": "http://example.com", @@ -30,16 +29,14 @@ def test_is_redirect_to_private_ip_success(): def test_is_redirect_to_private_ip_no_redirects(): - url = create_url("http://192.168.0.1/") # Private IP context = MagicMock() context.outgoing_req_redirects = [] - result = is_redirect_to_private_ip(url, context) + result = is_redirect_to_private_ip("192.168.0.1", context, 80) assert result is None def test_is_redirect_to_private_ip_not_private_ip(): - url = create_url("https://example.com/") # Not a private IP context = MagicMock() context.outgoing_req_redirects = [ { @@ -48,20 +45,11 @@ def test_is_redirect_to_private_ip_not_private_ip(): }, ] - with MagicMock() as mock_contains_private_ip_address: - mock_contains_private_ip_address.return_value = False - with pytest.MonkeyPatch.context() as mp: - mp.setattr( - "aikido_firewall.vulnerabilities.ssrf.contains_private_ip_address", - mock_contains_private_ip_address, - ) - - result = is_redirect_to_private_ip(url, context) - assert result is None + result = is_redirect_to_private_ip("example.com", context, 443) + assert result is None def test_is_redirect_to_private_ip_redirect_origin_not_found(): - url = create_url("http://192.168.0.1/") # Private IP context = MagicMock() context.outgoing_req_redirects = [ { @@ -70,27 +58,21 @@ def test_is_redirect_to_private_ip_redirect_origin_not_found(): }, ] - with MagicMock() as mock_contains_private_ip_address, MagicMock() as mock_get_redirect_origin: + with MagicMock() as mock_get_redirect_origin: - mock_contains_private_ip_address.return_value = True mock_get_redirect_origin.return_value = None with pytest.MonkeyPatch.context() as mp: - mp.setattr( - "aikido_firewall.vulnerabilities.ssrf.contains_private_ip_address", - mock_contains_private_ip_address, - ) mp.setattr( "aikido_firewall.vulnerabilities.ssrf.get_redirect_origin", mock_get_redirect_origin, ) - result = is_redirect_to_private_ip(url, context) + result = is_redirect_to_private_ip("192.168.0.1", context, 80) assert result is None def test_is_redirect_to_private_ip_hostname_not_found_in_context(): - url = create_url("http://192.168.0.1/") # Private IP context = MagicMock() context.outgoing_req_redirects = [ { @@ -99,19 +81,14 @@ def test_is_redirect_to_private_ip_hostname_not_found_in_context(): }, ] - with MagicMock() as mock_contains_private_ip_address, MagicMock() as mock_get_redirect_origin, MagicMock() as mock_find_hostname_in_context: + with MagicMock() as mock_get_redirect_origin, MagicMock() as mock_find_hostname_in_context: - mock_contains_private_ip_address.return_value = True mock_get_redirect_origin.return_value = MagicMock( hostname="example.com", port=80 ) mock_find_hostname_in_context.return_value = False with pytest.MonkeyPatch.context() as mp: - mp.setattr( - "aikido_firewall.vulnerabilities.ssrf.contains_private_ip_address", - mock_contains_private_ip_address, - ) mp.setattr( "aikido_firewall.vulnerabilities.ssrf.get_redirect_origin", mock_get_redirect_origin, @@ -121,5 +98,5 @@ def test_is_redirect_to_private_ip_hostname_not_found_in_context(): mock_find_hostname_in_context, ) - result = is_redirect_to_private_ip(url, context) + result = is_redirect_to_private_ip("192.168.0.1", context, 80) assert result is None