Skip to content

Commit

Permalink
Fix bug in SSRF redirect protection with infinte loops
Browse files Browse the repository at this point in the history
  • Loading branch information
Wout Feys committed Nov 20, 2024
1 parent d3e2c7d commit 83c070f
Show file tree
Hide file tree
Showing 2 changed files with 324 additions and 34 deletions.
83 changes: 50 additions & 33 deletions aikido_zen/vulnerabilities/ssrf/get_redirect_origin.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,66 @@
"""Exports get_redirect_origin function"""
"""Exports get_redirect_origin"""

import copy
from aikido_zen.helpers.get_port_from_url import get_port_from_url
from aikido_zen.helpers.urls.normalize_url import normalize_url


def compare_urls(dst, src):
"""Compares normalized urls"""
if len(src) == 2:
# Source is a hostname, port tuple. Check if it matches :
port_matches = get_port_from_url(dst, parsed=True) == src[1]
return dst.hostname == src[0] and port_matches

normalized_dst = normalize_url(dst.geturl())
normalized_src = normalize_url(src.geturl())
return normalized_dst == normalized_src


def get_redirect_origin(redirects, hostname, port):
"""
This function checks if the given URL is part of a redirect chain that is passed in the
redirects parameter.
This function checks if the given URL is part of a redirect chain that is passed in the redirects parameter.
It returns the origin of a redirect chain if the URL is the result of a redirect.
The origin is the first URL in the chain, so the initial URL that was requested and redirected
to the given URL or in case of multiple redirects the URL that was redirected to the given URL.
to the given URL or, in the case of multiple redirects, the URL that was redirected to the given URL.
Example:
Redirect chain: A -> B -> C: getRedirectOrigin([A -> B, B -> C], C) => A
: getRedirectOrigin([A -> B, B -> C], B) => A
: getRedirectOrigin([A -> B, B -> C], D) => undefined
Redirect chain: A -> B -> C: get_redirect_origin([A -> B, B -> C], C) => A
: get_redirect_origin([A -> B, B -> C], B) => A
: get_redirect_origin([A -> B, B -> C], D) => None
"""
if not isinstance(redirects, list):
return None
current_url = copy.deepcopy((hostname, port))

# Follow the redirect chain until we reach the origin or don't find a redirect
visited = set()
current_urls = find_url_matching_hostname_and_port(hostname, port, redirects)
for url in current_urls:
origin = find_origin(redirects, url, visited)
if origin and not compare_urls(origin, url):
# If origin exists and it's different return
return origin


def find_origin(redirects, url, visited):
"""Recursive function that traverses the redirects"""
if url is None or url.geturl() in visited:
# To avoid infinite loops in case of cyclic redirects
return url

visited.add(url.geturl())

# Find a redirect where the current URL is the destination
redirect = next((r for r in redirects if compare_urls(r["destination"], url)), None)
print("Redirects: ", redirect)
if redirect:
# Recursively find the origin starting from the source URL
return find_origin(redirects, redirect["source"], visited)

# If no redirect leads to this URL, return the URL itself
return url


def compare_urls(dst, src):
"""Compares normalized URLs."""
normalized_dst = normalize_url(dst.geturl())
normalized_src = normalize_url(src.geturl())
return normalized_dst == normalized_src

while True:
redirect = None
for r in redirects:
if compare_urls(r["destination"], current_url):
redirect = r
if not redirect:
break
current_url = redirect["source"]

current_url_changed = current_url != (hostname, port)
return current_url if current_url_changed else None
def find_url_matching_hostname_and_port(hostname, port, redirects):
"""Finds the initial url to start with (the one matching hostname and port)"""
urls = []
for redirect in redirects:
r_port = redirect["destination"].port
if r_port and r_port != port:
continue

Check warning on line 62 in aikido_zen/vulnerabilities/ssrf/get_redirect_origin.py

View check run for this annotation

Codecov / codecov/patch

aikido_zen/vulnerabilities/ssrf/get_redirect_origin.py#L62

Added line #L62 was not covered by tests
if redirect["destination"].hostname != hostname:
continue
urls.append(redirect["destination"])
return urls
275 changes: 274 additions & 1 deletion aikido_zen/vulnerabilities/ssrf/get_redirect_origin_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,4 +109,277 @@ def test_get_redirect_origin_multiple_redirects():
) == create_url("https://example.com")


# To run the tests, use the command: pytest <filename>.py
def test_avoids_infinite_loops_with_unrelated_cyclic_redirects():
result = get_redirect_origin(
[
# Unrelated cyclic redirects
{
"source": create_url("https://cycle.com/a"),
"destination": create_url("https://cycle.com/b"),
},
{
"source": create_url("https://cycle.com/b"),
"destination": create_url("https://cycle.com/c"),
},
{
"source": create_url("https://cycle.com/c"),
"destination": create_url("https://cycle.com/a"),
},
# Relevant redirects
{
"source": create_url("https://start.com"),
"destination": create_url("https://middle.com"),
},
{
"source": create_url("https://middle.com"),
"destination": create_url("https://end.com"),
},
],
"end.com",
443,
)
assert result == create_url("https://start.com")


def test_handles_multiple_requests_with_overlapping_redirects():
result = get_redirect_origin(
[
# Overlapping redirects
{
"source": create_url("https://site1.com"),
"destination": create_url("https://site2.com"),
},
{
"source": create_url("https://site2.com"),
"destination": create_url("https://site3.com"),
},
{
"source": create_url("https://site3.com"),
"destination": create_url("https://site1.com"), # Cycle
},
# Relevant redirects
{
"source": create_url("https://origin.com"),
"destination": create_url("https://destination.com"),
},
],
"destination.com",
443,
)
assert result == create_url("https://origin.com")


def test_avoids_infinite_loops_when_cycles_are_part_of_the_redirect_chain():
result = get_redirect_origin(
[
{
"source": create_url("https://start.com"),
"destination": create_url("https://loop.com/a"),
},
{
"source": create_url("https://loop.com/a"),
"destination": create_url("https://loop.com/b"),
},
{
"source": create_url("https://loop.com/b"),
"destination": create_url("https://loop.com/c"),
},
{
"source": create_url("https://loop.com/c"),
"destination": create_url("https://loop.com/a"), # Cycle here
},
],
"loop.com",
443,
)
assert result == create_url("https://start.com")


def test_redirects_with_query_parameters():
result = get_redirect_origin(
[
{
"source": create_url("https://example.com"),
"destination": create_url("https://example.com?param=value"),
},
],
"example.com",
443,
)
assert result == create_url("https://example.com")


def test_redirects_with_fragment_identifiers():
result = get_redirect_origin(
[
{
"source": create_url("https://example.com"),
"destination": create_url("https://example.com#section"),
},
],
"example.com",
443,
)
assert result == create_url("https://example.com")


def test_redirects_with_different_protocols():
result = get_redirect_origin(
[
{
"source": create_url("http://example.com"),
"destination": create_url("https://example.com"),
},
],
"example.com",
443,
)
assert result == create_url("http://example.com")


def test_redirects_with_different_ports():
result = get_redirect_origin(
[
{
"source": create_url("https://example.com"),
"destination": create_url("https://example.com:8080"),
},
],
"example.com",
8080,
)
assert result == create_url("https://example.com")


def test_redirects_with_paths():
result = get_redirect_origin(
[
{
"source": create_url("https://example.com"),
"destination": create_url("https://example.com/home"),
},
{
"source": create_url("https://example.com/home"),
"destination": create_url("https://example.com/home/welcome"),
},
],
"example.com",
443,
)
assert result == create_url("https://example.com")


def test_multiple_redirects_to_same_destination():
result = get_redirect_origin(
[
{
"source": create_url("https://a.com"),
"destination": create_url("https://d.com"),
},
{
"source": create_url("https://b.com"),
"destination": create_url("https://d.com"),
},
{
"source": create_url("https://c.com"),
"destination": create_url("https://d.com"),
},
],
"d.com",
443,
)
assert result == create_url("https://a.com")


def test_multiple_redirect_paths_to_same_url():
result = get_redirect_origin(
[
{
"source": create_url("https://x.com"),
"destination": create_url("https://y.com"),
},
{
"source": create_url("https://y.com"),
"destination": create_url("https://z.com"),
},
{
"source": create_url("https://a.com"),
"destination": create_url("https://b.com"),
},
{
"source": create_url("https://b.com"),
"destination": create_url("https://z.com"),
},
],
"z.com",
443,
)
assert result == create_url("https://x.com")


def test_returns_undefined_when_source_and_destination_are_same_url():
result = get_redirect_origin(
[
{
"source": create_url("https://example.com"),
"destination": create_url("https://example.com"),
},
],
"example.com",
443,
)
assert result is None


def test_handles_very_long_redirect_chains():
redirects = []
for i in range(100):
redirects.append(
{
"source": create_url(f"https://example.com/{i}"),
"destination": create_url(f"https://example.com/{i + 1}"),
}
)

result = get_redirect_origin(redirects, "example.com", 443)
assert result == create_url("https://example.com/0")


def test_handles_redirects_with_cycles_longer_than_one_redirect():
result = get_redirect_origin(
[
{
"source": create_url("https://a.com"),
"destination": create_url("https://b.com"),
},
{
"source": create_url("https://b.com"),
"destination": create_url("https://c.com"),
},
{
"source": create_url("https://c.com"),
"destination": create_url("https://a.com"),
},
],
"a.com",
443,
)
assert result is None


def test_handles_redirects_with_different_query_parameters():
result = get_redirect_origin(
[
{
"source": create_url("https://example.com"),
"destination": create_url("https://example.com?param=1"),
},
{
"source": create_url("https://example.com?param=1"),
"destination": create_url("https://example.com?param=2"),
},
],
"example.com",
443,
)
assert result == create_url("https://example.com")

0 comments on commit 83c070f

Please sign in to comment.