From 6164d5bcd224404c2b3fb31d60642bea1455b2fd Mon Sep 17 00:00:00 2001 From: John Carlyle Date: Fri, 28 Apr 2017 22:53:21 -0700 Subject: [PATCH] Add support for a more granular CORSConfig object. (#311) Add support for a more granular CORSConfig object. --- CHANGELOG.rst | 3 +- README.rst | 43 +++++++++++++++++- chalice/__init__.py | 2 +- chalice/app.py | 67 +++++++++++++++++++++++++-- chalice/app.pyi | 12 +++-- chalice/deploy/swagger.py | 22 +++++---- chalice/local.py | 6 +-- tests/integration/test_features.py | 25 ++++++++++- tests/integration/testapp/app.py | 19 +++++++- tests/unit/deploy/test_swagger.py | 72 +++++++++++++++++++++++++++++- tests/unit/test_local.py | 29 +++++++++++- 11 files changed, 268 insertions(+), 32 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 9182d41b0..7a93d36a2 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -7,7 +7,8 @@ Next Release (TBD) * Alway overwrite existing API Gateway Rest API on updates (`#305 `__) - +* Added more granular support for CORS + (`#311 `__) 0.8.0 ===== diff --git a/README.rst b/README.rst index 162649234..226a6c376 100644 --- a/README.rst +++ b/README.rst @@ -730,15 +730,54 @@ The preflight request will return a response that includes: * ``Access-Control-Allow-Headers: Content-Type,X-Amz-Date,Authorization, X-Api-Key,X-Amz-Security-Token``. +If more fine grained control of the CORS headers is desired, set the ``cors`` +parameter to an instance of ``CORSConfig`` instead of ``True``. The +``CORSConfig`` object can be imported from from the ``chalice`` package it's +constructor takes the following keyword arguments that map to CORS headers: + +================= ==== ================================ +Argument Type Header +================= ==== ================================ +allow_origin str Access-Control-Allow-Origin +allow_headers list Access-Control-Allow-Headers +expose_headers list Access-Control-Expose-Headers +max_age int Access-Control-Max-Age +allow_credentials bool Access-Control-Allow-Credentials +================= ==== ================================ + +Code sample defining more CORS headers: + +.. code-block:: python + + from chalice import CORSConfig + cors_config = CORSConfig( + allow_origin='https://foo.example.com', + allow_headers=['X-Special-Header'], + max_age=600, + expose_headers=['X-Special-Header'], + allow_credentials=True + ) + @app.route('/custom_cors', methods=['GET'], cors=cors_config) + def supports_custom_cors(): + return {'cors': True} + + There's a couple of things to keep in mind when enabling cors for a view: * An ``OPTIONS`` method for preflighting is always injected. Ensure that you don't have ``OPTIONS`` in the ``methods=[...]`` list of your view function. +* Even though the ``Access-Control-Allow-Origin`` header can be set to a + string that is a space separated list of origins, this behavior does not + work on all clients that implement CORS. You should only supply a single + origin to the ``CORSConfig`` object. If you need to supply multiple origins + you will need to define a custom handler for it that accepts ``OPTIONS`` + requests and matches the ``Origin`` header against a whitelist of origins. + If the match is succssful then return just their ``Origin`` back to them + in the ``Access-Control-Allow-Origin`` header. * Every view function must explicitly enable CORS support. -* There's no support for customizing the CORS configuration. -The last two points will change in the future. See +The last point will change in the future. See `this issue `_ for more information. diff --git a/chalice/__init__.py b/chalice/__init__.py index 303693d16..db3836562 100644 --- a/chalice/__init__.py +++ b/chalice/__init__.py @@ -1,7 +1,7 @@ from chalice.app import Chalice from chalice.app import ( ChaliceViewError, BadRequestError, UnauthorizedError, ForbiddenError, - NotFoundError, ConflictError, TooManyRequestsError, Response + NotFoundError, ConflictError, TooManyRequestsError, Response, CORSConfig ) __version__ = '0.8.0' diff --git a/chalice/app.py b/chalice/app.py index 00ac993f2..3b4f6f3de 100644 --- a/chalice/app.py +++ b/chalice/app.py @@ -100,6 +100,54 @@ def __repr__(self): return 'CaseInsensitiveMapping(%s)' % repr(self._dict) +class CORSConfig(object): + """A cors configuration to attach to a route.""" + + _REQUIRED_HEADERS = ['Content-Type', 'X-Amz-Date', 'Authorization', + 'X-Api-Key', 'X-Amz-Security-Token'] + + def __init__(self, allow_origin='*', allow_headers=None, + expose_headers=None, max_age=None, allow_credentials=None): + self.allow_origin = allow_origin + + if allow_headers is None: + allow_headers = set(self._REQUIRED_HEADERS) + else: + allow_headers = set(allow_headers + self._REQUIRED_HEADERS) + self._allow_headers = allow_headers + + if expose_headers is None: + expose_headers = [] + self._expose_headers = expose_headers + + self._max_age = max_age + self._allow_credentials = allow_credentials + + @property + def allow_headers(self): + return ','.join(sorted(self._allow_headers)) + + def get_access_control_headers(self): + headers = { + 'Access-Control-Allow-Origin': self.allow_origin, + 'Access-Control-Allow-Headers': self.allow_headers + } + if self._expose_headers: + headers.update({ + 'Access-Control-Expose-Headers': ','.join(self._expose_headers) + }) + if self._max_age is not None: + headers.update({ + 'Access-Control-Max-Age': str(self._max_age) + }) + if self._allow_credentials is True: + headers.update({ + 'Access-Control-Allow-Credentials': 'true' + }) + + return headers + + class Request(object): """The current request from API gateway.""" @@ -167,6 +215,15 @@ def __init__(self, view_function, view_name, path, methods, #: e.g, '/foo/{bar}/{baz}/qux -> ['bar', 'baz'] self.view_args = self._parse_view_args() self.content_types = content_types + # cors is passed as either a boolean or a CORSConfig object. If it is a + # boolean it needs to be replaced with a real CORSConfig object to + # pass the typechecker. None in this context will not inject any cors + # headers, otherwise the CORSConfig object will determine which + # headers are injected. + if cors is True: + cors = CORSConfig() + elif cors is False: + cors = None self.cors = cors def _parse_view_args(self): @@ -309,7 +366,7 @@ def __call__(self, event, context): response = self._get_view_function_response(view_function, function_args) if self._cors_enabled_for_route(route_entry): - self._add_cors_headers(response) + self._add_cors_headers(response, route_entry.cors) return response.to_dict() def _get_view_function_response(self, view_function, function_args): @@ -354,7 +411,9 @@ def _validate_response(self, response): (header, value)) def _cors_enabled_for_route(self, route_entry): - return route_entry.cors + return route_entry.cors is not None - def _add_cors_headers(self, response): - response.headers['Access-Control-Allow-Origin'] = '*' + def _add_cors_headers(self, response, cors): + for name, value in cors.get_access_control_headers().items(): + if name not in response.headers: + response.headers[name] = value diff --git a/chalice/app.pyi b/chalice/app.pyi index 577b905f7..e3003e7e6 100644 --- a/chalice/app.pyi +++ b/chalice/app.pyi @@ -1,4 +1,4 @@ -from typing import Dict, List, Any, Callable +from typing import Dict, List, Any, Callable, Union class ChaliceError(Exception): ... class ChaliceViewError(ChaliceError): @@ -14,6 +14,12 @@ class TooManyRequestsError(ChaliceViewError): ... ALL_ERRORS = ... # type: List[ChaliceViewError] +class CORSConfig: + allow_origin = ... # type: str + allow_headers = ... # type: str + get_access_control_headers = ... # type: Callable[..., Dict[str, str]] + + class Request: query_params = ... # type: Dict[str, str] headers = ... # type: Dict[str, str] @@ -60,14 +66,14 @@ class RouteEntry(object): api_key_required = ... # type: bool content_types = ... # type: List[str] view_args = ... # type: List[str] - cors = ... # type: bool + cors = ... # type: CORSConfig def __init__(self, view_function: Callable[..., Any], view_name: str, path: str, methods: List[str], authorizer_name: str=None, api_key_required: bool=None, content_types: List[str]=None, - cors: bool=False) -> None: ... + cors: Union[bool, CORSConfig]=False) -> None: ... def _parse_view_args(self) -> List[str]: ... diff --git a/chalice/deploy/swagger.py b/chalice/deploy/swagger.py index 4b0f1a029..ccbe14ebb 100644 --- a/chalice/deploy/swagger.py +++ b/chalice/deploy/swagger.py @@ -46,7 +46,7 @@ def _add_route_paths(self, api, app): self._add_to_security_definition( current['security'], api, app.authorizers) swagger_for_path[http_method.lower()] = current - if view.cors: + if view.cors is not None: self._add_preflight_request(view, swagger_for_path) def _add_to_security_definition(self, security, api_config, authorizers): @@ -141,16 +141,18 @@ def _add_view_args(self, apig_integ, view_args): def _add_preflight_request(self, view, swagger_for_path): # type: (RouteEntry, Dict[str, Any]) -> None + cors = view.cors methods = view.methods + ['OPTIONS'] allowed_methods = ','.join(methods) + response_params = { - "method.response.header.Access-Control-Allow-Methods": ( - "'%s'" % allowed_methods), - "method.response.header.Access-Control-Allow-Headers": ( - "'Content-Type,X-Amz-Date,Authorization,X-Api-Key," - "X-Amz-Security-Token'"), - "method.response.header.Access-Control-Allow-Origin": "'*'" + 'Access-Control-Allow-Methods': '%s' % allowed_methods } + response_params.update(cors.get_access_control_headers()) + + headers = {k: {'type': 'string'} for k, _ in response_params.items()} + response_params = {'method.response.header.%s' % k: "'%s'" % v for k, v + in response_params.items()} options_request = { "consumes": ["application/json"], @@ -159,11 +161,7 @@ def _add_preflight_request(self, view, swagger_for_path): "200": { "description": "200 response", "schema": {"$ref": "#/definitions/Empty"}, - "headers": { - "Access-Control-Allow-Origin": {"type": "string"}, - "Access-Control-Allow-Methods": {"type": "string"}, - "Access-Control-Allow-Headers": {"type": "string"}, - } + "headers": headers } }, "x-amazon-apigateway-integration": { diff --git a/chalice/local.py b/chalice/local.py index 1c8451d2a..85ee47c18 100644 --- a/chalice/local.py +++ b/chalice/local.py @@ -8,9 +8,9 @@ from six.moves.BaseHTTPServer import HTTPServer from six.moves.BaseHTTPServer import BaseHTTPRequestHandler -from chalice.app import Chalice # noqa -from chalice.compat import urlparse, parse_qs +from chalice.app import Chalice, CORSConfig # noqa from typing import List, Any, Dict, Tuple, Callable # noqa +from chalice.compat import urlparse, parse_qs MatchResult = namedtuple('MatchResult', ['route', 'captured', 'query_params']) @@ -156,7 +156,7 @@ def do_OPTIONS(self): self._send_autogen_options_response() def _cors_enabled_for_route(self, lambda_event): - # type: (EventType) -> bool + # type: (EventType) -> CORSConfig route_key = lambda_event['requestContext']['resourcePath'] route_entry = self.app_object.routes[route_key] return route_entry.cors diff --git a/tests/integration/test_features.py b/tests/integration/test_features.py index adf23d14c..a3134e418 100644 --- a/tests/integration/test_features.py +++ b/tests/integration/test_features.py @@ -188,10 +188,33 @@ def test_can_support_cors(smoke_test_app): headers = response.headers assert headers['Access-Control-Allow-Origin'] == '*' assert headers['Access-Control-Allow-Headers'] == ( - 'Content-Type,X-Amz-Date,Authorization,X-Api-Key,X-Amz-Security-Token') + 'Authorization,Content-Type,X-Amz-Date,X-Amz-Security-Token,' + 'X-Api-Key') assert headers['Access-Control-Allow-Methods'] == 'GET,POST,PUT,OPTIONS' +def test_can_support_custom_cors(smoke_test_app): + response = requests.get(smoke_test_app.url + '/custom_cors') + response.raise_for_status() + expected_allow_origin = 'https://foo.example.com' + assert response.headers[ + 'Access-Control-Allow-Origin'] == expected_allow_origin + + # Should also have injected an OPTIONs request. + response = requests.options(smoke_test_app.url + '/custom_cors') + response.raise_for_status() + headers = response.headers + print(headers) + assert headers['Access-Control-Allow-Origin'] == expected_allow_origin + assert headers['Access-Control-Allow-Headers'] == ( + 'Authorization,Content-Type,X-Amz-Date,X-Amz-Security-Token,' + 'X-Api-Key,X-Special-Header') + assert headers['Access-Control-Allow-Methods'] == 'GET,POST,PUT,OPTIONS' + assert headers['Access-Control-Max-Age'] == '600' + assert headers['Access-Control-Expose-Headers'] == 'X-Special-Header' + assert headers['Access-Control-Allow-Credentials'] == 'true' + + def test_to_dict_is_also_json_serializable(smoke_test_app): assert 'headers' in smoke_test_app.get_json('/todict') diff --git a/tests/integration/testapp/app.py b/tests/integration/testapp/app.py index dc66a028c..98c512032 100644 --- a/tests/integration/testapp/app.py +++ b/tests/integration/testapp/app.py @@ -1,5 +1,10 @@ -from chalice import Chalice, BadRequestError, NotFoundError, Response -from chalice.compat import parse_qs +from chalice import Chalice, BadRequestError, NotFoundError, Response,\ + CORSConfig + +try: + from urllib.parse import parse_qs +except: + from urlparse import parse_qs # This is a test app that is used by integration tests. # This app exercises all the major features of chalice @@ -75,6 +80,16 @@ def supports_cors(): return {'cors': True} +@app.route('/custom_cors', methods=['GET', 'POST', 'PUT'], cors=CORSConfig( + allow_origin='https://foo.example.com', + allow_headers=['X-Special-Header'], + max_age=600, + expose_headers=['X-Special-Header'], + allow_credentials=True)) +def supports_custom_cors(): + return {'cors': True} + + @app.route('/todict', methods=['GET']) def todict(): return app.current_request.to_dict() diff --git a/tests/unit/deploy/test_swagger.py b/tests/unit/deploy/test_swagger.py index 6fa6f16a5..8ffb94cfc 100644 --- a/tests/unit/deploy/test_swagger.py +++ b/tests/unit/deploy/test_swagger.py @@ -1,4 +1,5 @@ from chalice.deploy.swagger import SwaggerGenerator +from chalice import CORSConfig from pytest import fixture @@ -71,6 +72,73 @@ def multiple_methods(): def test_can_add_preflight_cors(sample_app, swagger_gen): + @sample_app.route('/cors', methods=['GET', 'POST'], cors=CORSConfig( + allow_origin='http://foo.com', + allow_headers=['X-ZZ-Top', 'X-Special-Header'], + expose_headers=['X-Exposed', 'X-Special'], + max_age=600, + allow_credentials=True)) + def cors_request(): + pass + + doc = swagger_gen.generate_swagger(sample_app) + view_config = doc['paths']['/cors'] + # We should add an OPTIONS preflight request automatically. + assert 'options' in view_config, ( + 'Preflight OPTIONS method not added to CORS view') + options = view_config['options'] + expected_response_params = { + 'method.response.header.Access-Control-Allow-Methods': ( + "'GET,POST,OPTIONS'"), + 'method.response.header.Access-Control-Allow-Headers': ( + "'Authorization,Content-Type,X-Amz-Date,X-Amz-Security-Token," + "X-Api-Key,X-Special-Header,X-ZZ-Top'"), + 'method.response.header.Access-Control-Allow-Origin': ( + "'http://foo.com'"), + 'method.response.header.Access-Control-Expose-Headers': ( + "'X-Exposed,X-Special'"), + 'method.response.header.Access-Control-Max-Age': ( + "'600'"), + 'method.response.header.Access-Control-Allow-Credentials': ( + "'true'"), + + } + assert options == { + 'consumes': ['application/json'], + 'produces': ['application/json'], + 'responses': { + '200': { + 'description': '200 response', + 'schema': { + '$ref': '#/definitions/Empty' + }, + 'headers': { + 'Access-Control-Allow-Origin': {'type': 'string'}, + 'Access-Control-Allow-Methods': {'type': 'string'}, + 'Access-Control-Allow-Headers': {'type': 'string'}, + 'Access-Control-Expose-Headers': {'type': 'string'}, + 'Access-Control-Max-Age': {'type': 'string'}, + 'Access-Control-Allow-Credentials': {'type': 'string'}, + } + } + }, + 'x-amazon-apigateway-integration': { + 'responses': { + 'default': { + 'statusCode': '200', + 'responseParameters': expected_response_params, + } + }, + 'requestTemplates': { + 'application/json': '{"statusCode": 200}' + }, + 'passthroughBehavior': 'when_no_match', + 'type': 'mock', + }, + } + + +def test_can_add_preflight_custom_cors(sample_app, swagger_gen): @sample_app.route('/cors', methods=['GET', 'POST'], cors=True) def cors_request(): pass @@ -85,8 +153,8 @@ def cors_request(): 'method.response.header.Access-Control-Allow-Methods': ( "'GET,POST,OPTIONS'"), 'method.response.header.Access-Control-Allow-Headers': ( - "'Content-Type,X-Amz-Date,Authorization," - "X-Api-Key,X-Amz-Security-Token'"), + "'Authorization,Content-Type,X-Amz-Date,X-Amz-Security-Token," + "X-Api-Key'"), 'method.response.header.Access-Control-Allow-Origin': "'*'", } assert options == { diff --git a/tests/unit/test_local.py b/tests/unit/test_local.py index 462fe0957..841a23d1c 100644 --- a/tests/unit/test_local.py +++ b/tests/unit/test_local.py @@ -1,4 +1,4 @@ -from chalice import local, BadRequestError +from chalice import local, BadRequestError, CORSConfig import json import decimal import pytest @@ -42,6 +42,16 @@ def put(): def cors(): return {'cors': True} + @demo.route('/custom_cors', methods=['GET', 'PUT'], cors=CORSConfig( + allow_origin='https://foo.bar', + allow_headers=['Header-A', 'Header-B'], + expose_headers=['Header-A', 'Header-B'], + max_age=600, + allow_credentials=True + )) + def custom_cors(): + return {'cors': True} + @demo.route('/options', methods=['OPTIONS']) def options(): return {'options': True} @@ -129,6 +139,22 @@ def test_will_respond_with_cors_enabled(handler): assert b'Access-Control-Allow-Origin: *' in response_lines +def test_will_respond_with_custom_cors_enabled(handler): + headers = {'content-type': 'application/json', 'origin': 'null'} + set_current_request(handler, method='GET', path='/custom_cors', + headers=headers) + handler.do_GET() + response = handler.wfile.getvalue().splitlines() + print(response) + assert b'Access-Control-Allow-Origin: https://foo.bar' in response + assert (b'Access-Control-Allow-Headers: Authorization,Content-Type,' + b'Header-A,Header-B,X-Amz-Date,X-Amz-Security-Token,' + b'X-Api-Key') in response + assert b'Access-Control-Expose-Headers: Header-A,Header-B' in response + assert b'Access-Control-Max-Age: 600' in response + assert b'Access-Control-Allow-Credentials: true' in response + + def test_can_preflight_request(handler): headers = {'content-type': 'application/json', 'origin': 'null'} set_current_request(handler, method='OPTIONS', path='/cors', @@ -166,6 +192,7 @@ def test_can_support_patch_method(handler): handler.do_PATCH() assert _get_body_from_response_stream(handler) == {'patch': True} + def test_can_support_decimals(handler): set_current_request(handler, method='GET', path='/decimals') handler.do_PATCH()