Skip to content

Commit

Permalink
Legibility refactor (more higher order functions instead of dataclasses)
Browse files Browse the repository at this point in the history
  • Loading branch information
alukach committed Dec 13, 2024
1 parent a34c370 commit a1db2ba
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 85 deletions.
8 changes: 4 additions & 4 deletions src/stac_auth_proxy/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from .auth import OpenIdConnectAuth
from .config import Settings
from .handlers import OpenApiSpecHandler, ReverseProxyHandler
from .handlers import ReverseProxyHandler, build_openapi_spec_handler
from .middleware import AddProcessTimeHeaderMiddleware

# from .utils import apply_filter
Expand Down Expand Up @@ -55,7 +55,7 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI:
collections_filter=collections_filter,
items_filter=items_filter,
)
openapi_handler = OpenApiSpecHandler(
openapi_handler = build_openapi_spec_handler(
proxy=proxy_handler,
oidc_config_url=str(settings.oidc_discovery_url),
)
Expand All @@ -67,7 +67,7 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI:
(
proxy_handler.stream
if path != settings.openapi_spec_endpoint
else openapi_handler.dispatch
else openapi_handler
),
methods=methods,
dependencies=[Security(auth_scheme.validated_user)],
Expand All @@ -80,7 +80,7 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI:
(
proxy_handler.stream
if path != settings.openapi_spec_endpoint
else openapi_handler.dispatch
else openapi_handler
),
methods=methods,
dependencies=[Security(auth_scheme.maybe_validated_user)],
Expand Down
132 changes: 70 additions & 62 deletions src/stac_auth_proxy/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
import urllib.request
from dataclasses import dataclass, field
from typing import Annotated, Any, Callable, Optional, Sequence
from typing import Annotated, Optional, Sequence

import jwt
from fastapi import HTTPException, Security, security, status
Expand All @@ -25,8 +25,6 @@ class OpenIdConnectAuth:
# Generated attributes
auth_scheme: SecurityBase = field(init=False)
jwks_client: jwt.PyJWKClient = field(init=False)
validated_user: Callable[..., Any] = field(init=False)
maybe_validated_user: Callable[..., Any] = field(init=False)

def __post_init__(self):
"""Initialize the OIDC authentication class."""
Expand All @@ -50,70 +48,80 @@ def __post_init__(self):
openIdConnectUrl=str(self.openid_configuration_url),
auto_error=False,
)
self.validated_user = self._build(auto_error=True)
self.maybe_validated_user = self._build(auto_error=False)

def _build(self, auto_error: bool = True):
"""Build a dependency for validating an OIDC token."""

def valid_token_dependency(
auth_header: Annotated[str, Security(self.auth_scheme)],
required_scopes: security.SecurityScopes,
):
"""Dependency to validate an OIDC token."""
if not auth_header:

# Update annotations to support FastAPI's dependency injection
for endpoint in [self.validated_user, self.maybe_validated_user]:
endpoint.__annotations__["auth_header"] = Annotated[
str,
Security(self.auth_scheme),
]

def maybe_validated_user(
self,
auth_header: Annotated[str, Security(...)],
required_scopes: security.SecurityScopes,
):
"""Dependency to validate an OIDC token."""
return self.validated_user(auth_header, required_scopes, auto_error=False)

def validated_user(
self,
auth_header: Annotated[str, Security(...)],
required_scopes: security.SecurityScopes,
auto_error: bool = True,
):
"""Dependency to validate an OIDC token."""
if not auth_header:
if auto_error:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not authenticated",
)
return None

# Extract token from header
token_parts = auth_header.split(" ")
if len(token_parts) != 2 or token_parts[0].lower() != "bearer":
logger.error(f"Invalid token: {auth_header}")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
[_, token] = token_parts

# Parse & validate token
try:
key = self.jwks_client.get_signing_key_from_jwt(token).key
payload = jwt.decode(
token,
key,
algorithms=["RS256"],
# NOTE: Audience validation MUST match audience claim if set in token (https://pyjwt.readthedocs.io/en/stable/changelog.html?highlight=audience#id40)
audience=self.allowed_jwt_audiences,
)
except (jwt.exceptions.InvalidTokenError, jwt.exceptions.DecodeError) as e:
logger.exception(f"InvalidTokenError: {e=}")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
) from e

# Validate scopes (if required)
for scope in required_scopes.scopes:
if scope not in payload["scope"]:
if auto_error:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Not authenticated",
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Not enough permissions",
headers={
"WWW-Authenticate": f'Bearer scope="{required_scopes.scope_str}"'
},
)
return None

# Extract token from header
token_parts = auth_header.split(" ")
if len(token_parts) != 2 or token_parts[0].lower() != "bearer":
logger.error(f"Invalid token: {auth_header}")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
[_, token] = token_parts

# Parse & validate token
try:
key = self.jwks_client.get_signing_key_from_jwt(token).key
payload = jwt.decode(
token,
key,
algorithms=["RS256"],
# NOTE: Audience validation MUST match audience claim if set in token (https://pyjwt.readthedocs.io/en/stable/changelog.html?highlight=audience#id40)
audience=self.allowed_jwt_audiences,
)
except (jwt.exceptions.InvalidTokenError, jwt.exceptions.DecodeError) as e:
logger.exception(f"InvalidTokenError: {e=}")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
) from e

# Validate scopes (if required)
for scope in required_scopes.scopes:
if scope not in payload["scope"]:
if auto_error:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Not enough permissions",
headers={
"WWW-Authenticate": f'Bearer scope="{required_scopes.scope_str}"'
},
)
return None

return payload

return valid_token_dependency
return payload


class OidcFetchError(Exception):
Expand Down
2 changes: 1 addition & 1 deletion src/stac_auth_proxy/filters/template.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Generate CQL2 filter expressions via Jinja2 templating."""

from typing import Any, Annotated, Callable
from typing import Annotated, Any, Callable

from cql2 import Expr
from fastapi import Request, Security
Expand Down
4 changes: 2 additions & 2 deletions src/stac_auth_proxy/handlers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Handlers to process requests."""

from .open_api_spec import OpenApiSpecHandler
from .open_api_spec import build_openapi_spec_handler
from .reverse_proxy import ReverseProxyHandler

__all__ = ["OpenApiSpecHandler", "ReverseProxyHandler"]
__all__ = ["build_openapi_spec_handler", "ReverseProxyHandler"]
28 changes: 12 additions & 16 deletions src/stac_auth_proxy/handlers/open_api_spec.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Custom request handlers."""

import logging
from dataclasses import dataclass

from fastapi import Request, Response
from fastapi.routing import APIRoute
Expand All @@ -12,17 +11,14 @@
logger = logging.getLogger(__name__)


@dataclass
class OpenApiSpecHandler:
"""Handler for OpenAPI spec requests."""

proxy: ReverseProxyHandler
oidc_config_url: str
auth_scheme_name: str = "oidcAuth"

async def dispatch(self, req: Request, res: Response):
def build_openapi_spec_handler(
proxy: ReverseProxyHandler,
oidc_config_url: str,
auth_scheme_name: str = "oidcAuth",
):
async def dispatch(req: Request, res: Response):
"""Proxy the OpenAPI spec from the upstream STAC API, updating it with OIDC security requirements."""
oidc_spec_response = await self.proxy.proxy_request(req)
oidc_spec_response = await proxy.proxy_request(req)
openapi_spec = oidc_spec_response.json()

# Pass along the response headers
Expand All @@ -45,10 +41,10 @@ async def dispatch(self, req: Request, res: Response):

# Add the OIDC security scheme to the components
openapi_spec.setdefault("components", {}).setdefault("securitySchemes", {})[
self.auth_scheme_name
auth_scheme_name
] = {
"type": "openIdConnect",
"openIdConnectUrl": self.oidc_config_url,
"openIdConnectUrl": oidc_config_url,
}

# Update the paths with the specified security requirements
Expand All @@ -61,9 +57,9 @@ async def dispatch(self, req: Request, res: Response):
if match.name != "FULL":
continue
# Add the OIDC security requirement
config.setdefault("security", []).append(
{self.auth_scheme_name: []}
)
config.setdefault("security", []).append({auth_scheme_name: []})
break

return openapi_spec

return dispatch

0 comments on commit a1db2ba

Please sign in to comment.