Skip to content

Commit

Permalink
Add support for a more granular CORSConfig object. (#311)
Browse files Browse the repository at this point in the history
Add support for a more granular CORSConfig object.
  • Loading branch information
stealthycoin authored Apr 29, 2017
1 parent b4ad6c6 commit 6164d5b
Show file tree
Hide file tree
Showing 11 changed files with 268 additions and 32 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ Next Release (TBD)

* Alway overwrite existing API Gateway Rest API on updates
(`#305 <https://github.com/awslabs/chalice/issues/305>`__)

* Added more granular support for CORS
(`#311 <https://github.com/awslabs/chalice/pull/311>`__)

0.8.0
=====
Expand Down
43 changes: 41 additions & 2 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -730,15 +730,54 @@ The preflight request will return a response that includes:
* ``Access-Control-Allow-Headers: Content-Type,X-Amz-Date,Authorization,
X-Api-Key,X-Amz-Security-Token``.

If more fine grained control of the CORS headers is desired, set the ``cors``
parameter to an instance of ``CORSConfig`` instead of ``True``. The
``CORSConfig`` object can be imported from from the ``chalice`` package it's
constructor takes the following keyword arguments that map to CORS headers:

================= ==== ================================
Argument Type Header
================= ==== ================================
allow_origin str Access-Control-Allow-Origin
allow_headers list Access-Control-Allow-Headers
expose_headers list Access-Control-Expose-Headers
max_age int Access-Control-Max-Age
allow_credentials bool Access-Control-Allow-Credentials
================= ==== ================================

Code sample defining more CORS headers:

.. code-block:: python
from chalice import CORSConfig
cors_config = CORSConfig(
allow_origin='https://foo.example.com',
allow_headers=['X-Special-Header'],
max_age=600,
expose_headers=['X-Special-Header'],
allow_credentials=True
)
@app.route('/custom_cors', methods=['GET'], cors=cors_config)
def supports_custom_cors():
return {'cors': True}
There's a couple of things to keep in mind when enabling cors for a view:

* An ``OPTIONS`` method for preflighting is always injected. Ensure that
you don't have ``OPTIONS`` in the ``methods=[...]`` list of your
view function.
* Even though the ``Access-Control-Allow-Origin`` header can be set to a
string that is a space separated list of origins, this behavior does not
work on all clients that implement CORS. You should only supply a single
origin to the ``CORSConfig`` object. If you need to supply multiple origins
you will need to define a custom handler for it that accepts ``OPTIONS``
requests and matches the ``Origin`` header against a whitelist of origins.
If the match is succssful then return just their ``Origin`` back to them
in the ``Access-Control-Allow-Origin`` header.
* Every view function must explicitly enable CORS support.
* There's no support for customizing the CORS configuration.

The last two points will change in the future. See
The last point will change in the future. See
`this issue
<https://github.com/awslabs/chalice/issues/70#issuecomment-248787037>`_
for more information.
Expand Down
2 changes: 1 addition & 1 deletion chalice/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from chalice.app import Chalice
from chalice.app import (
ChaliceViewError, BadRequestError, UnauthorizedError, ForbiddenError,
NotFoundError, ConflictError, TooManyRequestsError, Response
NotFoundError, ConflictError, TooManyRequestsError, Response, CORSConfig
)

__version__ = '0.8.0'
67 changes: 63 additions & 4 deletions chalice/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,54 @@ def __repr__(self):
return 'CaseInsensitiveMapping(%s)' % repr(self._dict)


class CORSConfig(object):
"""A cors configuration to attach to a route."""

_REQUIRED_HEADERS = ['Content-Type', 'X-Amz-Date', 'Authorization',
'X-Api-Key', 'X-Amz-Security-Token']

def __init__(self, allow_origin='*', allow_headers=None,
expose_headers=None, max_age=None, allow_credentials=None):
self.allow_origin = allow_origin

if allow_headers is None:
allow_headers = set(self._REQUIRED_HEADERS)
else:
allow_headers = set(allow_headers + self._REQUIRED_HEADERS)
self._allow_headers = allow_headers

if expose_headers is None:
expose_headers = []
self._expose_headers = expose_headers

self._max_age = max_age
self._allow_credentials = allow_credentials

@property
def allow_headers(self):
return ','.join(sorted(self._allow_headers))

def get_access_control_headers(self):
headers = {
'Access-Control-Allow-Origin': self.allow_origin,
'Access-Control-Allow-Headers': self.allow_headers
}
if self._expose_headers:
headers.update({
'Access-Control-Expose-Headers': ','.join(self._expose_headers)
})
if self._max_age is not None:
headers.update({
'Access-Control-Max-Age': str(self._max_age)
})
if self._allow_credentials is True:
headers.update({
'Access-Control-Allow-Credentials': 'true'
})

return headers


class Request(object):
"""The current request from API gateway."""

Expand Down Expand Up @@ -167,6 +215,15 @@ def __init__(self, view_function, view_name, path, methods,
#: e.g, '/foo/{bar}/{baz}/qux -> ['bar', 'baz']
self.view_args = self._parse_view_args()
self.content_types = content_types
# cors is passed as either a boolean or a CORSConfig object. If it is a
# boolean it needs to be replaced with a real CORSConfig object to
# pass the typechecker. None in this context will not inject any cors
# headers, otherwise the CORSConfig object will determine which
# headers are injected.
if cors is True:
cors = CORSConfig()
elif cors is False:
cors = None
self.cors = cors

def _parse_view_args(self):
Expand Down Expand Up @@ -309,7 +366,7 @@ def __call__(self, event, context):
response = self._get_view_function_response(view_function,
function_args)
if self._cors_enabled_for_route(route_entry):
self._add_cors_headers(response)
self._add_cors_headers(response, route_entry.cors)
return response.to_dict()

def _get_view_function_response(self, view_function, function_args):
Expand Down Expand Up @@ -354,7 +411,9 @@ def _validate_response(self, response):
(header, value))

def _cors_enabled_for_route(self, route_entry):
return route_entry.cors
return route_entry.cors is not None

def _add_cors_headers(self, response):
response.headers['Access-Control-Allow-Origin'] = '*'
def _add_cors_headers(self, response, cors):
for name, value in cors.get_access_control_headers().items():
if name not in response.headers:
response.headers[name] = value
12 changes: 9 additions & 3 deletions chalice/app.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List, Any, Callable
from typing import Dict, List, Any, Callable, Union

class ChaliceError(Exception): ...
class ChaliceViewError(ChaliceError):
Expand All @@ -14,6 +14,12 @@ class TooManyRequestsError(ChaliceViewError): ...

ALL_ERRORS = ... # type: List[ChaliceViewError]

class CORSConfig:
allow_origin = ... # type: str
allow_headers = ... # type: str
get_access_control_headers = ... # type: Callable[..., Dict[str, str]]


class Request:
query_params = ... # type: Dict[str, str]
headers = ... # type: Dict[str, str]
Expand Down Expand Up @@ -60,14 +66,14 @@ class RouteEntry(object):
api_key_required = ... # type: bool
content_types = ... # type: List[str]
view_args = ... # type: List[str]
cors = ... # type: bool
cors = ... # type: CORSConfig

def __init__(self, view_function: Callable[..., Any],
view_name: str, path: str, methods: List[str],
authorizer_name: str=None,
api_key_required: bool=None,
content_types: List[str]=None,
cors: bool=False) -> None: ...
cors: Union[bool, CORSConfig]=False) -> None: ...

def _parse_view_args(self) -> List[str]: ...

Expand Down
22 changes: 10 additions & 12 deletions chalice/deploy/swagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def _add_route_paths(self, api, app):
self._add_to_security_definition(
current['security'], api, app.authorizers)
swagger_for_path[http_method.lower()] = current
if view.cors:
if view.cors is not None:
self._add_preflight_request(view, swagger_for_path)

def _add_to_security_definition(self, security, api_config, authorizers):
Expand Down Expand Up @@ -141,16 +141,18 @@ def _add_view_args(self, apig_integ, view_args):

def _add_preflight_request(self, view, swagger_for_path):
# type: (RouteEntry, Dict[str, Any]) -> None
cors = view.cors
methods = view.methods + ['OPTIONS']
allowed_methods = ','.join(methods)

response_params = {
"method.response.header.Access-Control-Allow-Methods": (
"'%s'" % allowed_methods),
"method.response.header.Access-Control-Allow-Headers": (
"'Content-Type,X-Amz-Date,Authorization,X-Api-Key,"
"X-Amz-Security-Token'"),
"method.response.header.Access-Control-Allow-Origin": "'*'"
'Access-Control-Allow-Methods': '%s' % allowed_methods
}
response_params.update(cors.get_access_control_headers())

headers = {k: {'type': 'string'} for k, _ in response_params.items()}
response_params = {'method.response.header.%s' % k: "'%s'" % v for k, v
in response_params.items()}

options_request = {
"consumes": ["application/json"],
Expand All @@ -159,11 +161,7 @@ def _add_preflight_request(self, view, swagger_for_path):
"200": {
"description": "200 response",
"schema": {"$ref": "#/definitions/Empty"},
"headers": {
"Access-Control-Allow-Origin": {"type": "string"},
"Access-Control-Allow-Methods": {"type": "string"},
"Access-Control-Allow-Headers": {"type": "string"},
}
"headers": headers
}
},
"x-amazon-apigateway-integration": {
Expand Down
6 changes: 3 additions & 3 deletions chalice/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
from six.moves.BaseHTTPServer import HTTPServer
from six.moves.BaseHTTPServer import BaseHTTPRequestHandler

from chalice.app import Chalice # noqa
from chalice.compat import urlparse, parse_qs
from chalice.app import Chalice, CORSConfig # noqa
from typing import List, Any, Dict, Tuple, Callable # noqa
from chalice.compat import urlparse, parse_qs


MatchResult = namedtuple('MatchResult', ['route', 'captured', 'query_params'])
Expand Down Expand Up @@ -156,7 +156,7 @@ def do_OPTIONS(self):
self._send_autogen_options_response()

def _cors_enabled_for_route(self, lambda_event):
# type: (EventType) -> bool
# type: (EventType) -> CORSConfig
route_key = lambda_event['requestContext']['resourcePath']
route_entry = self.app_object.routes[route_key]
return route_entry.cors
Expand Down
25 changes: 24 additions & 1 deletion tests/integration/test_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,33 @@ def test_can_support_cors(smoke_test_app):
headers = response.headers
assert headers['Access-Control-Allow-Origin'] == '*'
assert headers['Access-Control-Allow-Headers'] == (
'Content-Type,X-Amz-Date,Authorization,X-Api-Key,X-Amz-Security-Token')
'Authorization,Content-Type,X-Amz-Date,X-Amz-Security-Token,'
'X-Api-Key')
assert headers['Access-Control-Allow-Methods'] == 'GET,POST,PUT,OPTIONS'


def test_can_support_custom_cors(smoke_test_app):
response = requests.get(smoke_test_app.url + '/custom_cors')
response.raise_for_status()
expected_allow_origin = 'https://foo.example.com'
assert response.headers[
'Access-Control-Allow-Origin'] == expected_allow_origin

# Should also have injected an OPTIONs request.
response = requests.options(smoke_test_app.url + '/custom_cors')
response.raise_for_status()
headers = response.headers
print(headers)
assert headers['Access-Control-Allow-Origin'] == expected_allow_origin
assert headers['Access-Control-Allow-Headers'] == (
'Authorization,Content-Type,X-Amz-Date,X-Amz-Security-Token,'
'X-Api-Key,X-Special-Header')
assert headers['Access-Control-Allow-Methods'] == 'GET,POST,PUT,OPTIONS'
assert headers['Access-Control-Max-Age'] == '600'
assert headers['Access-Control-Expose-Headers'] == 'X-Special-Header'
assert headers['Access-Control-Allow-Credentials'] == 'true'


def test_to_dict_is_also_json_serializable(smoke_test_app):
assert 'headers' in smoke_test_app.get_json('/todict')

Expand Down
19 changes: 17 additions & 2 deletions tests/integration/testapp/app.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from chalice import Chalice, BadRequestError, NotFoundError, Response
from chalice.compat import parse_qs
from chalice import Chalice, BadRequestError, NotFoundError, Response,\
CORSConfig

try:
from urllib.parse import parse_qs
except:
from urlparse import parse_qs

# This is a test app that is used by integration tests.
# This app exercises all the major features of chalice
Expand Down Expand Up @@ -75,6 +80,16 @@ def supports_cors():
return {'cors': True}


@app.route('/custom_cors', methods=['GET', 'POST', 'PUT'], cors=CORSConfig(
allow_origin='https://foo.example.com',
allow_headers=['X-Special-Header'],
max_age=600,
expose_headers=['X-Special-Header'],
allow_credentials=True))
def supports_custom_cors():
return {'cors': True}


@app.route('/todict', methods=['GET'])
def todict():
return app.current_request.to_dict()
Expand Down
Loading

0 comments on commit 6164d5b

Please sign in to comment.