Skip to content

Commit

Permalink
Add support for a more granular CORSConfig object.
Browse files Browse the repository at this point in the history
Allows specification of specific access control headers. Docs still
need to be written.
  • Loading branch information
jcarlyl committed Apr 27, 2017
1 parent 5bb2239 commit 1133702
Show file tree
Hide file tree
Showing 9 changed files with 156 additions and 15 deletions.
2 changes: 1 addition & 1 deletion chalice/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from chalice.app import Chalice
from chalice.app import (
ChaliceViewError, BadRequestError, UnauthorizedError, ForbiddenError,
NotFoundError, ConflictError, TooManyRequestsError, Response
NotFoundError, ConflictError, TooManyRequestsError, Response, CORSConfig
)

__version__ = '0.8.0'
60 changes: 56 additions & 4 deletions chalice/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,54 @@ def __repr__(self):
return 'CaseInsensitiveMapping(%s)' % repr(self._dict)


class CORSConfig(object):
"""A cors configuration to attach to a route."""

_REQUIRED_HEADERS = ['Content-Type', 'X-Amz-Date', 'Authorization',
'X-Api-Key', 'X-Amz-Security-Token']

def __init__(self, allow_origin='*', allow_headers=None,
expose_headers=None, max_age=None, allow_credentials=None):
self.allow_origin = allow_origin

if allow_headers is None:
allow_headers = self._REQUIRED_HEADERS.copy()
else:
allow_headers.extend(self._REQUIRED_HEADERS)
self._allow_headers = allow_headers

if expose_headers is None:
expose_headers = []
self._expose_headers = expose_headers

self._max_age = max_age
self._allow_credentials = allow_credentials

@property
def allow_headers(self):
return ','.join(self._allow_headers)

def get_access_control_headers(self):
headers = {
'Access-Control-Allow-Origin': self.allow_origin,
'Access-Control-Allow-Headers': ','.join(self._allow_headers),
}
if self._expose_headers:
headers.update({
'Access-Control-Expose-Headers': ','.join(self._expose_headers)
})
if self._max_age is not None:
headers.update({
'Access-Control-Max-Age': str(self._max_age)
})
if self._allow_credentials is not None:
headers.update({
'Access-Control-Allow-Credentials': self._allow_credentials
})

return headers


class Request(object):
"""The current request from API gateway."""

Expand Down Expand Up @@ -167,6 +215,10 @@ def __init__(self, view_function, view_name, path, methods,
#: e.g, '/foo/{bar}/{baz}/qux -> ['bar', 'baz']
self.view_args = self._parse_view_args()
self.content_types = content_types
if cors is True:
cors = CORSConfig()
elif cors is False:
cors = None
self.cors = cors

def _parse_view_args(self):
Expand Down Expand Up @@ -309,7 +361,7 @@ def __call__(self, event, context):
response = self._get_view_function_response(view_function,
function_args)
if self._cors_enabled_for_route(route_entry):
self._add_cors_headers(response)
self._add_cors_headers(response, route_entry.cors)
return response.to_dict()

def _get_view_function_response(self, view_function, function_args):
Expand Down Expand Up @@ -354,7 +406,7 @@ def _validate_response(self, response):
(header, value))

def _cors_enabled_for_route(self, route_entry):
return route_entry.cors
return route_entry.cors is not None

def _add_cors_headers(self, response):
response.headers['Access-Control-Allow-Origin'] = '*'
def _add_cors_headers(self, response, cors):
response.headers.update(cors.get_access_control_headers())
10 changes: 7 additions & 3 deletions chalice/app.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List, Any, Callable
from typing import Dict, List, Any, Callable, Union

class ChaliceError(Exception): ...
class ChaliceViewError(ChaliceError):
Expand All @@ -14,6 +14,10 @@ class TooManyRequestsError(ChaliceViewError): ...

ALL_ERRORS = ... # type: List[ChaliceViewError]

class CORSConfig:
allow_origin = ... # type: str
allow_headers = ... # type: str

class Request:
query_params = ... # type: Dict[str, str]
headers = ... # type: Dict[str, str]
Expand Down Expand Up @@ -60,14 +64,14 @@ class RouteEntry(object):
api_key_required = ... # type: bool
content_types = ... # type: List[str]
view_args = ... # type: List[str]
cors = ... # type: bool
cors = ... # type: CORSConfig

def __init__(self, view_function: Callable[..., Any],
view_name: str, path: str, methods: List[str],
authorizer_name: str=None,
api_key_required: bool=None,
content_types: List[str]=None,
cors: bool=False) -> None: ...
cors: Union[bool, CORSConfig]=False) -> None: ...

def _parse_view_args(self) -> List[str]: ...

Expand Down
9 changes: 5 additions & 4 deletions chalice/deploy/swagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def _add_route_paths(self, api, app):
self._add_to_security_definition(
current['security'], api, app.authorizers)
swagger_for_path[http_method.lower()] = current
if view.cors:
if view.cors is not None:
self._add_preflight_request(view, swagger_for_path)

def _add_to_security_definition(self, security, api_config, authorizers):
Expand Down Expand Up @@ -141,15 +141,16 @@ def _add_view_args(self, apig_integ, view_args):

def _add_preflight_request(self, view, swagger_for_path):
# type: (RouteEntry, Dict[str, Any]) -> None
cors = view.cors
methods = view.methods + ['OPTIONS']
allowed_methods = ','.join(methods)
response_params = {
"method.response.header.Access-Control-Allow-Methods": (
"'%s'" % allowed_methods),
"method.response.header.Access-Control-Allow-Origin": (
"'%s'" % cors.allow_origin),
"method.response.header.Access-Control-Allow-Headers": (
"'Content-Type,X-Amz-Date,Authorization,X-Api-Key,"
"X-Amz-Security-Token'"),
"method.response.header.Access-Control-Allow-Origin": "'*'"
"'%s'" % cors.allow_headers)
}

options_request = {
Expand Down
6 changes: 3 additions & 3 deletions chalice/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from six.moves.BaseHTTPServer import BaseHTTPRequestHandler


from chalice.app import Chalice # noqa
from typing import List, Any, Dict, Tuple, Callable # noqa
from chalice.app import Chalice, CORSConfig # noqa
from typing import List, Any, Dict, Tuple, Callable, Union # noqa

try:
from urllib.parse import urlparse, parse_qs
Expand Down Expand Up @@ -161,7 +161,7 @@ def do_OPTIONS(self):
self._send_autogen_options_response()

def _cors_enabled_for_route(self, lambda_event):
# type: (EventType) -> bool
# type: (EventType) -> CORSConfig
route_key = lambda_event['requestContext']['resourcePath']
route_entry = self.app_object.routes[route_key]
return route_entry.cors
Expand Down
18 changes: 18 additions & 0 deletions tests/integration/test_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,24 @@ def test_can_support_cors(smoke_test_app):
assert headers['Access-Control-Allow-Methods'] == 'GET,POST,PUT,OPTIONS'


def test_can_support_custom_cors(smoke_test_app):
response = requests.get(smoke_test_app.url + '/custom_cors')
response.raise_for_status()
expected_allow_origin = 'https://foo.example.com'
assert response.headers[
'Access-Control-Allow-Origin'] == expected_allow_origin

# Should also have injected an OPTIONs request.
response = requests.options(smoke_test_app.url + '/custom_cors')
response.raise_for_status()
headers = response.headers
assert headers['Access-Control-Allow-Origin'] == expected_allow_origin
assert headers['Access-Control-Allow-Headers'] == (
'X-Special-Header,Content-Type,X-Amz-Date,Authorization,X-Api-Key,'
'X-Amz-Security-Token')
assert headers['Access-Control-Allow-Methods'] == 'GET,POST,PUT,OPTIONS'


def test_to_dict_is_also_json_serializable(smoke_test_app):
assert 'headers' in smoke_test_app.get_json('/todict')

Expand Down
10 changes: 10 additions & 0 deletions tests/integration/testapp/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,16 @@ def supports_cors():
return {'cors': True}


@app.route('/custom_cors', methods=['GET', 'POST', 'PUT'], cors=CORSConfig(
allow_origin='https://foo.example.com',
allow_headers=['X-Special-Header'],
max_age=600,
expose_headers=['X-Special-Header'],
allow_credentials=False))
def supports_custom_cors():
return {'cors': True}


@app.route('/todict', methods=['GET'])
def todict():
return app.current_request.to_dict()
Expand Down
55 changes: 55 additions & 0 deletions tests/unit/deploy/test_swagger.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from chalice.deploy.swagger import SwaggerGenerator
from chalice import CORSConfig

from pytest import fixture

Expand Down Expand Up @@ -71,6 +72,60 @@ def multiple_methods():


def test_can_add_preflight_cors(sample_app, swagger_gen):
@sample_app.route('/cors', methods=['GET', 'POST'], cors=CORSConfig(
allow_origin='http://foo.com',
allow_headers=['X-Special-Header']))
def cors_request():
pass

doc = swagger_gen.generate_swagger(sample_app)
view_config = doc['paths']['/cors']
# We should add an OPTIONS preflight request automatically.
assert 'options' in view_config, (
'Preflight OPTIONS method not added to CORS view')
options = view_config['options']
expected_response_params = {
'method.response.header.Access-Control-Allow-Methods': (
"'GET,POST,OPTIONS'"),
'method.response.header.Access-Control-Allow-Headers': (
"'X-Special-Header,Content-Type,X-Amz-Date,Authorization,"
"X-Api-Key,X-Amz-Security-Token'"),
'method.response.header.Access-Control-Allow-Origin': (
"'http://foo.com'"),
}
assert options == {
'consumes': ['application/json'],
'produces': ['application/json'],
'responses': {
'200': {
'description': '200 response',
'schema': {
'$ref': '#/definitions/Empty'
},
'headers': {
'Access-Control-Allow-Origin': {'type': 'string'},
'Access-Control-Allow-Methods': {'type': 'string'},
'Access-Control-Allow-Headers': {'type': 'string'},
}
}
},
'x-amazon-apigateway-integration': {
'responses': {
'default': {
'statusCode': '200',
'responseParameters': expected_response_params,
}
},
'requestTemplates': {
'application/json': '{"statusCode": 200}'
},
'passthroughBehavior': 'when_no_match',
'type': 'mock',
},
}


def test_can_add_preflight_custom_cors(sample_app, swagger_gen):
@sample_app.route('/cors', methods=['GET', 'POST'], cors=True)
def cors_request():
pass
Expand Down
1 change: 1 addition & 0 deletions tests/unit/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def test_can_support_patch_method(handler):
handler.do_PATCH()
assert _get_body_from_response_stream(handler) == {'patch': True}


def test_can_support_decimals(handler):
set_current_request(handler, method='GET', path='/decimals')
handler.do_PATCH()
Expand Down

0 comments on commit 1133702

Please sign in to comment.