Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix (backend): Patching the SSRF vulnerability in Github/Web Search/Request related blocks #8531

Merged
merged 18 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions autogpt_platform/backend/backend/blocks/github/issues.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from urllib.parse import urlparse

import requests
from typing_extensions import TypedDict

Expand All @@ -13,6 +15,10 @@
)


def is_github_url(url: str) -> bool:
return urlparse(url).netloc == "github.com"


# --8<-- [start:GithubCommentBlockExample]
class GithubCommentBlock(Block):
class Input(BlockSchema):
Expand Down Expand Up @@ -62,6 +68,10 @@ def __init__(self):
def post_comment(
credentials: GithubCredentials, issue_url: str, body_text: str
) -> tuple[int, str]:

if is_github_url(issue_url) is False:
raise ValueError("The input URL must be a valid GitHub URL.")

majdyz marked this conversation as resolved.
Show resolved Hide resolved
if "/pull/" in issue_url:
api_url = (
issue_url.replace("github.com", "api.github.com/repos").replace(
Expand Down Expand Up @@ -156,6 +166,9 @@ def __init__(self):
def create_issue(
credentials: GithubCredentials, repo_url: str, title: str, body: str
) -> tuple[int, str]:
if is_github_url(repo_url) is False:
raise ValueError("The input URL must be a valid GitHub URL.")

api_url = repo_url.replace("github.com", "api.github.com/repos") + "/issues"
headers = {
"Authorization": credentials.bearer(),
Expand Down Expand Up @@ -232,6 +245,10 @@ def __init__(self):
def read_issue(
credentials: GithubCredentials, issue_url: str
) -> tuple[str, str, str]:

if not is_github_url(issue_url):
raise ValueError("The input URL must be a valid GitHub URL.")

api_url = issue_url.replace("github.com", "api.github.com/repos")

headers = {
Expand Down Expand Up @@ -318,6 +335,10 @@ def __init__(self):
def list_issues(
credentials: GithubCredentials, repo_url: str
) -> list[Output.IssueItem]:

if not is_github_url(repo_url):
raise ValueError("The input URL must be a valid GitHub URL.")

api_url = repo_url.replace("github.com", "api.github.com/repos") + "/issues"
headers = {
"Authorization": credentials.bearer(),
Expand Down Expand Up @@ -385,6 +406,10 @@ def __init__(self):

@staticmethod
def add_label(credentials: GithubCredentials, issue_url: str, label: str) -> str:

if is_github_url(issue_url) is False:
raise ValueError("The input URL must be a valid GitHub URL.")

# Convert the provided GitHub URL to the API URL
if "/pull/" in issue_url:
api_url = (
Expand Down Expand Up @@ -463,6 +488,9 @@ def __init__(self):

@staticmethod
def remove_label(credentials: GithubCredentials, issue_url: str, label: str) -> str:
if is_github_url(issue_url) is False:
raise ValueError("The input URL must be a valid GitHub URL.")

# Convert the provided GitHub URL to the API URL
if "/pull/" in issue_url:
api_url = (
Expand Down Expand Up @@ -550,6 +578,9 @@ def assign_issue(
issue_url: str,
assignee: str,
) -> str:
if is_github_url(issue_url) is False:
raise ValueError("The input URL must be a valid GitHub URL.")

# Extracting repo path and issue number from the issue URL
repo_path, issue_number = issue_url.replace("https://github.com/", "").split(
"/issues/"
Expand Down Expand Up @@ -629,6 +660,9 @@ def unassign_issue(
issue_url: str,
assignee: str,
) -> str:
if is_github_url(issue_url) is False:
raise ValueError("The input URL must be a valid GitHub URL.")

# Extracting repo path and issue number from the issue URL
repo_path, issue_number = issue_url.replace("https://github.com/", "").split(
"/issues/"
Expand Down
23 changes: 23 additions & 0 deletions autogpt_platform/backend/backend/blocks/github/pull_requests.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from urllib.parse import urlparse

import requests
from typing_extensions import TypedDict

Expand All @@ -13,6 +15,10 @@
)


def is_github_url(url: str) -> bool:
return urlparse(url).netloc == "github.com"


class GithubListPullRequestsBlock(Block):
class Input(BlockSchema):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
Expand Down Expand Up @@ -64,6 +70,9 @@ def __init__(self):

@staticmethod
def list_prs(credentials: GithubCredentials, repo_url: str) -> list[Output.PRItem]:
if is_github_url(repo_url) is False:
raise ValueError("The input URL must be a valid GitHub URL.")

api_url = repo_url.replace("github.com", "api.github.com/repos") + "/pulls"
headers = {
"Authorization": credentials.bearer(),
Expand Down Expand Up @@ -162,6 +171,8 @@ def create_pr(
head: str,
base: str,
) -> tuple[int, str]:
if is_github_url(repo_url) is False:
raise ValueError("The input URL must be a valid GitHub URL.")
repo_path = repo_url.replace("https://github.com/", "")
api_url = f"https://api.github.com/repos/{repo_path}/pulls"
headers = {
Expand Down Expand Up @@ -255,6 +266,9 @@ def __init__(self):

@staticmethod
def read_pr(credentials: GithubCredentials, pr_url: str) -> tuple[str, str, str]:
if is_github_url(pr_url) is False:
raise ValueError("The input URL must be a valid GitHub URL.")

api_url = pr_url.replace("github.com", "api.github.com/repos").replace(
"/pull/", "/issues/"
)
Expand Down Expand Up @@ -367,6 +381,9 @@ def __init__(self):
def assign_reviewer(
credentials: GithubCredentials, pr_url: str, reviewer: str
) -> str:
if is_github_url(pr_url) is False:
raise ValueError("The input URL must be a valid GitHub URL.")

# Convert the PR URL to the appropriate API endpoint
api_url = (
pr_url.replace("github.com", "api.github.com/repos").replace(
Expand Down Expand Up @@ -456,6 +473,9 @@ def __init__(self):
def unassign_reviewer(
credentials: GithubCredentials, pr_url: str, reviewer: str
) -> str:
if is_github_url(pr_url) is False:
raise ValueError("The input URL must be a valid GitHub URL.")

api_url = (
pr_url.replace("github.com", "api.github.com/repos").replace(
"/pull/", "/pulls/"
Expand Down Expand Up @@ -544,6 +564,9 @@ def __init__(self):
def list_reviewers(
credentials: GithubCredentials, pr_url: str
) -> list[Output.ReviewerItem]:
if is_github_url(pr_url) is False:
raise ValueError("The input URL must be a valid GitHub URL.")

api_url = (
pr_url.replace("github.com", "api.github.com/repos").replace(
"/pull/", "/pulls/"
Expand Down
23 changes: 23 additions & 0 deletions autogpt_platform/backend/backend/blocks/github/repo.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import base64
from urllib.parse import urlparse

import requests
from typing_extensions import TypedDict
Expand All @@ -15,6 +16,10 @@
)


def is_github_url(url: str) -> bool:
return urlparse(url).netloc == "github.com"


class GithubListTagsBlock(Block):
class Input(BlockSchema):
credentials: GithubCredentialsInput = GithubCredentialsField("repo")
Expand Down Expand Up @@ -68,6 +73,9 @@ def __init__(self):
def list_tags(
credentials: GithubCredentials, repo_url: str
) -> list[Output.TagItem]:
if is_github_url(repo_url) is False:
raise ValueError("The input URL must be a valid GitHub URL.")

repo_path = repo_url.replace("https://github.com/", "")
api_url = f"https://api.github.com/repos/{repo_path}/tags"
headers = {
Expand Down Expand Up @@ -157,6 +165,9 @@ def __init__(self):
def list_branches(
credentials: GithubCredentials, repo_url: str
) -> list[Output.BranchItem]:
if is_github_url(repo_url) is False:
raise ValueError("The input URL must be a valid GitHub URL.")

api_url = repo_url.replace("github.com", "api.github.com/repos") + "/branches"
headers = {
"Authorization": credentials.bearer(),
Expand Down Expand Up @@ -246,6 +257,8 @@ def __init__(self):
def list_discussions(
credentials: GithubCredentials, repo_url: str, num_discussions: int
) -> list[Output.DiscussionItem]:
if is_github_url(repo_url) is False:
raise ValueError("The input URL must be a valid GitHub URL.")
repo_path = repo_url.replace("https://github.com/", "")
owner, repo = repo_path.split("/")
query = """
Expand Down Expand Up @@ -348,6 +361,8 @@ def __init__(self):
def list_releases(
credentials: GithubCredentials, repo_url: str
) -> list[Output.ReleaseItem]:
if is_github_url(repo_url) is False:
raise ValueError("The input URL must be a valid GitHub URL.")
repo_path = repo_url.replace("https://github.com/", "")
api_url = f"https://api.github.com/repos/{repo_path}/releases"
headers = {
Expand Down Expand Up @@ -432,6 +447,8 @@ def __init__(self):
def read_file(
credentials: GithubCredentials, repo_url: str, file_path: str, branch: str
) -> tuple[str, int]:
if is_github_url(repo_url) is False:
raise ValueError("The input URL must be a valid GitHub URL.")
repo_path = repo_url.replace("https://github.com/", "")
api_url = f"https://api.github.com/repos/{repo_path}/contents/{file_path}?ref={branch}"
headers = {
Expand Down Expand Up @@ -549,6 +566,8 @@ def __init__(self):
def read_folder(
credentials: GithubCredentials, repo_url: str, folder_path: str, branch: str
) -> tuple[list[Output.FileEntry], list[Output.DirEntry]]:
if is_github_url(repo_url) is False:
raise ValueError("The input URL must be a valid GitHub URL.")
repo_path = repo_url.replace("https://github.com/", "")
api_url = f"https://api.github.com/repos/{repo_path}/contents/{folder_path}?ref={branch}"
headers = {
Expand Down Expand Up @@ -656,6 +675,8 @@ def create_branch(
new_branch: str,
source_branch: str,
) -> str:
if is_github_url(repo_url) is False:
raise ValueError("The input URL must be a valid GitHub URL.")
repo_path = repo_url.replace("https://github.com/", "")
ref_api_url = (
f"https://api.github.com/repos/{repo_path}/git/refs/heads/{source_branch}"
Expand Down Expand Up @@ -735,6 +756,8 @@ def __init__(self):
def delete_branch(
credentials: GithubCredentials, repo_url: str, branch: str
) -> str:
if is_github_url(repo_url) is False:
raise ValueError("The input URL must be a valid GitHub URL.")
repo_path = repo_url.replace("https://github.com/", "")
api_url = f"https://api.github.com/repos/{repo_path}/git/refs/heads/{branch}"
headers = {
Expand Down
38 changes: 37 additions & 1 deletion autogpt_platform/backend/backend/blocks/http.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import ipaddress
import json
import socket
from enum import Enum
from urllib.parse import urlparse

import requests

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.settings import Config


class HttpMethod(Enum):
Expand All @@ -17,6 +21,35 @@ class HttpMethod(Enum):
HEAD = "HEAD"


def validate_url(url: str) -> str:
"""
To avoid SSRF attacks, the URL should not be a private IP address
unless it is whitelisted in TRUST_ENDPOINTS_FOR_REQUESTS config.
"""
if any(url.startswith(origin) for origin in Config().trust_endpoints_for_requests):
return url

parsed_url = urlparse(url)
hostname = parsed_url.hostname

if not hostname:
raise ValueError(f"Invalid URL: Unable to determine hostname from {url}")

try:
host = socket.gethostbyname_ex(hostname)
for ip in host[2]:
ip_addr = ipaddress.ip_address(ip)
if ip_addr.is_global:
return url
raise ValueError(
f"Access to private or untrusted IP address at {hostname} is not allowed."
)
except ValueError:
raise
except Exception as e:
raise ValueError(f"Invalid or unresolvable URL: {url}") from e


class SendWebRequestBlock(Block):
class Input(BlockSchema):
url: str = SchemaField(
Expand Down Expand Up @@ -54,11 +87,14 @@ def run(self, input_data: Input, **kwargs) -> BlockOutput:
if isinstance(input_data.body, str):
input_data.body = json.loads(input_data.body)

validated_url = validate_url(input_data.url)

response = requests.request(
input_data.method.value,
input_data.url,
validated_url,
headers=input_data.headers,
json=input_data.body,
allow_redirects=False,
)
if response.status_code // 100 == 2:
yield "response", response.json()
Expand Down
40 changes: 38 additions & 2 deletions autogpt_platform/backend/backend/blocks/search.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,57 @@
import ipaddress
import socket
from typing import Any, Literal
from urllib.parse import quote
from urllib.parse import quote, urlparse

import requests
from autogpt_libs.supabase_integration_credentials_store.types import APIKeyCredentials
from pydantic import SecretStr

from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import CredentialsField, CredentialsMetaInput, SchemaField
from backend.util.settings import Config


class GetRequest:
@classmethod
def get_request(cls, url: str, json=False) -> Any:
response = requests.get(url)
validated_url = cls().validate_url(url)

response = requests.get(validated_url, allow_redirects=False)
response.raise_for_status()
return response.json() if json else response.text

@classmethod
def validate_url(self, url: str) -> str:
"""
To avoid SSRF attacks, the URL should not be a private IP address
unless it is whitelisted in TRUST_ENDPOINTS_FOR_REQUESTS config.
"""
if any(
url.startswith(origin) for origin in Config().trust_endpoints_for_requests
):
return url

parsed_url = urlparse(url)
hostname = parsed_url.hostname

if not hostname:
raise ValueError(f"Invalid URL: Unable to determine hostname from {url}")

try:
host = socket.gethostbyname_ex(hostname)
for ip in host[2]:
ip_addr = ipaddress.ip_address(ip)
if ip_addr.is_global:
return url
raise ValueError(
f"Access to private or untrusted IP address at {hostname} is not allowed."
)
except ValueError:
raise
except Exception as e:
raise ValueError(f"Invalid or unresolvable URL: {url}") from e


class GetWikipediaSummaryBlock(Block, GetRequest):
class Input(BlockSchema):
Expand Down
Loading
Loading