diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index a97f1c7..8d13930 100755 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -31,9 +31,10 @@ jobs: strategy: matrix: python: - - "3.7" - "3.8" - "3.9" + - "3.10" + - "3.11" runs-on: ubuntu-latest @@ -70,11 +71,12 @@ jobs: run: jupyter kernelgateway --help - name: Run tests - run: nosetests --process-restartworker --with-coverage --cover-package=kernel_gateway + run: pytest -vv -W default --cov kernel_gateway --cov-branch --cov-report term-missing:skip-covered env: ASYNC_TEST_TIMEOUT: 10 - name: Upload coverage to Codecov uses: codecov/codecov-action@v1 with: + token: ${{ secrets.CODECOV_TOKEN }} fail_ci_if_error: true diff --git a/Makefile b/Makefile index a6275be..392cd75 100644 --- a/Makefile +++ b/Makefile @@ -71,10 +71,10 @@ sdist: ## Make a dist/*.tar.gz source distribution test: TEST?= test: ## Make a python3 test run ifeq ($(TEST),) - $(SA) $(ENV) && nosetests + $(SA) $(ENV) && pytest -vv else -# e.g., make test TEST="test_gatewayapp.TestGatewayAppConfig" - $(SA) $(ENV) && nosetests kernel_gateway.tests.$(TEST) +# e.g., make test TEST="test_gatewayapp.py::TestGatewayAppConfig" + $(SA) $(ENV) && pytest -vv kernel_gateway/tests/$(TEST) endif release: POST_SDIST=register upload diff --git a/codecov.yml b/codecov.yml new file mode 100644 index 0000000..36d97aa --- /dev/null +++ b/codecov.yml @@ -0,0 +1,10 @@ +coverage: + status: + project: + default: + target: auto + threshold: 5% + patch: + default: + target: 50% + range: 80..100 diff --git a/conftest.py b/conftest.py new file mode 100644 index 0000000..706c30a --- /dev/null +++ b/conftest.py @@ -0,0 +1,104 @@ +# Copyright (c) Jupyter Development Team. +# Distributed under the terms of the Modified BSD License. 
+ +import os +import logging +import pytest +from binascii import hexlify +from traitlets.config import Config +from kernel_gateway.gatewayapp import KernelGatewayApp + +pytest_plugins = ["pytest_jupyter.jupyter_core", "pytest_jupyter.jupyter_server"] + + +@pytest.fixture(scope="function") +def jp_configurable_serverapp( + jp_nbconvert_templates, # this fixture must precede jp_environ + jp_environ, + jp_server_config, + jp_argv, + jp_http_port, + jp_base_url, + tmp_path, + jp_root_dir, + jp_logging_stream, + jp_asyncio_loop, + io_loop, +): + """Starts a Jupyter Server instance based on + the provided configuration values. + The fixture is a factory; it can be called like + a function inside a unit test. Here's a basic + example of how use this fixture: + + .. code-block:: python + + def my_test(jp_configurable_serverapp): + app = jp_configurable_serverapp(...) + ... + """ + KernelGatewayApp.clear_instance() + + def _configurable_serverapp( + config=jp_server_config, + base_url=jp_base_url, + argv=jp_argv, + http_port=jp_http_port, + **kwargs, + ): + c = Config(config) + + if "auth_token" not in c.KernelGatewayApp and not c.IdentityProvider.token: + default_token = hexlify(os.urandom(4)).decode("ascii") + c.IdentityProvider.token = default_token + + app = KernelGatewayApp.instance( + # Set the log level to debug for testing purposes + log_level="DEBUG", + port=http_port, + port_retries=0, + base_url=base_url, + config=c, + **kwargs, + ) + app.log.propagate = True + app.log.handlers = [] + # Initialize app without httpserver + if jp_asyncio_loop.is_running(): + app.initialize(argv=argv, new_httpserver=False) + else: + + async def initialize_app(): + app.initialize(argv=argv, new_httpserver=False) + + jp_asyncio_loop.run_until_complete(initialize_app()) + # Reroute all logging StreamHandlers away from stdin/stdout since pytest hijacks + # these streams and closes them at unfortunate times. 
+ stream_handlers = [h for h in app.log.handlers if isinstance(h, logging.StreamHandler)] + for handler in stream_handlers: + handler.setStream(jp_logging_stream) + app.log.propagate = True + app.log.handlers = [] + app.start_app() + return app + + return _configurable_serverapp + + +@pytest.fixture(autouse=True) +def jp_server_cleanup(jp_asyncio_loop): + yield + app: KernelGatewayApp = KernelGatewayApp.instance() + try: + jp_asyncio_loop.run_until_complete(app.async_shutdown()) + except (RuntimeError, SystemExit) as e: + print("ignoring cleanup error", e) + if hasattr(app, "kernel_manager"): + app.kernel_manager.context.destroy() + KernelGatewayApp.clear_instance() + + +@pytest.fixture +def jp_auth_header(jp_serverapp): + """Configures an authorization header using the token from the serverapp fixture.""" + return {"Authorization": f"token {jp_serverapp.identity_provider.token}"} diff --git a/docs/source/features.md b/docs/source/features.md index 0bfae4d..cc7feab 100644 --- a/docs/source/features.md +++ b/docs/source/features.md @@ -24,4 +24,4 @@ The Jupyter Kernel Gateway has the following features: * Generation of [Swagger specs](http://swagger.io/introducing-the-open-api-initiative/) for notebook-defined API in `notebook-http` mode * A CLI for launching the kernel gateway: `jupyter kernelgateway OPTIONS` -* A Python 2.7 and 3.3+ compatible implementation +* A Python 3.8+ compatible implementation diff --git a/kernel_gateway/auth/identity.py b/kernel_gateway/auth/identity.py new file mode 100644 index 0000000..2b1a63a --- /dev/null +++ b/kernel_gateway/auth/identity.py @@ -0,0 +1,58 @@ +# Copyright (c) Jupyter Development Team. +# Distributed under the terms of the Modified BSD License. +"""Gateway Identity Provider interface + +This defines the _authentication_ layer of Jupyter Server, +to be used in combination with Authorizer for _authorization_. 
+""" +from traitlets import default +from tornado import web + +from jupyter_server.auth.identity import IdentityProvider, User +from jupyter_server.base.handlers import JupyterHandler + + +class GatewayIdentityProvider(IdentityProvider): + """ + Interface for providing identity management and authentication for a Gateway server. + """ + + @default("token") + def _token_default(self): + return self.parent.auth_token + + @property + def auth_enabled(self): + if not self.token: + return False + return True + + def should_check_origin(self, handler: JupyterHandler) -> bool: + """Should the Handler check for CORS origin validation? + + Origin check should be skipped for token-authenticated requests. + + Returns: + - True, if Handler must check for valid CORS origin. + - False, if Handler should skip origin check since requests are token-authenticated. + """ + # Always check the origin unless operator configured gateway to allow any + return handler.settings["kg_allow_origin"] != "*" + + def generate_anonymous_user(self, handler: web.RequestHandler) -> User: + """Generate a random anonymous user. + + For use when a single shared token is used, + but does not identify a user. + """ + name = display_name = f"Anonymous" + initials = "An" + color = None + return User(name.lower(), name, display_name, initials, None, color) + + def is_token_authenticated(self, handler: web.RequestHandler) -> bool: + """The default authentication flow of Gateway is token auth. 
+ + The only other option is no auth + """ + return True diff --git a/kernel_gateway/base/handlers.py b/kernel_gateway/base/handlers.py index 464ed5a..984cd17 100644 --- a/kernel_gateway/base/handlers.py +++ b/kernel_gateway/base/handlers.py @@ -3,18 +3,20 @@ """Tornado handlers for the base of the API.""" from tornado import web -import notebook.base.handlers as notebook_handlers +import jupyter_server.base.handlers as server_handlers from ..mixins import TokenAuthorizationMixin, CORSMixin, JSONErrorsMixin + class APIVersionHandler(TokenAuthorizationMixin, CORSMixin, JSONErrorsMixin, - notebook_handlers.APIVersionHandler): + server_handlers.APIVersionHandler): """Extends the notebook server base API handler with token auth, CORS, and JSON errors. """ pass + class NotFoundHandler(JSONErrorsMixin, web.RequestHandler): """Catches all requests and responds with 404 JSON messages. @@ -28,6 +30,7 @@ class NotFoundHandler(JSONErrorsMixin, web.RequestHandler): def prepare(self): raise web.HTTPError(404) + default_handlers = [ (r'/api', APIVersionHandler), (r'/(.*)', NotFoundHandler) diff --git a/kernel_gateway/gatewayapp.py b/kernel_gateway/gatewayapp.py index 67006f1..508721e 100644 --- a/kernel_gateway/gatewayapp.py +++ b/kernel_gateway/gatewayapp.py @@ -2,47 +2,52 @@ # Distributed under the terms of the Modified BSD License. 
"""Kernel Gateway Jupyter application.""" +import asyncio import errno import importlib import logging +import hashlib +import hmac import os import sys import signal +import select import socket import ssl +import threading +from base64 import encodebytes from distutils.util import strtobool import nbformat -from notebook.services.kernels.kernelmanager import MappingKernelManager +from jupyter_server.services.kernels.kernelmanager import MappingKernelManager -try: - from urlparse import urlparse -except ImportError: - from urllib.parse import urlparse - -from traitlets import Unicode, Integer, default, observe, Type, Instance, List, CBool +from urllib.parse import urlparse +from traitlets import Unicode, Integer, Bytes, default, observe, Type, Instance, List, CBool from jupyter_core.application import JupyterApp, base_aliases from jupyter_client.kernelspec import KernelSpecManager -# Install the pyzmq ioloop. This has to be done before anything else from -# tornado is imported. -from zmq.eventloop import ioloop -ioloop.install() - from tornado import httpserver -from tornado import web +from tornado import web, ioloop from tornado.log import enable_pretty_logging, LogFormatter -from notebook.notebookapp import random_ports +from jupyter_core.paths import secure_write +from jupyter_server.serverapp import random_ports from ._version import __version__ from .services.sessions.sessionmanager import SessionManager from .services.kernels.manager import SeedingMappingKernelManager +from .auth.identity import GatewayIdentityProvider + # Only present for generating help documentation from .notebook_http import NotebookHTTPPersonality from .jupyter_websocket import JupyterWebsocketPersonality +from jupyter_server.auth.authorizer import AllowAllAuthorizer, Authorizer +from jupyter_server.services.kernels.connection.base import BaseKernelWebsocketConnection +from jupyter_server.services.kernels.connection.channels import ZMQChannelsWebsocketConnection + + # Add additional 
command line aliases aliases = dict(base_aliases) aliases.update({ @@ -310,6 +315,65 @@ def ssl_version_default(self): ssl_from_env = os.getenv(self.ssl_version_env) return ssl_from_env if ssl_from_env is None else int(ssl_from_env) + cookie_secret_file = Unicode( + config=True, help="""The file where the cookie secret is stored.""" + ) + + @default("cookie_secret_file") + def _default_cookie_secret_file(self): + return os.path.join(self.runtime_dir, "jupyter_cookie_secret") + + cookie_secret = Bytes( + b"", + config=True, + help="""The random bytes used to secure cookies. + By default this is a new random number every time you start the server. + Set it to a value in a config file to enable logins to persist across server sessions. + + Note: Cookie secrets should be kept private, do not share config files with + cookie_secret stored in plaintext (you can read the value from a file). + """, + ) + + @default("cookie_secret") + def _default_cookie_secret(self): + if os.path.exists(self.cookie_secret_file): + with open(self.cookie_secret_file, "rb") as f: + key = f.read() + else: + key = encodebytes(os.urandom(32)) + self._write_cookie_secret_file(key) + h = hmac.new(key, digestmod=hashlib.sha256) + # h.update(self.password.encode()) # password is deprecated in 2.0 + return h.digest() + + def _write_cookie_secret_file(self, secret): + """write my secret to my secret_file""" + self.log.info("Writing Jupyter server cookie secret to %s", self.cookie_secret_file) + try: + with secure_write(self.cookie_secret_file, True) as f: + f.write(secret) + except OSError as e: + self.log.error( + "Failed to write cookie secret to %s: %s", + self.cookie_secret_file, + e, + ) + + ws_ping_interval_env = "KG_WS_PING_INTERVAL_SECS" + ws_ping_interval_default_value = 30 + ws_ping_interval = Integer( + ws_ping_interval_default_value, + config=True, + help="""Specifies the ping interval(in seconds) that should be used by zmq port + associated with spawned kernels. 
Set this variable to 0 to disable ping mechanism. + (KG_WS_PING_INTERVAL_SECS env var)""", + ) + + @default("ws_ping_interval") + def _ws_ping_interval_default(self) -> int: + return int(os.getenv(self.ws_ping_interval_env, self.ws_ping_interval_default_value)) + _log_formatter_cls = LogFormatter # traitlet default is LevelFormatter @default("log_format") @@ -337,6 +401,27 @@ def _default_log_format(self) -> str: help="""The kernel manager class to use.""" ) + kernel_websocket_connection_class = Type( + default_value=ZMQChannelsWebsocketConnection, + klass=BaseKernelWebsocketConnection, + config=True, + help="""The kernel websocket connection class to use.""", + ) + + authorizer_class = Type( + default_value=AllowAllAuthorizer, + klass=Authorizer, + config=True, + help="The authorizer class to use.", + ) + + identity_provider_class = Type( + default_value=GatewayIdentityProvider, + klass=GatewayIdentityProvider, + config=True, + help="The identity provider class to use.", + ) + def _load_api_module(self, module_name): """Tries to import the given module name. @@ -391,7 +476,7 @@ def _load_notebook(self, uri): return notebook - def initialize(self, argv=None): + def initialize(self, argv=None, new_httpserver=True,): """Initializes the base class, configurable manager instances, the Tornado web app, and the tornado HTTP server. 
@@ -399,11 +484,18 @@ def initialize(self, argv=None): ---------- argv Command line arguments + + new_httpserver + Indicates that a new HTTP server instance should be created """ - super(KernelGatewayApp, self).initialize(argv) + super().initialize(argv) + + self.init_io_loop() self.init_configurables() self.init_webapp() - self.init_http_server() + self.init_signal() + if new_httpserver: + self.init_http_server() def init_configurables(self): """Initializes all configurable objects including a kernel manager, kernel @@ -445,17 +537,22 @@ def init_configurables(self): ) self.contents_manager = None + self.identity_provider = self.identity_provider_class(parent=self, log=self.log) + + self.authorizer = self.authorizer_class( + parent=self, log=self.log, identity_provider=self.identity_provider + ) + if self.prespawn_count: if self.max_kernels and self.prespawn_count > self.max_kernels: - raise RuntimeError('cannot prespawn {}; more than max kernels {}'.format( - self.prespawn_count, self.max_kernels) - ) + msg = f"Cannot prespawn {self.prespawn_count} kernels; more than max kernels {self.max_kernels}" + raise RuntimeError(msg) api_module = self._load_api_module(self.api) func = getattr(api_module, 'create_personality') self.personality = func(parent=self, log=self.log) - self.personality.init_configurables() + self.io_loop.call_later(0.1, lambda: asyncio.create_task(self.personality.init_configurables())) def init_webapp(self): """Initializes Tornado web application with uri handlers. 
@@ -493,8 +590,17 @@ def init_webapp(self): allow_origin=self.allow_origin, # Set base_url for use in request handlers base_url=self.base_url, + # Authentication + cookie_secret=self.cookie_secret, # Always allow remote access (has been limited to localhost >= notebook 5.6) - allow_remote_access=True + allow_remote_access=True, + # setting ws_ping_interval value that can allow it to be modified for the purpose of toggling ping mechanism + # for zmq web-sockets or increasing/decreasing web socket ping interval/timeouts. + ws_ping_interval=self.ws_ping_interval * 1000, + # Add a pass-through authorizer for now + authorizer=self.authorizer_class(parent=self), + identity_provider=self.identity_provider, + kernel_websocket_connection_class=self.kernel_websocket_connection_class, ) # promote the current personality's "config" tagged traitlet values to webapp settings @@ -562,9 +668,78 @@ def init_http_server(self): 'no available port could be found.') self.exit(1) + def init_io_loop(self): + """init self.io_loop so that an extension can use it by io_loop.call_later() to create background tasks""" + self.io_loop = ioloop.IOLoop.current() + + def init_signal(self): + """Initialize signal handlers.""" + if not sys.platform.startswith("win") and sys.stdin and sys.stdin.isatty(): + signal.signal(signal.SIGINT, self._handle_sigint) + signal.signal(signal.SIGTERM, self._signal_stop) + signal.signal(signal.SIGQUIT, self._signal_stop) + + def _handle_sigint(self, sig, frame): + """SIGINT handler spawns confirmation dialog""" + # register more forceful signal handler for ^C^C case + signal.signal(signal.SIGINT, self._signal_stop) + # request confirmation dialog in bg thread, to avoid + # blocking the App + thread = threading.Thread(target=self._confirm_exit) + thread.daemon = True + thread.start() + + def _restore_sigint_handler(self): + """callback for restoring original SIGINT handler""" + signal.signal(signal.SIGINT, self._handle_sigint) + + def _confirm_exit(self): + 
"""confirm shutdown on ^C + + A second ^C, or answering 'y' within 5s will cause shutdown, + otherwise original SIGINT handler will be restored. + + This doesn't work on Windows. + """ + info = self.log.info + info("interrupted") + # Check if answer_yes is set + if self.answer_yes: + self.log.critical("Shutting down...") + # schedule stop on the main thread, + # since this might be called from a signal handler + self.stop(from_signal=True) + return + yes = "y" + no = "n" + sys.stdout.write("Shutdown this Jupyter server (%s/[%s])? " % (yes, no)) + sys.stdout.flush() + r, w, x = select.select([sys.stdin], [], [], 5) + if r: + line = sys.stdin.readline() + if line.lower().startswith(yes) and no not in line.lower(): + self.log.critical("Shutdown confirmed") + # schedule stop on the main thread, + # since this might be called from a signal handler + self.stop(from_signal=True) + return + else: + info("No answer for 5s:") + info("resuming operation...") + # no answer, or answer is no: + # set it back to original SIGINT handler + # use IOLoop.add_callback because signal.signal must be called + # from main thread + self.io_loop.add_callback_from_signal(self._restore_sigint_handler) + + def _signal_stop(self, sig, frame): + """Handle a stop signal.""" + self.log.critical("received signal %s, stopping", sig) + self.stop(from_signal=True) + def start_app(self): """Starts the application (with ioloop to follow). 
""" - super(KernelGatewayApp, self).start() + super().start() self.log.info('Jupyter Kernel Gateway {} is available at http{}://{}:{}'.format( KernelGatewayApp.version, 's' if self.keyfile else '', self.ip, self.port )) @@ -574,8 +749,6 @@ def start(self): self.start_app() - self.io_loop = ioloop.IOLoop.current() - if sys.platform != 'win32': signal.signal(signal.SIGHUP, signal.SIG_IGN) @@ -586,23 +759,36 @@ def start(self): except KeyboardInterrupt: self.log.info("Interrupted...") finally: - self.shutdown() + self.stop() - def stop(self): - """ - Stops the HTTP server and IO loop associated with the application. - """ - def _stop(): - self.http_server.stop() + async def _stop(self): + """Cleanup resources and stop the IO Loop.""" + await self.personality.shutdown() + await self.kernel_websocket_connection_class.close_all() + if getattr(self, "io_loop", None): self.io_loop.stop() - self.io_loop.add_callback(_stop) + + def stop(self, from_signal=False): + """Cleanup resources and stop the server.""" + if hasattr(self, "http_server"): + # Stop a server if its set. 
+ self.http_server.stop() + if getattr(self, "io_loop", None): + # use IOLoop.add_callback because signal.signal must be called + # from main thread + if from_signal: + self.io_loop.add_callback_from_signal(self._stop) + else: + self.io_loop.add_callback(self._stop) def shutdown(self): """Stop all kernels in the pool.""" - self.personality.shutdown() + self.io_loop.add_callback(self._stop) + + async def async_shutdown(self): + """Stop all kernels in the pool.""" + if hasattr(self, "personality"): + await self.personality.shutdown() - def _signal_stop(self, sig, frame): - self.log.info("Received signal to terminate.") - self.io_loop.stop() launch_instance = KernelGatewayApp.launch_instance diff --git a/kernel_gateway/jupyter_websocket/__init__.py b/kernel_gateway/jupyter_websocket/__init__.py index dba4105..ec87d5d 100644 --- a/kernel_gateway/jupyter_websocket/__init__.py +++ b/kernel_gateway/jupyter_websocket/__init__.py @@ -9,7 +9,7 @@ from ..services.kernelspecs.handlers import default_handlers as default_kernelspec_handlers from ..services.sessions.handlers import default_handlers as default_session_handlers from .handlers import default_handlers as default_api_handlers -from notebook.utils import url_path_join +from jupyter_server.utils import url_path_join from traitlets import Bool, List, default from traitlets.config.configurable import LoggingConfigurable @@ -39,8 +39,14 @@ def list_kernels_default(self): def env_whitelist_default(self): return os.getenv(self.env_whitelist_env, '').split(',') - def init_configurables(self): - self.kernel_pool = KernelPool( + kernel_pool: KernelPool + + def __init__(self, *args, **kwargs): + super().__init__(**kwargs) + self.kernel_pool = KernelPool() + + async def init_configurables(self): + await self.kernel_pool.initialize( self.parent.prespawn_count, self.parent.kernel_manager ) @@ -70,12 +76,12 @@ def create_request_handlers(self): def should_seed_cell(self, code): """Determines whether the given code cell source should 
be executed when seeding a new kernel.""" - # seed all code cells + # seed all code cells in websocket personality return True - def shutdown(self): + async def shutdown(self): """Stop all kernels in the pool.""" - self.kernel_pool.shutdown() + await self.kernel_pool.shutdown() def create_personality(*args, **kwargs): diff --git a/kernel_gateway/jupyter_websocket/handlers.py b/kernel_gateway/jupyter_websocket/handlers.py index cba187f..506d3c1 100644 --- a/kernel_gateway/jupyter_websocket/handlers.py +++ b/kernel_gateway/jupyter_websocket/handlers.py @@ -2,11 +2,11 @@ # Distributed under the terms of the Modified BSD License. """Tornado handlers for kernel specs.""" -from notebook.utils import maybe_future -from tornado import gen, web -from ..mixins import TokenAuthorizationMixin, CORSMixin, JSONErrorsMixin +from tornado import web +from ..mixins import CORSMixin import os + class BaseSpecHandler(CORSMixin, web.StaticFileHandler): """Exposes the ability to return specifications from static files""" @staticmethod @@ -23,30 +23,31 @@ def initialize(self): """ web.StaticFileHandler.initialize(self, path=os.path.dirname(__file__)) - @gen.coroutine - def get(self): + async def get(self): """Handler for a get on a specific handler """ resource_name, content_type = self.get_resource_metadata() self.set_header('Content-Type', content_type) - res = web.StaticFileHandler.get(self, resource_name) - yield maybe_future(res) + res = await web.StaticFileHandler.get(self, resource_name) + return res def options(self, **kwargs): """Method for properly handling CORS pre-flight""" self.finish() + class SpecJsonHandler(BaseSpecHandler): """Exposes a JSON swagger specification""" @staticmethod def get_resource_metadata(): - return ('swagger.json','application/json') + return 'swagger.json','application/json' + class APIYamlHandler(BaseSpecHandler): """Exposes a YAML swagger specification""" @staticmethod def get_resource_metadata(): - return ('swagger.yaml', 'text/x-yaml') + return 
'swagger.yaml', 'text/x-yaml' default_handlers = [ diff --git a/kernel_gateway/mixins.py b/kernel_gateway/mixins.py index d06ceb1..2b02469 100644 --- a/kernel_gateway/mixins.py +++ b/kernel_gateway/mixins.py @@ -2,14 +2,11 @@ # Distributed under the terms of the Modified BSD License. """Mixins for Tornado handlers.""" +from http.client import responses import json import traceback from tornado import web -try: - # py3 - from http.client import responses -except ImportError: - from httplib import responses + class CORSMixin(object): """Mixes CORS headers into tornado.web.RequestHandlers.""" @@ -18,16 +15,25 @@ class CORSMixin(object): 'kg_allow_headers': 'Access-Control-Allow-Headers', 'kg_allow_methods': 'Access-Control-Allow-Methods', 'kg_allow_origin': 'Access-Control-Allow-Origin', - 'kg_expose_headers' : 'Access-Control-Expose-Headers', + 'kg_expose_headers': 'Access-Control-Expose-Headers', 'kg_max_age': 'Access-Control-Max-Age' } - def set_default_headers(self): + + def set_cors_headers(self): """Sets the CORS headers as the default for all responses. Disables CSP configured by the notebook package. It's not necessary for a programmatic API. + + Notes + ----- + This method name was changed from set_default_headers to set_cors_header + when adding support for JupyterServer 2.x. In that release, JS changed the + way the headers were implemented due to changes in the way a user is authenticated. + See https://github.com/jupyter-server/jupyter_server/pull/671. 
""" - super(CORSMixin, self).set_default_headers() + super().set_cors_headers() + # Add CORS headers after default if they have a non-blank value for settings_name, header_name in self.SETTINGS_TO_HEADERS.items(): header_value = self.settings.get(settings_name) @@ -77,7 +83,7 @@ def prepare(self): client_token = None if client_token != server_token: return self.send_error(401) - return super(TokenAuthorizationMixin, self).prepare() + return super().prepare() class JSONErrorsMixin(object): diff --git a/kernel_gateway/notebook_http/__init__.py b/kernel_gateway/notebook_http/__init__.py index 41a7c90..beedc11 100644 --- a/kernel_gateway/notebook_http/__init__.py +++ b/kernel_gateway/notebook_http/__init__.py @@ -10,7 +10,7 @@ from .cell.parser import APICellParser from .swagger.handlers import SwaggerSpecHandler from .handlers import NotebookAPIHandler, parameterize_path, NotebookDownloadHandler -from notebook.utils import url_path_join +from jupyter_server.utils import url_path_join from traitlets import Bool, Unicode, Dict, default from traitlets.config.configurable import LoggingConfigurable @@ -56,8 +56,10 @@ def allow_notebook_download_default(self): def static_path_default(self): return os.getenv(self.static_path_env) + kernel_pool: ManagedKernelPool + def __init__(self, *args, **kwargs): - super(NotebookHTTPPersonality, self).__init__(*args, **kwargs) + super().__init__(**kwargs) # Import the module to use for cell endpoint parsing cell_parser_module = importlib.import_module(self.cell_parser) # Build the parser using the comment syntax for the notebook language @@ -71,10 +73,10 @@ def __init__(self, *args, **kwargs): comment_prefix=prefix, notebook_cells=self.parent.seed_notebook.cells) self.kernel_language = kernel_language + self.kernel_pool = ManagedKernelPool() - def init_configurables(self): - """Create a managed kernel pool""" - self.kernel_pool = ManagedKernelPool( + async def init_configurables(self): + await self.kernel_pool.initialize( 
self.parent.prespawn_count, self.parent.kernel_manager ) @@ -147,11 +149,12 @@ def should_seed_cell(self, code): """Determines whether the given code cell source should be executed when seeding a new kernel.""" # seed cells that are uninvolved with the presented API - return (not self.api_parser.is_api_cell(code) and not self.api_parser.is_api_response_cell(code)) + return not self.api_parser.is_api_cell(code) and not self.api_parser.is_api_response_cell(code) - def shutdown(self): + async def shutdown(self): """Stop all kernels in the pool.""" - self.kernel_pool.shutdown() + await self.kernel_pool.shutdown() + def create_personality(*args, **kwargs): return NotebookHTTPPersonality(*args, **kwargs) diff --git a/kernel_gateway/notebook_http/cell/parser.py b/kernel_gateway/notebook_http/cell/parser.py index 58ebd0a..056e7fe 100644 --- a/kernel_gateway/notebook_http/cell/parser.py +++ b/kernel_gateway/notebook_http/cell/parser.py @@ -7,6 +7,7 @@ from traitlets import Unicode from traitlets.config.configurable import LoggingConfigurable + def first_path_param_index(endpoint): """Gets the index to the first path parameter for the endpoint. 
The returned value is not the string index, but rather the depth of where the @@ -37,6 +38,7 @@ def first_path_param_index(endpoint): index = endpoint.count('/', 0, endpoint.find(':')) - 1 return index + class APICellParser(LoggingConfigurable): """A utility class for parsing Jupyter code cells to find API annotations of the form: @@ -67,8 +69,8 @@ class APICellParser(LoggingConfigurable): api_indicator = Unicode(default_value=r'{}\s+(GET|PUT|POST|DELETE)\s+(\/.*)+') api_response_indicator = Unicode(default_value=r'{}\s+ResponseInfo\s+(GET|PUT|POST|DELETE)\s+(\/.*)+') - def __init__(self, comment_prefix, *args, **kwargs): - super(APICellParser, self).__init__(*args, **kwargs) + def __init__(self, comment_prefix, notebook_cells=None, **kwargs): + super().__init__(**kwargs) self.kernelspec_api_indicator = re.compile(self.api_indicator.format(comment_prefix)) self.kernelspec_api_response_indicator = re.compile(self.api_response_indicator.format(comment_prefix)) @@ -214,5 +216,6 @@ def get_default_api_spec(self): """ return {'swagger': '2.0', 'paths': {}, 'info': {'version': '0.0.0'}} + def create_parser(*args, **kwargs): return APICellParser(*args, **kwargs) diff --git a/kernel_gateway/notebook_http/errors.py b/kernel_gateway/notebook_http/errors.py index 3b0bb42..54519c3 100644 --- a/kernel_gateway/notebook_http/errors.py +++ b/kernel_gateway/notebook_http/errors.py @@ -2,12 +2,14 @@ # Distributed under the terms of the Modified BSD License. """Exception classes for notebook-http mode.""" + class CodeExecutionError(Exception): """Raised when a notebook's code fails to execute in response to an API request. """ pass + class UnsupportedMethodError(Exception): """Raised when a notebook-defined API does not support the requested HTTP method. 
diff --git a/kernel_gateway/notebook_http/handlers.py b/kernel_gateway/notebook_http/handlers.py index bbd5d0a..341034c 100644 --- a/kernel_gateway/notebook_http/handlers.py +++ b/kernel_gateway/notebook_http/handlers.py @@ -5,16 +5,15 @@ import os import json import tornado.web -from notebook.utils import maybe_future from tornado.log import access_log -from .request_utils import (parse_body, parse_args, format_request, - headers_to_dict, parameterize_path) -from tornado import gen +from typing import Optional +from .request_utils import parse_body, parse_args, format_request, headers_to_dict, parameterize_path from tornado.concurrent import Future from ..mixins import TokenAuthorizationMixin, CORSMixin, JSONErrorsMixin from functools import partial from .errors import UnsupportedMethodError, CodeExecutionError + class NotebookAPIHandler(TokenAuthorizationMixin, CORSMixin, JSONErrorsMixin, @@ -152,14 +151,13 @@ def execute_code(self, kernel_client, kernel_id, source_code): If the kernel returns any error """ future = Future() - result_accumulator = {'stream' : [], 'error' : None, 'result' : None} + result_accumulator = {'stream': [], 'error': None, 'result': None} parent_header = kernel_client.execute(source_code) on_recv_func = partial(self.on_recv, result_accumulator, future, parent_header) self.kernel_pool.on_recv(kernel_id, on_recv_func) return future - @gen.coroutine - def _handle_request(self): + async def _handle_request(self): """Turns an HTTP request into annotated notebook code to execute on a kernel. @@ -167,7 +165,7 @@ def _handle_request(self): result of the kernel execution. Then finishes the Tornado response. 
""" self.response_future = Future() - kernel_client, kernel_id = yield self.kernel_pool.acquire() + kernel_client, kernel_id = await self.kernel_pool.acquire() try: # Method not supported if self.request.method not in self.sources: @@ -181,18 +179,18 @@ def _handle_request(self): source_code = self.sources[self.request.method] # Build the request dictionary request = json.dumps({ - 'body' : parse_body(self.request), - 'args' : parse_args(self.request.query_arguments), - 'path' : self.path_kwargs, - 'headers' : headers_to_dict(self.request.headers) + 'body': parse_body(self.request), + 'args': parse_args(self.request.query_arguments), + 'path': self.path_kwargs, + 'headers': headers_to_dict(self.request.headers) }) # Turn the request string into a valid code string request_code = format_request(request, self.kernel_language) # Run the request and source code and yield until there's a result access_log.debug('Request code for notebook cell is: {}'.format(request_code)) - yield self.execute_code(kernel_client, kernel_id, request_code) - source_result = yield self.execute_code(kernel_client, kernel_id, source_code) + await self.execute_code(kernel_client, kernel_id, request_code) + source_result = await self.execute_code(kernel_client, kernel_id, source_code) # If a response code cell exists, execute it if self.request.method in self.response_sources: @@ -200,7 +198,7 @@ def _handle_request(self): response_future = self.execute_code(kernel_client, kernel_id, response_code) # Wait for the response and parse the json value - response_result = yield response_future + response_result = await response_future response = json.loads(response_result) # Copy all the header values into the tornado response @@ -219,49 +217,45 @@ def _handle_request(self): except CodeExecutionError as err: self.write(str(err)) self.set_status(500) - # An unspported method was called on this handler + # An unsupported method was called on this handler except UnsupportedMethodError: 
self.set_status(405) finally: # Always make sure we release the kernel and finish the request self.response_future.set_result(None) self.kernel_pool.release(kernel_id) - self.finish() + await self.finish() - @gen.coroutine - def get(self, **kwargs): - self._handle_request() - yield self.response_future + async def get(self, **kwargs): + await self._handle_request() + await self.response_future - @gen.coroutine - def post(self, **kwargs): - self._handle_request() - yield self.response_future + async def post(self, **kwargs): + await self._handle_request() + await self.response_future - @gen.coroutine - def put(self, **kwargs): - self._handle_request() - yield self.response_future + async def put(self, **kwargs): + await self._handle_request() + await self.response_future - @gen.coroutine - def delete(self, **kwargs): - self._handle_request() - yield self.response_future + async def delete(self, **kwargs): + await self._handle_request() + await self.response_future def options(self, **kwargs): self.finish() + class NotebookDownloadHandler(TokenAuthorizationMixin, CORSMixin, JSONErrorsMixin, tornado.web.StaticFileHandler): """Handles requests to download the annotated notebook behind the web API. 
""" - def initialize(self, path): + def initialize(self, path: str, default_filename: Optional[str] = None): self.dirname, self.filename = os.path.split(path) super(NotebookDownloadHandler, self).initialize(self.dirname) - @gen.coroutine - def get(self, include_body=True): - res = super(NotebookDownloadHandler, self).get(self.filename, include_body) - yield maybe_future(res) + async def get(self, include_body: bool = True): + res = await super().get(path=self.filename, include_body=include_body) + return res diff --git a/kernel_gateway/notebook_http/request_utils.py b/kernel_gateway/notebook_http/request_utils.py index 599efd6..1b0ae49 100644 --- a/kernel_gateway/notebook_http/request_utils.py +++ b/kernel_gateway/notebook_http/request_utils.py @@ -4,15 +4,17 @@ import json import re -from tornado.httputil import parse_body_arguments +from typing import List, Union +from tornado.httputil import HTTPHeaders, HTTPServerRequest -_named_param_regex = re.compile('(:([^/\s]*))') -FORM_URLENCODED = 'application/x-www-form-urlencoded' -MULTIPART_FORM_DATA = 'multipart/form-data' -APPLICATION_JSON = 'application/json' -TEXT_PLAIN = 'text/plain' +_named_param_regex = re.compile(r"(:([^/\s]*))") +FORM_URLENCODED = "application/x-www-form-urlencoded" +MULTIPART_FORM_DATA = "multipart/form-data" +APPLICATION_JSON = "application/json" +TEXT_PLAIN = "text/plain" -def format_request(bundle, kernel_language=''): + +def format_request(bundle, kernel_language: str = "") -> str: """Creates an assignment statement of bundle JSON-encoded to a variable named `REQUEST` by default or kernel_language specific. 
@@ -23,15 +25,16 @@ def format_request(bundle, kernel_language=''): ` = ""` """ bundle = json.dumps(bundle) - if kernel_language.lower() == 'perl': - statement = "my $REQUEST = {}".format(bundle) - elif kernel_language.lower() == 'bash': - statement = "REQUEST={}".format(bundle) + if kernel_language.lower() == "perl": + statement = f"my $REQUEST = {bundle}" + elif kernel_language.lower() == "bash": + statement = f"REQUEST={bundle}" else: - statement = "REQUEST = {}".format(bundle) + statement = f"REQUEST = {bundle}" return statement -def parameterize_path(path): + +def parameterize_path(path: str) -> str: """Creates a regex to match all named parameters in a path. Parameters @@ -47,10 +50,11 @@ def parameterize_path(path): """ matches = re.findall(_named_param_regex, path) for match in matches: - path = path.replace(match[0], '(?P<{}>[^\/]+)'.format(match[1])) + path = path.replace(match[0], r"(?P<{}>[^\/]+)".format(match[1])) return path.strip() -def parse_body(request): + +def parse_body(request: HTTPServerRequest) -> Union[str, dict]: """Parses the body of an HTTP request based on its Content-Type. If no Content-Type is found, treats the value as plain text. @@ -67,9 +71,9 @@ def parse_body(request): """ content_type = TEXT_PLAIN body = request.body - body = body.decode(encoding='UTF-8') if body else '' - if 'Content-Type' in request.headers: - content_type = request.headers['Content-Type'] + body = body.decode(encoding="UTF-8") if body else "" + if "Content-Type" in request.headers: + content_type = request.headers["Content-Type"] return_body = body if content_type == FORM_URLENCODED or content_type.startswith(MULTIPART_FORM_DATA): # If there is form data, we already have the values in body_arguments, we @@ -84,7 +88,8 @@ def parse_body(request): pass return return_body -def parse_args(args): + +def parse_args(args: List[str]) -> dict: """Decodes UTF-8 encoded argument values. 
Parameters @@ -101,10 +106,11 @@ def parse_args(args): for key in args: rv[key] = [] for value in args[key]: - rv[key].append(value.decode(encoding='UTF-8')) + rv[key].append(value.decode(encoding="UTF-8")) return rv -def headers_to_dict(headers): + +def headers_to_dict(headers: HTTPHeaders) -> dict: """Turns a set of tornado headers into a Python dict. Repeat headers are aggregated into lists. diff --git a/kernel_gateway/notebook_http/swagger/builders.py b/kernel_gateway/notebook_http/swagger/builders.py index 24ce78f..16c5787 100644 --- a/kernel_gateway/notebook_http/swagger/builders.py +++ b/kernel_gateway/notebook_http/swagger/builders.py @@ -4,6 +4,7 @@ import os + class SwaggerSpecBuilder(object): """Builds a Swagger specification. diff --git a/kernel_gateway/notebook_http/swagger/handlers.py b/kernel_gateway/notebook_http/swagger/handlers.py index 6623305..d51a1be 100644 --- a/kernel_gateway/notebook_http/swagger/handlers.py +++ b/kernel_gateway/notebook_http/swagger/handlers.py @@ -9,6 +9,7 @@ from .builders import SwaggerSpecBuilder from ...mixins import TokenAuthorizationMixin, CORSMixin, JSONErrorsMixin + class SwaggerSpecHandler(TokenAuthorizationMixin, CORSMixin, JSONErrorsMixin, @@ -38,8 +39,8 @@ def initialize(self, notebook_path, source_cells, cell_parser): spec_builder.set_default_title(notebook_path) SwaggerSpecHandler.output = json.dumps(spec_builder.build()) - def get(self, **kwargs): + async def get(self, **kwargs): """Responds with the spec in JSON format.""" self.set_header('Content-Type', 'application/json') self.set_status(200) - self.finish(self.output) + await self.finish(self.output) diff --git a/kernel_gateway/notebook_http/swagger/parser.py b/kernel_gateway/notebook_http/swagger/parser.py index 005d90a..5527247 100644 --- a/kernel_gateway/notebook_http/swagger/parser.py +++ b/kernel_gateway/notebook_http/swagger/parser.py @@ -8,6 +8,7 @@ from traitlets import List, Unicode from traitlets.config.configurable import LoggingConfigurable + 
def _swaggerlet_from_markdown(cell_source): """ Pulls apart the first block comment of a cell's source, then tries to parse it as a JSON object. If it contains a 'swagger' @@ -34,6 +35,7 @@ def _swaggerlet_from_markdown(cell_source): pass return None + class SwaggerCellParser(LoggingConfigurable): """A utility class for parsing Jupyter code cells to find API annotations of the form: @@ -66,8 +68,9 @@ class SwaggerCellParser(LoggingConfigurable): operation_response_indicator = Unicode(default_value=r'{}\s*ResponseInfo\s+operationId:\s*(.*)') notebook_cells = List() - def __init__(self, comment_prefix, *args, **kwargs): - super(SwaggerCellParser, self).__init__(*args, **kwargs) + def __init__(self, comment_prefix, notebook_cells, **kwargs): + super(SwaggerCellParser, self).__init__(**kwargs) + self.notebook_cells = notebook_cells self.kernelspec_operation_indicator = re.compile(self.operation_indicator.format(comment_prefix)) self.kernelspec_operation_response_indicator = re.compile(self.operation_response_indicator.format(comment_prefix)) self.swagger = dict() @@ -92,10 +95,10 @@ def __init__(self, comment_prefix, *args, **kwargs): operationIdsDeclared.append(operationId) for operationId in operationIdsDeclared: if operationId not in operationIdsFound: - self.log.warning('Operation {} was declared but not referenced in a cell'.format(operationId)) + self.log.warning(f'Operation {operationId} was declared but not referenced in a cell') for operationId in operationIdsFound: if operationId not in operationIdsDeclared: - self.log.warning('Operation {} was referenced in a cell but not declared'.format(operationId)) + self.log.warning(f'Operation {operationId} was referenced in a cell but not declared') else: self.log.warning('No paths documented in Swagger documentation') @@ -286,5 +289,6 @@ def get_default_api_spec(self): return self.swagger return {'swagger': '2.0', 'paths': {}, 'info': {'version': '0.0.0', 'title': 'Default Title'}} + def create_parser(*args, 
**kwargs): return SwaggerCellParser(*args, **kwargs) diff --git a/kernel_gateway/services/kernels/handlers.py b/kernel_gateway/services/kernels/handlers.py index 9ebc0c5..677ecdf 100644 --- a/kernel_gateway/services/kernels/handlers.py +++ b/kernel_gateway/services/kernels/handlers.py @@ -5,7 +5,7 @@ import os import tornado -import notebook.services.kernels.handlers as notebook_handlers +import jupyter_server.services.kernels.handlers as server_handlers from tornado import gen from functools import partial from ...mixins import TokenAuthorizationMixin, CORSMixin, JSONErrorsMixin @@ -14,7 +14,7 @@ class MainKernelHandler(TokenAuthorizationMixin, CORSMixin, JSONErrorsMixin, - notebook_handlers.MainKernelHandler): + server_handlers.MainKernelHandler): """Extends the notebook main kernel handler with token auth, CORS, and JSON errors. """ @@ -27,8 +27,7 @@ def env_whitelist(self): def env_process_whitelist(self): return self.settings['kg_env_process_whitelist'] - @gen.coroutine - def post(self): + async def post(self): """Overrides the super class method to honor the max number of allowed kernels configuration setting and to support custom kernel environment variables for every request. 
@@ -59,22 +58,22 @@ def post(self): env = {'PATH': os.getenv('PATH', '')} # Whitelist environment variables from current process environment env.update({key: value for key, value in os.environ.items() - if key in self.env_process_whitelist}) + if key in self.env_process_whitelist}) # Whitelist KERNEL_* args and those allowed by configuration from client env.update({key: value for key, value in model['env'].items() - if key.startswith('KERNEL_') or key in self.env_whitelist}) + if key.startswith('KERNEL_') or key in self.env_whitelist}) # No way to override the call to start_kernel on the kernel manager # so do a temporary partial (ugh) orig_start = self.kernel_manager.start_kernel self.kernel_manager.start_kernel = partial(self.kernel_manager.start_kernel, env=env) try: - yield super(MainKernelHandler, self).post() + await super().post() finally: self.kernel_manager.start_kernel = orig_start else: - yield super(MainKernelHandler, self).post() + await super().post() - def get(self): + async def get(self): """Overrides the super class method to honor the kernel listing configuration setting. @@ -88,7 +87,7 @@ def get(self): if not self.settings.get('kg_list_kernels'): raise tornado.web.HTTPError(403, 'Forbidden') else: - super(MainKernelHandler, self).get() + await super(MainKernelHandler, self).get() def options(self, **kwargs): """Method for properly handling CORS pre-flight""" @@ -98,7 +97,7 @@ def options(self, **kwargs): class KernelHandler(TokenAuthorizationMixin, CORSMixin, JSONErrorsMixin, - notebook_handlers.KernelHandler): + server_handlers.KernelHandler): """Extends the notebook kernel handler with token auth, CORS, and JSON errors. 
""" @@ -108,7 +107,7 @@ def options(self, **kwargs): default_handlers = [] -for path, cls in notebook_handlers.default_handlers: +for path, cls in server_handlers.default_handlers: if cls.__name__ in globals(): # Use the same named class from here if it exists default_handlers.append((path, globals()[cls.__name__])) diff --git a/kernel_gateway/services/kernels/manager.py b/kernel_gateway/services/kernels/manager.py index cc54e12..7a596cb 100644 --- a/kernel_gateway/services/kernels/manager.py +++ b/kernel_gateway/services/kernels/manager.py @@ -1,22 +1,32 @@ # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. """Kernel manager that optionally seeds kernel memory.""" +import os +from typing import List, Optional +from traitlets import default +from jupyter_server.services.kernels.kernelmanager import AsyncMappingKernelManager +from jupyter_client.ioloop import AsyncIOLoopKernelManager -from functools import partial -from tornado import gen, ioloop -from notebook.services.kernels.kernelmanager import MappingKernelManager -from notebook.utils import maybe_future -from jupyter_client.ioloop import IOLoopKernelManager -class SeedingMappingKernelManager(MappingKernelManager): - """Extends the notebook kernel manager to optionally execute the contents +class SeedingMappingKernelManager(AsyncMappingKernelManager): + """Extends the server's kernel manager to optionally execute the contents of a notebook on a kernel when it starts. 
""" + + _seed_source: Optional[List] + _seed_kernelspec: Optional[str] + + @default("root_dir") + def _default_root_dir(self): + return os.getcwd() + def _kernel_manager_class_default(self): - return 'kernel_gateway.services.kernels.manager.KernelGatewayIOLoopKernelManager' + return ( + "kernel_gateway.services.kernels.manager.KernelGatewayIOLoopKernelManager" + ) @property - def seed_kernelspec(self): + def seed_kernelspec(self) -> Optional[str]: """Gets the kernel spec name for run the seed notebook. Prefers the spec name forced by configuration over the spec in the @@ -27,21 +37,23 @@ def seed_kernelspec(self): str Name of the notebook kernelspec or None if no seed notebook exists """ - if hasattr(self, '_seed_kernelspec'): + if hasattr(self, "_seed_kernelspec"): return self._seed_kernelspec if self.parent.seed_notebook: if self.parent.force_kernel_name: self._seed_kernelspec = self.parent.force_kernel_name else: - self._seed_kernelspec = self.parent.seed_notebook['metadata']['kernelspec']['name'] + self._seed_kernelspec = self.parent.seed_notebook["metadata"][ + "kernelspec" + ]["name"] else: self._seed_kernelspec = None return self._seed_kernelspec @property - def seed_source(self): + def seed_source(self) -> Optional[List]: """Gets the source of the seed notebook in cell order. Returns @@ -49,37 +61,37 @@ def seed_source(self): list Notebook code cell contents or None if no seed notebook exists """ - if hasattr(self, '_seed_source'): + if hasattr(self, "_seed_source"): return self._seed_source if self.parent.seed_notebook: self._seed_source = [ - cell['source'] for cell in self.parent.seed_notebook.cells - if cell['cell_type'] == 'code' + cell["source"] + for cell in self.parent.seed_notebook.cells + if cell["cell_type"] == "code" ] else: self._seed_source = None return self._seed_source - def start_seeded_kernel(self, *args, **kwargs): + async def start_seeded_kernel(self, *args, **kwargs): """Start a kernel using the language specified in the seed notebook. 
Run synchronously so that any exceptions thrown while seed rise up to the caller. """ - start = partial(self.start_kernel, kernel_name=self.seed_kernelspec, - *args, **kwargs) - return ioloop.IOLoop.current().run_sync(start) + await self.start_kernel(kernel_name=self.seed_kernelspec, *args, **kwargs) - @gen.coroutine - def start_kernel(self, *args, **kwargs): + async def start_kernel(self, *args, **kwargs): """Starts a kernel and then executes a list of code cells on it if a seed notebook exists. """ if self.parent.force_kernel_name: - kwargs['kernel_name'] = self.parent.force_kernel_name - kernel_id = yield maybe_future(super(SeedingMappingKernelManager, self).start_kernel(*args, **kwargs)) + kwargs["kernel_name"] = self.parent.force_kernel_name + kernel_id = await super(SeedingMappingKernelManager, self).start_kernel( + *args, **kwargs + ) if kernel_id and self.seed_source is not None: # Only run source if the kernel spec matches the notebook kernel spec @@ -95,34 +107,44 @@ def start_kernel(self, *args, **kwargs): ) # Only start channels and wait for ready in HTTP mode client.start_channels() - client.wait_for_ready() + await client.wait_for_ready() for code in self.seed_source: # Check with the personality whether it wants the cell # executed if self.parent.personality.should_seed_cell(code): client.execute(code) - msg = client.get_shell_msg() - if msg['content']['status'] != 'ok': - # Shutdown the channels to remove any lingering ZMQ messages - client.stop_channels() - # Shutdown the kernel - self.shutdown_kernel(kernel_id) - raise RuntimeError('Error seeding kernel memory', msg['content']) + msg_type = "kernel_info_reply" + while msg_type == "kernel_info_reply": + msg = await client.get_shell_msg() + msg_type = msg["msg_type"] + if msg["content"]["status"] != "ok": + # Shutdown the channels to remove any lingering ZMQ messages + client.stop_channels() + # Shutdown the kernel + await self.shutdown_kernel(kernel_id) + raise RuntimeError( + "Error seeding kernel 
memory", msg["content"] + ) # Shutdown the channels to remove any lingering ZMQ messages client.stop_channels() - raise gen.Return(kernel_id) + return kernel_id -class KernelGatewayIOLoopKernelManager(IOLoopKernelManager): + +class KernelGatewayIOLoopKernelManager(AsyncIOLoopKernelManager): """Extends the IOLoopKernelManager used by the SeedingMappingKernelManager. Sets the environment variable 'KERNEL_GATEWAY' to '1' to indicate that the kernel is executing within a Jupyter Kernel Gateway instance. Removes the - KG_AUTH_TOKEN from the environment variables passed to the kernel when it + KG_AUTH_TOKEN from the environment variables passed to the kernel when it starts. """ - def _launch_kernel(self, kernel_cmd, **kw): - env = kw['env'] - env['KERNEL_GATEWAY'] = '1' - if 'KG_AUTH_TOKEN' in env: - del env['KG_AUTH_TOKEN'] - return super(KernelGatewayIOLoopKernelManager, self)._launch_kernel(kernel_cmd, **kw) + + async def _async_launch_kernel(self, kernel_cmd, **kw): + # TODO - should probably figure out a better place to deal with this + env = kw["env"] + env["KERNEL_GATEWAY"] = "1" + if "KG_AUTH_TOKEN" in env: + del env["KG_AUTH_TOKEN"] + return await super(KernelGatewayIOLoopKernelManager, self)._async_launch_kernel( + kernel_cmd, **kw + ) diff --git a/kernel_gateway/services/kernels/pool.py b/kernel_gateway/services/kernels/pool.py index de1c988..e74591f 100644 --- a/kernel_gateway/services/kernels/pool.py +++ b/kernel_gateway/services/kernels/pool.py @@ -2,11 +2,15 @@ # Distributed under the terms of the Modified BSD License. 
"""Kernel pools that track and delegate to kernels.""" -from jupyter_client.session import Session - +import asyncio +import tornado.gen from tornado.locks import Semaphore -from tornado import gen +from tornado.concurrent import Future from traitlets.config.configurable import LoggingConfigurable +from typing import Awaitable, List, Optional + +from jupyter_client.session import Session +from jupyter_server.services.kernels.kernelmanager import MappingKernelManager class KernelPool(LoggingConfigurable): @@ -22,19 +26,37 @@ class KernelPool(LoggingConfigurable): kernel_manager Kernel manager instance """ - def __init__(self, prespawn_count, kernel_manager): + + kernel_manager: Optional[MappingKernelManager] + pool_initialized: Future + + def __init__(self): + super().__init__() + self.kernel_manager = None + self.pool_initialized = Future() + + async def initialize(self, prespawn_count, kernel_manager, **kwargs): self.kernel_manager = kernel_manager # Make sure we've got a int if not prespawn_count: prespawn_count = 0 + + kernels_to_spawn: List[Awaitable] = [] for _ in range(prespawn_count): - self.kernel_manager.start_seeded_kernel() + kernels_to_spawn.append(self.kernel_manager.start_seeded_kernel()) + + await asyncio.gather(*kernels_to_spawn) + + # Indicate that pool initialization has completed + self.pool_initialized.set_result(True) - def shutdown(self): + async def shutdown(self): """Shuts down all running kernels.""" + await self.pool_initialized kids = self.kernel_manager.list_kernel_ids() for kid in kids: - self.kernel_manager.shutdown_kernel(kid, now=True) + await self.kernel_manager.shutdown_kernel(kid, now=True) + class ManagedKernelPool(KernelPool): """Spawns a pool of kernels that are treated as identical delegates for @@ -61,19 +83,29 @@ class ManagedKernelPool(KernelPool): kernel_semaphore : tornado.locks.Semaphore Semaphore that controls access to the kernel pool """ - def __init__(self, prespawn_count, kernel_manager): + kernel_clients: dict + 
on_recv_funcs: dict + kernel_pool: list + kernel_semaphore: Semaphore + managed_pool_initialized: Future + + def __init__(self): + super().__init__() + self.kernel_clients = {} + self.on_recv_funcs = {} + self.kernel_pool = [] + self.managed_pool_initialized = Future() + + async def initialize(self, prespawn_count, kernel_manager, **kwargs): # Make sure there's at least one kernel as a delegate if not prespawn_count: prespawn_count = 1 - super(ManagedKernelPool, self).__init__(prespawn_count, kernel_manager) + self.kernel_semaphore = Semaphore(prespawn_count) - self.kernel_clients = {} - self.on_recv_funcs = {} - self.kernel_pool = [] + await super(ManagedKernelPool, self).initialize(prespawn_count, kernel_manager) kernel_ids = self.kernel_manager.list_kernel_ids() - self.kernel_semaphore = Semaphore(len(kernel_ids)) # Create clients and iopub handlers for prespawned kernels for kernel_id in kernel_ids: @@ -82,8 +114,10 @@ def __init__(self, prespawn_count, kernel_manager): iopub = self.kernel_manager.connect_iopub(kernel_id) iopub.on_recv(self.create_on_reply(kernel_id)) - @gen.coroutine - def acquire(self): + # Indicate that pool initialization has completed + self.managed_pool_initialized.set_result(True) + + async def acquire(self): """Gets a kernel client and removes it from the available pool of clients. @@ -92,10 +126,11 @@ def acquire(self): tuple Kernel client instance, kernel ID """ - yield self.kernel_semaphore.acquire() + await self.managed_pool_initialized + await self.kernel_semaphore.acquire() kernel_id = self.kernel_pool[0] del self.kernel_pool[0] - raise gen.Return((self.kernel_clients[kernel_id], kernel_id)) + return self.kernel_clients[kernel_id], kernel_id def release(self, kernel_id): """Puts a kernel back into the pool of kernels available to handle @@ -164,12 +199,13 @@ def on_recv(self, kernel_id, func): """ self.on_recv_funcs[kernel_id] = func - def shutdown(self): + async def shutdown(self): """Shuts down all kernels and their clients. 
""" + await self.managed_pool_initialized for kid in self.kernel_clients: self.kernel_clients[kid].stop_channels() - self.kernel_manager.shutdown_kernel(kid, now=True) + await self.kernel_manager.shutdown_kernel(kid, now=True) # Any remaining kernels that were not created for our pool should be shutdown - super(ManagedKernelPool, self).shutdown() + await super(ManagedKernelPool, self).shutdown() diff --git a/kernel_gateway/services/kernelspecs/handlers.py b/kernel_gateway/services/kernelspecs/handlers.py index 01f46b1..131941d 100644 --- a/kernel_gateway/services/kernelspecs/handlers.py +++ b/kernel_gateway/services/kernelspecs/handlers.py @@ -2,20 +2,20 @@ # Distributed under the terms of the Modified BSD License. """Tornado handlers for kernel specs.""" -import notebook.services.kernelspecs.handlers as notebook_handlers -import notebook.kernelspecs.handlers as notebook_kernelspecs_resources_handlers +import jupyter_server.services.kernelspecs.handlers as server_handlers +import jupyter_server.kernelspecs.handlers as server_kernelspecs_resources_handlers from ...mixins import TokenAuthorizationMixin, CORSMixin, JSONErrorsMixin -# Extends the default handlers from the notebook package with token auth, CORS +# Extends the default handlers from the jupyter_server package with token auth, CORS # and JSON errors. 
default_handlers = [] -for path, cls in notebook_handlers.default_handlers: +for path, cls in server_handlers.default_handlers: # Everything should have CORS and token auth bases = (TokenAuthorizationMixin, CORSMixin, JSONErrorsMixin, cls) default_handlers.append((path, type(cls.__name__, bases, {}))) -for path, cls in notebook_kernelspecs_resources_handlers.default_handlers: +for path, cls in server_kernelspecs_resources_handlers.default_handlers: # Everything should have CORS and token auth bases = (TokenAuthorizationMixin, CORSMixin, JSONErrorsMixin, cls) default_handlers.append((path, type(cls.__name__, bases, {}))) diff --git a/kernel_gateway/services/sessions/handlers.py b/kernel_gateway/services/sessions/handlers.py index 0db38ae..1878e91 100644 --- a/kernel_gateway/services/sessions/handlers.py +++ b/kernel_gateway/services/sessions/handlers.py @@ -3,17 +3,18 @@ """Tornado handlers for session CRUD.""" import tornado -import notebook.services.sessions.handlers as notebook_handlers +import jupyter_server.services.sessions.handlers as server_handlers from ...mixins import TokenAuthorizationMixin, CORSMixin, JSONErrorsMixin + class SessionRootHandler(TokenAuthorizationMixin, CORSMixin, JSONErrorsMixin, - notebook_handlers.SessionRootHandler): + server_handlers.SessionRootHandler): """Extends the notebook root session handler with token auth, CORS, and JSON errors. """ - def get(self): + async def get(self): """Overrides the super class method to honor the kernel listing configuration setting. 
@@ -25,10 +26,11 @@ def get(self): if 'kg_list_kernels' not in self.settings or self.settings['kg_list_kernels'] != True: raise tornado.web.HTTPError(403, 'Forbidden') else: - super(SessionRootHandler, self).get() + await super(SessionRootHandler, self).get() + default_handlers = [] -for path, cls in notebook_handlers.default_handlers: +for path, cls in server_handlers.default_handlers: if cls.__name__ in globals(): # Use the same named class from here if it exists default_handlers.append((path, globals()[cls.__name__])) diff --git a/kernel_gateway/services/sessions/sessionmanager.py b/kernel_gateway/services/sessions/sessionmanager.py index 210bc96..fa4bbe6 100644 --- a/kernel_gateway/services/sessions/sessionmanager.py +++ b/kernel_gateway/services/sessions/sessionmanager.py @@ -3,9 +3,9 @@ """Session manager that keeps all its metadata in memory.""" import uuid -from notebook.utils import maybe_future -from tornado import web, gen +from tornado import web from traitlets.config.configurable import LoggingConfigurable +from typing import List, Optional class SessionManager(LoggingConfigurable): @@ -26,13 +26,13 @@ class SessionManager(LoggingConfigurable): _columns : list Session metadata key names """ - def __init__(self, kernel_manager, *args, **kwargs): - super(SessionManager, self).__init__(*args, **kwargs) + def __init__(self, kernel_manager, **kwargs): + super(SessionManager, self).__init__(**kwargs) self.kernel_manager = kernel_manager self._sessions = [] self._columns = ['session_id', 'path', 'kernel_id'] - def session_exists(self, path, *args, **kwargs): + def session_exists(self, path, *args, **kwargs) -> bool: """Checks to see if the session with the given path value exists. 
Parameters @@ -46,12 +46,11 @@ def session_exists(self, path, *args, **kwargs): """ return bool([item for item in self._sessions if item['path'] == path]) - def new_session_id(self): + def new_session_id(self) -> str: """Creates a uuid for a new session.""" return str(uuid.uuid4()) - @gen.coroutine - def create_session(self, path=None, kernel_name=None, kernel_id=None, *args, **kwargs): + async def create_session(self, path=None, kernel_name=None, kernel_id=None, *args, **kwargs) -> dict: """Creates a session and returns its model. Launches a kernel and stores the session metadata for later lookup. @@ -72,10 +71,10 @@ def create_session(self, path=None, kernel_name=None, kernel_id=None, *args, **k """ session_id = self.new_session_id() # allow nbm to specify kernels cwd - kernel_id = yield maybe_future(self.kernel_manager.start_kernel(path=path, kernel_name=kernel_name)) - raise gen.Return(self.save_session(session_id, path=path, kernel_id=kernel_id)) + kernel_id = await self.kernel_manager.start_kernel(path=path, kernel_name=kernel_name) + return self.save_session(session_id, path=path, kernel_id=kernel_id) - def save_session(self, session_id, path=None, kernel_id=None, *args, **kwargs): + def save_session(self, session_id, path=None, kernel_id=None, *args, **kwargs) -> dict: """Saves the metadata for the session with the given `session_id`. Given a `session_id` (and any other of the arguments), this method @@ -101,7 +100,7 @@ def save_session(self, session_id, path=None, kernel_id=None, *args, **kwargs): return self.get_session(session_id=session_id) - def get_session_by_key(self, key, val, *args, **kwargs): + def get_session_by_key(self, key, val, *args, **kwargs) -> Optional[dict]: """Gets the first session with the given key/value pair. Parameters @@ -187,6 +186,10 @@ def update_session(self, session_id, *args, **kwargs): if not row: raise KeyError + # if kernel_id is in kwargs, validate it prior to removing the row... 
+ if 'kernel_id' in kwargs and kwargs['kernel_id'] not in self.kernel_manager: + raise KeyError(f"Kernel '{kwargs['kernel_id']}' does not exist.") + self._sessions.remove(row) if 'path' in kwargs: @@ -223,7 +226,7 @@ def row_to_model(self, row, *args, **kwargs): } return model - def list_sessions(self, *args, **kwargs): + def list_sessions(self, *args, **kwargs) -> List: """Returns a list of dictionaries containing all the information from the session store. @@ -232,10 +235,10 @@ def list_sessions(self, *args, **kwargs): list Dictionaries from `row_to_model` """ - l = [self.row_to_model(r) for r in self._sessions] - return l + sessions = [self.row_to_model(r) for r in self._sessions] + return sessions - def delete_session(self, session_id, *args, **kwargs): + async def delete_session(self, session_id, *args, **kwargs): """Deletes the session in the session store with given `session_id`. Raises @@ -248,5 +251,5 @@ def delete_session(self, session_id, *args, **kwargs): if not s: raise KeyError - self.kernel_manager.shutdown_kernel(s['kernel_id']) + await self.kernel_manager.shutdown_kernel(s['kernel_id']) self._sessions.remove(s) diff --git a/kernel_gateway/tests/__init__.py b/kernel_gateway/tests/__init__.py index 45de8cc..c146300 100644 --- a/kernel_gateway/tests/__init__.py +++ b/kernel_gateway/tests/__init__.py @@ -1,11 +1,2 @@ # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. -from tornado import ioloop - -def teardown(): - """The test fixture appears to leak something on certain platforms that - endlessly tries an async socket connect and fails after the tests end. - As a stopgap, force a cleanup here. 
- """ - ioloop.IOLoop.current().stop() - ioloop.IOLoop.current().close(True) diff --git a/kernel_gateway/tests/notebook_http/cell/test_parser.py b/kernel_gateway/tests/notebook_http/cell/test_parser.py index 513e5ca..6ab63e2 100644 --- a/kernel_gateway/tests/notebook_http/cell/test_parser.py +++ b/kernel_gateway/tests/notebook_http/cell/test_parser.py @@ -2,114 +2,113 @@ # Distributed under the terms of the Modified BSD License. """Tests for notebook cell parsing.""" -import unittest import sys from kernel_gateway.notebook_http.cell.parser import APICellParser -class TestAPICellParser(unittest.TestCase): +class TestAPICellParser: """Unit tests the APICellParser class.""" def test_is_api_cell(self): """Parser should correctly identify annotated API cells.""" - parser = APICellParser(comment_prefix='#') - self.assertTrue(parser.is_api_cell('# GET /yes'), 'API cell was not detected') - self.assertFalse(parser.is_api_cell('no'), 'API cell was not detected') + parser = APICellParser(comment_prefix="#") + assert parser.is_api_cell("# GET /yes"), "API cell was not detected" + assert parser.is_api_cell("no") is False, "API cell was not detected" def test_endpoint_sort_default_strategy(self): """Parser should sort duplicate endpoint paths.""" source_cells = [ - '# POST /:foo', - '# POST /hello/:foo', - '# GET /hello/:foo', - '# PUT /hello/world' + "# POST /:foo", + "# POST /hello/:foo", + "# GET /hello/:foo", + "# PUT /hello/world" ] - parser = APICellParser(comment_prefix='#') + parser = APICellParser(comment_prefix="#") endpoints = parser.endpoints(source_cells) - expected_values = ['/hello/world', '/hello/:foo', '/:foo'] + expected_values = ["/hello/world", "/hello/:foo", "/:foo"] for index in range(0, len(expected_values)): endpoint, _ = endpoints[index] - self.assertEqual(expected_values[index], endpoint, 'Endpoint was not found in expected order') + assert expected_values[index] == endpoint, "Endpoint was not found in expected order" def 
test_endpoint_sort_custom_strategy(self): """Parser should sort duplicate endpoint paths using a custom sort strategy. """ source_cells = [ - '# POST /1', - '# POST /+', - '# GET /a' + "# POST /1", + "# POST /+", + "# GET /a" ] def custom_sort_fun(endpoint): index = sys.maxsize - if endpoint.find('1') >= 0: + if endpoint.find("1") >= 0: return 0 - elif endpoint.find('a') >= 0: + elif endpoint.find("a") >= 0: return 1 else: return 2 - parser = APICellParser(comment_prefix='#') + parser = APICellParser(comment_prefix="#") endpoints = parser.endpoints(source_cells, custom_sort_fun) - expected_values = ['/+', '/a', '/1'] + expected_values = ["/+", "/a", "/1"] for index in range(0, len(expected_values)): endpoint, _ = endpoints[index] - self.assertEqual(expected_values[index], endpoint, 'Endpoint was not found in expected order') + assert expected_values[index] == endpoint, "Endpoint was not found in expected order" def test_get_cell_endpoint_and_verb(self): """Parser should extract API endpoint and verb from cell annotations.""" - parser = APICellParser(comment_prefix='#') - endpoint, verb = parser.get_cell_endpoint_and_verb('# GET /foo') - self.assertEqual(endpoint, '/foo', 'Endpoint was not extracted correctly') - self.assertEqual(verb, 'GET', 'Endpoint was not extracted correctly') - endpoint, verb = parser.get_cell_endpoint_and_verb('# POST /bar/quo') - self.assertEqual(endpoint, '/bar/quo', 'Endpoint was not extracted correctly') - self.assertEqual(verb, 'POST', 'Endpoint was not extracted correctly') + parser = APICellParser(comment_prefix="#") + endpoint, verb = parser.get_cell_endpoint_and_verb("# GET /foo") + assert endpoint == "/foo", "Endpoint was not extracted correctly" + assert verb == "GET", "Endpoint was not extracted correctly" + endpoint, verb = parser.get_cell_endpoint_and_verb("# POST /bar/quo") + assert endpoint == "/bar/quo", "Endpoint was not extracted correctly" + assert verb == "POST", "Endpoint was not extracted correctly" - endpoint, verb = 
parser.get_cell_endpoint_and_verb('some regular code') - self.assertEqual(endpoint, None, 'Endpoint was not extracted correctly') - self.assertEqual(verb, None, 'Endpoint was not extracted correctly') + endpoint, verb = parser.get_cell_endpoint_and_verb("some regular code") + assert endpoint is None, "Endpoint was not extracted correctly" + assert verb is None, "Endpoint was not extracted correctly" def test_endpoint_concatenation(self): """Parser should concatenate multiple cells with the same verb+path.""" source_cells = [ - '# POST /foo/:bar', - '# POST /foo/:bar', - '# POST /foo', - 'ignored', - '# GET /foo/:bar' + "# POST /foo/:bar", + "# POST /foo/:bar", + "# POST /foo", + "ignored", + "# GET /foo/:bar" ] - parser = APICellParser(comment_prefix='#') + parser = APICellParser(comment_prefix="#") endpoints = parser.endpoints(source_cells) - self.assertEqual(len(endpoints), 2) + assert len(endpoints) == 2 # for ease of testing endpoints = dict(endpoints) - self.assertEqual(len(endpoints['/foo']), 1) - self.assertEqual(len(endpoints['/foo/:bar']), 2) - self.assertEqual(endpoints['/foo']['POST'], '# POST /foo\n') - self.assertEqual(endpoints['/foo/:bar']['POST'], '# POST /foo/:bar\n# POST /foo/:bar\n') - self.assertEqual(endpoints['/foo/:bar']['GET'], '# GET /foo/:bar\n') + assert len(endpoints["/foo"]) == 1 + assert len(endpoints["/foo/:bar"]) == 2 + assert endpoints["/foo"]["POST"] == "# POST /foo\n" + assert endpoints["/foo/:bar"]["POST"] == "# POST /foo/:bar\n# POST /foo/:bar\n" + assert endpoints["/foo/:bar"]["GET"] == "# GET /foo/:bar\n" def test_endpoint_response_concatenation(self): """Parser should concatenate multiple response cells with the same verb+path. 
""" source_cells = [ - '# ResponseInfo POST /foo/:bar', - '# ResponseInfo POST /foo/:bar', - '# ResponseInfo POST /foo', - 'ignored', - '# ResponseInfo GET /foo/:bar' + "# ResponseInfo POST /foo/:bar", + "# ResponseInfo POST /foo/:bar", + "# ResponseInfo POST /foo", + "ignored", + "# ResponseInfo GET /foo/:bar" ] - parser = APICellParser(comment_prefix='#') + parser = APICellParser(comment_prefix="#") endpoints = parser.endpoint_responses(source_cells) - self.assertEqual(len(endpoints), 2) + assert len(endpoints) == 2 # for ease of testing endpoints = dict(endpoints) - self.assertEqual(len(endpoints['/foo']), 1) - self.assertEqual(len(endpoints['/foo/:bar']), 2) - self.assertEqual(endpoints['/foo']['POST'], '# ResponseInfo POST /foo\n') - self.assertEqual(endpoints['/foo/:bar']['POST'], '# ResponseInfo POST /foo/:bar\n# ResponseInfo POST /foo/:bar\n') - self.assertEqual(endpoints['/foo/:bar']['GET'], '# ResponseInfo GET /foo/:bar\n') + assert len(endpoints["/foo"]) == 1 + assert len(endpoints["/foo/:bar"]) == 2 + assert endpoints["/foo"]["POST"] == "# ResponseInfo POST /foo\n" + assert endpoints["/foo/:bar"]["POST"] == "# ResponseInfo POST /foo/:bar\n# ResponseInfo POST /foo/:bar\n" + assert endpoints["/foo/:bar"]["GET"] == "# ResponseInfo GET /foo/:bar\n" diff --git a/kernel_gateway/tests/notebook_http/swagger/test_builders.py b/kernel_gateway/tests/notebook_http/swagger/test_builders.py index c8e604b..da6d348 100644 --- a/kernel_gateway/tests/notebook_http/swagger/test_builders.py +++ b/kernel_gateway/tests/notebook_http/swagger/test_builders.py @@ -3,35 +3,35 @@ """Tests for swagger spec generation.""" import json -import unittest -from nose.tools import assert_not_equal + from kernel_gateway.notebook_http.swagger.builders import SwaggerSpecBuilder from kernel_gateway.notebook_http.cell.parser import APICellParser from kernel_gateway.notebook_http.swagger.parser import SwaggerCellParser -class TestSwaggerBuilders(unittest.TestCase): + +class TestSwaggerBuilders: 
"""Unit tests the swagger spec builder.""" def test_add_title_adds_title_to_spec(self): """Builder should store an API title.""" - expected = 'Some New Title' - builder = SwaggerSpecBuilder(APICellParser(comment_prefix='#')) + expected = "Some New Title" + builder = SwaggerSpecBuilder(APICellParser(comment_prefix="#")) builder.set_default_title(expected) result = builder.build() - self.assertEqual(result['info']['title'] ,expected,'Title was not set to new value') + assert result["info"]["title"] == expected, "Title was not set to new value" def test_add_cell_adds_api_cell_to_spec(self): """Builder should store an API cell annotation.""" expected = { - 'get' : { - 'responses' : { - 200 : { 'description': 'Success'} + "get": { + "responses": { + 200: {"description": "Success"} } } } - builder = SwaggerSpecBuilder(APICellParser(comment_prefix='#')) - builder.add_cell('# GET /some/resource') + builder = SwaggerSpecBuilder(APICellParser(comment_prefix="#")) + builder.add_cell("# GET /some/resource") result = builder.build() - self.assertEqual(result['paths']['/some/resource'] ,expected,'Title was not set to new value') + assert result["paths"]["/some/resource"] == expected, "Title was not set to new value" def test_all_swagger_preserved_in_spec(self): """Builder should store the swagger documented cell.""" @@ -73,25 +73,25 @@ def test_all_swagger_preserved_in_spec(self): } } ''' - builder = SwaggerSpecBuilder(SwaggerCellParser(comment_prefix='#', notebook_cells = [{"source":expected}])) + builder = SwaggerSpecBuilder(SwaggerCellParser(comment_prefix='#', notebook_cells=[{"source":expected}])) builder.add_cell(expected) result = builder.build() self.maxDiff = None - self.assertEqual(result['paths']['/some/resource']['get']['description'], json.loads(expected)['paths']['/some/resource']['get']['description'], 'description was not preserved') - self.assertTrue('info' in result, 'info was not preserved') - self.assertTrue('title' in result['info'], 'title was not present') 
- self.assertEqual(result['info']['title'], json.loads(expected)['info']['title'], 'title was not preserved') - self.assertEqual(json.dumps(result['paths']['/some/resource'], sort_keys=True), json.dumps(json.loads(expected)['paths']['/some/resource'], sort_keys=True), 'operations were not as expected') + assert result["paths"]["/some/resource"]["get"]["description"] == json.loads(expected)["paths"]["/some/resource"]["get"]["description"], "description was not preserved" + assert "info" in result, "info was not preserved" + assert "title" in result["info"], "title was not present" + assert result["info"]["title"] == json.loads(expected)["info"]["title"], "title was not preserved" + assert json.dumps(result["paths"]["/some/resource"], sort_keys=True) == json.dumps(json.loads(expected)["paths"]["/some/resource"], sort_keys=True), "operations were not as expected" - new_title = 'new title. same contents.' + new_title = "new title. same contents." builder.set_default_title(new_title) result = builder.build() - assert_not_equal(result['info']['title'], new_title, 'title should not have been changed') + assert result["info"]["title"] != new_title, "title should not have been changed" def test_add_undocumented_cell_does_not_add_non_api_cell_to_spec(self): """Builder should store ignore non-API cells.""" - builder = SwaggerSpecBuilder(SwaggerCellParser(comment_prefix='#')) - builder.add_cell('regular code cell') - builder.add_cell('# regular commented cell') + builder = SwaggerSpecBuilder(SwaggerCellParser(comment_prefix="#", notebook_cells=[])) + builder.add_cell("regular code cell") + builder.add_cell("# regular commented cell") result = builder.build() - self.assertEqual('paths' in result , 0, 'unexpected paths were found') + assert "paths" not in result, "unexpected paths were found" diff --git a/kernel_gateway/tests/notebook_http/swagger/test_parser.py b/kernel_gateway/tests/notebook_http/swagger/test_parser.py index 5118b9b..649e5c8 100644 --- 
a/kernel_gateway/tests/notebook_http/swagger/test_parser.py +++ b/kernel_gateway/tests/notebook_http/swagger/test_parser.py @@ -2,44 +2,43 @@ # Distributed under the terms of the Modified BSD License. """Tests for notebook cell parsing.""" -import unittest from kernel_gateway.notebook_http.swagger.parser import SwaggerCellParser -class TestSwaggerAPICellParser(unittest.TestCase): +class TestSwaggerAPICellParser: """Unit tests the SwaggerCellParser class.""" def test_basic_swagger_parse(self): """Parser should correctly identify Swagger cells.""" parser = SwaggerCellParser(comment_prefix='#', notebook_cells=[{"source":'```\n{"swagger":"2.0", "paths": {"": {"post": {"operationId": "foo", "parameters": [{"name": "foo"}]}}}}\n```\n'}]) - self.assertTrue('swagger' in parser.swagger, 'Swagger doc was not detected') + assert 'swagger' in parser.swagger, 'Swagger doc was not detected' def test_basic_is_api_cell(self): """Parser should correctly identify operation cells.""" parser = SwaggerCellParser(comment_prefix='#', notebook_cells=[{"source":'```\n{"swagger":"2.0", "paths": {"": {"post": {"operationId": "foo", "parameters": [{"name": "foo"}]}}}}\n```\n'}]) - self.assertTrue(parser.is_api_cell('#operationId:foo'), 'API cell was not detected with ' + str(parser.kernelspec_operation_indicator)) - self.assertTrue(parser.is_api_cell('# operationId:foo'), 'API cell was not detected with ' + str(parser.kernelspec_operation_indicator)) - self.assertTrue(parser.is_api_cell('#operationId: foo'), 'API cell was not detected with ' + str(parser.kernelspec_operation_indicator)) - self.assertFalse(parser.is_api_cell('no'), 'API cell was detected') - self.assertFalse(parser.is_api_cell('# another comment'), 'API cell was detected') + assert parser.is_api_cell('#operationId:foo'), 'API cell was not detected with ' + str(parser.kernelspec_operation_indicator) + assert parser.is_api_cell('# operationId:foo'), 'API cell was not detected with ' + str(parser.kernelspec_operation_indicator) + 
assert parser.is_api_cell('#operationId: foo'), 'API cell was not detected with ' + str(parser.kernelspec_operation_indicator) + assert parser.is_api_cell('no') is False, 'API cell was detected' + assert parser.is_api_cell('# another comment') is False, 'API cell was detected' def test_basic_is_api_response_cell(self): """Parser should correctly identify ResponseInfo cells.""" parser = SwaggerCellParser(comment_prefix='#', notebook_cells=[{"source":'```\n{"swagger":"2.0", "paths": {"": {"post": {"operationId": "foo", "parameters": [{"name": "foo"}]}}}}\n```\n'}]) - self.assertTrue(parser.is_api_response_cell('#ResponseInfo operationId:foo'), 'Response cell was not detected with ' + str(parser.kernelspec_operation_response_indicator)) - self.assertTrue(parser.is_api_response_cell('# ResponseInfo operationId:foo'), 'Response cell was not detected with ' + str(parser.kernelspec_operation_response_indicator)) - self.assertTrue(parser.is_api_response_cell('# ResponseInfo operationId: foo'), 'Response cell was not detected with ' + str(parser.kernelspec_operation_response_indicator)) - self.assertTrue(parser.is_api_response_cell('#ResponseInfo operationId: foo'), 'Response cell was not detected with ' + str(parser.kernelspec_operation_response_indicator)) - self.assertFalse(parser.is_api_response_cell('# operationId: foo'), 'API cell was detected as a ResponseInfo cell ' + str(parser.kernelspec_operation_response_indicator)) - self.assertFalse(parser.is_api_response_cell('no'), 'API cell was detected') + assert parser.is_api_response_cell('#ResponseInfo operationId:foo'), 'Response cell was not detected with ' + str(parser.kernelspec_operation_response_indicator) + assert parser.is_api_response_cell('# ResponseInfo operationId:foo'), 'Response cell was not detected with ' + str(parser.kernelspec_operation_response_indicator) + assert parser.is_api_response_cell('# ResponseInfo operationId: foo'), 'Response cell was not detected with ' + 
str(parser.kernelspec_operation_response_indicator) + assert parser.is_api_response_cell('#ResponseInfo operationId: foo'), 'Response cell was not detected with ' + str(parser.kernelspec_operation_response_indicator) + assert parser.is_api_response_cell('# operationId: foo') is False, 'API cell was detected as a ResponseInfo cell ' + str(parser.kernelspec_operation_response_indicator) + assert parser.is_api_response_cell('no') is False, 'API cell was detected' def test_endpoint_sort_default_strategy(self): """Parser should sort duplicate endpoint paths.""" source_cells = [ - {"source":'\n```\n{"swagger":"2.0","paths":{"":{"post":{"operationId":"postRoot","parameters":[{"name":"foo"}]}},"/hello":{"post":{"operationId":"postHello","parameters":[{"name":"foo"}]},"get":{"operationId":"getHello","parameters":[{"name":"foo"}]}},"/hello/world":{"put":{"operationId":"putWorld"}}}}\n```\n'}, - {"source":'# operationId:putWorld'}, - {"source":'# operationId:getHello'}, - {"source":'# operationId:postHello'}, - {"source":'# operationId:postRoot'}, + {"source": '\n```\n{"swagger":"2.0","paths":{"":{"post":{"operationId":"postRoot","parameters":[{"name":"foo"}]}},"/hello":{"post":{"operationId":"postHello","parameters":[{"name":"foo"}]},"get":{"operationId":"getHello","parameters":[{"name":"foo"}]}},"/hello/world":{"put":{"operationId":"putWorld"}}}}\n```\n'}, + {"source": '# operationId:putWorld'}, + {"source": '# operationId:getHello'}, + {"source": '# operationId:postHello'}, + {"source": '# operationId:postRoot'}, ] parser = SwaggerCellParser(comment_prefix='#', notebook_cells = source_cells) endpoints = parser.endpoints(cell['source'] for cell in source_cells) @@ -48,141 +47,128 @@ def test_endpoint_sort_default_strategy(self): try: for index in range(0, len(expected_values)): endpoint, _ = endpoints[index] - self.assertEqual(expected_values[index], endpoint, 'Endpoint was not found in expected order') + assert expected_values[index] == endpoint, 'Endpoint was not found in 
expected order' except IndexError: - self.fail(endpoints) + raise RuntimeError(endpoints) def test_endpoint_sort_custom_strategy(self): - """Parser should sort duplicate endpoint paths using a custom sort - strategy. - """ + """Parser should sort duplicate endpoint paths using a custom sort strategy.""" source_cells = [ - {"source":'```\n{"swagger": "2.0", "paths": {"/1": {"post": {"operationId": "post1"}},"/+": {"post": {"operationId": "postPlus"}},"/a": {"get": {"operationId": "getA"}}}}\n```\n'}, - {"source":'# operationId: post1'}, - {"source":'# operationId: postPlus'}, - {"source":'# operationId: getA'}, + {"source": '```\n{"swagger": "2.0", "paths": {"/1": {"post": {"operationId": "post1"}},"/+": {"post": {"operationId": "postPlus"}},"/a": {"get": {"operationId": "getA"}}}}\n```\n'}, + {"source": "# operationId: post1"}, + {"source": "# operationId: postPlus"}, + {"source": "# operationId: getA"}, ] def custom_sort_fun(endpoint): - if endpoint.find('1') >= 0: + if endpoint.find("1") >= 0: return 0 - elif endpoint.find('a') >= 0: + elif endpoint.find("a") >= 0: return 1 else: return 2 - parser = SwaggerCellParser(comment_prefix='#', notebook_cells=source_cells) - endpoints = parser.endpoints((cell['source'] for cell in source_cells), custom_sort_fun) + parser = SwaggerCellParser(comment_prefix="#", notebook_cells=source_cells) + endpoints = parser.endpoints((cell["source"] for cell in source_cells), custom_sort_fun) print(str(endpoints)) - expected_values = ['/+', '/a', '/1'] + expected_values = ["/+", "/a", "/1"] for index in range(0, len(expected_values)): endpoint, _ = endpoints[index] - self.assertEqual(expected_values[index], endpoint, 'Endpoint was not found in expected order') + assert expected_values[index] == endpoint, "Endpoint was not found in expected order" def test_get_cell_endpoint_and_verb(self): """Parser should extract API endpoint and verb from cell annotations.""" parser = SwaggerCellParser(comment_prefix='#', 
notebook_cells=[{'source':'```\n{"swagger":"2.0", "paths": {"/foo": {"get": {"operationId": "getFoo"}}, "/bar/quo": {"post": {"operationId": "post_bar_Quo"}}}}\n```\n'}]) - endpoint, verb = parser.get_cell_endpoint_and_verb('# operationId: getFoo') - self.assertEqual(endpoint, '/foo', 'Endpoint was not extracted correctly') - self.assertEqual(verb.lower(), 'get', 'Endpoint was not extracted correctly') - endpoint, verb = parser.get_cell_endpoint_and_verb('# operationId: post_bar_Quo') - self.assertEqual(endpoint, '/bar/quo', 'Endpoint was not extracted correctly') - self.assertEqual(verb.lower(), 'post', 'Endpoint was not extracted correctly') + endpoint, verb = parser.get_cell_endpoint_and_verb("# operationId: getFoo") + assert endpoint == "/foo", "Endpoint was not extracted correctly" + assert verb.lower() == "get", "Endpoint was not extracted correctly" + endpoint, verb = parser.get_cell_endpoint_and_verb("# operationId: post_bar_Quo") + assert endpoint == "/bar/quo", "Endpoint was not extracted correctly" + assert verb.lower() == "post", "Endpoint was not extracted correctly" - endpoint, verb = parser.get_cell_endpoint_and_verb('some regular code') - self.assertEqual(endpoint, None, 'Endpoint was not extracted correctly (something was actually returned)') - self.assertEqual(verb, None, 'Endpoint was not extracted correctly (something was actually returned)') + endpoint, verb = parser.get_cell_endpoint_and_verb("some regular code") + assert endpoint is None, "Endpoint was not extracted correctly (something was actually returned)" + assert verb is None, "Endpoint was not extracted correctly (something was actually returned)" def test_endpoint_concatenation(self): """Parser should concatenate multiple cells with the same verb+path.""" cells = [ - {"source":'```\n{"swagger":"2.0", "paths": {"/foo": {"put": {"operationId":"putFoo","parameters": [{"name": "bar"}]},"post":{"operationId":"postFooBody"},"get": {"operationId":"getFoo","parameters": [{"name": 
"bar"}]}}}}\n```\n'}, - {"source":'# operationId: postFooBody '}, - {"source":'# unrelated comment '}, - {"source":'# operationId: putFoo'}, - {"source":'# operationId: puttFoo'}, - {"source":'# operationId: getFoo'}, - {"source":'# operationId: putFoo'} + {"source": '```\n{"swagger":"2.0", "paths": {"/foo": {"put": {"operationId":"putFoo","parameters": [{"name": "bar"}]},"post":{"operationId":"postFooBody"},"get": {"operationId":"getFoo","parameters": [{"name": "bar"}]}}}}\n```\n'}, + {"source": "# operationId: postFooBody "}, + {"source": "# unrelated comment "}, + {"source": "# operationId: putFoo"}, + {"source": "# operationId: puttFoo"}, + {"source": "# operationId: getFoo"}, + {"source": "# operationId: putFoo"} ] - parser = SwaggerCellParser(comment_prefix='#', notebook_cells=cells) - endpoints = parser.endpoints(cell['source'] for cell in cells) - self.assertEqual(len(endpoints), 2, endpoints) + parser = SwaggerCellParser(comment_prefix="#", notebook_cells=cells) + endpoints = parser.endpoints(cell["source"] for cell in cells) + assert len(endpoints) == 2, endpoints # for ease of testing endpoints = dict(endpoints) - self.assertEqual(len(endpoints['/foo']), 1) - self.assertEqual(len(endpoints['/foo/:bar']), 2) - self.assertEqual(endpoints['/foo']['post'], '# operationId: postFooBody \n') - self.assertEqual(endpoints['/foo/:bar']['get'], '# operationId: getFoo\n') - self.assertEqual(endpoints['/foo/:bar']['put'], '# operationId: putFoo\n# operationId: putFoo\n') + assert len(endpoints["/foo"]) == 1 + assert len(endpoints["/foo/:bar"]) == 2 + assert endpoints["/foo"]["post"] == "# operationId: postFooBody \n" + assert endpoints["/foo/:bar"]["get"] == "# operationId: getFoo\n" + assert endpoints["/foo/:bar"]["put"] == "# operationId: putFoo\n# operationId: putFoo\n" def test_endpoint_response_concatenation(self): - """Parser should concatenate multiple response cells with the same - verb+path. 
- """ + """Parser should concatenate multiple response cells with the same verb+path.""" source_cells = [ - {"source":'```\n{"swagger":"2.0", "paths": {"/foo": {"put": {"operationId":"putbar","parameters": [{"name": "bar"}]},"post":{"operationId":"postbar"},"get": {"operationId":"get","parameters": [{"name": "bar"}]}}}}\n```\n'}, - {"source":'# ResponseInfo operationId: get'}, - {"source":'# ResponseInfo operationId: postbar '}, - {"source":'# ResponseInfo operationId: putbar'}, - {"source":'# ResponseInfo operationId: puttbar'}, - {"source":'ignored'}, - {"source":'# ResponseInfo operationId: putbar '} + {"source": '```\n{"swagger":"2.0", "paths": {"/foo": {"put": {"operationId":"putbar","parameters": [{"name": "bar"}]},"post":{"operationId":"postbar"},"get": {"operationId":"get","parameters": [{"name": "bar"}]}}}}\n```\n'}, + {"source": "# ResponseInfo operationId: get"}, + {"source": "# ResponseInfo operationId: postbar "}, + {"source": "# ResponseInfo operationId: putbar"}, + {"source": "# ResponseInfo operationId: puttbar"}, + {"source": "ignored"}, + {"source": "# ResponseInfo operationId: putbar "} ] - parser = SwaggerCellParser(comment_prefix='#', notebook_cells=source_cells) - endpoints = parser.endpoint_responses(cell['source'] for cell in source_cells) - self.assertEqual(len(endpoints), 2) + parser = SwaggerCellParser(comment_prefix="#", notebook_cells=source_cells) + endpoints = parser.endpoint_responses(cell["source"] for cell in source_cells) + assert len(endpoints) == 2 # for ease of testing endpoints = dict(endpoints) - self.assertEqual(len(endpoints['/foo']), 1) - self.assertEqual(len(endpoints['/foo/:bar']), 2) - self.assertEqual(endpoints['/foo']['post'], '# ResponseInfo operationId: postbar \n') - self.assertEqual(endpoints['/foo/:bar']['put'], '# ResponseInfo operationId: putbar\n# ResponseInfo operationId: putbar \n') - self.assertEqual(endpoints['/foo/:bar']['get'], '# ResponseInfo operationId: get\n') - - def 
test_undeclared_operations(self): - """Parser should warn about operations that aren't documented in the - swagger cell - """ + assert len(endpoints["/foo"]) == 1 + assert len(endpoints["/foo/:bar"]) == 2 + assert endpoints["/foo"]["post"] == "# ResponseInfo operationId: postbar \n" + assert endpoints["/foo/:bar"]["put"] == "# ResponseInfo operationId: putbar\n# ResponseInfo operationId: putbar \n" + assert endpoints["/foo/:bar"]["get"] == "# ResponseInfo operationId: get\n" + + def test_undeclared_operations(self, caplog): + """Parser should warn about operations that aren't documented in the swagger cell.""" source_cells = [ - {"source":'```\n{"swagger":"2.0", "paths": {"/foo": {"put": {"operationId":"putbar","parameters": [{"name": "bar"}]},"post":{"operationId":"postbar"},"get": {"operationId":"get","parameters": [{"name": "bar"}]}}}}\n```\n'}, - {"source":'# operationId: get'}, - {"source":'# operationId: postbar '}, - {"source":'# operationId: putbar'}, - {"source":'# operationId: extraOperation'}, + {"source": '```\n{"swagger":"2.0", "paths": {"/foo": {"put": {"operationId":"putbar","parameters": [{"name": "bar"}]},"post":{"operationId":"postbar"},"get": {"operationId":"get","parameters": [{"name": "bar"}]}}}}\n```\n'}, + {"source": "# operationId: get"}, + {"source": "# operationId: postbar "}, + {"source": "# operationId: putbar"}, + {"source": "# operationId: extraOperation"}, ] - with self.assertLogs(level='WARNING') as warnings: - SwaggerCellParser(comment_prefix='#', notebook_cells=source_cells) - for output in warnings.output: - self.assertRegex(output, 'extraOperation') - - def test_undeclared_operations_reversed(self): - """Parser should warn about operations that aren't documented in the - swagger cell - """ + + SwaggerCellParser(comment_prefix="#", notebook_cells=source_cells) + assert "extraOperation" in caplog.text + + def test_undeclared_operations_reversed(self, caplog): + """Parser should warn about operations that aren"t documented in the 
swagger cell.""" source_cells = [ - {"source":'# operationId: get'}, - {"source":'# operationId: postbar '}, - {"source":'# operationId: putbar'}, - {"source":'# operationId: extraOperation'}, - {"source":'```\n{"swagger":"2.0", "paths": {"/foo": {"put": {"operationId":"putbar","parameters": [{"name": "bar"}]},"post":{"operationId":"postbar"},"get": {"operationId":"get","parameters": [{"name": "bar"}]}}}}\n```\n'}, + {"source": "# operationId: get"}, + {"source": "# operationId: postbar "}, + {"source": "# operationId: putbar"}, + {"source": "# operationId: extraOperation"}, + {"source": '```\n{"swagger":"2.0", "paths": {"/foo": {"put": {"operationId":"putbar","parameters": [{"name": "bar"}]},"post":{"operationId":"postbar"},"get": {"operationId":"get","parameters": [{"name": "bar"}]}}}}\n```\n'}, ] - with self.assertLogs(level='WARNING') as warnings: - SwaggerCellParser(comment_prefix='#', notebook_cells=source_cells) - for output in warnings.output: - self.assertRegex(output, 'extraOperation') - - def test_unreferenced_operations(self): - """Parser should warn about documented operations that aren't referenced - in a cell - """ + + SwaggerCellParser(comment_prefix="#", notebook_cells=source_cells) + assert "extraOperation" in caplog.text + + def test_unreferenced_operations(self, caplog): + """Parser should warn about documented operations that aren"t referenced in a cell.""" source_cells = [ {"source": '```\n{"swagger":"2.0", "paths": {"/foo": {"put": {"operationId":"putbar","parameters": [{"name": "bar"}]},"post":{"operationId":"postbar"},"get": {"operationId":"get","parameters": [{"name": "bar"}]}}}}\n```\n'}, - {"source": '# operationId: get'}, - {"source": '# operationId: putbar'}, - {"source": '# operationId: putbar '} + {"source": "# operationId: get"}, + {"source": "# operationId: putbar"}, + {"source": "# operationId: putbar "} ] - with self.assertLogs(level='WARNING') as warnings: - SwaggerCellParser(comment_prefix='#', notebook_cells=source_cells) - 
for output in warnings.output: - self.assertRegex(output, 'postbar') + + SwaggerCellParser(comment_prefix="#", notebook_cells=source_cells) + assert "postbar" in caplog.text diff --git a/kernel_gateway/tests/notebook_http/test_request_utils.py b/kernel_gateway/tests/notebook_http/test_request_utils.py index 2d4e3b0..3db7e3b 100644 --- a/kernel_gateway/tests/notebook_http/test_request_utils.py +++ b/kernel_gateway/tests/notebook_http/test_request_utils.py @@ -7,11 +7,13 @@ from kernel_gateway.notebook_http.request_utils import (format_request, parse_body, parameterize_path, headers_to_dict, parse_args) + class MockRequest(dict): def __init__(self, *args, **kwargs): super(MockRequest, self).__init__(*args, **kwargs) self.__dict__ = self + class MockHeaders(object): def __init__(self, headers, **kwargs): self.headers = headers @@ -19,6 +21,7 @@ def __init__(self, headers, **kwargs): def get_all(self): return self.headers + class TestRequestUtils(unittest.TestCase): """Unit tests the request utility helper functions.""" def test_parse_body_text(self): diff --git a/kernel_gateway/tests/test_gatewayapp.py b/kernel_gateway/tests/test_gatewayapp.py index bf0dfbd..d250fde 100644 --- a/kernel_gateway/tests/test_gatewayapp.py +++ b/kernel_gateway/tests/test_gatewayapp.py @@ -1,159 +1,126 @@ # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. 
"""Tests for basic gateway app behavior.""" +import ssl - -import logging import nbformat import os -from io import StringIO -import unittest -from unittest.mock import patch -from kernel_gateway.gatewayapp import KernelGatewayApp, ioloop + +from kernel_gateway.gatewayapp import KernelGatewayApp from kernel_gateway import __version__ -from ..notebook_http.swagger.handlers import SwaggerSpecHandler -from tornado.testing import AsyncHTTPTestCase, ExpectLog RESOURCES = os.path.join(os.path.dirname(__file__), 'resources') -class TestGatewayAppConfig(unittest.TestCase): +class TestGatewayAppConfig: """Tests configuration of the gateway app.""" - def setUp(self): - """Saves a copy of the environment.""" - self.environ = dict(os.environ) - - def tearDown(self): - """Resets the environment.""" - os.environ = self.environ - - def test_config_env_vars(self): + def test_config_env_vars(self, monkeypatch): """Env vars should be honored for traitlets.""" # Environment vars are always strings - os.environ['KG_PORT'] = '1234' - os.environ['KG_PORT_RETRIES'] = '4321' - os.environ['KG_IP'] = '1.1.1.1' - os.environ['KG_AUTH_TOKEN'] = 'fake-token' - os.environ['KG_ALLOW_CREDENTIALS'] = 'true' - os.environ['KG_ALLOW_HEADERS'] = 'Authorization' - os.environ['KG_ALLOW_METHODS'] = 'GET' - os.environ['KG_ALLOW_ORIGIN'] = '*' - os.environ['KG_EXPOSE_HEADERS'] = 'X-Fake-Header' - os.environ['KG_MAX_AGE'] = '5' - os.environ['KG_BASE_URL'] = '/fake/path' - os.environ['KG_MAX_KERNELS'] = '1' - os.environ['KG_SEED_URI'] = 'fake-notebook.ipynb' - os.environ['KG_PRESPAWN_COUNT'] = '1' - os.environ['KG_FORCE_KERNEL_NAME'] = 'fake_kernel_forced' - os.environ['KG_DEFAULT_KERNEL_NAME'] = 'fake_kernel' - os.environ['KG_KEYFILE'] = '/test/fake.key' - os.environ['KG_CERTFILE'] = '/test/fake.crt' - os.environ['KG_CLIENT_CA'] = '/test/fake_ca.crt' - os.environ['KG_SSL_VERSION'] = '3' - os.environ['KG_TRUST_XHEADERS'] = 'false' + monkeypatch.setenv("KG_PORT", "1234") + monkeypatch.setenv("KG_PORT_RETRIES", 
"4321") + monkeypatch.setenv("KG_IP", "1.1.1.1") + monkeypatch.setenv("KG_AUTH_TOKEN", "fake-token") + monkeypatch.setenv("KG_ALLOW_CREDENTIALS", "true") + monkeypatch.setenv("KG_ALLOW_HEADERS", "Authorization") + monkeypatch.setenv("KG_ALLOW_METHODS", "GET") + monkeypatch.setenv("KG_ALLOW_ORIGIN", "*") + monkeypatch.setenv("KG_EXPOSE_HEADERS", "X-Fake-Header") + monkeypatch.setenv("KG_MAX_AGE", "5") + monkeypatch.setenv("KG_BASE_URL", "/fake/path") + monkeypatch.setenv("KG_MAX_KERNELS", "1") + monkeypatch.setenv("KG_SEED_URI", "fake-notebook.ipynb") + monkeypatch.setenv("KG_PRESPAWN_COUNT", "1") + monkeypatch.setenv("KG_FORCE_KERNEL_NAME", "fake_kernel_forced") + monkeypatch.setenv("KG_DEFAULT_KERNEL_NAME", "fake_kernel") + monkeypatch.setenv("KG_KEYFILE", "/test/fake.key") + monkeypatch.setenv("KG_CERTFILE", "/test/fake.crt") + monkeypatch.setenv("KG_CLIENT_CA", "/test/fake_ca.crt") + monkeypatch.setenv("KG_SSL_VERSION", "3") + monkeypatch.setenv("KG_TRUST_XHEADERS", "false") app = KernelGatewayApp() - self.assertEqual(app.port, 1234) - self.assertEqual(app.port_retries, 4321) - self.assertEqual(app.ip, '1.1.1.1') - self.assertEqual(app.auth_token, 'fake-token') - self.assertEqual(app.allow_credentials, 'true') - self.assertEqual(app.allow_headers, 'Authorization') - self.assertEqual(app.allow_methods, 'GET') - self.assertEqual(app.allow_origin, '*') - self.assertEqual(app.expose_headers, 'X-Fake-Header') - self.assertEqual(app.max_age, '5') - self.assertEqual(app.base_url, '/fake/path') - self.assertEqual(app.max_kernels, 1) - self.assertEqual(app.seed_uri, 'fake-notebook.ipynb') - self.assertEqual(app.prespawn_count, 1) - self.assertEqual(app.default_kernel_name, 'fake_kernel') - self.assertEqual(app.force_kernel_name, 'fake_kernel_forced') - self.assertEqual(app.keyfile, '/test/fake.key') - self.assertEqual(app.certfile, '/test/fake.crt') - self.assertEqual(app.client_ca, '/test/fake_ca.crt') - self.assertEqual(app.ssl_version, 3) - 
self.assertEqual(app.trust_xheaders, False) - - def test_trust_xheaders(self): - + assert app.port == 1234 + assert app.port_retries == 4321 + assert app.ip == "1.1.1.1" + assert app.auth_token == "fake-token" + assert app.allow_credentials == "true" + assert app.allow_headers == "Authorization" + assert app.allow_methods == "GET" + assert app.allow_origin == "*" + assert app.expose_headers == "X-Fake-Header" + assert app.max_age == "5" + assert app.base_url == "/fake/path" + assert app.max_kernels == 1 + assert app.seed_uri == "fake-notebook.ipynb" + assert app.prespawn_count == 1 + assert app.default_kernel_name == "fake_kernel" + assert app.force_kernel_name == "fake_kernel_forced" + assert app.keyfile == "/test/fake.key" + assert app.certfile == "/test/fake.crt" + assert app.client_ca == "/test/fake_ca.crt" + assert app.ssl_version == 3 + assert app.trust_xheaders is False + KernelGatewayApp.clear_instance() + + def test_trust_xheaders(self, monkeypatch): app = KernelGatewayApp() - self.assertEqual(app.trust_xheaders, False) - os.environ['KG_TRUST_XHEADERS'] = 'true' + assert app.trust_xheaders is False + monkeypatch.setenv("KG_TRUST_XHEADERS", "true") app = KernelGatewayApp() - self.assertEqual(app.trust_xheaders, True) + assert app.trust_xheaders is True + KernelGatewayApp.clear_instance() - def test_ssl_options(self): + def test_ssl_options(self, monkeypatch): app = KernelGatewayApp() ssl_options = app._build_ssl_options() - self.assertIsNone(ssl_options) + assert ssl_options is None + KernelGatewayApp.clear_instance() + + # Set all options + monkeypatch.setenv("KG_CERTFILE", "/test/fake.crt") + monkeypatch.setenv("KG_KEYFILE", "/test/fake.key") + monkeypatch.setenv("KG_CLIENT_CA", "/test/fake.ca") + monkeypatch.setenv("KG_SSL_VERSION", "42") app = KernelGatewayApp() - os.environ['KG_CERTFILE'] = '/test/fake.crt' ssl_options = app._build_ssl_options() - self.assertEqual(ssl_options['ssl_version'], 5) - - def test_load_notebook_local(self): - nb_path = 
os.path.join(RESOURCES, 'weirdly%20named#notebook.ipynb') - os.environ['KG_SEED_URI'] = nb_path + assert ssl_options["certfile"] == "/test/fake.crt" + assert ssl_options["keyfile"] == "/test/fake.key" + assert ssl_options["ca_certs"] == "/test/fake.ca" + assert ssl_options["cert_reqs"] == ssl.CERT_REQUIRED + assert ssl_options["ssl_version"] == 42 + KernelGatewayApp.clear_instance() + + # Set few options + monkeypatch.delenv("KG_KEYFILE") + monkeypatch.delenv("KG_CLIENT_CA") + monkeypatch.delenv("KG_SSL_VERSION") + app = KernelGatewayApp() + ssl_options = app._build_ssl_options() + assert ssl_options["certfile"] == "/test/fake.crt" + assert ssl_options["ssl_version"] == ssl.PROTOCOL_TLSv1_2 + assert "cert_reqs" not in ssl_options + KernelGatewayApp.clear_instance() + + def test_load_notebook_local(self, monkeypatch): + nb_path = os.path.join(RESOURCES, "weirdly%20named#notebook.ipynb") + monkeypatch.setenv("KG_SEED_URI", nb_path) with open(nb_path) as nb_fh: nb_contents = nbformat.read(nb_fh, 4) app = KernelGatewayApp() + app.init_io_loop() app.init_configurables() - self.assertEqual(app.seed_notebook, nb_contents) - - @patch('sys.stderr', new_callable=StringIO) - def test_start_banner(self, stderr): + assert app.seed_notebook == nb_contents + KernelGatewayApp.clear_instance() + def test_start_banner(self, capsys): app = KernelGatewayApp() + app.init_io_loop() app.init_configurables() app.start_app() - banner = stderr.getvalue() - self.assertIn(f"Jupyter Kernel Gateway {__version__}", banner) - - -class TestGatewayAppBase(AsyncHTTPTestCase, ExpectLog): - """Base class for integration style tests using HTTP/Websockets against an - instance of the gateway app. 
- - Attributes - ---------- - app : KernelGatewayApp - Instance of the app - """ - - def tearDown(self): - """Shuts down the app after test run.""" - if self.app: - self.app.shutdown() - # Make sure the generated Swagger output is reset for subsequent tests - SwaggerSpecHandler.output = None - super(TestGatewayAppBase, self).tearDown() - - def get_new_ioloop(self): - """Uses a global zmq ioloop for tests.""" - return ioloop.IOLoop.current() - - def get_app(self): - """Returns a tornado.web.Application for the Tornado test runner.""" - if hasattr(self, '_app'): - return self._app - self.app = KernelGatewayApp(log_level=logging.CRITICAL) - self.setup_app() - self.app.init_configurables() - self.setup_configurables() - self.app.init_webapp() - return self.app.web_app - - def setup_app(self): - """Override to configure KernelGatewayApp instance before initializing - configurables and the web app. - """ - pass - - def setup_configurables(self): - """Override to configure further settings, such as the personality. - """ - pass + log = capsys.readouterr() + assert f"Jupyter Kernel Gateway {__version__}" in log.err + KernelGatewayApp.clear_instance() diff --git a/kernel_gateway/tests/test_jupyter_websocket.py b/kernel_gateway/tests/test_jupyter_websocket.py index 5079a90..07d9414 100644 --- a/kernel_gateway/tests/test_jupyter_websocket.py +++ b/kernel_gateway/tests/test_jupyter_websocket.py @@ -2,402 +2,301 @@ # Distributed under the terms of the Modified BSD License. 
"""Tests for jupyter-websocket mode.""" -import os -import sys import json +import os +import pytest +import uuid -from .test_gatewayapp import TestGatewayAppBase, RESOURCES +from jupyter_client.kernelspec import NoSuchKernel +from tornado.gen import sleep +from tornado.httpclient import HTTPClientError +from tornado.escape import json_encode, json_decode, url_escape +from tornado.web import HTTPError +from traitlets.config import Config from kernel_gateway.gatewayapp import KernelGatewayApp -from jupyter_client.kernelspec import NoSuchKernel +from kernel_gateway.services.kernels.manager import AsyncMappingKernelManager +from kernel_gateway.services.sessions.sessionmanager import SessionManager +from .test_gatewayapp import RESOURCES -from tornado.gen import coroutine, Return, sleep -from tornado.websocket import websocket_connect -from tornado.httpclient import HTTPRequest -from tornado.testing import gen_test -from tornado.escape import json_encode, json_decode, url_escape +@pytest.fixture +def jp_server_config(): + """Allows tests to setup their specific configuration values.""" + config = { + "KernelGatewayApp": { + "api": "kernel_gateway.jupyter_websocket", + } + } + return Config(config) -class TestJupyterWebsocket(TestGatewayAppBase): - """Base class for jupyter-websocket mode tests that spawn kernels.""" - @coroutine - def spawn_kernel(self, kernel_body='{}'): - """Spawns a kernel using the gateway API and connects a websocket - client to it. 
- Parameters - ---------- - kernel_body : str - POST /api/kernels body +@pytest.fixture +def spawn_kernel(jp_fetch, jp_http_port, jp_base_url, jp_ws_fetch): + """Spawns a kernel where request.param contains the request body and returns the websocket.""" - Returns - ------- - Future - Promise of a WebSocketClientConnection - """ + async def _spawn_kernel(body='{}'): # Request a kernel - response = yield self.http_client.fetch( - self.get_url('/api/kernels'), - method='POST', - body=kernel_body - ) - self.assertEqual(response.code, 201) + response = await jp_fetch("api", "kernels", method="POST", body=body) + assert response.code == 201 # Connect to the kernel via websocket kernel = json_decode(response.body) - ws_url = 'ws://localhost:{}/api/kernels/{}/channels'.format( - self.get_http_port(), - url_escape(kernel['id']) - ) - - ws = yield websocket_connect(ws_url) - raise Return(ws) - - def execute_request(self, code): - """Creates an execute_request message. - - Parameters - ---------- - code : str - Code to execute - - Returns - ------- - dict - The message - """ - return { - 'header': { - 'username': '', - 'version': '5.0', - 'session': '', - 'msg_id': 'fake-msg-id', - 'msg_type': 'execute_request' - }, - 'parent_header': {}, - 'channel': 'shell', - 'content': { - 'code': code, - 'silent': False, - 'store_history': False, - 'user_expressions' : {} - }, - 'metadata': {}, - 'buffers': {} - } + kernel_id = kernel['id'] + ws = await jp_ws_fetch("api", "kernels", kernel_id, "channels") + return ws + + return _spawn_kernel - @coroutine - def await_stream(self, ws): - """Returns stream output associated with an execute_request.""" - while 1: - msg = yield ws.read_message() - msg = json_decode(msg) - msg_type = msg['msg_type'] - parent_msg_id = msg['parent_header']['msg_id'] - if msg_type == 'stream' and parent_msg_id == 'fake-msg-id': - raise Return(msg['content']) +def get_execute_request(code: str) -> dict: + """Creates an execute_request message. 
-class TestDefaults(TestJupyterWebsocket): + Parameters + ---------- + code : str + Code to execute + + Returns + ------- + dict + The message + """ + return { + 'header': { + 'username': '', + 'version': '5.0', + 'session': '', + 'msg_id': 'fake-msg-id', + 'msg_type': 'execute_request' + }, + 'parent_header': {}, + 'channel': 'shell', + 'content': { + 'code': code, + 'silent': False, + 'store_history': False, + 'user_expressions': {} + }, + 'metadata': {}, + 'buffers': {} + } + + +async def await_stream(ws): + """Returns stream output associated with an execute_request.""" + while 1: + msg = await ws.read_message() + msg = json_decode(msg) + msg_type = msg["msg_type"] + parent_msg_id = msg["parent_header"]["msg_id"] + if msg_type == "stream" and parent_msg_id == "fake-msg-id": + return msg["content"] + + +class TestDefaults: """Tests gateway behavior.""" - @gen_test - def test_startup(self): + @pytest.mark.parametrize("jp_argv", (["--JupyterWebsocketPersonality.list_kernels=True"],)) + async def test_startup(self, jp_fetch, jp_argv): """Root of kernels resource should be OK.""" - self.app.web_app.settings['kg_list_kernels'] = True - response = yield self.http_client.fetch(self.get_url('/api/kernels')) - self.assertEqual(response.code, 200) + response = await jp_fetch("api", "kernels", method="GET") + assert response.code == 200 - @gen_test - def test_headless(self): + async def test_headless(self, jp_fetch): """Other notebook resources should not exist.""" - response = yield self.http_client.fetch(self.get_url('/api/contents'), - raise_error=False) - self.assertEqual(response.code, 404) - response = yield self.http_client.fetch(self.get_url('/'), - raise_error=False) - self.assertEqual(response.code, 404) - response = yield self.http_client.fetch(self.get_url('/tree'), - raise_error=False) - self.assertEqual(response.code, 404) - - @gen_test - def test_check_origin(self): + with pytest.raises(HTTPClientError) as e: + await jp_fetch("api", "contents", method="GET") 
+ assert e.value.code == 404 + + with pytest.raises(HTTPClientError) as e: + await jp_fetch("", method="GET") + assert e.value.code == 404 + + with pytest.raises(HTTPClientError) as e: + await jp_fetch("tree", method="GET") + assert e.value.code == 404 + + async def test_check_origin(self, jp_fetch, jp_web_app): """Allow origin setting should pass through to base handlers.""" - response = yield self.http_client.fetch( - self.get_url('/api/kernelspecs'), - method='GET', - headers={'Origin': 'fake.com:8888'}, - raise_error=False - ) - self.assertEqual(response.code, 404) - - app = self.get_app() - app.settings['allow_origin'] = '*' - - response = yield self.http_client.fetch( - self.get_url('/api/kernelspecs'), - method='GET', - headers={'Origin': 'fake.com:8888'}, - raise_error=False - ) - self.assertEqual(response.code, 200) - - @gen_test - def test_config_bad_api_value(self): + with pytest.raises(HTTPClientError) as e: + await jp_fetch("api", "kernelspecs", + headers={'Origin': 'fake.com:8888'}, method="GET") + assert e.value.code == 404 + + jp_web_app.settings['allow_origin'] = '*' + + response = await jp_fetch("api", "kernelspecs", + headers={'Origin': 'fake.com:8888'}, method="GET") + assert response.code == 200 + + @pytest.mark.parametrize("jp_server_config", (Config({"KernelGatewayApp": {"api": "notebook-gopher", }}),)) + async def test_config_bad_api_value(self, jp_configurable_serverapp, jp_server_config): """Should raise an ImportError for nonexistent API personality modules.""" - def _set_api(): - self.app.api = 'notebook-gopher' - self.assertRaises(ImportError, _set_api) + with pytest.raises(ImportError): + await jp_configurable_serverapp() - @gen_test - def test_options_without_auth_token(self): + async def test_options_without_auth_token(self, jp_fetch, jp_web_app): """OPTIONS requests doesn't need to submit a token. 
Used for CORS preflight.""" - # Set token requirement - app = self.get_app() - app.settings['kg_auth_token'] = 'fake-token' - # Confirm that OPTIONS request doesn't require token - response = yield self.http_client.fetch( - self.get_url('/api'), - method='OPTIONS' - ) - self.assertEqual(response.code, 200) - - @gen_test - def test_auth_token(self): + response = await jp_fetch("api", method='OPTIONS') + assert response.code == 200 + + @pytest.mark.parametrize("jp_server_config", (Config({"KernelGatewayApp": {"auth_token": "fake-token", }}),)) + async def test_auth_token(self, jp_server_config, jp_fetch, jp_web_app, jp_ws_fetch): """All server endpoints should check the configured auth token.""" - # Set token requirement - app = self.get_app() - app.settings['kg_auth_token'] = 'fake-token' - - # Requst API without the token - response = yield self.http_client.fetch( - self.get_url('/api'), - method='GET', - raise_error=False - ) - self.assertEqual(response.code, 401) + + # Request API without the token + # Note that we'd prefer not to set _any_ header, but `jp_auth_header` will force it + # to be set, so setting the empty authorization header is necessary for the tests + # asserting 401. 
+ with pytest.raises(HTTPClientError) as e: + await jp_fetch("api", method="GET", headers={'Authorization': ''}) + assert e.value.response.code == 401 # Now with it - response = yield self.http_client.fetch( - self.get_url('/api'), - method='GET', - headers={'Authorization': 'token fake-token'}, - raise_error=False - ) - self.assertEqual(response.code, 200) + response = await jp_fetch("api", method="GET", + headers={'Authorization': 'token fake-token'}) + assert response.code == 200 # Request kernelspecs without the token - response = yield self.http_client.fetch( - self.get_url('/api/kernelspecs'), - method='GET', - raise_error=False - ) - self.assertEqual(response.code, 401) + with pytest.raises(HTTPClientError) as e: + await jp_fetch("api", "kernelspecs", method="GET", headers={'Authorization': ''}) + assert e.value.response.code == 401 # Now with it - response = yield self.http_client.fetch( - self.get_url('/api/kernelspecs'), - method='GET', - headers={'Authorization': 'token fake-token'}, - raise_error=False - ) - self.assertEqual(response.code, 200) + response = await jp_fetch("api", "kernelspecs", method="GET", + headers={'Authorization': 'token fake-token'}) + assert response.code == 200 # Request a kernel without the token - response = yield self.http_client.fetch( - self.get_url('/api/kernels'), - method='POST', - body='{}', - raise_error=False - ) - self.assertEqual(response.code, 401) - - # Request with the token now - response = yield self.http_client.fetch( - self.get_url('/api/kernels'), - method='POST', - body='{}', - headers={'Authorization': 'token fake-token'}, - raise_error=False - ) - self.assertEqual(response.code, 201) + with pytest.raises(HTTPClientError) as e: + await jp_fetch("api", "kernels", method="POST", body='{}', headers={'Authorization': ''}) + assert e.value.response.code == 401 + # Now with it + response = await jp_fetch("api", "kernels", method="POST", body='{}', + headers={'Authorization': 'token fake-token'}) + assert 
response.code == 201 kernel = json_decode(response.body) + kernel_id = url_escape(kernel['id']) + # Request kernel info without the token - response = yield self.http_client.fetch( - self.get_url('/api/kernels/'+url_escape(kernel['id'])), - method='GET', - raise_error=False - ) - self.assertEqual(response.code, 401) + with pytest.raises(HTTPClientError) as e: + await jp_fetch("api", "kernels", kernel_id, method="GET", headers={'Authorization': ''}) + assert e.value.response.code == 401 # Now with it - response = yield self.http_client.fetch( - self.get_url('/api/kernels/'+url_escape(kernel['id'])), - method='GET', - headers={'Authorization': 'token fake-token'}, - raise_error=False - ) - self.assertEqual(response.code, 200) + response = await jp_fetch("api", "kernels", kernel_id, method="GET", + headers={'Authorization': 'token fake-token'}) + assert response.code == 200 # Request websocket connection without the token - ws_url = 'ws://localhost:{}/api/kernels/{}/channels'.format( - self.get_http_port(), - url_escape(kernel['id']) - ) + # No option to ignore errors so try/except - try: - ws = yield websocket_connect(ws_url) - except Exception as ex: - self.assertEqual(ex.code, 401) - else: - raise AssertionError('no exception raised') + with pytest.raises(HTTPClientError) as e: + await jp_ws_fetch("api", "kernels", kernel_id, "channels", headers={'Authorization': ''}) + assert e.value.response.code == 401 # Now request the websocket with the token - ws_req = HTTPRequest(ws_url, - headers={'Authorization': 'token fake-token'} - ) - ws = yield websocket_connect(ws_req) + ws = await jp_ws_fetch("api", "kernels", kernel_id, "channels", + headers={'Authorization': 'token fake-token'}) ws.close() - @gen_test - def test_cors_headers(self): + async def test_cors_headers(self, jp_fetch, jp_web_app): """All kernel endpoints should respond with configured CORS headers.""" - app = self.get_app() - app.settings['kg_allow_credentials'] = 'false' - 
app.settings['kg_allow_headers'] = 'Authorization,Content-Type' - app.settings['kg_allow_methods'] = 'GET,POST' - app.settings['kg_allow_origin'] = 'https://jupyter.org' - app.settings['kg_expose_headers'] = 'X-My-Fake-Header' - app.settings['kg_max_age'] = '600' - app.settings['kg_list_kernels'] = True + + jp_web_app.settings['kg_allow_credentials'] = 'false' + jp_web_app.settings['kg_allow_headers'] = 'Authorization,Content-Type' + jp_web_app.settings['kg_allow_methods'] = 'GET,POST' + jp_web_app.settings['kg_allow_origin'] = 'https://jupyter.org' + jp_web_app.settings['kg_expose_headers'] = 'X-My-Fake-Header' + jp_web_app.settings['kg_max_age'] = '600' + jp_web_app.settings['kg_list_kernels'] = True # Get kernels to check headers - response = yield self.http_client.fetch( - self.get_url('/api/kernels'), - method='GET' - ) - self.assertEqual(response.code, 200) - self.assertEqual(response.headers['Access-Control-Allow-Credentials'], 'false') - self.assertEqual(response.headers['Access-Control-Allow-Headers'], 'Authorization,Content-Type') - self.assertEqual(response.headers['Access-Control-Allow-Methods'], 'GET,POST') - self.assertEqual(response.headers['Access-Control-Allow-Origin'], 'https://jupyter.org') - self.assertEqual(response.headers['Access-Control-Expose-Headers'], 'X-My-Fake-Header') - self.assertEqual(response.headers['Access-Control-Max-Age'], '600') - self.assertEqual(response.headers.get('Content-Security-Policy'), None) - - @gen_test - def test_cors_options_headers(self): + response = await jp_fetch("api", "kernels", method="GET") + assert response.code == 200 + assert response.headers['Access-Control-Allow-Credentials'] == 'false' + assert response.headers['Access-Control-Allow-Headers'] == 'Authorization,Content-Type' + assert response.headers['Access-Control-Allow-Methods'] == 'GET,POST' + assert response.headers['Access-Control-Allow-Origin'] == 'https://jupyter.org' + assert response.headers['Access-Control-Expose-Headers'] == 
'X-My-Fake-Header' + assert response.headers['Access-Control-Max-Age'] == '600' + assert response.headers.get('Content-Security-Policy') is None + + async def test_cors_options_headers(self, jp_fetch, jp_web_app): """All preflight OPTIONS requests should return configured headers.""" - app = self.get_app() - app.settings['kg_allow_headers'] = 'X-XSRFToken' - app.settings['kg_allow_methods'] = 'GET,POST,OPTIONS' - - response = yield self.http_client.fetch( - self.get_url('/api/kernelspecs'), - method='OPTIONS' - ) - self.assertEqual(response.code, 200) - self.assertEqual(response.headers['Access-Control-Allow-Methods'], 'GET,POST,OPTIONS') - self.assertEqual(response.headers['Access-Control-Allow-Headers'], 'X-XSRFToken') - - @gen_test - def test_max_kernels(self): + jp_web_app.settings['kg_allow_headers'] = 'X-XSRFToken' + jp_web_app.settings['kg_allow_methods'] = 'GET,POST,OPTIONS' + + response = await jp_fetch("api", "kernelspecs", method='OPTIONS') + assert response.code == 200 + assert response.headers['Access-Control-Allow-Methods'] == 'GET,POST,OPTIONS' + assert response.headers['Access-Control-Allow-Headers'] == 'X-XSRFToken' + + async def test_max_kernels(self, jp_fetch, jp_web_app): """Number of kernels should be limited.""" - app = self.get_app() - app.settings['kg_max_kernels'] = 1 + jp_web_app.settings['kg_max_kernels'] = 1 # Request a kernel - response = yield self.http_client.fetch( - self.get_url('/api/kernels'), - method='POST', - body='{}' - ) - self.assertEqual(response.code, 201) + response = await jp_fetch("api", "kernels", method="POST", body='{}') + assert response.code == 201 # Request another - response2 = yield self.http_client.fetch( - self.get_url('/api/kernels'), - method='POST', - body='{}', - raise_error=False - ) - self.assertEqual(response2.code, 403) + with pytest.raises(HTTPClientError) as e: + await jp_fetch("api", "kernels", method="POST", body='{}') + assert e.value.response.code == 403 # Shut down the kernel kernel = 
json_decode(response.body) - response = yield self.http_client.fetch( - self.get_url('/api/kernels/'+url_escape(kernel['id'])), - method='DELETE' - ) - self.assertEqual(response.code, 204) - - # Try again - response = yield self.http_client.fetch( - self.get_url('/api/kernels'), - method='POST', - body='{}' - ) - self.assertEqual(response.code, 201) - - @gen_test - def test_get_api(self): + response = await jp_fetch("api", "kernels", url_escape(kernel['id']), method="DELETE") + assert response.code == 204 + + # Try creation again + response = await jp_fetch("api", "kernels", method="POST", body='{}') + assert response.code == 201 + + async def test_get_api(self, jp_fetch): """Server should respond with the API version metadata.""" - response = yield self.http_client.fetch( - self.get_url('/api') - ) - self.assertEqual(response.code, 200) + response = await jp_fetch("api", method="GET") + assert response.code == 200 info = json_decode(response.body) - self.assertIn('version', info) + assert 'version' in info - @gen_test - def test_get_kernelspecs(self): + async def test_get_kernelspecs(self, jp_fetch): """Server should respond with kernel spec metadata.""" - response = yield self.http_client.fetch( - self.get_url('/api/kernelspecs') - ) - self.assertEqual(response.code, 200) + response = await jp_fetch("api", "kernelspecs", method="GET") + assert response.code == 200 specs = json_decode(response.body) - self.assertIn('kernelspecs', specs) - self.assertIn('default', specs) + assert 'kernelspecs' in specs + assert 'default' in specs - @gen_test - def test_get_kernels(self): + async def test_get_kernels(self, jp_fetch, jp_web_app): """Server should respond with running kernel information.""" - self.app.web_app.settings['kg_list_kernels'] = True - response = yield self.http_client.fetch( - self.get_url('/api/kernels') - ) - self.assertEqual(response.code, 200) + jp_web_app.settings['kg_list_kernels'] = True + response = await jp_fetch("api", "kernels", method="GET") + 
assert response.code == 200 kernels = json_decode(response.body) - self.assertEqual(len(kernels), 0) + assert len(kernels) == 0 # Launch a kernel - response = yield self.http_client.fetch( - self.get_url('/api/kernels'), - method='POST', - body='{}' - ) - self.assertEqual(response.code, 201) + response = await jp_fetch("api", "kernels", method="POST", body='{}') + assert response.code == 201 kernel = json_decode(response.body) # Check the list again - response = yield self.http_client.fetch( - self.get_url('/api/kernels') - ) - self.assertEqual(response.code, 200) + response = await jp_fetch("api", "kernels", method="GET") + assert response.code == 200 kernels = json_decode(response.body) - self.assertEqual(len(kernels), 1) - self.assertEqual(kernels[0]['id'], kernel['id']) + assert len(kernels) == 1 + assert kernels[0]['id'] == kernel['id'] - @gen_test - def test_kernel_comm(self): + async def test_kernel_comm(self, spawn_kernel): """Default kernel should launch and accept commands.""" - ws = yield self.spawn_kernel() + ws = await spawn_kernel() # Send a request for kernel info - ws.write_message(json_encode({ + await ws.write_message(json_encode({ 'header': { 'username': '', 'version': '5.0', @@ -414,412 +313,370 @@ def test_kernel_comm(self): # Assert the reply comes back. Test will timeout if this hangs. 
for _ in range(10): - msg = yield ws.read_message() + msg = await ws.read_message() msg = json_decode(msg) - if(msg['msg_type'] == 'kernel_info_reply'): + if msg['msg_type'] == 'kernel_info_reply': break else: raise AssertionError('never received kernel_info_reply') ws.close() - @gen_test - def test_no_discovery(self): + async def test_no_discovery(self, jp_fetch): """The list of kernels / sessions should be forbidden by default.""" - response = yield self.http_client.fetch( - self.get_url('/api/kernels'), - raise_error=False - ) - self.assertEqual(response.code, 403) - - response = yield self.http_client.fetch( - self.get_url('/api/sessions'), - raise_error=False - ) - self.assertEqual(response.code, 403) - - @gen_test - def test_crud_sessions(self): + with pytest.raises(HTTPClientError) as e: + await jp_fetch("api", "kernels", method="GET") + assert e.value.response.code == 403 + + with pytest.raises(HTTPClientError) as e: + await jp_fetch("api", "sessions", method="GET") + assert e.value.response.code == 403 + + async def test_crud_sessions(self, jp_fetch, jp_web_app): """Server should create, list, and delete sessions.""" - app = self.get_app() - app.settings['kg_list_kernels'] = True + jp_web_app.settings['kg_list_kernels'] = True # Ensure no sessions by default - response = yield self.http_client.fetch( - self.get_url('/api/sessions') - ) - self.assertEqual(response.code, 200) + response = await jp_fetch("api", "sessions", method="GET") + assert response.code == 200 sessions = json_decode(response.body) - self.assertEqual(len(sessions), 0) + assert len(sessions) == 0 # Launch a session - response = yield self.http_client.fetch( - self.get_url('/api/sessions'), - method='POST', - body='{"id":"any","notebook":{"path":"anywhere"},"kernel":{"name":"python"}}' - ) - self.assertEqual(response.code, 201) + response = await jp_fetch("api", "sessions", method="POST", + body='{"id":"any","notebook":{"path":"anywhere"},"kernel":{"name":"python"}}') + assert 
response.code == 201 session = json_decode(response.body) # Check the list again - response = yield self.http_client.fetch( - self.get_url('/api/sessions') - ) - self.assertEqual(response.code, 200) + response = await jp_fetch("api", "sessions", method="GET") + assert response.code == 200 sessions = json_decode(response.body) - self.assertEqual(len(sessions), 1) - self.assertEqual(sessions[0]['id'], session['id']) + assert len(sessions) == 1 + assert sessions[0]['id'] == session['id'] # Delete the session - response = yield self.http_client.fetch( - self.get_url('/api/sessions/'+session['id']), - method='DELETE' - ) - self.assertEqual(response.code, 204) + response = await jp_fetch("api", "sessions", session['id'], method="DELETE") + assert response.code == 204 # Make sure the list is empty - response = yield self.http_client.fetch( - self.get_url('/api/sessions') - ) - self.assertEqual(response.code, 200) + response = await jp_fetch("api", "sessions", method="GET") + assert response.code == 200 sessions = json_decode(response.body) - self.assertEqual(len(sessions), 0) + assert len(sessions) == 0 - @gen_test - def test_json_errors(self): + async def test_json_errors(self, jp_fetch): """Handlers should always return JSON errors.""" # A handler that we override - response = yield self.http_client.fetch( - self.get_url('/api/kernels'), - raise_error=False - ) - body = json_decode(response.body) - self.assertEqual(response.code, 403) - self.assertEqual(body['reason'], 'Forbidden') + with pytest.raises(HTTPClientError) as e: + await jp_fetch("api", "kernels", method="GET") + assert e.value.response.code == 403 + + body = json_decode(e.value.response.body) + assert body['reason'] == 'Forbidden' # A handler from the notebook base - response = yield self.http_client.fetch( - self.get_url('/api/kernels/1-2-3-4-5'), - raise_error=False - ) - body = json_decode(response.body) - self.assertEqual(response.code, 404) - # Base handler json_errors decorator does not capture reason 
properly - # self.assertEqual(body['reason'], 'Not Found') - self.assertIn('1-2-3-4-5', body['message']) + with pytest.raises(HTTPClientError) as e: + await jp_fetch("api", "kernels", "1-2-3-4-5", method="GET") + assert e.value.response.code == 404 + + body = json_decode(e.value.response.body) + assert "1-2-3-4-5" in body['message'] # The last resort not found handler - response = yield self.http_client.fetch( - self.get_url('/fake-endpoint'), - raise_error=False - ) - body = json_decode(response.body) - self.assertEqual(response.code, 404) - self.assertEqual(body['reason'], 'Not Found') - - @gen_test - def test_kernel_env(self): + with pytest.raises(HTTPClientError) as e: + await jp_fetch("fake-endpoint", method="GET") + assert e.value.response.code == 404 + + body = json_decode(e.value.response.body) + assert body["reason"] == "Not Found" + + @pytest.mark.parametrize("jp_argv", + (["--JupyterWebsocketPersonality.env_whitelist=TEST_VAR"],)) + async def test_kernel_env(self, spawn_kernel, jp_argv): """Kernel should start with environment vars defined in the request.""" - self.app.personality.env_whitelist = ['TEST_VAR'] + kernel_body = json.dumps({ - 'name': 'python', - 'env': { - 'KERNEL_FOO': 'kernel-foo-value', - 'NOT_KERNEL': 'ignored', - 'KERNEL_GATEWAY': 'overridden', - 'TEST_VAR': 'allowed' + "name": "python", + "env": { + "KERNEL_FOO": "kernel-foo-value", + "NOT_KERNEL": "ignored", + "KERNEL_GATEWAY": "overridden", + "TEST_VAR": "allowed" } }) - ws = yield self.spawn_kernel(kernel_body) - req = self.execute_request('import os; print(os.getenv("KERNEL_FOO"), os.getenv("NOT_KERNEL"), os.getenv("KERNEL_GATEWAY"), os.getenv("TEST_VAR"))') - ws.write_message(json_encode(req)) - content = yield self.await_stream(ws) - self.assertEqual(content['name'], 'stdout') - self.assertIn('kernel-foo-value', content['text']) - self.assertNotIn('ignored', content['text']) - self.assertNotIn('overridden', content['text']) - self.assertIn('allowed', content['text']) - + ws = 
await spawn_kernel(kernel_body) + req = get_execute_request('import os; print(os.getenv("KERNEL_FOO"), os.getenv("NOT_KERNEL"), ' + 'os.getenv("KERNEL_GATEWAY"), os.getenv("TEST_VAR"))') + + await ws.write_message(json_encode(req)) + content = await await_stream(ws) + + assert content["name"] == "stdout" + assert "kernel-foo-value" in content["text"] + assert "ignored" not in content["text"] + assert "overridden" not in content["text"] + assert "allowed" in content["text"] ws.close() - @gen_test - def test_get_swagger_yaml_spec(self): + async def test_get_swagger_yaml_spec(self, jp_fetch): """Getting the swagger.yaml spec should be ok""" - response = yield self.http_client.fetch(self.get_url('/api/swagger.yaml')) - self.assertEqual(response.code, 200) + response = await jp_fetch("api", "swagger.yaml", method="GET") + assert response.code == 200 - @gen_test - def test_get_swagger_json_spec(self): + async def test_get_swagger_json_spec(self, jp_fetch): """Getting the swagger.json spec should be ok""" - response = yield self.http_client.fetch(self.get_url('/api/swagger.json')) - self.assertEqual(response.code, 200) + response = await jp_fetch("api", "swagger.json", method="GET") + assert response.code == 200 - @gen_test - def test_kernel_env_auth_token(self): + async def test_kernel_env_auth_token(self, monkeypatch, spawn_kernel): """Kernel should not have KG_AUTH_TOKEN in its environment.""" - os.environ['KG_AUTH_TOKEN'] = 'fake-secret' + monkeypatch.setenv("KG_AUTH_TOKEN", "fake-secret") + ws = None try: - ws = yield self.spawn_kernel() - req = self.execute_request('import os; print(os.getenv("KG_AUTH_TOKEN"))') - ws.write_message(json_encode(req)) - content = yield self.await_stream(ws) - self.assertNotIn('fake-secret', content['text']) + ws = await spawn_kernel() + req = get_execute_request("import os; print(os.getenv('KG_AUTH_TOKEN'))") + await ws.write_message(json_encode(req)) + content = await await_stream(ws) + assert "fake-secret" not in content["text"] + 
assert "None" in content["text"] # ensure None was printed finally: - del os.environ['KG_AUTH_TOKEN'] - ws.close() + if ws is not None: + ws.close() -class TestCustomDefaultKernel(TestJupyterWebsocket): +@pytest.mark.parametrize("jp_argv", + (["--KernelGatewayApp.default_kernel_name=fake-kernel"],)) +class TestCustomDefaultKernel: """Tests gateway behavior when setting a custom default kernelspec.""" - def setup_app(self): - self.app.default_kernel_name = 'fake-kernel' - - @gen_test - def test_default_kernel_name(self): + async def test_default_kernel_name(self, jp_argv, jp_fetch): """The default kernel name should be used on empty requests.""" - # Request without an explicit kernel name - response = yield self.http_client.fetch( - self.get_url('/api/kernels'), - method='POST', - body='', - raise_error=False - ) - self.assertEqual(response.code, 500) - self.assertTrue('raise NoSuchKernel' in str(response.body)) - - -class TestForceKernel(TestJupyterWebsocket): + with pytest.raises(HTTPClientError) as e: + await jp_fetch("api", "kernels", method="POST", body='') + assert e.value.response.code == 500 + assert "raise NoSuchKernel" in str(e.value.response.body) + + +@pytest.mark.parametrize("jp_argv", + (["--KernelGatewayApp.prespawn_count=2", + f"--KernelGatewayApp.seed_uri={os.path.join(RESOURCES, 'zen.ipynb')}", + "--KernelGatewayApp.force_kernel_name=python3"],)) +class TestForceKernel: """Tests gateway behavior when forcing a kernelspec.""" - def setup_app(self): - self.app.prespawn_count = 2 - self.app.seed_uri = os.path.join(RESOURCES, - 'zen.ipynb') - self.app.force_kernel_name = 'python3' - - @gen_test - def test_force_kernel_name(self): + async def test_force_kernel_name(self, jp_argv, jp_fetch): """Should create a Python kernel.""" - response = yield self.http_client.fetch( - self.get_url('/api/kernels'), - method='POST', - body='{"name": "fake-kernel"}', - raise_error=False - ) - self.assertEqual(response.code, 201) + response = await jp_fetch("api", 
"kernels", method="POST", body='{"name": "fake-kernel"}') + assert response.code == 201 + kernel = json_decode(response.body) + assert kernel["name"] == "python3" -class TestEnableDiscovery(TestJupyterWebsocket): +class TestEnableDiscovery: """Tests gateway behavior with kernel listing enabled.""" - def setup_configurables(self): - """Enables kernel listing for all tests.""" - self.app.personality.list_kernels = True - - @gen_test - def test_enable_kernel_list(self): + @pytest.mark.parametrize("jp_argv", (["--JupyterWebsocketPersonality.list_kernels=True"],)) + async def test_enable_kernel_list(self, jp_fetch, jp_argv): """The list of kernels, sessions, and activities should be available.""" - response = yield self.http_client.fetch( - self.get_url('/api/kernels'), - ) - self.assertEqual(response.code, 200) - self.assertTrue('[]' in str(response.body)) - response = yield self.http_client.fetch( - self.get_url('/api/sessions'), - ) - self.assertEqual(response.code, 200) - self.assertTrue('[]' in str(response.body)) - -class TestPrespawnKernels(TestJupyterWebsocket): - """Tests gateway behavior when kernels are spawned at startup.""" - def setup_app(self): - """Always prespawn 2 kernels.""" - self.app.prespawn_count = 2 - @gen_test(timeout=10) - def test_prespawn_count(self): + response = await jp_fetch("api", "kernels", method="GET") + assert response.code == 200 + assert '[]' in str(response.body) + + response = await jp_fetch("api", "sessions", method="GET") + assert response.code == 200 + assert '[]' in str(response.body) + + +class TestPrespawnKernels: + """Tests gateway behavior when kernels are spawned at startup.""" + @pytest.mark.parametrize("jp_argv", (["--KernelGatewayApp.prespawn_count=2"],)) + async def test_prespawn_count(self, jp_fetch, jp_web_app, jp_argv): """Server should launch the given number of kernels on startup.""" - self.app.web_app.settings['kg_list_kernels'] = True - response = yield self.http_client.fetch( - self.get_url('/api/kernels') - 
) - self.assertEqual(response.code, 200) + jp_web_app.settings['kg_list_kernels'] = True + await sleep(0.5) + response = await jp_fetch("api", "kernels", method="GET") + assert response.code == 200 + kernels = json_decode(response.body) - self.assertEqual(len(kernels), 2) + assert len(kernels) == 2 def test_prespawn_max_conflict(self): - """Server should error if prespawn count is greater than max allowed - kernels. - """ + """Server should error if prespawn count is greater than max allowed kernels.""" app = KernelGatewayApp() app.prespawn_count = 3 app.max_kernels = 2 - self.assertRaises(RuntimeError, app.init_configurables) + with pytest.raises(RuntimeError): + app.init_configurables() -class TestBaseURL(TestJupyterWebsocket): +class TestBaseURL: """Tests gateway behavior when a custom base URL is configured.""" - def setup_app(self): - """Sets the custom base URL and enables kernel listing.""" - self.app.base_url = '/fake/path' - - def setup_configurables(self): - """Enables kernel listing for all tests.""" - self.app.personality.list_kernels = True - - @gen_test - def test_base_url(self): + @pytest.mark.parametrize("jp_argv", (["--JupyterWebsocketPersonality.list_kernels=True"],)) + @pytest.mark.parametrize("jp_base_url", ("/fake/path",)) + async def test_base_url(self, jp_base_url, jp_argv, jp_fetch): """Server should mount resources under configured base.""" - # Should not exist at root - response = yield self.http_client.fetch( - self.get_url('/api/kernels'), - method='GET', - raise_error=False - ) - self.assertEqual(response.code, 404) - # Should exist under path - response = yield self.http_client.fetch( - self.get_url('/fake/path/api/kernels'), - method='GET' - ) - self.assertEqual(response.code, 200) + response = await jp_fetch("api", "kernels", method="GET") + assert response.code == 200 + assert "/fake/path/api/kernels" in response.effective_url -class TestRelativeBaseURL(TestJupyterWebsocket): +class TestRelativeBaseURL: """Tests gateway behavior 
when a relative base URL is configured.""" - def setup_app(self): - """Sets the custom base URL as a relative path.""" - self.app.base_url = 'fake/path' - - @gen_test - def test_base_url(self): + @pytest.mark.parametrize("jp_argv", (["--JupyterWebsocketPersonality.list_kernels=True"],)) + @pytest.mark.parametrize("jp_base_url", ("/fake/path",)) + async def test_base_url(self, jp_base_url, jp_argv, jp_fetch): """Server should mount resources under fixed base.""" - self.app.web_app.settings['kg_list_kernels'] = True # Should exist under path - response = yield self.http_client.fetch( - self.get_url('/fake/path/api/kernels'), - method='GET' - ) - self.assertEqual(response.code, 200) - + response = await jp_fetch("api", "kernels", method="GET") + assert response.code == 200 + assert "/fake/path/api/kernels" in response.effective_url -class TestSeedURI(TestJupyterWebsocket): - """Tests gateway behavior when a seeding kernel memory with code from a - notebook.""" - def setup_app(self): - self.app.seed_uri = os.path.join(RESOURCES, - 'zen.ipynb') - @gen_test - def test_seed(self): - """Kernel should have variables preseeded from the notebook.""" - ws = yield self.spawn_kernel() +class TestSeedURI: + """Tests gateway behavior when seeding kernel memory with code from a notebook.""" + @pytest.mark.parametrize("jp_argv", + ([f"--KernelGatewayApp.seed_uri={os.path.join(RESOURCES, 'zen.ipynb')}"],)) + async def test_seed(self, jp_argv, spawn_kernel): + """Kernel should have variables pre-seeded from the notebook.""" + ws = await spawn_kernel() # Print the encoded "zen of python" string, the kernel should have # it imported - req = self.execute_request('print(this.s)') - ws.write_message(json_encode(req)) - content = yield self.await_stream(ws) - self.assertEqual(content['name'], 'stdout') - self.assertIn('Gur Mra bs Clguba', content['text']) + req = get_execute_request("print(this.s)") + await ws.write_message(json_encode(req)) + content = await await_stream(ws) + assert 
content["name"] == "stdout" + assert "Gur Mra bs Clguba" in content["text"] ws.close() -class TestRemoteSeedURI(TestSeedURI): - """Tests gateway behavior when a seeding kernel memory with code from a - remote notebook. - """ - def setup_app(self): - """Sets the seed notebook to a remote notebook.""" - self.app.seed_uri = 'https://gist.githubusercontent.com/parente/ccd36bd7db2f617d58ce/raw/zen3.ipynb' +class TestRemoteSeedURI: + """Tests gateway behavior when seeding kernel memory with code from a remote notebook.""" + @pytest.mark.parametrize("jp_argv", + (["--KernelGatewayApp.seed_uri=" + "https://gist.githubusercontent.com/parente/ccd36bd7db2f617d58ce/raw/zen3.ipynb"],)) + async def test_seed(self, jp_argv, spawn_kernel): + """Kernel should have variables pre-seeded from the notebook.""" + ws = await spawn_kernel() + + # Print the encoded "zen of python" string, the kernel should have + # it imported + req = get_execute_request("print(this.s)") + await ws.write_message(json_encode(req)) + content = await await_stream(ws) + assert content["name"] == "stdout" + assert "Gur Mra bs Clguba" in content["text"] + ws.close() -class TestBadSeedURI(TestJupyterWebsocket): - """Tests gateway behavior when seeding kernel memory with notebook code - that fails. - """ - def setup_app(self): - """Sets the seed notebook to one of the test resources.""" - self.app.seed_uri = os.path.join(RESOURCES, - 'failing_code.ipynb') - @gen_test - def test_seed_error(self): +class TestBadSeedURI: + """Tests gateway behavior when seeding kernel memory with notebook code that fails.""" + @pytest.mark.parametrize("jp_argv", + ([f"--KernelGatewayApp.seed_uri={os.path.join(RESOURCES, 'failing_code.ipynb')}", + "--JupyterWebsocketPersonality.list_kernels=True"],)) + async def test_seed_error(self, jp_argv, jp_fetch): """ Server should shutdown kernel and respond with error when seed notebook has an execution error. 
""" - self.app.web_app.settings['kg_list_kernels'] = True # Request a kernel - response = yield self.http_client.fetch( - self.get_url('/api/kernels'), - method='POST', - body='{}', - raise_error=False - ) - self.assertEqual(response.code, 500) + with pytest.raises(HTTPClientError) as e: + await jp_fetch("api", "kernels", method='POST', body='{}') + assert e.value.response.code == 500 # No kernels should be running - response = yield self.http_client.fetch( - self.get_url('/api/kernels'), - method='GET' - ) + response = await jp_fetch("api", "kernels", method="GET") + assert response.code == 200 kernels = json_decode(response.body) - self.assertEqual(len(kernels), 0) + assert len(kernels) == 0 - def test_seed_kernel_failing(self): + async def test_seed_kernel_not_available(self): """ - Server should error because seed notebook requires a kernel that is not - installed. - """ - app = KernelGatewayApp() - app.prespawn_count = 1 - app.seed_uri = os.path.join(RESOURCES, 'failing_code.ipynb') - self.assertRaises(RuntimeError, app.init_configurables) - - def test_seed_kernel_not_available(self): - """ - Server should error because seed notebook requires a kernel that is not - installed. + Server should error because seed notebook requires a kernel that is not installed. """ app = KernelGatewayApp() app.seed_uri = os.path.join(RESOURCES, 'unknown_kernel.ipynb') - self.assertRaises(NoSuchKernel, app.init_configurables) + with pytest.raises(NoSuchKernel): + app.init_configurables() -class TestKernelLanguageSupport(TestJupyterWebsocket): +@pytest.mark.parametrize("jp_argv", + ([f"--KernelGatewayApp.seed_uri={os.path.join(RESOURCES, 'zen.ipynb')}", + "--KernelGatewayApp.prespawn_count=1"],)) +class TestKernelLanguageSupport: """Tests gateway behavior when a client requests a specific kernel spec.""" - def setup_app(self): - """Sets the app to prespawn one kernel and preseed it with one of the - test notebooks. 
- """ - self.app.prespawn_count = 1 - self.app.seed_uri = os.path.join(RESOURCES, - 'zen.ipynb') - - @coroutine - def spawn_kernel(self): - """Override the base class spawn utility method to set the Python kernel - version number when spawning. - """ - kernel_body = json.dumps({"name":"python3"}) - ws = yield super(TestKernelLanguageSupport, self).spawn_kernel(kernel_body) - raise Return(ws) - - @gen_test - def test_seed_language_support(self): - """Kernel should have variables preseeded from notebook.""" - ws = yield self.spawn_kernel() + async def test_seed_language_support(self, jp_argv, spawn_kernel): + """Kernel should have variables pre-seeded from notebook.""" + ws = await spawn_kernel(body=json.dumps({"name": "python3"})) code = 'print(this.s)' - # Print the encoded "zen of python" string, the kernel should have - # it imported - req = self.execute_request(code) - ws.write_message(json_encode(req)) - content = yield self.await_stream(ws) - self.assertEqual(content['name'], 'stdout') - self.assertIn('Gur Mra bs Clguba', content['text']) + # Print the encoded "zen of python" string, the kernel should have it imported + req = get_execute_request(code) + await ws.write_message(json_encode(req)) + content = await await_stream(ws) + assert content['name'] == 'stdout' + assert 'Gur Mra bs Clguba' in content['text'] ws.close() + + +class TestSessionApi: + """Test session object API to improve coverage.""" + + async def test_session_api(self, tmp_path, jp_environ): + + # Create the manager instances + akm = AsyncMappingKernelManager() + sm = SessionManager(akm) + + row_model = await sm.create_session(path=str(tmp_path), kernel_name="python3") + assert "id" in row_model + assert "kernel" in row_model + assert row_model["notebook"]["path"] == str(tmp_path) + + session_id = row_model["id"] + kernel_id = row_model["kernel"]["id"] + + # Perform some get_session tests + with pytest.raises(TypeError): + sm.get_session() # no kwargs + + with pytest.raises(TypeError): + 
kwargs = {"bogus_column": 1} + sm.get_session(**kwargs) # bad column + + non_existent_session_id = uuid.uuid4() + with pytest.raises(HTTPError) as e: + kwargs = {"session_id": str(non_existent_session_id)} + sm.get_session(**kwargs) # bad session id + assert e.value.status_code == 404 + + # Perform some update_session tests + sm.update_session(session_id) # no kwargs - success expected + + with pytest.raises(KeyError): + kwargs = {"kernel_id": kernel_id} + sm.update_session(str(non_existent_session_id), **kwargs) # bad session id + + kwargs = {"path": "/tmp"} + sm.update_session(session_id, **kwargs) # update path of session + + # confirm update + kwargs = {"session_id": session_id} + row_model = sm.get_session(**kwargs) + assert row_model["notebook"]["path"] == "/tmp" + + kwargs = {"kernel_id": str(uuid.uuid4())} + with pytest.raises(KeyError): + sm.update_session(session_id, **kwargs) # bad kernel_id + + await sm.delete_session(session_id) + + with pytest.raises(HTTPError) as e: + kwargs = {"session_id": session_id} + sm.get_session(**kwargs) + assert e.value.status_code == 404 diff --git a/kernel_gateway/tests/test_mixins.py b/kernel_gateway/tests/test_mixins.py index 33f0d58..3f41b98 100644 --- a/kernel_gateway/tests/test_mixins.py +++ b/kernel_gateway/tests/test_mixins.py @@ -3,17 +3,13 @@ """Tests for handler mixins.""" import json -import unittest - -try: - from unittest.mock import Mock, MagicMock -except ImportError: - # Python 2.7: use backport - from mock import Mock, MagicMock - +import pytest +from unittest.mock import Mock from tornado import web + from kernel_gateway.mixins import TokenAuthorizationMixin, JSONErrorsMixin + class SuperTokenAuthHandler(object): """Super class for the handler using TokenAuthorizationMixin.""" is_prepared = False @@ -22,10 +18,11 @@ def prepare(self): # called by the mixin when authentication succeeds self.is_prepared = True -class TestableTokenAuthHandler(TokenAuthorizationMixin, SuperTokenAuthHandler): + +class 
CustomTokenAuthHandler(TokenAuthorizationMixin, SuperTokenAuthHandler): """Implementation that uses the TokenAuthorizationMixin for testing.""" - def __init__(self, token=''): - self.settings = { 'kg_auth_token': token } + def __init__(self, token=""): + self.settings = {"kg_auth_token": token} self.arguments = {} self.response = None self.status_code = None @@ -33,105 +30,96 @@ def __init__(self, token=''): def send_error(self, status_code): self.status_code = status_code - def get_argument(self, name, default=''): + def get_argument(self, name, default=""): return self.arguments.get(name, default) -class TestTokenAuthMixin(unittest.TestCase): - """Unit tests the Token authorization mixin.""" - def setUp(self): - """Creates a handler that uses the mixin.""" - self.mixin = TestableTokenAuthHandler('YouKnowMe') +@pytest.fixture +def auth_mixin(): + auth_mixin_instance = CustomTokenAuthHandler("YouKnowMe") + yield auth_mixin_instance - def test_no_token_required(self): + +class TestTokenAuthMixin: + """Unit tests the Token authorization mixin.""" + def test_no_token_required(self, auth_mixin): """Status should be None.""" - self.mixin.request = Mock({}) - self.mixin.settings['kg_auth_token'] = '' - self.mixin.prepare() - self.assertEqual(self.mixin.is_prepared, True) - self.assertEqual(self.mixin.status_code, None) + auth_mixin.request = Mock({}) + auth_mixin.settings["kg_auth_token"] = "" + auth_mixin.prepare() + assert auth_mixin.is_prepared is True + assert auth_mixin.status_code is None - def test_missing_token(self): + def test_missing_token(self, auth_mixin): """Status should be 'unauthorized'.""" - attrs = { 'headers' : { - } } - self.mixin.request = Mock(**attrs) - self.mixin.prepare() - self.assertEqual(self.mixin.is_prepared, False) - self.assertEqual(self.mixin.status_code, 401) - - def test_valid_header_token(self): + attrs = {"headers": {}} + auth_mixin.request = Mock(**attrs) + auth_mixin.prepare() + assert auth_mixin.is_prepared is False + assert 
auth_mixin.status_code == 401 + + def test_valid_header_token(self, auth_mixin): """Status should be None.""" - attrs = { 'headers' : { - 'Authorization' : 'token YouKnowMe' - } } - self.mixin.request = Mock(**attrs) - self.mixin.prepare() - self.assertEqual(self.mixin.is_prepared, True) - self.assertEqual(self.mixin.status_code, None) - - def test_wrong_header_token(self): + attrs = {"headers": {"Authorization": "token YouKnowMe"}} + auth_mixin.request = Mock(**attrs) + auth_mixin.prepare() + assert auth_mixin.is_prepared is True + assert auth_mixin.status_code is None + + def test_wrong_header_token(self, auth_mixin): """Status should be 'unauthorized'.""" - attrs = { 'headers' : { - 'Authorization' : 'token NeverHeardOf' - } } - self.mixin.request = Mock(**attrs) - self.mixin.prepare() - self.assertEqual(self.mixin.is_prepared, False) - self.assertEqual(self.mixin.status_code, 401) - - def test_valid_url_token(self): + attrs = {"headers": {"Authorization": "token NeverHeardOf"}} + auth_mixin.request = Mock(**attrs) + auth_mixin.prepare() + assert auth_mixin.is_prepared is False + assert auth_mixin.status_code == 401 + + def test_valid_url_token(self, auth_mixin): """Status should be None.""" - self.mixin.arguments['token'] = 'YouKnowMe' - attrs = { 'headers' : { - } } - self.mixin.request = Mock(**attrs) - self.mixin.prepare() - self.assertEqual(self.mixin.is_prepared, True) - self.assertEqual(self.mixin.status_code, None) - - def test_wrong_url_token(self): + auth_mixin.arguments["token"] = "YouKnowMe" + attrs = {"headers": {}} + auth_mixin.request = Mock(**attrs) + auth_mixin.prepare() + assert auth_mixin.is_prepared is True + assert auth_mixin.status_code is None + + def test_wrong_url_token(self, auth_mixin): """Status should be 'unauthorized'.""" - self.mixin.arguments['token'] = 'NeverHeardOf' - attrs = { 'headers' : { - } } - self.mixin.request = Mock(**attrs) - self.mixin.prepare() - self.assertEqual(self.mixin.is_prepared, False) - 
self.assertEqual(self.mixin.status_code, 401) - - def test_differing_tokens_valid_url(self): + auth_mixin.arguments["token"] = "NeverHeardOf" + attrs = {"headers": {}} + auth_mixin.request = Mock(**attrs) + auth_mixin.prepare() + assert auth_mixin.is_prepared is False + assert auth_mixin.status_code == 401 + + def test_differing_tokens_valid_url(self, auth_mixin): """Status should be None, URL token takes precedence""" - self.mixin.arguments['token'] = 'YouKnowMe' - attrs = { 'headers' : { - 'Authorization' : 'token NeverHeardOf' - } } - self.mixin.request = Mock(**attrs) - self.mixin.prepare() - self.assertEqual(self.mixin.is_prepared, True) - self.assertEqual(self.mixin.status_code, None) - - def test_differing_tokens_wrong_url(self): + auth_mixin.arguments["token"] = "YouKnowMe" + attrs = {"headers": {"Authorization": "token NeverHeardOf"}} + auth_mixin.request = Mock(**attrs) + auth_mixin.prepare() + assert auth_mixin.is_prepared is True + assert auth_mixin.status_code is None + + def test_differing_tokens_wrong_url(self, auth_mixin): """Status should be 'unauthorized', URL token takes precedence""" - attrs = { 'headers' : { - 'Authorization' : 'token YouKnowMe' - } } - self.mixin.request = Mock(**attrs) - self.mixin.arguments['token'] = 'NeverHeardOf' - self.mixin.prepare() - self.assertEqual(self.mixin.is_prepared, False) - self.assertEqual(self.mixin.status_code, 401) - - def test_unset_client_token_with_options(self): + attrs = {"headers": {"Authorization": "token YouKnowMe"}} + auth_mixin.request = Mock(**attrs) + auth_mixin.arguments["token"] = "NeverHeardOf" + auth_mixin.prepare() + assert auth_mixin.is_prepared is False + assert auth_mixin.status_code == 401 + + def test_unset_client_token_with_options(self, auth_mixin): """No token is needed for an OPTIONS request. 
Enables CORS.""" - attrs = { 'method' : 'OPTIONS' } - self.mixin.request = Mock(**attrs) - self.mixin.prepare() - self.assertEqual(self.mixin.is_prepared, True) - self.assertEqual(self.mixin.status_code, None) + attrs = {"method": "OPTIONS"} + auth_mixin.request = Mock(**attrs) + auth_mixin.prepare() + assert auth_mixin.is_prepared is True + assert auth_mixin.status_code is None -class TestableJSONErrorsHandler(JSONErrorsMixin): +class CustomJSONErrorsHandler(JSONErrorsMixin): """Implementation that uses the JSONErrorsMixin for testing.""" def __init__(self): self.headers = {} @@ -153,36 +141,39 @@ def set_status(self, status_code, reason=None): def set_header(self, name, value): self.headers[name] = value -class TestJSONErrorsMixin(unittest.TestCase): - """Unit tests the JSON errors mixin.""" - def setUp(self): - """Creates a handler that uses the mixin.""" - self.mixin = TestableJSONErrorsHandler() - def test_status(self): +@pytest.fixture +def errors_mixin(): + errors_mixin_instance = CustomJSONErrorsHandler() + yield errors_mixin_instance + + +class TestJSONErrorsMixin: + """Unit tests the JSON errors mixin.""" + def test_status(self, errors_mixin): """Status should be set on the response.""" - self.mixin.write_error(404) - response = json.loads(self.mixin.response) - self.assertEqual(self.mixin.status_code, 404) - self.assertEqual(response['reason'], 'Not Found') - self.assertEqual(response['message'], '') + errors_mixin.write_error(404) + response = json.loads(errors_mixin.response) + assert errors_mixin.status_code == 404 + assert response["reason"] == "Not Found" + assert response["message"] == "" - def test_custom_status(self): + def test_custom_status(self, errors_mixin): """Custom reason from exeception should be set in the response.""" - exc = web.HTTPError(500, reason='fake-reason') - self.mixin.write_error(500, exc_info=[None, exc]) + exc = web.HTTPError(500, reason="fake-reason") + errors_mixin.write_error(500, exc_info=[None, exc]) - response = 
json.loads(self.mixin.response) - self.assertEqual(self.mixin.status_code, 500) - self.assertEqual(response['reason'], 'fake-reason') - self.assertEqual(response['message'], '') + response = json.loads(errors_mixin.response) + assert errors_mixin.status_code == 500 + assert response["reason"] == "fake-reason" + assert response["message"] == "" - def test_log_message(self): + def test_log_message(self, errors_mixin): """Custom message from exeception should be set in the response.""" - exc = web.HTTPError(410, log_message='fake-message') - self.mixin.write_error(410, exc_info=[None, exc]) + exc = web.HTTPError(410, log_message="fake-message") + errors_mixin.write_error(410, exc_info=[None, exc]) - response = json.loads(self.mixin.response) - self.assertEqual(self.mixin.status_code, 410) - self.assertEqual(response['reason'], 'Gone') - self.assertEqual(response['message'], 'fake-message') + response = json.loads(errors_mixin.response) + assert errors_mixin.status_code == 410 + assert response["reason"] == "Gone" + assert response["message"] == "fake-message" diff --git a/kernel_gateway/tests/test_notebook_http.py b/kernel_gateway/tests/test_notebook_http.py index 8dd327f..03b701e 100644 --- a/kernel_gateway/tests/test_notebook_http.py +++ b/kernel_gateway/tests/test_notebook_http.py @@ -2,434 +2,266 @@ # Distributed under the terms of the Modified BSD License. 
"""Tests for notebook-http mode.""" +import asyncio import os import json +import pytest + +from tornado.httpclient import HTTPClientError +from traitlets.config import Config -from .test_gatewayapp import TestGatewayAppBase, RESOURCES from ..notebook_http.swagger.handlers import SwaggerSpecHandler -from tornado.testing import gen_test +from .test_gatewayapp import RESOURCES + + +@pytest.fixture +def jp_server_config(): + """Allows tests to setup their specific configuration values.""" + config = { + "KernelGatewayApp": { + "api": "kernel_gateway.notebook_http", + "seed_uri": os.path.join(RESOURCES, "kernel_api.ipynb"), + } + } + return Config(config) -class TestDefaults(TestGatewayAppBase): + +class TestDefaults: """Tests gateway behavior.""" - def setup_app(self): - """Sets the notebook-http mode and points to a local test notebook as - the basis for the API. - """ - self.app.api = 'kernel_gateway.notebook_http' - self.app.seed_uri = os.path.join(RESOURCES, - 'kernel_api.ipynb') - @gen_test - def test_api_get_endpoint(self): + async def test_api_get_endpoint(self, jp_fetch): """GET HTTP method should be callable""" - response = yield self.http_client.fetch( - self.get_url('/hello'), - method='GET', - raise_error=False - ) - self.assertEqual(response.code, 200, 'GET endpoint did not return 200.') - self.assertEqual(response.body, b'hello world\n', 'Unexpected body in response to GET.') - - @gen_test - def test_api_get_endpoint_with_path_param(self): + response = await jp_fetch("hello", method="GET") + assert response.code == 200, "GET endpoint did not return 200." + assert response.body == b"hello world\n", "Unexpected body in response to GET." 
+ + async def test_api_get_endpoint_with_path_param(self, jp_fetch): """GET HTTP method should be callable with a path param""" - response = yield self.http_client.fetch( - self.get_url('/hello/governor'), - method='GET', - raise_error=False - ) - self.assertEqual(response.code, 200, 'GET endpoint did not return 200.') - self.assertEqual(response.body, b'hello governor\n', 'Unexpected body in response to GET.') - - @gen_test - def test_api_get_endpoint_with_query_param(self): + response = await jp_fetch("hello", "governor", method="GET") + assert response.code == 200, "GET endpoint did not return 200." + assert response.body == b"hello governor\n", "Unexpected body in response to GET." + + async def test_api_get_endpoint_with_query_param(self, jp_fetch): """GET HTTP method should be callable with a query param""" - response = yield self.http_client.fetch( - self.get_url('/hello/person?person=governor'), - method='GET', - raise_error=False - ) - self.assertEqual(response.code, 200, 'GET endpoint did not return 200.') - self.assertEqual(response.body, b'hello governor\n', 'Unexpected body in response to GET.') - - @gen_test - def test_api_get_endpoint_with_multiple_query_params(self): + response = await jp_fetch("hello", "person", params={"person": "governor"}, method="GET") + assert response.code == 200, "GET endpoint did not return 200." + print(f"response.body = '{response.body}'") + assert response.body == b"hello governor\n", "Unexpected body in response to GET." 
+ + async def test_api_get_endpoint_with_multiple_query_params(self, jp_fetch): """GET HTTP method should be callable with multiple query params""" - response = yield self.http_client.fetch( - self.get_url('/hello/persons?person=governor&person=rick'), - method='GET', - raise_error=False - ) - self.assertEqual(response.code, 200, 'GET endpoint did not return 200.') - self.assertEqual(response.body, b'hello governor, rick\n', 'Unexpected body in response to GET.') - - @gen_test - def test_api_put_endpoint(self): + response = await jp_fetch("hello", "persons", params={"person": "governor, rick"}, method="GET") + assert response.code == 200, "GET endpoint did not return 200." + assert response.body == b"hello governor, rick\n", "Unexpected body in response to GET." + + async def test_api_put_endpoint(self, jp_fetch): """PUT HTTP method should be callable""" - response = yield self.http_client.fetch( - self.get_url('/message'), - method='PUT', - body='hola {}', - raise_error=False - ) - self.assertEqual(response.code, 200, 'PUT endpoint did not return 200.') - - response = yield self.http_client.fetch( - self.get_url('/message'), - method='GET', - raise_error=False - ) - self.assertEqual(response.code, 200, 'GET endpoint did not return 200.') - self.assertEqual(response.body, b'hola {}\n', 'Unexpected body in response to GET after performing PUT.') - - @gen_test - def test_api_post_endpoint(self): + response = await jp_fetch("message", method="PUT", body="hola {}") + assert response.code == 200, "PUT endpoint did not return 200." + + response = await jp_fetch("message", method="GET") + assert response.code == 200, "GET endpoint did not return 200." + assert response.body == b"hola {}\n", "Unexpected body in response to GET after performing PUT." 
+ + async def test_api_post_endpoint(self, jp_fetch): """POST endpoint should be callable""" expected = b'["Rick", "Maggie", "Glenn", "Carol", "Daryl"]\n' - response = yield self.http_client.fetch( - self.get_url('/people'), - method='POST', - body=expected.decode('UTF-8'), - raise_error=False, - headers={'Content-Type': 'application/json'} - ) - self.assertEqual(response.code, 200, 'POST endpoint did not return 200.') - self.assertEqual(response.body, expected, 'Unexpected body in response to POST.') - - @gen_test - def test_api_delete_endpoint(self): + response = await jp_fetch("people", method="POST", body=expected.decode("UTF-8"), headers={"Content-Type": "application/json"}) + assert response.code == 200, "POST endpoint did not return 200." + assert response.body == expected, "Unexpected body in response to POST." + + async def test_api_delete_endpoint(self, jp_fetch): """DELETE HTTP method should be callable""" expected = b'["Rick", "Maggie", "Glenn", "Carol", "Daryl"]\n' - response = yield self.http_client.fetch( - self.get_url('/people'), - method='POST', - body=expected.decode('UTF-8'), - raise_error=False, - headers={'Content-Type': 'application/json'} - ) - response = yield self.http_client.fetch( - self.get_url('/people/2'), - method='DELETE', - raise_error=False, - ) - self.assertEqual(response.code, 200, 'DELETE endpoint did not return 200.') - self.assertEqual(response.body, b'["Rick", "Maggie", "Carol", "Daryl"]\n', 'Unexpected body in response to DELETE.') - - @gen_test - def test_api_error_endpoint(self): + response = await jp_fetch("people", method="POST", body=expected.decode("UTF-8"), headers={"Content-Type": "application/json"}) + response = await jp_fetch("people", "2", method="DELETE") + assert response.code == 200, "DELETE endpoint did not return 200." + assert response.body == b'["Rick", "Maggie", "Carol", "Daryl"]\n', "Unexpected body in response to DELETE." 
+ + async def test_api_error_endpoint(self, jp_fetch): """Error in a cell should cause 500 HTTP status""" - response = yield self.http_client.fetch( - self.get_url('/error'), - method='GET', - raise_error=False - ) - self.assertEqual(response.code, 500, 'Cell with error did not return 500 status code.') - - @gen_test - def test_api_stderr_endpoint(self): + with pytest.raises(HTTPClientError) as e: + await jp_fetch("error", method="GET") + assert e.value.code == 500, "Cell with error did not return 500 status code." + + async def test_api_stderr_endpoint(self, jp_fetch): """stderr output in a cell should be dropped""" - response = yield self.http_client.fetch( - self.get_url('/stderr'), - method='GET', - raise_error=False - ) - self.assertEqual(response.body, b'I am text on stdout\n', 'Unexpected text in response') - - @gen_test - def test_api_unsupported_method(self): + response = await jp_fetch("stderr", method="GET") + assert response.body == b"I am text on stdout\n", "Unexpected text in response" + + async def test_api_unsupported_method(self, jp_fetch): """Endpoints which do no support an HTTP verb should respond with 405. """ - response = yield self.http_client.fetch( - self.get_url('/message'), - method='DELETE', - raise_error=False - ) - self.assertEqual(response.code, 405, 'Endpoint which exists, but does not support DELETE, did not return 405 status code.') - - @gen_test - def test_api_undefined(self): + with pytest.raises(HTTPClientError) as e: + await jp_fetch("message", method="DELETE") + assert e.value.code == 405, "Endpoint which exists, but does not support DELETE, did not return 405 status code." + + async def test_api_undefined(self, jp_fetch): """Endpoints which are not registered at all should respond with 404. 
""" - response = yield self.http_client.fetch( - self.get_url('/not/an/endpoint'), - method='GET', - raise_error=False - ) - body = json.loads(response.body.decode('UTF-8')) - self.assertEqual(response.code, 404, 'Endpoint which should not exist did not return 404 status code.') - self.assertEqual(body['reason'], 'Not Found') - - @gen_test - def test_api_access_http_header(self): + with pytest.raises(HTTPClientError) as e: + await jp_fetch("not", "an", "endpoint", method="GET") + + assert e.value.code == 404, "Endpoint which should not exist did not return 404 status code." + body = json.loads(e.value.response.body.decode("UTF-8")) + assert body["reason"] == "Not Found" + + async def test_api_access_http_header(self, jp_fetch): """HTTP endpoints should be able to access request headers""" - content_types = ['text/plain', 'application/json', 'application/atom+xml', 'foo'] + content_types = ["text/plain", "application/json", "application/atom+xml", "foo"] for content_type in content_types: - response = yield self.http_client.fetch( - self.get_url('/content-type'), - method='GET', - raise_error=False, - headers={'Content-Type': content_type} - ) - self.assertEqual(response.code, 200, 'GET endpoint did not return 200.') - self.assertEqual(response.body.decode(encoding='UTF-8'), '{}\n'.format(content_type), 'Unexpected value in response') - - @gen_test - def test_format_request_code_escaped_integration(self): + response = await jp_fetch("content-type", method="GET", headers={"Content-Type": content_type}) + assert response.code == 200, "GET endpoint did not return 200." 
+ assert response.body.decode(encoding="UTF-8") == f"{content_type}\n", "Unexpected value in response" + + async def test_format_request_code_escaped_integration(self, jp_fetch): """Quotes should be properly escaped in request headers.""" - #Test query with escaping of arguements and headers with multiple escaped quotes - response = yield self.http_client.fetch( - self.get_url('/hello/person?person=governor'), - method='GET', - headers={'If-None-Match': '\"\"9a28a9262f954494a8de7442c63d6d0715ce0998\"\"'}, - raise_error=False - ) - self.assertEqual(response.code, 200, 'GET endpoint did not return 200.') - self.assertEqual(response.body, b'hello governor\n', 'Unexpected body in response to GET.') - - @gen_test - def test_blocked_download_notebook_source(self): + # Test query with escaping of arguments and headers with multiple escaped quotes + response = await jp_fetch("hello", "person", params={"person": "governor"}, method="GET", + headers={"If-None-Match": '\"\"9a28a9262f954494a8de7442c63d6d0715ce0998\"\"'}) + assert response.code == 200, "GET endpoint did not return 200." + assert response.body == b"hello governor\n", "Unexpected body in response to GET." + + async def test_blocked_download_notebook_source(self, jp_fetch): """Notebook source should not exist under the path /_api/source when `allow_notebook_download` is False or not configured. """ - response = yield self.http_client.fetch( - self.get_url('/_api/source'), - method='GET', - raise_error=False - ) - self.assertEqual(response.code, 404, "/_api/source found when allow_notebook_download is false") - - @gen_test - def test_blocked_public(self): + with pytest.raises(HTTPClientError) as e: + await jp_fetch("_api", "source", method="GET") + assert e.value.code == 404, "/_api/source found when allow_notebook_download is false" + + async def test_blocked_public(self, jp_fetch): """Public static assets should not exist under the path /public when `static_path` is False or not configured. 
""" - response = yield self.http_client.fetch( - self.get_url('/public'), - method='GET', - raise_error=False - ) - self.assertEqual(response.code, 404, "/public found when static_path is false") - - @gen_test - def test_api_returns_execute_result(self): + with pytest.raises(HTTPClientError) as e: + await jp_fetch("public", method="GET") + assert e.value.code == 404, "/public found when static_path is false" + + async def test_api_returns_execute_result(self, jp_fetch): """GET HTTP method should return the result of cell execution""" - response = yield self.http_client.fetch( - self.get_url('/execute_result'), - method='GET', - raise_error=False - ) - self.assertEqual(response.code, 200, 'GET endpoint did not return 200.') - self.assertEqual(response.body, b'{"text/plain": "2"}', 'Unexpected body in response to GET.') - - @gen_test - def test_cells_concatenate(self): + response = await jp_fetch("execute_result", method="GET") + assert response.code == 200, "GET endpoint did not return 200." + assert response.body == b'{"text/plain": "2"}', "Unexpected body in response to GET." + + async def test_cells_concatenate(self, jp_fetch): """Multiple cells with the same verb and path should concatenate.""" - response = yield self.http_client.fetch( - self.get_url('/multi'), - method='GET', - raise_error=False - ) - self.assertEqual(response.code, 200, 'GET endpoint did not return 200.') - self.assertEqual(response.body, b'x is 1\n', 'Unexpected body in response to GET.') - - @gen_test - def test_kernel_gateway_environment_set(self): + response = await jp_fetch("multi", method="GET") + assert response.code == 200, "GET endpoint did not return 200." + assert response.body == b"x is 1\n", "Unexpected body in response to GET." 
+ + async def test_kernel_gateway_environment_set(self, jp_fetch): """GET HTTP method should be callable with multiple query params""" - response = yield self.http_client.fetch( - self.get_url('/env_kernel_gateway'), - method='GET', - raise_error=False - ) - self.assertEqual(response.code, 200, 'GET endpoint did not return 200.') - self.assertEqual(response.body, b'KERNEL_GATEWAY is 1\n', 'Unexpected body in response to GET.') - -class TestPublicStatic(TestGatewayAppBase): - """Tests gateway behavior when public static assets are enabled.""" - def setup_app(self): - """Sets the notebook-http mode and points to a local test notebook as - the basis for the API. - """ - self.app.api = 'kernel_gateway.notebook_http' - self.app.seed_uri = os.path.join(RESOURCES, - 'kernel_api.ipynb') + response = await jp_fetch("env_kernel_gateway", method="GET") + assert response.code == 200, "GET endpoint did not return 200." + assert response.body == b"KERNEL_GATEWAY is 1\n", "Unexpected body in response to GET." 
- def setup_configurables(self): - """Configures the static path at the root of the resources/public folder.""" - self.app.personality.static_path = os.path.join(RESOURCES, 'public') - @gen_test - def test_get_public(self): +@pytest.mark.parametrize("jp_argv", + ([f"--NotebookHTTPPersonality.static_path={os.path.join(RESOURCES, 'public')}"],)) +class TestPublicStatic: + """Tests gateway behavior when public static assets are enabled.""" + + async def test_get_public(self, jp_fetch, jp_argv): """index.html should exist under `/public/index.html`.""" - response = yield self.http_client.fetch( - self.get_url('/public/index.html'), - method='GET', - raise_error=False - ) - self.assertEqual(response.code, 200) - self.assertEqual(response.headers.get('Content-Type'), 'text/html') - -class TestSourceDownload(TestGatewayAppBase): - """Tests gateway behavior when notebook download is allowed.""" - def setup_app(self): - """Sets the notebook-http mode, points to a local test notebook as - the basis for the API, and enables downloads of that notebook. 
- """ - self.app.api = 'kernel_gateway.notebook_http' - self.app.seed_uri = os.path.join(RESOURCES, - 'kernel_api.ipynb') + response = await jp_fetch("public", "index.html", method="GET") + assert response.code == 200 + assert response.headers.get("Content-Type") == "text/html" - def setup_configurables(self): - self.app.personality.allow_notebook_download = True - @gen_test - def test_download_notebook_source(self): +@pytest.mark.parametrize("jp_argv", + (["--NotebookHTTPPersonality.allow_notebook_download=True"],)) +class TestSourceDownload: + """Tests gateway behavior when notebook download is allowed.""" + + async def test_download_notebook_source(self, jp_fetch, jp_argv): """Notebook source should exist under the path `/_api/source`.""" - response = yield self.http_client.fetch( - self.get_url('/_api/source'), - method='GET', - raise_error=False - ) - self.assertEqual(response.code, 200, "/_api/source did not correctly return the downloaded notebook") - -class TestCustomResponse(TestGatewayAppBase): + response = await jp_fetch("_api", "source", method="GET") + assert response.code == 200, "/_api/source did not correctly return the downloaded notebook" + nb = json.loads(response.body) + for key in ["cells", "metadata", "nbformat", "nbformat_minor"]: + assert key in nb + + +@pytest.mark.parametrize("jp_argv", + ([f"--KernelGatewayApp.seed_uri={os.path.join(RESOURCES, 'responses.ipynb')}"],)) +class TestCustomResponse: """Tests gateway behavior when the notebook contains ResponseInfo cells.""" - def setup_app(self): - """Sets the notebook-http mode and points to a local test notebook as - the basis for the API. 
- """ - self.app.api = 'kernel_gateway.notebook_http' - self.app.seed_uri = os.path.join(RESOURCES, - 'responses.ipynb') - @gen_test - def test_setting_content_type(self): + async def test_setting_content_type(self, jp_fetch, jp_argv): """A response cell should allow the content type to be set""" - response = yield self.http_client.fetch( - self.get_url('/json'), - method='GET', - raise_error=False - ) - result = json.loads(response.body.decode('UTF-8')) - self.assertEqual(response.code, 200, 'Response status was not 200') - self.assertEqual(response.headers['Content-Type'], 'application/json', 'Incorrect mime type was set on response') - self.assertEqual(result, {'hello' : 'world'}, 'Incorrect response value.') - - @gen_test - def test_setting_response_status_code(self): + response = await jp_fetch("json", method="GET") + result = json.loads(response.body.decode("UTF-8")) + assert response.code == 200, "Response status was not 200" + assert response.headers["Content-Type"] == "application/json", "Incorrect mime type was set on response" + assert result == {"hello": "world"}, "Incorrect response value." + + async def test_setting_response_status_code(self, jp_fetch, jp_argv): """A response cell should allow the response status code to be set""" - response = yield self.http_client.fetch( - self.get_url('/nocontent'), - method='GET', - raise_error=False - ) - self.assertEqual(response.code, 204, 'Response status was not 204') - self.assertEqual(response.body, b'', 'Incorrect response value.') - - @gen_test - def test_setting_etag_header(self): + response = await jp_fetch("nocontent", method="GET") + assert response.code == 204, "Response status was not 204" + assert response.body == b"", "Incorrect response value." 
+ + async def test_setting_etag_header(self, jp_fetch, jp_argv): """A response cell should allow the etag header to be set""" - response = yield self.http_client.fetch( - self.get_url('/etag'), - method='GET', - raise_error=False - ) - result = json.loads(response.body.decode('UTF-8')) - self.assertEqual(response.code, 200, 'Response status was not 200') - self.assertEqual(response.headers['Content-Type'], 'application/json', 'Incorrect mime type was set on response') - self.assertEqual(result, {'hello' : 'world'}, 'Incorrect response value.') - self.assertEqual(response.headers['Etag'], '1234567890', 'Incorrect Etag header value.') - -class TestKernelPool(TestGatewayAppBase): - """Tests gateway behavior with more than one kernel in the kernel pool.""" - def setup_app(self): - """Sets the notebook-http mode, points to a local test notebook as - the basis for the API, and spawns 3 kernels to service requests. - """ - self.app.prespawn_count = 3 - self.app.api = 'kernel_gateway.notebook_http' - self.app.seed_uri = os.path.join(RESOURCES, - 'kernel_api.ipynb') + response = await jp_fetch("etag", method="GET") + result = json.loads(response.body.decode("UTF-8")) + assert response.code == 200, "Response status was not 200" + assert response.headers["Content-Type"] == "application/json", "Incorrect mime type was set on response" + assert result == {"hello": "world"}, "Incorrect response value." + assert response.headers["Etag"] == "1234567890", "Incorrect Etag header value." 
+ + +@pytest.mark.parametrize("jp_argv", (["--KernelGatewayApp.prespawn_count=3"],)) +class TestKernelPool: - @gen_test - def test_should_cycle_through_kernels(self): + async def test_should_cycle_through_kernels(self, jp_fetch, jp_argv): """Requests should cycle through kernels""" - response = yield self.http_client.fetch( - self.get_url('/message'), - method='PUT', - body='hola {}', - raise_error=False - ) - self.assertEqual(response.code, 200, 'PUT endpoint did not return 200.') - - for i in range(self.app.prespawn_count): - response = yield self.http_client.fetch( - self.get_url('/message'), - method='GET', - raise_error=False - ) - - if i != self.app.prespawn_count-1: - self.assertEqual(response.body, b'hello {}\n', 'Unexpected body in response to GET after performing PUT.') + response = await jp_fetch("message", method="PUT", body='hola {}') + assert response.code == 200, 'PUT endpoint did not return 200.' + + for i in range(3): + response = await jp_fetch("message", method="GET") + + if i != 2: + assert response.body == b"hello {}\n", "Unexpected body in response to GET after performing PUT." else: - self.assertEqual(response.body, b'hola {}\n', 'Unexpected body in response to GET after performing PUT.') - @gen_test - def test_concurrent_request_should_not_be_blocked(self): + assert response.body == b"hola {}\n", "Unexpected body in response to GET after performing PUT." 
+ + @pytest.mark.timeout(10) + async def test_concurrent_request_should_not_be_blocked(self, jp_fetch, jp_argv): """Concurrent requests should not be blocked""" - response_long_running = self.http_client.fetch( - self.get_url('/sleep/6'), - method='GET', - raise_error=False - ) - if callable(getattr(response_long_running, 'done', "")): - # Tornado 5 - self.assertFalse(response_long_running.done(), 'Long HTTP Request is not running') - else: - # Tornado 4 - self.assertTrue(response_long_running.running(), 'Long HTTP Request is not running') - - response_short_running = yield self.http_client.fetch( - self.get_url('/sleep/3'), - method='GET', - raise_error=False - ) - if callable(getattr(response_long_running, 'done', "")): - # Tornado 5 - self.assertFalse(response_long_running.done(), 'Long HTTP Request is not running') - else: - # Tornado 4 - self.assertTrue(response_long_running.running(), 'Long HTTP Request is not running') - - self.assertEqual(response_short_running.code, 200, 'Short HTTP Request did not return proper status code of 200') - - @gen_test - def test_locking_semaphore_of_kernel_resources(self): + response_long_running = jp_fetch("sleep", "6", method="GET") + assert response_long_running.done() is False, "Long HTTP Request is not running" + + response_short_running = await jp_fetch("sleep", "3", method="GET") + assert response_short_running.code == 200, "Short HTTP Request did not return proper status code of 200" + assert response_long_running.done() is False, "Long HTTP Request is not running" + while not response_long_running.done(): + await asyncio.sleep(0.3) # let the long request complete + + async def test_locking_semaphore_of_kernel_resources(self, jp_fetch, jp_argv): """Kernel pool should prevent more than one request from running on a kernel at a time. 
""" futures = [] - for _ in range(self.app.prespawn_count*2+1): - futures.append(self.http_client.fetch( - self.get_url('/sleep/1'), - method='GET', - raise_error=False - )) + for _ in range(7): + futures.append(jp_fetch("sleep", "1", method="GET")) count = 0 for future in futures: - yield future + await future count += 1 - if count >= self.app.prespawn_count + 1: + if count >= 4: break -class TestSwaggerSpec(TestGatewayAppBase): - """Tests gateway behavior when generating a custom base URL is configured.""" - def setup_app(self): - """Sets a different notebook for testing the swagger generation.""" - self.app.api = 'kernel_gateway.notebook_http' - self.app.seed_uri = os.path.join(RESOURCES, - 'simple_api.ipynb') + for future in futures: + await future + - @gen_test - def test_generation_of_swagger_spec(self): +@pytest.mark.parametrize("jp_argv", + ([f"--KernelGatewayApp.seed_uri={os.path.join(RESOURCES, 'simple_api.ipynb')}"],)) +class TestSwaggerSpec: + async def test_generation_of_swagger_spec(self, jp_fetch, jp_argv): """Server should expose a swagger specification of its notebook-defined API. 
""" @@ -447,78 +279,18 @@ def test_generation_of_swagger_spec(self): "swagger": "2.0" } - response = yield self.http_client.fetch( - self.get_url('/_api/spec/swagger.json'), - method='GET', - raise_error=False - ) - result = json.loads(response.body.decode('UTF-8')) - self.assertEqual(response.code, 200, "Swagger spec endpoint did not return the correct status code") - self.assertEqual(result, expected_response, "Swagger spec endpoint did not return the correct value") - self.assertIsNotNone(SwaggerSpecHandler.output, "Swagger spec output wasn't cached for later requests") - -class TestBaseURL(TestGatewayAppBase): - """Tests gateway behavior when a custom base URL is configured.""" - def setup_app(self): - """Sets the custom base URL and enables the notebook-defined API.""" - self.app.base_url = '/fake/path' - self.app.api = 'kernel_gateway.notebook_http' - self.app.seed_uri = os.path.join(RESOURCES, - 'kernel_api.ipynb') - - def setup_configurables(self): - self.app.personality.allow_notebook_download = True - - @gen_test - def test_base_url(self): - """Server should mount resources under the configured base.""" - # Should not exist at root - response = yield self.http_client.fetch( - self.get_url('/hello'), - method='GET', - raise_error=False - ) - self.assertEqual(response.code, 404) - - response = yield self.http_client.fetch( - self.get_url('/_api/spec/swagger.json'), - method='GET', - raise_error=False - ) - self.assertEqual(response.code, 404) - - # Should exist under path - response = yield self.http_client.fetch( - self.get_url('/fake/path/hello'), - method='GET', - raise_error=False - ) - self.assertEqual(response.code, 200) - - response = yield self.http_client.fetch( - self.get_url('/fake/path/_api/spec/swagger.json'), - method='GET', - raise_error=False - ) - self.assertEqual(response.code, 200) - - -class TestForceKernel(TestGatewayAppBase): - """Tests gateway behavior when forcing a kernel spec.""" - def setup_app(self): - """Sets the notebook-http 
mode, points to a local test notebook as - the basis for the API, and forces a Python kernel. - """ - self.app.api = 'kernel_gateway.notebook_http' - self.app.seed_uri = os.path.join(RESOURCES, 'unknown_kernel.ipynb') - self.app.force_kernel_name = 'python3' + response = await jp_fetch("_api", "spec", "swagger.json", method="GET") + result = json.loads(response.body.decode("UTF-8")) + assert response.code == 200, "Swagger spec endpoint did not return the correct status code" + assert result == expected_response, "Swagger spec endpoint did not return the correct value" + assert SwaggerSpecHandler.output is not None, "Swagger spec output wasn't cached for later requests" + - @gen_test - def test_force_kernel_spec(self): +@pytest.mark.parametrize("jp_argv", + ([f"--KernelGatewayApp.seed_uri={os.path.join(RESOURCES, 'unknown_kernel.ipynb')}", + "--KernelGatewayApp.force_kernel_name=python3"],)) +class TestForceKernel: + async def test_force_kernel_spec(self, jp_fetch, jp_argv): """Should start properly..""" - response = yield self.http_client.fetch( - self.get_url('/_api/spec/swagger.json'), - method='GET', - raise_error=False - ) - self.assertEqual(response.code, 200) + response = await jp_fetch("_api", "spec", "swagger.json", method="GET") + assert response.code == 200 diff --git a/requirements-test.txt b/requirements-test.txt index 33f4945..6284dfd 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -1,2 +1,6 @@ coverage -nose +pytest +pytest-cov +pytest_jupyter +pytest-timeout +ipykernel diff --git a/requirements.txt b/requirements.txt index 6c8b02b..f5e545d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ -jupyter_core>=4.4.0 -jupyter_client>=5.2.0 -notebook>=5.7.6,<7.0 -traitlets>=4.2.0 -tornado>=4.2.0 +jupyter_core>=4.12 +jupyter_client>=7.4.4 +jupyter_server>=2.12.2 +traitlets>=5.6.0 +tornado>=6.2.0 requests>=2.7,<3.0 diff --git a/setup.py b/setup.py index 9640f1c..afdc0e0 100644 --- a/setup.py +++ b/setup.py @@ -53,14 +53,14 
@@ 'scripts/jupyter-kernelgateway' ], install_requires=[ - 'jupyter_core>=4.4.0', - 'jupyter_client>=5.2.0,<8.0', # Pin < 8 until we update packages in #377 - 'notebook>=5.7.6,<7.0', - 'traitlets>=4.2.0', - 'tornado>=4.2.0', + 'jupyter_client>=7.4.4', + 'jupyter_core>=4.12,!=5.0.*', + 'jupyter_server>=2.0', + 'traitlets>=5.6.0', + 'tornado>=6.2.0', 'requests>=2.7,<3.0' ], - python_requires='>=3.7', + python_requires='>=3.8', classifiers=[ 'Intended Audience :: Developers', 'Intended Audience :: System Administrators', @@ -69,9 +69,10 @@ 'Operating System :: OS Independent', 'Programming Language :: Python', 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.10', + 'Programming Language :: Python :: 3.11', ], include_package_data=True, )