-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix bug in SSRF redirect protection with infinite loops
- Loading branch information
Wout Feys
committed
Nov 20, 2024
1 parent
d3e2c7d
commit 83c070f
Showing
2 changed files
with
324 additions
and
34 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,49 +1,66 @@ | ||
"""Exports get_redirect_origin function""" | ||
"""Exports get_redirect_origin""" | ||
|
||
import copy | ||
from aikido_zen.helpers.get_port_from_url import get_port_from_url | ||
from aikido_zen.helpers.urls.normalize_url import normalize_url | ||
|
||
|
||
def compare_urls(dst, src):
    """Tell whether ``dst`` refers to the same location as ``src``.

    ``dst`` is a parsed URL object (it must expose ``hostname`` and
    ``geturl()``). ``src`` is either a two-item ``(hostname, port)``
    sequence or another parsed URL object.

    Returns a truthy value on a match and a falsy one otherwise.
    """
    if len(src) != 2:
        # Two parsed URLs: compare their normalized string forms.
        return normalize_url(dst.geturl()) == normalize_url(src.geturl())

    # (hostname, port) pair: the port is resolved from the parsed
    # destination and both parts have to agree.
    same_port = get_port_from_url(dst, parsed=True) == src[1]
    return dst.hostname == src[0] and same_port
|
||
|
||
def get_redirect_origin(redirects, hostname, port):
    """Return the origin of the redirect chain that leads to hostname:port.

    The origin is the first URL of the chain, i.e. the URL that was
    requested initially and ended up (possibly through several hops) at
    the given destination.

    Example:
        Redirect chain: A -> B -> C
        get_redirect_origin([A -> B, B -> C], C) => A
        get_redirect_origin([A -> B, B -> C], B) => A
        get_redirect_origin([A -> B, B -> C], D) => None

    Args:
        redirects: list of redirect records; each record is indexed with
            the keys "source" and "destination", both parsed URL objects.
        hostname: hostname of the URL being checked.
        port: port of the URL being checked.

    Returns:
        The parsed URL at the start of the chain, or None when
        ``redirects`` is not a list or when no redirect ending at
        hostname:port has an origin different from that destination.
    """
    if not isinstance(redirects, list):
        return None

    # One visited set is shared by every starting point, so a URL is
    # expanded at most once and cyclic redirects cannot loop forever.
    visited = set()
    current_urls = find_url_matching_hostname_and_port(hostname, port, redirects)
    for url in current_urls:
        origin = find_origin(redirects, url, visited)
        # Only a URL that differs from the starting point counts as an
        # origin; getting the same URL back means nothing redirected to it.
        if origin and not compare_urls(origin, url):
            return origin
    return None
|
||
|
||
def find_origin(redirects, url, visited):
    """Walk the redirect chain backwards from ``url`` and return its start.

    Args:
        redirects: iterable of redirect records, each indexed with the
            keys "source" and "destination" (parsed URL objects).
        url: parsed URL object to start from, or None.
        visited: set of URL strings (``geturl()`` values) already
            expanded. It is updated in place so callers can share it
            between several lookups.

    Returns:
        The URL that no recorded redirect leads to, or the URL at which
        an already visited entry was reached (cyclic redirects), or
        None when ``url`` is None.
    """
    # Iterative on purpose: the chain length depends on the recorded
    # redirects, so it must not be limited by the interpreter's recursion
    # depth.
    while url is not None and url.geturl() not in visited:
        visited.add(url.geturl())

        # Find a redirect where the current URL is the destination.
        redirect = next(
            (r for r in redirects if compare_urls(r["destination"], url)), None
        )
        if not redirect:
            # No redirect leads to this URL, so it is the origin itself.
            return url

        # Continue from the URL that redirected here.
        url = redirect["source"]

    # Either url is None or it was visited before (cyclic redirect).
    return url
|
||
|
||
def compare_urls(dst, src):
    """Return True when both parsed URLs have the same normalized form.

    Both arguments must expose ``geturl()``; the resulting strings are
    passed through ``normalize_url`` before being compared.
    """
    return normalize_url(dst.geturl()) == normalize_url(src.geturl())
|
||
while True: | ||
redirect = None | ||
for r in redirects: | ||
if compare_urls(r["destination"], current_url): | ||
redirect = r | ||
if not redirect: | ||
break | ||
current_url = redirect["source"] | ||
|
||
current_url_changed = current_url != (hostname, port) | ||
return current_url if current_url_changed else None | ||
def find_url_matching_hostname_and_port(hostname, port, redirects):
    """Collect the redirect destinations that point at hostname:port.

    A destination qualifies when its hostname equals ``hostname`` and its
    port is either absent/falsy or equal to ``port``. The result keeps the
    order of ``redirects`` and holds the destination objects themselves.
    """

    def points_at_target(destination):
        # The port is checked first; the hostname is only compared when
        # the port does not rule the destination out.
        explicit_port = destination.port
        port_ok = not explicit_port or explicit_port == port
        return port_ok and destination.hostname == hostname

    return [
        entry["destination"]
        for entry in redirects
        if points_at_target(entry["destination"])
    ]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters