diff --git a/.dockerignore b/.dockerignore index 5af9064..e145e98 100644 --- a/.dockerignore +++ b/.dockerignore @@ -3,3 +3,5 @@ ./.travis.yml ./.env ./docker-compose.yml +*.log +*.db diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 991ab31..6399338 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -1,4 +1,4 @@ -name: Release +name: Publish on: release: @@ -16,19 +16,21 @@ jobs: steps: - uses: actions/checkout@v4 with: + # Fetch all commits and tags to make sure the version gets set right by hatch-vcs. + # This should be temporary; once these changes are merged into the primary branch it can + # probably be removed. fetch-depth: 0 fetch-tags: true - name: Set up Python uses: actions/setup-python@v5 with: - python-version: "3.9" - - name: Install dependencies + python-version: "3.10" + - name: Install hatch run: | - python -m pip install -U pip build - python -m pip install -r requirements.txt . + python -m pip install hatch - name: Build dist run: | - python -m build + hatch build - name: Store the distribution packages uses: actions/upload-artifact@v4 with: diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml new file mode 100644 index 0000000..e2772f5 --- /dev/null +++ b/.github/workflows/run-tests.yml @@ -0,0 +1,41 @@ +name: Run tests + +on: + push: + branches: + - main + - develop + pull_request: + branches: + - main + - develop + +jobs: + test: + name: Run tests + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.10", "3.11"] + + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + fetch-tags: true + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install hatch + run: | + python -m pip install hatch + - name: Lint + run: | + hatch fmt + - name: Test + run: | + hatch run +py=${{ matrix.python-version }} test:pytest + - name: Build dist + run: | + hatch build diff --git a/.gitignore b/.gitignore index e45e603..5b8211f 100644 --- a/.gitignore +++ b/.gitignore @@ -31,6 +31,9 @@ wheels/ .installed.cfg *.egg +# Auto generated during builds +aardvark/__version__.py + # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. @@ -106,4 +109,5 @@ ENV/ .mypy_cache/ # config -config.py +.secrets.yaml +settings.local.yaml.bak diff --git a/Dockerfile b/Dockerfile index ba074e8..e4e7b40 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.8 +FROM python:3.10 RUN apt-get update -y \ && apt-get upgrade -y \ @@ -15,10 +15,12 @@ WORKDIR /etc/aardvark ENV AARDVARK_DATA_DIR=/data \ AARDVARK_ROLE=Aardvark \ ARN_PARTITION=aws \ - AWS_DEFAULT_REGION=us-east-1 + AWS_DEFAULT_REGION=us-east-1 \ + FLASK_APP=aardvark EXPOSE 5000 +COPY ./settings.yaml . 
 COPY ./entrypoint.sh /etc/aardvark/entrypoint.sh
 ENTRYPOINT [ "/etc/aardvark/entrypoint.sh" ]
diff --git a/MANIFEST.in b/MANIFEST.in
index bab913d..8348f77 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1,3 +1,3 @@
-include setup.py README.md MANIFEST.in LICENSE
+include pyproject.toml requirements.txt README.md MANIFEST.in LICENSE aardvark/config_default.yaml
 recursive-include aardvark *.js
 global-exclude *~
\ No newline at end of file
diff --git a/README.md b/README.md
index b589a14..48de534 100644
--- a/README.md
+++ b/README.md
@@ -3,46 +3,71 @@ Aardvark - Multi-Account AWS IAM Access Advisor API
 [![NetflixOSS Lifecycle](https://img.shields.io/osslifecycle/Netflix/osstracker.svg)]()
 [![Discord chat](https://img.shields.io/discord/754080763070382130?logo=discord)](https://discord.gg/9kwMWa6)
 
-Aardvark Logo
+![Aardvark Logo](docs/images/aardvark_logo_small.png)
 
 Aardvark is a multi-account AWS IAM Access Advisor API (and caching layer).
 
-## Install:
+## New in `v1.0.0`
 
-Ensure that you have Python 3.6 or later. Python 2 is no longer supported.
+⚠️ Breaking change
 
-```bash
+✨ Enhancement
+
+- ⚠️ Upgrade to Python 3.10+
+- ⚠️ New configuration format
+- ✨ Pluggable persistence layer
+- ✨ Pluggable retrievers
+
+## Install
+
+Ensure that you have Python 3.10 or later.
+
+Use pip to install Aardvark:
+
+```shell
+pip install aardvark
+```
+
+Alternatively, clone the repository and install a development version:
+
+```shell
 git clone https://github.com/Netflix-Skunkworks/aardvark.git
 cd aardvark
-python3 -m venv env
-. env/bin/activate
-python setup.py develop
+python3 -m venv venv
+source venv/bin/activate
+pip install -e .
 ```
 
-### Known Dependencies
- - libpq-dev
+To run the test suite, you'll need to install the test requirements:
+
+```shell
+pip install -r requirements-test.txt
+pytest test/
+```
 
 ## Configure Aardvark
 
 The Aardvark config wizard will guide you through the setup.
-```
-% aardvark config
+```shell
+❯ aardvark config
 
-Aardvark can use SWAG to look up accounts. https://github.com/Netflix-Skunkworks/swag-client
-Do you use SWAG to track accounts? [yN]: no
-ROLENAME: Aardvark
-DATABASE [sqlite:////home/github/aardvark/aardvark.db]:
-# Threads [5]:
+Aardvark can use SWAG to look up accounts. See https://github.com/Netflix-Skunkworks/swag-client
+Do you use SWAG to track accounts? [yN]: N
+Role Name [Aardvark]: Aardvark
+Database URI [sqlite:///aardvark.db]:
+Worker Count [5]: 5
+Config file location [settings.yaml]: settings.local.yaml
 
->> Writing to config.py
+writing config file to settings.local.yaml
 ```
 
 - Whether to use [SWAG](https://github.com/Netflix-Skunkworks/swag-client) to enumerate your AWS accounts. (Optional, but useful when you have many accounts.)
 - The name of the IAM Role to assume into in each account.
 - The Database connection string. (Defaults to sqlite in the current working directory. Use RDS Postgres for production.)
+- The number of workers to create.
 
 ## Create the DB tables
 
-```
+```shell
 aardvark create_db
 ```
 
@@ -50,27 +75,95 @@ aardvark create_db
 
 Aardvark needs an IAM Role in each account that will be queried. Additionally, Aardvark needs to be launched with a role or user which can `sts:AssumeRole` into the different account roles.
 
-AardvarkInstanceProfile:
+### Hub role (`AardvarkInstanceProfile`):
+
 - Only create one.
-- Needs the ability to call `sts:AssumeRole` into all of the AardvarkRole's
+- Needs the ability to call `sts:AssumeRole` into all of the `AardvarkRole`s
+
+Inline policy example:
+
+```json
+{
+    "Version": "2012-10-17",
+    "Statement": [
+        {
+            "Sid": "AssumeSpokeRoles",
+            "Effect": "Allow",
+            "Action": [
+                "sts:AssumeRole"
+            ],
+            "Resource": [
+                "arn:aws:iam::*:role/AardvarkRole"
+            ]
+        }
+    ]
+}
+```
+
+### Spoke roles (`AardvarkRole`):
 
-AardvarkRole:
 - Must exist in every account to be monitored.
 - Must have a trust policy allowing `AardvarkInstanceProfile`.
 - Has these permissions:
+
 ```
 iam:GenerateServiceLastAccessedDetails
 iam:GetServiceLastAccessedDetails
-iam:listrolepolicies
-iam:listroles
+iam:ListRolePolicies
+iam:ListRoles
 iam:ListUsers
 iam:ListPolicies
 iam:ListGroups
 ```
 
+Assume role policy document example (be sure to replace the account ID with a real one):
+
+```json
+{
+    "Version": "2012-10-17",
+    "Statement": [
+        {
+            "Sid": "AllowHubRoleAssume",
+            "Effect": "Allow",
+            "Principal": {
+                "AWS": [
+                    "arn:aws:iam::111111111111:role/AardvarkInstanceProfile"
+                ]
+            },
+            "Action": "sts:AssumeRole"
+        }
+    ]
+}
+```
+
+Inline policy example:
+
+```json
+{
+    "Version": "2012-10-17",
+    "Statement": [
+        {
+            "Sid": "IAMAccess",
+            "Effect": "Allow",
+            "Action": [
+                "iam:GenerateServiceLastAccessedDetails",
+                "iam:GetServiceLastAccessedDetails",
+                "iam:ListRolePolicies",
+                "iam:ListRoles",
+                "iam:ListUsers",
+                "iam:ListPolicies",
+                "iam:ListGroups"
+            ],
+            "Resource": [
+                "*"
+            ]
+        }
+    ]
+}
+```
 
-So if you are monitoring `n` accounts, you will always need `n+1` roles. (`n` AardvarkRoles and `1` AardvarkInstanceProfile).
+So if you are monitoring `n` accounts, you will always need `n+1` roles. (one `AardvarkInstanceProfile` and `n` `AardvarkRole`s).
 
-Note: For locally running aardvark, you don't have to take care of the AardvarkInstanceProfile. Instead, just attach a policy which contains "sts:AssumeRole" to the user you are using on the AWS CLI to assume Aardvark Role. Also, the same user should be mentioned in the trust policy of Aardvark Role for proper assignment of the privileges.
+Note: When running Aardvark locally, you don't need an `AardvarkInstanceProfile`. Instead, attach a policy that allows `sts:AssumeRole` to the AWS CLI user you run Aardvark as, and list that user as a principal in the trust policy of each `AardvarkRole` so the assumption is permitted.
 
 ## Gather Access Advisor Data
 
@@ -80,24 +173,30 @@ You'll likely want to refresh the Access Advisor data regularly. We recommend r
 
 If you don't have SWAG you can pass one or more account numbers:
 
-    aardvark update -a 123456789012,210987654321
+    aardvark update -a 123456789012 -a 210987654321
 
 #### With SWAG:
 
 Aardvark can use [SWAG](https://github.com/Netflix-Skunkworks/swag-client) to look up accounts, so you can run against all with:
 
-    aardvark update
+```shell
+aardvark update
+```
 
 or by account name/tag with:
 
-    aardvark update -a dev,test,prod
+```shell
+aardvark update -a dev -a test -a prod
+```
 
 ## API
 
 ### Start the API
 
-    aardvark start_api -b 0.0.0.0:5000
+```shell
+FLASK_APP=aardvark flask run --host 0.0.0.0 --port 5000
+```
 
 In production, you'll likely want to have something like supervisor starting the API for you.
 
@@ -106,7 +205,7 @@ In production, you'll likely want to have something like supervisor starting th
 
 Swagger is available for the API at `/apidocs/#!`. Aardvark responds to get/post requests.
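For example, the same queries can be issued from Python (a minimal sketch using the third-party `requests` package; it assumes the API is running on the default host and port from the `flask run` example above):

```python
import requests

# Ask the advisors endpoint for roles whose ARN contains a phrase.
# The parameters mirror the curl examples below.
response = requests.get(
    "http://localhost:5000/api/1/advisors",
    params={"phrase": "SecurityMonkey", "page": 1, "count": 30},
    timeout=30,
)
response.raise_for_status()
results = response.json()

# Pagination metadata is returned alongside the per-ARN records.
print(results["page"], results["count"], results["total"])
for arn, services in results.items():
    if arn in ("page", "count", "total"):
        continue
    print(arn, [service["serviceNamespace"] for service in services])
```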
 All results are paginated and pagination can be controlled by passing `count` and/or `page` arguments. Here are a few example queries:
 
-```bash
+```shell
 curl localhost:5000/api/1/advisors
 curl localhost:5000/api/1/advisors?phrase=SecurityMonkey
 curl "localhost:5000/api/1/advisors?arn=arn:aws:iam::000000000000:role/SecurityMonkey&arn=arn:aws:iam::111111111111:role/SecurityMonkey"
@@ -143,7 +242,7 @@ Once this file is created, then build the containers and start the services. Aar
 
 - API Server - This is the HTTP webserver that will serve the data. By default, this is listening on [http://localhost:5000/apidocs/#!](http://localhost:5000/apidocs/#!).
 - Collector - This is a daemon that will fetch and cache the data in the local SQL database. This should be run periodically.
 
-```bash
+```shell
 # build the containers
 docker-compose build
 
 # start up the containers
 docker-compose up
 ```
 
 Finally, to clean up the environment
 
-```bash
+```shell
 # bring down the containers
 docker-compose down
 
@@ -209,7 +308,7 @@ if __name__ == "__main__":
 
 This file can now be invoked in the same way as `manage.py`:
 
-```bash
+```shell
 python signals_example.py update -a cool_account
 ```
 
@@ -230,7 +329,3 @@ INFO: Thread #1 FINISHED persisting data for account 123456789012
 
 |-------|---------|
 | `manage.UpdateAccountThread` | `on_ready`, `on_complete`, `on_failure` |
 | `updater.AccountToUpdate` | `on_ready`, `on_complete`, `on_error`, `on_failure` |
-
-## TODO:
-
-See [TODO](TODO.md)
diff --git a/TODO.md b/TODO.md
deleted file mode 100644
index d4e3e88..0000000
--- a/TODO.md
+++ /dev/null
@@ -1,3 +0,0 @@
-# TODO:
-- Unit tests
-- Better docs
diff --git a/aardvark/__version__.py b/aardvark/__version__.py
new file mode 100644
index 0000000..61a989f
--- /dev/null
+++ b/aardvark/__version__.py
@@ -0,0 +1,8 @@
+# this is a placeholder and will be overwritten at build time
+try:
+    from setuptools_scm import get_version
+    __version__ = get_version(root="..", relative_to=__file__)
+except (ImportError, OSError):
+    # ImportError: setuptools_scm isn't installed
+    # OSError: git isn't installed
+    __version__ = "0.0.0.dev0+placeholder"
diff --git a/aardvark/__version__.py.bak b/aardvark/__version__.py.bak
new file mode 100644
index 0000000..61a989f
--- /dev/null
+++ b/aardvark/__version__.py.bak
@@ -0,0 +1,8 @@
+# this is a placeholder and will be overwritten at build time
+try:
+    from setuptools_scm import get_version
+    __version__ = get_version(root="..", relative_to=__file__)
+except (ImportError, OSError):
+    # ImportError: setuptools_scm isn't installed
+    # OSError: git isn't installed
+    __version__ = "0.0.0.dev0+placeholder"
diff --git a/aardvark/_config.py b/aardvark/_config.py
deleted file mode 100644
index e424bf2..0000000
--- a/aardvark/_config.py
+++ /dev/null
@@ -1,8 +0,0 @@
-SQLALCHEMY_DATABASE_URI = "sqlite:///:memory:"
-SQLALCHEMY_TRACK_MODIFICATIONS = False
-
-# Use a set to store ARNs that are constantly failing.
-# Aardvark will only log these errors at the INFO level
-# instead of the ERROR level
-FAILING_ARNS = set()
-# FAILING_ARNS = {'ASDF', 'DEFG'}
diff --git a/aardvark/advisors.py b/aardvark/advisors.py
new file mode 100644
index 0000000..6fb33cd
--- /dev/null
+++ b/aardvark/advisors.py
@@ -0,0 +1,110 @@
+from flask import Blueprint, abort, jsonify, request
+
+from aardvark.persistence.sqlalchemy import SQLAlchemyPersistence
+
+advisor_bp = Blueprint("advisor", __name__)
+
+
+@advisor_bp.route("/advisors", methods=["GET", "POST"])
+def post():
+    """Get access advisor data for role(s)
+    Returns access advisor information for role(s) that match filters
+    ---
+    consumes:
+      - 'application/json'
+    produces:
+      - 'application/json'
+
+    parameters:
+      - name: page
+        in: query
+        type: integer
+        description: return results from given page of total results
+        required: false
+      - name: count
+        in: query
+        type: integer
+        description: specifies how many results should be returned per page
+        required: false
+      - name: combine
+        in: query
+        type: boolean
+        description: combine access advisor data for all results [Default False]
+        required: false
+      - name: phrase
+        in: query
+        type: string
+        description: TODO
+        required: false
+      - name: regex
+        in: query
+        type: string
+        description: TODO
+        required: false
+      - name: arn
+        in: query
+        type: string
+        description: TODO
+        required: false
+
+    definitions:
+      AdvisorData:
+        type: object
+        properties:
+          lastAuthenticated:
+            type: number
+          lastAuthenticatedEntity:
+            type: string
+          lastUpdated:
+            type: string
+          serviceName:
+            type: string
+          serviceNamespace:
+            type: string
+          totalAuthenticatedEntities:
+            type: number
+      QueryBody:
+        type: object
+        properties:
+          phrase:
+            type: string
+          regex:
+            type: string
+          arn:
+            type: array
+            items:
+              type: string
+      Results:
+        type: array
+        items:
+          $ref: '#/definitions/AdvisorData'
+
+    responses:
+      200:
+        description: Query successful, results in body
+        schema:
+          $ref: '#/definitions/AdvisorData'
+      400:
+        description: Bad request - error message in body
+    """
+    try:
+        page = int(request.args.get("page", 1))
+        count = int(request.args.get("count", 30))
+        combine = request.args.get("combine", type=str, default="false")
+        combine = combine.lower() == "true"
+        phrase = request.args.get("phrase")
+        regex = request.args.get("regex", default=None)
+        arns = request.args.get("arn")
+        arns = arns.split(",") if arns else []
+    except Exception as e:
+        raise abort(400, str(e)) from e
+
+    values = SQLAlchemyPersistence().get_role_data(
+        page=page,
+        count=count,
+        combine=combine,
+        phrase=phrase,
+        arns=arns,
+        regex=regex,
+    )
+
+    return jsonify(values)
diff --git a/aardvark/app.py b/aardvark/app.py
index ddf7274..29b8e60 100644
--- a/aardvark/app.py
+++ b/aardvark/app.py
@@ -22,11 +22,11 @@
     advisor_bp
 ]
 
-API_VERSION = '1'
+API_VERSION = "1"
 
 
 def create_app(config_override: Config = None):
-    app = Flask(__name__, static_url_path='/static')
+    app = Flask(__name__, static_url_path="/static")
     Swagger(app)
 
     if config_override:
@@ -34,13 +34,13 @@
     else:
         path = _find_config()
         if not path:
-            print('No config')
-            app.config.from_pyfile('_config.py')
+            print("No config")
+            app.config.from_pyfile("_config.py")
         else:
             app.config.from_pyfile(path)
 
     # For ELB and/or Eureka
-    @app.route('/healthcheck')
+    @app.route("/healthcheck")
     def healthcheck():
         """Healthcheck
         Simple healthcheck that indicates the services is up
@@ -49,11 +49,11 @@
           200:
             description: service is up
         """
-        return 'ok'
+        return "ok"
 
     # Blueprints
     for bp in BLUEPRINTS:
-        app.register_blueprint(bp, url_prefix="/api/{0}".format(API_VERSION))
+        app.register_blueprint(bp, url_prefix=f"/api/{API_VERSION}")
 
     # Extensions:
     db.init_app(app)
@@ -64,10 +64,12 @@ def healthcheck():
 
 def _find_config():
     """Search for config.py in order of preference and return path if it exists, else None"""
-    CONFIG_PATHS = [os.path.join(os.getcwd(), 'config.py'),
-                    '/etc/aardvark/config.py',
-                    '/apps/aardvark/config.py']
-    for path in CONFIG_PATHS:
+    config_paths = [
+        os.path.join(os.getcwd(), "config.py"),
+        "/etc/aardvark/config.py",
+        "/apps/aardvark/config.py",
+    ]
+    for path in config_paths:
         if os.path.exists(path):
             return path
     return None
@@ -75,16 +77,13 @@ def _find_config():
 
 def setup_logging(app):
     if not app.debug:
-        if app.config.get('LOG_CFG'):
+        if app.config.get("LOG_CFG"):
             # initialize the Flask logger (removes all handlers)
-            app.logger
-            dictConfig(app.config.get('LOG_CFG'))
+            dictConfig(app.config.get("LOG_CFG"))
             app.logger = logging.getLogger(__name__)
         else:
             handler = StreamHandler(stream=sys.stderr)
-            handler.setFormatter(Formatter(
-                '%(asctime)s %(levelname)s: %(message)s '
-                '[in %(pathname)s:%(lineno)d]'))
-            app.logger.setLevel(app.config.get('LOG_LEVEL', DEBUG))
+            handler.setFormatter(Formatter("%(asctime)s %(levelname)s: %(message)s [in %(pathname)s:%(lineno)d]"))
+            app.logger.setLevel(app.config.get("LOG_LEVEL", DEBUG))
             app.logger.addHandler(handler)
diff --git a/aardvark/config.py b/aardvark/config.py
new file mode 100644
index 0000000..780a4da
--- /dev/null
+++ b/aardvark/config.py
@@ -0,0 +1,143 @@
+import contextlib
+import logging
+import os
+
+from dynaconf import Dynaconf, Validator
+
+cwd_path = os.path.join(os.getcwd(), "settings.yaml")
+
+settings = Dynaconf(
+    envvar_prefix="AARDVARK",
+    settings_files=[
+        "settings.yaml",
+        ".secrets.yaml",
+        cwd_path,
+        "/etc/aardvark/settings.yaml",
+    ],
+    env_switcher="AARDVARK_ENV",
+    environments=True,
+    validators=[
+        Validator("AWS_ARN_PARTITION", default="aws"),
+        Validator("AWS_REGION", default="us-east-1"),
+        Validator("SQLALCHEMY_DATABASE_URI", default="sqlite:///aardvark.db"),
+        Validator("UPDATER_NUM_THREADS", default=1),
+    ],
+)
+
+log = logging.getLogger(__name__)
+
+
+def create_config(
+    *,
+    aardvark_role: str = "",
+    swag_bucket: str = "",
+    swag_filter: str = "",
+    swag_service_enabled_requirement: str = "",
+    arn_partition: str = "",
+    sqlalchemy_database_uri: str = "",
+    sqlalchemy_track_modifications: bool = False,
+    num_threads: int = 5,
+    region: str = "",
+    filename: str = "settings.yaml",
+    environment: str = "default",
+):
+    if aardvark_role:
+        settings.set("aws_rolename", aardvark_role)
+    if arn_partition:
+        settings.set("aws_arn_partition", arn_partition)
+    if region:
+        settings.set("aws_region", region)
+    if swag_bucket:
+        settings.set("swag.bucket", swag_bucket)
+    if swag_filter:
+        settings.set("swag.filter", swag_filter)
+    if swag_service_enabled_requirement:
+        settings.set("swag.service_enabled_requirement", swag_service_enabled_requirement)
+    if sqlalchemy_database_uri:
+        settings.set("sqlalchemy_database_uri", sqlalchemy_database_uri)
+    if sqlalchemy_track_modifications:
+        settings.set("sqlalchemy_track_modifications", sqlalchemy_track_modifications)
+    if num_threads:
+        settings.set("updater_num_threads", num_threads)
+    write_config(filename, environment=environment)
+
+
+def find_legacy_config():
+    """Search for config.py in order of preference and return path if it exists, else None"""
config_paths = [ + os.path.join(os.getcwd(), "config.py"), + "/etc/aardvark/config.py", + "/apps/aardvark/config.py", + ] + for path in config_paths: + if os.path.exists(path): + return path + return None + + +def convert_config( + filename: str, + *, + write: bool = False, + output_filename: str = "settings.yaml", + environment: str = "default", +): + """Convert a pre-1.0 config to a YAML config file""" + import importlib.util + + spec = importlib.util.spec_from_file_location("aardvark.config.legacy", filename) + old_config = importlib.util.module_from_spec(spec) + spec.loader.exec_module(old_config) + + with contextlib.suppress(AttributeError): + settings.set("aws_rolename", old_config.ROLENAME) + + with contextlib.suppress(AttributeError): + settings.set("aws_region", old_config.REGION) + + with contextlib.suppress(AttributeError): + settings.set("aws_arn_partition", old_config.ARN_PARTITION) + + with contextlib.suppress(AttributeError): + settings.set("sqlalchemy_database_uri", old_config.SQLALCHEMY_DATABASE_URI) + + with contextlib.suppress(AttributeError): + settings.set("sqlalchemy_track_modifications", old_config.SQLALCHEMY_TRACK_MODIFICATIONS) + + with contextlib.suppress(AttributeError): + settings.set("swag.bucket", old_config.SWAG_BUCKET) + + with contextlib.suppress(AttributeError): + settings.set("swag.opts", old_config.SWAG_OPTS) + + with contextlib.suppress(AttributeError): + settings.set("swag.filter", old_config.SWAG_FILTER) + + with contextlib.suppress(AttributeError): + settings.set( + "swag.service_enabled_requirement", + old_config.SWAG_SERVICE_ENABLED_REQUIREMENT, + ) + + with contextlib.suppress(AttributeError): + settings.set("updater_failing_arns", old_config.FAILING_ARNS) + + with contextlib.suppress(AttributeError): + settings.set("updater_num_threads", old_config.NUM_THREADS) + + if write: + write_config(output_filename, environment=environment) + + +def open_config(filepath: str): + settings.load_file(filepath) + + +def write_config(filename: str = "settings.yaml", environment: str = "default"): + from dynaconf import loaders + from dynaconf.utils.boxing import DynaBox + + data = settings.as_dict() + log.info("writing config file to %s", filename) + loaders.write(filename, DynaBox(data).to_dict(), env=environment) diff --git a/aardvark/exceptions.py b/aardvark/exceptions.py new file mode 100644 index 0000000..ef53475 --- /dev/null +++ b/aardvark/exceptions.py @@ -0,0 +1,18 @@ +class AardvarkError(Exception): + pass + + +class AccessAdvisorError(AardvarkError): + pass + + +class CombineError(AardvarkError): + pass + + +class DatabaseError(AardvarkError): + pass + + +class RetrieverError(AardvarkError): + pass diff --git a/aardvark/manage.py b/aardvark/manage.py index 97ebbe0..875601a 100644 --- a/aardvark/manage.py +++ b/aardvark/manage.py @@ -1,141 +1,60 @@ +from __future__ import annotations + +import asyncio +import logging import os import queue -import re +import sys import threading -from blinker import Signal -from bunch import Bunch -from flask import current_app -from flask_script import Manager, Command, Option -from swag_client.backend import SWAGManager -from swag_client.exceptions import InvalidSWAGDataException -from swag_client.util import parse_swag_config_options +import click -from aardvark.app import create_app, db -from aardvark.updater import AccountToUpdate +from aardvark import create_app, init_logging +from aardvark.config import convert_config, create_config, find_legacy_config +from aardvark.exceptions import AardvarkError +from 
aardvark.persistence.sqlalchemy import SQLAlchemyPersistence +from aardvark.retrievers.runner import RetrieverRunner -manager = Manager(create_app) +APP = None +log = logging.getLogger("aardvark") ACCOUNT_QUEUE = queue.Queue() -DB_LOCK = threading.Lock() QUEUE_LOCK = threading.Lock() UPDATE_DONE = False -SWAG_REPO_URL = 'https://github.com/Netflix-Skunkworks/swag-client' +SWAG_REPO_URL = "https://github.com/Netflix-Skunkworks/swag-client" -LOCALDB = 'sqlite' +LOCALDB = "sqlite" +DEFAULT_LOCALDB_FILENAME = "aardvark.db" # Configuration default values. -DEFAULT_LOCALDB_FILENAME = 'aardvark.db' -DEFAULT_SWAG_BUCKET = 'swag-data' -DEFAULT_AARDVARK_ROLE = 'Aardvark' -DEFAULT_NUM_THREADS = 5 # testing shows problems with more than 6 threads - - -class UpdateAccountThread(threading.Thread): - global ACCOUNT_QUEUE, DB_LOCK, QUEUE_LOCK, UPDATE_DONE - on_ready = Signal() - on_complete = Signal() - on_failure = Signal() - - def __init__(self, thread_ID): - self.thread_ID = thread_ID - threading.Thread.__init__(self) - self.app = current_app._get_current_object() - - def run(self): - while not UPDATE_DONE: - self.on_ready.send(self) - - QUEUE_LOCK.acquire() - - if not ACCOUNT_QUEUE.empty(): - (account_num, role_name, arns) = ACCOUNT_QUEUE.get() +DEFAULT_SWAG_BUCKET = "swag-data" +DEFAULT_AARDVARK_ROLE = "Aardvark" +DEFAULT_NUM_THREADS = 5 - self.app.logger.info("Thread #{} updating account {} with {} arns".format( - self.thread_ID, account_num, 'all' if arns[0] == 'all' else len(arns))) - self.app.logger.debug(f"ACCOUNT_QUEUE depth now ~ {ACCOUNT_QUEUE.qsize()}") +def get_app(): + global APP # noqa: PLW0603 + if not APP: + APP = create_app() + return APP - QUEUE_LOCK.release() - try: - account = AccountToUpdate(self.app, account_num, role_name, arns) - ret_code, aa_data = account.update_account() - except Exception as e: - self.on_failure.send(self, error=e) - self.app.logger.exception(f"Thread #{self.thread_ID} caught exception - {e} - while attempting to update account {account_num}. Continuing.") - # Assume that whatever went wrong isn't transient; to avoid an - # endless loop we don't put the account back on the queue. - continue - - if ret_code != 0: # retrieve wasn't successful, put back on queue - self.on_failure.send(self) - QUEUE_LOCK.acquire() - ACCOUNT_QUEUE.put((account_num, role_name, arns)) - QUEUE_LOCK.release() - - self.app.logger.info("Thread #{} persisting data for account {}".format(self.thread_ID, account_num)) - - DB_LOCK.acquire() - persist_aa_data(self.app, aa_data) - DB_LOCK.release() - - self.on_complete.send(self) - self.app.logger.info("Thread #{} FINISHED persisting data for account {}".format(self.thread_ID, account_num)) - else: - QUEUE_LOCK.release() - - -def persist_aa_data(app, aa_data): - """ - Reads access advisor JSON file & persists to our database - """ - from aardvark.model import AWSIAMObject, AdvisorData - - with app.app_context(): - if not aa_data: - app.logger.warn('Cannot persist Access Advisor Data as no data was collected.') - return - - arn_cache = {} - for arn, data in aa_data.items(): - if arn in arn_cache: - item = arn_cache[arn] - else: - item = AWSIAMObject.get_or_create(arn) - arn_cache[arn] = item - for service in data: - AdvisorData.create_or_update(item.id, - service['LastAuthenticated'], - service['ServiceName'], - service['ServiceNamespace'], - service.get('LastAuthenticatedEntity'), - service['TotalAuthenticatedEntities']) - db.session.commit() - - -@manager.command -def drop_db(): - """ Drops the database. 
""" - db.drop_all() - - -@manager.command -def create_db(): - """ Creates the database. """ - db.create_all() +@click.group() +def cli(): + init_logging() # All of these default to None rather than the corresponding DEFAULT_* values # so we can tell whether they were passed or not. We don't prompt for any of # the options that were passed as parameters. -@manager.option('-a', '--aardvark-role', dest='aardvark_role_param', type=str) -@manager.option('-b', '--swag-bucket', dest='bucket_param', type=str) -@manager.option('-d', '--db-uri', dest='db_uri_param', type=str) -@manager.option('--num-threads', dest='num_threads_param', type=int) -@manager.option('--no-prompt', dest='no_prompt', action='store_true', default=False) -def config(aardvark_role_param, bucket_param, db_uri_param, num_threads_param, no_prompt): +@cli.command("config") +@click.option("--aardvark-role", "-a", type=str) +@click.option("--swag-bucket", "-b", type=str) +@click.option("--db-uri", "-d", type=str) +@click.option("--num-threads", type=int) +@click.option("--no-prompt", is_flag=True, default=False) +def config(aardvark_role, swag_bucket, db_uri, num_threads, no_prompt): """ Creates a config.py configuration file from user input or default values. @@ -163,237 +82,101 @@ def config(aardvark_role_param, bucket_param, db_uri_param, num_threads_param, n LOG_CFG = {...} """ # We don't set these until runtime. - default_db_uri = '{localdb}:///{path}/{filename}'.format( - localdb=LOCALDB, path=os.getcwd(), filename=DEFAULT_LOCALDB_FILENAME - ) + default_db_uri = f"{LOCALDB}:///{os.getcwd()}/{DEFAULT_LOCALDB_FILENAME}" + default_save_file = "settings.local.yaml" if no_prompt: # Just take the parameters as currently constituted. - aardvark_role = aardvark_role_param or DEFAULT_AARDVARK_ROLE - num_threads = num_threads_param or DEFAULT_NUM_THREADS - db_uri = db_uri_param or default_db_uri + aardvark_role = aardvark_role or DEFAULT_AARDVARK_ROLE + num_threads = num_threads or DEFAULT_NUM_THREADS + db_uri = db_uri or default_db_uri # If a swag bucket was specified we set write_swag here so it gets # written out to the config file below. - write_swag = bool(bucket_param) - bucket = bucket_param or DEFAULT_SWAG_BUCKET - + bucket = swag_bucket or DEFAULT_SWAG_BUCKET + save_file = default_save_file else: # This is essentially the same "param, or input, or default" # structure as the additional parameters below. - if bucket_param: - bucket = bucket_param - write_swag = True + if swag_bucket: + bucket = swag_bucket else: - print('\nAardvark can use SWAG to look up accounts. See {repo_url}'.format(repo_url=SWAG_REPO_URL)) - use_swag = input('Do you use SWAG to track accounts? [yN]: ') - if len(use_swag) > 0 and 'yes'.startswith(use_swag.lower()): - bucket_prompt = 'SWAG_BUCKET [{default}]: '.format(default=DEFAULT_SWAG_BUCKET) + print(f"\nAardvark can use SWAG to look up accounts. See {SWAG_REPO_URL}") # noqa: T201 + use_swag = input("Do you use SWAG to track accounts? 
[yN]: ") + if len(use_swag) > 0 and "yes".startswith(use_swag.lower()): + bucket_prompt = f"SWAG bucket [{DEFAULT_SWAG_BUCKET}]: " bucket = input(bucket_prompt) or DEFAULT_SWAG_BUCKET - write_swag = True else: - write_swag = False - - aardvark_role_prompt = 'ROLENAME [{default}]: '.format(default=DEFAULT_AARDVARK_ROLE) - db_uri_prompt = 'DATABASE URI [{default}]: '.format(default=default_db_uri) - num_threads_prompt = '# THREADS [{default}]: '.format(default=DEFAULT_NUM_THREADS) - - aardvark_role = aardvark_role_param or input(aardvark_role_prompt) or DEFAULT_AARDVARK_ROLE - db_uri = db_uri_param or input(db_uri_prompt) or default_db_uri - num_threads = num_threads_param or input(num_threads_prompt) or DEFAULT_NUM_THREADS - - log = """LOG_CFG = { - 'version': 1, - 'disable_existing_loggers': False, - 'formatters': { - 'standard': { - 'format': '%(asctime)s %(levelname)s: %(message)s ' - '[in %(pathname)s:%(lineno)d]' - } - }, - 'handlers': { - 'file': { - 'class': 'logging.handlers.RotatingFileHandler', - 'level': 'DEBUG', - 'formatter': 'standard', - 'filename': 'aardvark.log', - 'maxBytes': 10485760, - 'backupCount': 100, - 'encoding': 'utf8' - }, - 'console': { - 'class': 'logging.StreamHandler', - 'level': 'DEBUG', - 'formatter': 'standard', - 'stream': 'ext://sys.stdout' - } - }, - 'loggers': { - 'aardvark': { - 'handlers': ['file', 'console'], - 'level': 'DEBUG' - } - } -}""" - - with open('config.py', 'w') as filedata: - print('\n>> Writing to config.py') - filedata.write('# Autogenerated config file\n') - if write_swag: - filedata.write("SWAG_OPTS = {{'swag.type': 's3', 'swag.bucket_name': '{bucket}'}}\n".format(bucket=bucket)) - filedata.write("SWAG_FILTER = None\n") - filedata.write("SWAG_SERVICE_ENABLED_REQUIREMENT = None\n") - filedata.write('ROLENAME = "{role}"\n'.format(role=aardvark_role)) - filedata.write('REGION = "us-east-1"\n') - filedata.write('ARN_PARTITION = "aws"\n') - filedata.write('SQLALCHEMY_DATABASE_URI = "{uri}"\n'.format(uri=db_uri)) - filedata.write('SQLALCHEMY_TRACK_MODIFICATIONS = False\n') - filedata.write('NUM_THREADS = {num_threads}\n'.format(num_threads=num_threads)) - filedata.write(log) - - -@manager.option('-a', '--accounts', dest='accounts', type=str, default='all') -@manager.option('-r', '--arns', dest='arns', type=str, default='all') -def update(accounts, arns): - """ - Asks AWS for new Access Advisor information. 
- """ - accounts = _prep_accounts(accounts) - arns = arns.split(',') - app = create_app() - - global ACCOUNT_QUEUE, QUEUE_LOCK, UPDATE_DONE - - role_name = app.config.get('ROLENAME') - num_threads = app.config.get('NUM_THREADS') or 5 - - if num_threads > 6: - current_app.logger.warn('Greater than 6 threads seems to cause problems') - - QUEUE_LOCK.acquire() - for account_number in accounts: - ACCOUNT_QUEUE.put((account_number, role_name, arns)) - current_app.logger.debug(f"Starting update operation for {ACCOUNT_QUEUE.qsize()} accounts using {num_threads} threads.") - QUEUE_LOCK.release() - - threads = [] - for thread_num in range(num_threads): - thread = UpdateAccountThread(thread_num + 1) - thread.start() - threads.append(thread) - - while not ACCOUNT_QUEUE.empty(): - pass - UPDATE_DONE = True - current_app.logger.debug("Queue is empty; no more accounts to process.") + bucket = "" + + aardvark_role_prompt = f"Role Name [{DEFAULT_AARDVARK_ROLE}]: " + db_uri_prompt = f"Database URI [{default_db_uri}]: " + num_threads_prompt = f"Worker Count [{DEFAULT_NUM_THREADS}]: " + save_file_prompt = f"Config file location [{default_save_file}]: " + + aardvark_role = aardvark_role or input(aardvark_role_prompt) or DEFAULT_AARDVARK_ROLE + db_uri = db_uri or input(db_uri_prompt) or default_db_uri + num_threads = num_threads or input(num_threads_prompt) or DEFAULT_NUM_THREADS + save_file = input(save_file_prompt) or default_save_file + + create_config( + aardvark_role=aardvark_role, + swag_bucket=bucket or "", + swag_filter="", + swag_service_enabled_requirement="", + sqlalchemy_database_uri=db_uri, + sqlalchemy_track_modifications=False, + num_threads=num_threads, + region="us-east-1", + filename=save_file, + ) -def _prep_accounts(account_names): +@cli.command("update") +@click.option("--account", "-a", type=str, default=[], multiple=True) +@click.option("--arn", "-r", type=str, default=[], multiple=True) +def update(account: list[str], arn: list[str]): """ - Convert CLI provided account names into list of accounts from SWAG. - Considers account aliases as well as account names. - Returns a list of account numbers + Asks AWS for new Access Advisor information. 
""" - matching_accounts = list() - account_names = account_names.split(',') - account_names = {name.lower().strip() for name in account_names} - - # create a new copy of the account_names list so we can remove accounts as needed - for account in list(account_names): - if re.match('\d{12}', account): - account_names.remove(account) - matching_accounts.append(account) - - if not account_names: - return matching_accounts - + accounts = list(account) + arns = list(arn) + r = RetrieverRunner() try: - current_app.logger.info('getting bucket {}'.format( - current_app.config.get('SWAG_BUCKET'))) - - swag = SWAGManager(**parse_swag_config_options(current_app.config.get('SWAG_OPTS'))) - - all_accounts = swag.get_all(current_app.config.get('SWAG_FILTER')) - - service_enabled_requirement = current_app.config.get('SWAG_SERVICE_ENABLED_REQUIREMENT', None) - if service_enabled_requirement: - all_accounts = swag.get_service_enabled(service_enabled_requirement, accounts_list=all_accounts) - - except (KeyError, InvalidSWAGDataException, Exception) as e: - current_app.logger.error('Account names passed but SWAG not configured or unavailable: {}'.format(e)) - - if 'all' in account_names: - return [account['id'] for account in all_accounts] + asyncio.run(r.run(accounts=accounts, arns=arns)) + except KeyboardInterrupt: + r.cancel() + except AardvarkError: + log.exception() + sys.exit(1) - lookup = {account['name']: Bunch(account) for account in all_accounts} - for account in all_accounts: - # get the right key, depending on whether we're using swag v1 or v2 - alias_key = 'aliases' if account['schemaVersion'] == '2' else 'alias' - for alias in account[alias_key]: - lookup[alias] = Bunch(account) - - for name in account_names: - if name not in lookup: - current_app.logger.warn('Could not find an account named %s' - % name) - continue +@cli.command("drop_db") +def drop_db(): + """Drops the database.""" + SQLAlchemyPersistence().teardown_db() - account_number = lookup[name].get('id', None) - if account_number: - matching_accounts.append(account_number) - return matching_accounts +@cli.command("create_db") +def create_db(): + """Creates the database.""" + SQLAlchemyPersistence().init_db() + + +@cli.command("migrate_config") +@click.option("--environment", "-e", type=str, default="default") +@click.option("--config-file", "-c", type=str) +@click.option("--write/--no-write", type=bool, default=True) +@click.option("--output-file", "-o", type=str, default="settings.yaml") +def migrate_config(environment, config_file, write, output_file): + if not config_file: + config_file = find_legacy_config() + convert_config( + config_file, + write=write, + output_filename=output_file, + environment=environment, + ) -class GunicornServer(Command): - """ - This is the main GunicornServer server, it runs the flask app with gunicorn and - uses any configuration options passed to it. - You can pass all standard gunicorn flags to this command as if you were - running gunicorn itself. - For example: - aardvark start_api -w 4 -b 127.0.0.0:8002 - Will start gunicorn with 4 workers bound to 127.0.0.0:8002 - """ - description = 'Run the app within Gunicorn' - - def get_options(self): - options = [] - try: - from gunicorn.config import make_settings - except ImportError: - # Gunicorn does not yet support Windows. - # See issue #524. https://github.com/benoitc/gunicorn/issues/524 - # For dev on Windows, make this an optional import. 
- print('Could not import gunicorn, skipping.') - return options - - settings = make_settings() - for setting, klass in settings.items(): - if klass.cli: - if klass.action: - if klass.action == 'store_const': - options.append(Option(*klass.cli, const=klass.const, action=klass.action)) - else: - options.append(Option(*klass.cli, action=klass.action)) - else: - options.append(Option(*klass.cli)) - return options - - def run(self, *args, **kwargs): - from gunicorn.app.wsgiapp import WSGIApplication - - app = WSGIApplication() - - app.app_uri = 'aardvark.app:create_app()' - return app.run() - - -def main(): - manager.add_command("start_api", GunicornServer()) - manager.run() - - -if __name__ == '__main__': - manager.add_command("start_api", GunicornServer()) - manager.run() +if __name__ == "__main__": + cli() diff --git a/aardvark/model.py b/aardvark/model.py index d182796..e69de29 100644 --- a/aardvark/model.py +++ b/aardvark/model.py @@ -1,138 +0,0 @@ -from __future__ import annotations - -import datetime - -from flask import current_app -from sqlalchemy import BigInteger, Column, Integer, Text, TIMESTAMP -import sqlalchemy.exc -from sqlalchemy.orm import relationship -from sqlalchemy.schema import ForeignKey - -from aardvark.app import db -from aardvark.utils.sqla_regex import String - - -class AWSIAMObject(db.Model): - """ - Meant to model AWS IAM Object Access Advisor. - """ - __tablename__ = "aws_iam_object" - id = Column(Integer, primary_key=True) - arn = Column(String(2048), nullable=True, index=True, unique=True) - lastUpdated = Column(TIMESTAMP) - usage = relationship("AdvisorData", backref="item", cascade="all, delete, delete-orphan", - foreign_keys="AdvisorData.item_id") - - @staticmethod - def get_or_create(arn): - item = AWSIAMObject.query.filter(AWSIAMObject.arn == arn).scalar() - - added = False - try: - item = AWSIAMObject.query.filter(AWSIAMObject.arn == arn).scalar() - except sqlalchemy.exc.SQLAlchemyException as e: - current_app.logger.error('Database exception: {}'.format(e.message)) - - if not item: - item = AWSIAMObject(arn=arn, lastUpdated=datetime.datetime.utcnow()) - added = True - else: - item.lastUpdated = datetime.datetime.utcnow() - db.session.add(item) - - # we only need a refresh if the object was created - if added: - db.session.commit() - db.session.refresh(item) - return item - - -class AdvisorData(db.Model): - """ - Models certain IAM Access Advisor Data fields. - - { - "totalAuthenticatedEntities": 1, - "lastAuthenticatedEntity": "arn:aws:iam::XXXXXXXX:role/name", - "serviceName": "Amazon Simple Systems Manager", - "lastAuthenticated": 1489176000000, - "serviceNamespace": "ssm" - } - """ - __tablename__ = "advisor_data" - id = Column(Integer, primary_key=True) - item_id = Column(Integer, ForeignKey("aws_iam_object.id"), nullable=False, index=True) - lastAuthenticated = Column(BigInteger) - serviceName = Column(String(128), index=True) - serviceNamespace = Column(String(64), index=True) - lastAuthenticatedEntity = Column(Text) - totalAuthenticatedEntities = Column(Integer) - - @staticmethod - def create_or_update(item_id, lastAuthenticated, serviceName, serviceNamespace, lastAuthenticatedEntity, - totalAuthenticatedEntities): - # Truncate service name and namespace to make sure they fit in our DB fields - serviceName = serviceName[:128] - serviceNamespace = serviceNamespace[:64] - - # Query the database for an existing entry that matches this item ID and service namespace. 
If there is none, - # instantiate an empty AdvisorData - item: AdvisorData | None = None - try: - item = db.session.query(AdvisorData).filter( - AdvisorData.item_id == item_id, - AdvisorData.serviceNamespace == serviceNamespace, - ).scalar() - except sqlalchemy.exc.SQLAlchemyError as e: - current_app.logger.error( - 'Database error: %s item_id: %s serviceNamespace: %s', - str(e), - item_id, - serviceNamespace - ) - - if not item: - item = AdvisorData() - - # Save existing lastAuthenticated timestamp for later comparison - existingLastAuthenticated = item.lastAuthenticated or 0 - - # Set all fields to the provided values. SQLAlchemy will only mark the model instance as modified if the actual - # values have changed, so this will be a no-op if the values are all the same. - item.item_id = item_id - item.lastAuthenticated = lastAuthenticated - item.lastAuthenticatedEntity = lastAuthenticatedEntity - item.serviceName = serviceName - item.serviceNamespace = serviceNamespace - item.totalAuthenticatedEntities = totalAuthenticatedEntities - - # When there is no AA data about a service, the lastAuthenticated key is missing from the returned dictionary. - # This is perfectly valid, either because the service in question was not accessed in the past 365 days or - # the entity granting access to it was created recently enough that no AA data is available yet (it can take - # up to 4 hours for this to happen). - # - # When this happens, the AccountToUpdate._get_job_results() method will set lastAuthenticated to 0. Usually - # we don't want to persist such an entity, with one exception: there's already a recorded, non-zero - # lastAuthenticated timestamp persisted for this item. That means the service was accessed at some point in - # time, but now more than 365 passed since the last access, so AA no longer returns a timestamp for it. - if lastAuthenticated < existingLastAuthenticated: - if lastAuthenticated == 0: - current_app.logger.info( - 'Previously seen object not accessed in the past 365 days (got null lastAuthenticated from AA). ' - 'Setting to 0. 
Object %s service %s previous timestamp %d',
-                    item.item_id,
-                    item.serviceName,
-                    item.lastAuthenticated
-                )
-            else:
-                current_app.logger.warning(
-                    "Received an older time than previously seen for object %s service %s (%d < %d)!",
-                    item.item_id,
-                    item.serviceName,
-                    lastAuthenticated,
-                    existingLastAuthenticated
-                )
-            item.lastAuthenticated = existingLastAuthenticated
-
-        # Add the updated item to the session so it gets committed with everything else
-        db.session.add(item)
diff --git a/aardvark/persistence/__init__.py b/aardvark/persistence/__init__.py
new file mode 100644
index 0000000..1c0cded
--- /dev/null
+++ b/aardvark/persistence/__init__.py
@@ -0,0 +1,3 @@
+from aardvark.persistence.plugin import PersistencePlugin
+
+__all__ = ["PersistencePlugin"]
diff --git a/aardvark/persistence/plugin.py b/aardvark/persistence/plugin.py
new file mode 100644
index 0000000..5ea4bdb
--- /dev/null
+++ b/aardvark/persistence/plugin.py
@@ -0,0 +1,37 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+from aardvark.plugins import AardvarkPlugin
+
+if TYPE_CHECKING:
+    from dynaconf.utils import DynaconfDict
+
+
+class PersistencePlugin(AardvarkPlugin):
+    def __init__(self, alternative_config: DynaconfDict | None = None):
+        super().__init__(alternative_config=alternative_config)
+
+    def init_db(self):
+        raise NotImplementedError
+
+    def teardown_db(self):
+        raise NotImplementedError
+
+    def get_role_data(
+        self,
+        *,
+        page: int = 0,
+        count: int = 0,
+        combine: bool = False,
+        phrase: str = "",
+        arns: list[str] | None = None,
+        regex: str = "",
+    ) -> dict[str, Any]:
+        raise NotImplementedError
+
+    def store_role_data(self, access_advisor_data: dict[str, Any]) -> None:
+        raise NotImplementedError
+
+    def _combine_results(self, access_advisor_data: dict[str, Any]) -> dict[str, Any]:
+        raise NotImplementedError
diff --git a/aardvark/persistence/sqlalchemy/__init__.py b/aardvark/persistence/sqlalchemy/__init__.py
new file mode 100644
index 0000000..3f50522
--- /dev/null
+++ b/aardvark/persistence/sqlalchemy/__init__.py
@@ -0,0 +1,3 @@
+from aardvark.persistence.sqlalchemy.sa_persistence import SQLAlchemyPersistence
+
+__all__ = ["SQLAlchemyPersistence"]
diff --git a/aardvark/persistence/sqlalchemy/models.py b/aardvark/persistence/sqlalchemy/models.py
new file mode 100644
index 0000000..1a155c9
--- /dev/null
+++ b/aardvark/persistence/sqlalchemy/models.py
@@ -0,0 +1,68 @@
+# ruff: noqa: N815
+from sqlalchemy import TIMESTAMP, BigInteger, Column, ForeignKey, Integer, Text
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import relationship
+
+from aardvark.utils.sqla_regex import String
+
+Base = declarative_base()
+
+
+class AdvisorData(Base):
+    """
+    Models certain IAM Access Advisor Data fields.
+
+    {
+        "totalAuthenticatedEntities": 1,
+        "lastAuthenticatedEntity": "arn:aws:iam::XXXXXXXX:role/name",
+        "serviceName": "Amazon Simple Systems Manager",
+        "lastAuthenticated": 1489176000000,
+        "serviceNamespace": "ssm"
+    }
+    """
+
+    __tablename__ = "advisor_data"
+    id = Column(Integer, primary_key=True)
+    item_id = Column(Integer, ForeignKey("aws_iam_object.id"), nullable=False, index=True)
+    lastAuthenticated = Column(BigInteger)
+    serviceName = Column(String(128), index=True)
+    serviceNamespace = Column(String(64), index=True)
+    lastAuthenticatedEntity = Column(Text)
+    totalAuthenticatedEntities = Column(Integer)
+
+
+class AWSIAMObject(Base):
+    """
+    Meant to model AWS IAM Object Access Advisor.
+    """
+
+    __tablename__ = "aws_iam_object"
+    id = Column(Integer, primary_key=True)
+    arn = Column(String(2048), nullable=True, index=True, unique=True)
+    lastUpdated = Column(TIMESTAMP)
+    usage = relationship(
+        "AdvisorData",
+        backref="item",
+        cascade="all, delete, delete-orphan",
+        foreign_keys="AdvisorData.item_id",
+    )
+    action_usage = relationship(
+        "ActionData",
+        backref="item",
+        cascade="all, delete, delete-orphan",
+        foreign_keys="ActionData.item_id",
+    )
+
+
+class ActionData(Base):
+    """
+    Models action-specific data from sources other than Access Advisor,
+    such as CloudTrail.
+    """
+
+    __tablename__ = "action_data"
+    id = Column(Integer, primary_key=True)
+    item_id = Column(Integer, ForeignKey("aws_iam_object.id"), nullable=False, index=True)
+    lastAuthenticated = Column(BigInteger)
+    serviceName = Column(String(128), index=True)
+    serviceNamespace = Column(String(64), index=True)
diff --git a/aardvark/persistence/sqlalchemy/sa_persistence.py b/aardvark/persistence/sqlalchemy/sa_persistence.py
new file mode 100644
index 0000000..c0fec2c
--- /dev/null
+++ b/aardvark/persistence/sqlalchemy/sa_persistence.py
@@ -0,0 +1,325 @@
+# ruff: noqa: DTZ003, DTZ004, DTZ007
+from __future__ import annotations
+
+import datetime
+import logging
+from contextlib import contextmanager
+from typing import TYPE_CHECKING, Any
+
+from sqlalchemy import create_engine, engine
+from sqlalchemy import func as sa_func
+from sqlalchemy.exc import SQLAlchemyError
+from sqlalchemy.orm import Session, scoped_session, sessionmaker
+
+from aardvark.exceptions import CombineError, DatabaseError
+from aardvark.persistence import PersistencePlugin
+from aardvark.persistence.sqlalchemy.models import AdvisorData, AWSIAMObject, Base
+
+if TYPE_CHECKING:
+    from dynaconf.utils import DynaconfDict
+
+log = logging.getLogger("aardvark")
+session_type = scoped_session | Session
+
+
+class SQLAlchemyPersistence(PersistencePlugin):
+    sa_engine: engine.Engine
+    session_factory: sessionmaker
+    session: session_type
+
+    def __init__(self, *, alternative_config: DynaconfDict | None = None, initialize: bool = True):
+        super().__init__(alternative_config=alternative_config)
+        if initialize:
+            self.init_db()
+
+    def init_db(self):
+        self.sa_engine = create_engine(self.config.get("sqlalchemy_database_uri"))
+        self.session_factory = sessionmaker(
+            autocommit=False,
+            autoflush=False,
+            bind=self.sa_engine,
+            expire_on_commit=False,
+        )
+        self.session = scoped_session(self.session_factory)
+        Base.query = self.session.query_property()
+
+        Base.metadata.create_all(bind=self.sa_engine)
+
+    def _create_session(self) -> scoped_session:
+        return self.session()
+
+    def teardown_db(self):
+        Base.metadata.drop_all(bind=self.sa_engine)
+
+    def create_iam_object(
+        self, arn: str, last_updated: datetime.datetime, session: session_type = None
+    ) -> AWSIAMObject:
+        with self.session_scope(session) as session:
+            item = AWSIAMObject(arn=arn, lastUpdated=last_updated)
+            session.add(item)
+            return item
+
+    @contextmanager
+    def session_scope(self, session: session_type = None) -> Session:
+        """Provide a transactional scope around a series of operations."""
+        if not session:
+            log.debug("creating new SQLAlchemy DB session")
+            session: session_type = self._create_session()
+            close_session = True
+        else:
+            log.debug("using provided SQLAlchemy DB session")
+            close_session = False
+        try:
+            yield session
+            log.debug("committing SQLAlchemy DB session")
+            session.commit()
+        except Exception as e:
+            log.warning("exception caught, rolling back session: %s", e, exc_info=True)
+            session.rollback()
+            raise
+        finally:
+            if close_session:
+                log.debug("closing SQLAlchemy DB session")
+                session.close()
+                self.session.remove()
+            else:
+                log.debug("not closing SQLAlchemy DB session")
+
+    def _combine_results(self, access_advisor_data: dict[str, Any]) -> dict[str, Any]:
+        access_advisor_data.pop("page")
+        access_advisor_data.pop("count")
+        access_advisor_data.pop("total")
+        usage: dict[str, dict] = {}
+        for services in access_advisor_data.values():
+            for service in services:
+                namespace = service.get("serviceNamespace")
+                last_authenticated = service.get("lastAuthenticated")
+                if namespace not in usage:
+                    usage[namespace] = service
+                else:
+                    count_entities = (
+                        usage[namespace]["totalAuthenticatedEntities"] + service["totalAuthenticatedEntities"]
+                    )
+                    if last_authenticated > usage[namespace]["lastAuthenticated"]:
+                        usage[namespace] = service
+                    usage[namespace]["totalAuthenticatedEntities"] = count_entities
+
+        for namespace, service in usage.items():
+            last_authenticated = service["lastAuthenticated"]
+            if isinstance(last_authenticated, int):
+                # Keep this naive (UTC) so it is comparable with utcnow() below.
+                dt_last_authenticated = datetime.datetime.utcfromtimestamp(last_authenticated / 1e3)
+            elif isinstance(last_authenticated, str):
+                dt_last_authenticated = datetime.datetime.strptime(last_authenticated, "%Y-%m-%d %H:%M:%S.%f")
+            else:
+                dt_last_authenticated = last_authenticated
+
+            dt_starting = datetime.datetime.utcnow() - datetime.timedelta(days=90)
+            usage[namespace]["USED_LAST_90_DAYS"] = dt_last_authenticated > dt_starting
+
+        return usage
+
+    def store_role_data(self, access_advisor_data: dict[str, Any], session: session_type = None):
+        with self.session_scope(session) as session:
+            if not access_advisor_data:
+                log.warning("Cannot persist Access Advisor Data as no data was collected.")
+                return
+
+            arn_cache = {}
+            for arn, data in access_advisor_data.items():
+                if arn in arn_cache:
+                    item = arn_cache[arn]
+                else:
+                    item = self.get_or_create_iam_object(arn, session=session)
+                    arn_cache[arn] = item
+                for service in data:
+                    self.create_or_update_advisor_data(
+                        item.id,
+                        service["LastAuthenticated"],
+                        service["ServiceName"],
+                        service["ServiceNamespace"],
+                        service.get("LastAuthenticatedEntity"),
+                        service["TotalAuthenticatedEntities"],
+                        session=session,
+                    )
+
+    def get_role_data(
+        self,
+        *,
+        page: int = 0,
+        count: int = 0,
+        combine: bool = False,
+        phrase: str = "",
+        arns: list[str] | None = None,
+        regex: str = "",
+        session: session_type = None,
+    ) -> dict[str, Any]:
+        offset = (page - 1) * count if page else 0
+        limit = count
+        with self.session_scope(session) as session:
+            # default unfiltered query
+            query = session.query(AWSIAMObject)
+
+            try:
+                if phrase:
+                    query = query.filter(AWSIAMObject.arn.ilike("%" + phrase + "%"))
+
+                if arns:
+                    query = query.filter(sa_func.lower(AWSIAMObject.arn).in_([arn.lower() for arn in arns]))
+
+                if regex:
+                    query = query.filter(AWSIAMObject.arn.regexp(regex))
+
+                total = query.count()
+
+                if offset:
+                    query = query.offset(offset)
+
+                if limit:
+                    query = query.limit(limit)
+
+                items = query.all()
+            except Exception as e:
+                raise DatabaseError from e
+
+            if not items:
+                items = session.query(AWSIAMObject).offset(offset).limit(limit).all()
+
+            values = {
+                "page": page,
+                "total": total,
+                "count": len(items),
+            }
+            for item in items:
+                values[item.arn] = [
+                    {
+                        "lastAuthenticated": advisor_data.lastAuthenticated,
+                        "serviceName": advisor_data.serviceName,
+                        "serviceNamespace": advisor_data.serviceNamespace,
"lastAuthenticatedEntity": advisor_data.lastAuthenticatedEntity, + "totalAuthenticatedEntities": advisor_data.totalAuthenticatedEntities, + "lastUpdated": item.lastUpdated, + } + for advisor_data in item.usage + ] + + if combine and total > len(items): + message = f"Error: Please specify a count of at least {total}." + raise CombineError(message) + + if combine: + return self._combine_results(values) + + return values + + def create_or_update_advisor_data( + self, + item_id: int, + last_authenticated: int, + service_name: str, + service_namespace: str, + last_authenticated_entity: str, + total_authenticated_entities: int, + session: session_type = None, + ): + with self.session_scope(session) as session: + service_name = service_name[:128] + service_namespace = service_namespace[:64] + item: AdvisorData | None = None + try: + item = ( + session.query(AdvisorData) + .filter( + AdvisorData.item_id == item_id, + AdvisorData.serviceNamespace == service_namespace, + ) + .scalar() + ) + except SQLAlchemyError: + log.exception( + "Database error: item_id: %s serviceNamespace: %s", + item_id, + service_namespace, + ) + raise + + if not item: + item = AdvisorData() + + # Save existing lastAuthenticated timestamp for later comparison + # sqlite will return a string for item.lastAuthenticated, so we parse that into a datetime + if isinstance(item.lastAuthenticated, str): + existing_last_authenticated = datetime.datetime.strptime(item.lastAuthenticated, "%Y-%m-%d %H:%M:%S.%f") + else: + existing_last_authenticated = item.lastAuthenticated + + # Set all fields to the provided values. SQLAlchemy will only mark the model instance as modified if the actual + # values have changed, so this will be a no-op if the values are all the same. + item.item_id = item_id + item.lastAuthenticated = last_authenticated + item.lastAuthenticatedEntity = last_authenticated_entity + item.serviceName = service_name + item.serviceNamespace = service_namespace + item.totalAuthenticatedEntities = total_authenticated_entities + + # When there is no AA data about a service, the lastAuthenticated key is missing from the returned data. + # This is perfectly valid, either because the service in question was not accessed in the past 365 days or + # the entity granting access to it was created recently enough that no AA data is available yet (it can + # take up to 4 hours for this to happen). + # + # When this happens, the AccountToUpdate._get_job_results() method will set lastAuthenticated to 0. Usually + # we don't want to persist such an entity, with one exception: there's already a recorded, non-zero + # lastAuthenticated timestamp persisted for this item. That means the service was accessed at some point in + # time, but now more than 365 passed since the last access, so AA no longer returns a timestamp for it. + if existing_last_authenticated is not None and existing_last_authenticated > last_authenticated: + if last_authenticated == 0: + log.info( + "Previously seen object not accessed in the past 365 days (got null lastAuthenticated from " + "AA). Setting to 0. Object %s service %s previous timestamp %d", + item.item_id, + item.serviceName, + item.lastAuthenticated, + ) + else: + log.warning( + "Received an older time than previously seen for object %s service %s (%d < %d). 
Not updating!", + item.item_id, + item.serviceName, + last_authenticated, + existing_last_authenticated, + ) + item.lastAuthenticated = existing_last_authenticated + + session.add(item) + + def get_or_create_iam_object(self, arn: str, session: session_type = None): + with self.session_scope(session) as session: + try: + item = session.query(AWSIAMObject).filter(AWSIAMObject.arn == arn).scalar() + except SQLAlchemyError: + log.exception("failed to retrieve IAM object") + raise + + added = False + if not item: + item = AWSIAMObject(arn=arn, lastUpdated=datetime.datetime.utcnow()) + added = True + else: + item.lastUpdated = datetime.datetime.utcnow() + + try: + session.add(item) + except SQLAlchemyError: + log.exception("failed to add AWSIAMObject item to session") + raise + + # we only need a refresh if the object was created + if added: + try: + session.commit() + session.refresh(item) + except SQLAlchemyError: + log.exception("failed to create IAM object") + raise + return item diff --git a/aardvark/plugins/__init__.py b/aardvark/plugins/__init__.py new file mode 100644 index 0000000..8a204e4 --- /dev/null +++ b/aardvark/plugins/__init__.py @@ -0,0 +1,3 @@ +from aardvark.plugins.plugin import AardvarkPlugin + +__all__ = ["AardvarkPlugin"] diff --git a/aardvark/plugins/plugin.py b/aardvark/plugins/plugin.py new file mode 100644 index 0000000..db8ac47 --- /dev/null +++ b/aardvark/plugins/plugin.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from aardvark.config import settings + +if TYPE_CHECKING: + from dynaconf.utils import DynaconfDict + + +class AardvarkPlugin: + def __init__(self, alternative_config: DynaconfDict | None = None): + if alternative_config: + self.config = alternative_config + else: + self.config = settings diff --git a/aardvark/retrievers/__init__.py b/aardvark/retrievers/__init__.py new file mode 100644 index 0000000..ed57277 --- /dev/null +++ b/aardvark/retrievers/__init__.py @@ -0,0 +1,25 @@ +r""" + ,-~~~~-, + .-~~~; ;~~~-. + / / \ \ + { .'{ O O }'. } + `~` { .-~~~~-. } `~` + ;/ \; + /'._ () _.'\ + / `~~~~` \ + ; ; + { } + { } { } + { } { } + / \ / \ + { { { }~~{ } } } + `~~~~~` `~~~~~` + (`"======="`) + (_.=======._) + +Good boy. 
+""" + +from aardvark.retrievers.plugin import RetrieverPlugin + +__all__ = ["RetrieverPlugin"] diff --git a/aardvark/retrievers/access_advisor/__init__.py b/aardvark/retrievers/access_advisor/__init__.py new file mode 100644 index 0000000..1c26467 --- /dev/null +++ b/aardvark/retrievers/access_advisor/__init__.py @@ -0,0 +1,3 @@ +from aardvark.retrievers.access_advisor.retriever import AccessAdvisorRetriever + +__all__ = ["AccessAdvisorRetriever"] diff --git a/aardvark/retrievers/access_advisor/retriever.py b/aardvark/retrievers/access_advisor/retriever.py new file mode 100644 index 0000000..59895bb --- /dev/null +++ b/aardvark/retrievers/access_advisor/retriever.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +import asyncio +import logging +from typing import TYPE_CHECKING, Any + +from asgiref.sync import sync_to_async +from cloudaux.aws.sts import boto3_cached_conn + +from aardvark.exceptions import AccessAdvisorError +from aardvark.retrievers import RetrieverPlugin + +if TYPE_CHECKING: + import datetime + + from dynaconf.utils import DynaconfDict + +log = logging.getLogger("aardvark") + + +class AccessAdvisorRetriever(RetrieverPlugin): + def __init__(self, alternative_config: DynaconfDict | None = None): + super().__init__("access_advisor", alternative_config=alternative_config) + self.max_retries = self.config.get("retrievers.access_advisor.max_retries", 10) + self.backoff_base = self.config.get("retrievers.access_advisor.retry_backoff_base", 2) + + async def _generate_service_last_accessed_details(self, iam_client, arn): + """Call IAM API to create an Access Advisor job.""" + result = await sync_to_async(iam_client.generate_service_last_accessed_details)(Arn=arn) + return result["JobId"] + + async def _get_service_last_accessed_details(self, iam_client, job_id): + """Retrieve Access Advisor job results. 
Back off exponentially while the job is not complete."""
+        attempts = 0
+        while attempts < self.max_retries:
+            details = await sync_to_async(iam_client.get_service_last_accessed_details)(JobId=job_id)
+            match details.get("JobStatus"):
+                case "COMPLETED":
+                    return details
+                case "IN_PROGRESS":
+                    # backoff sleep and try again
+                    await asyncio.sleep(self.backoff_base**attempts)
+                    attempts += 1
+                    continue
+                case _:
+                    message = f"Access Advisor job failed: {details.get('Error') or 'no error details provided'}"
+                    raise AccessAdvisorError(message)
+        message = "Access Advisor job failed: exceeded max retries"
+        raise AccessAdvisorError(message)
+
+    @staticmethod
+    def _get_account_from_arn(arn: str) -> str:
+        """Return the AWS account ID from an ARN."""
+        return arn.split(":")[4]
+
+    @staticmethod
+    def _transform_result(service_last_accessed: dict[str, str | int | datetime.datetime]) -> dict[str, str | int]:
+        """Convert the LastAuthenticated datetime in an Access Advisor result to an epoch timestamp."""
+        last_authenticated = service_last_accessed.get("LastAuthenticated")
+
+        # Convert from datetime to millisecond timestamp, defaulting to zero if there isn't one
+        last_authenticated = int(last_authenticated.timestamp() * 1000) if last_authenticated else 0
+
+        service_last_accessed["LastAuthenticated"] = last_authenticated
+        return service_last_accessed
+
+    async def run(self, arn: str, data: dict[str, Any]) -> dict[str, Any]:
+        """Retrieve Access Advisor data for the given ARN and add the results to `data["access_advisor"]`."""
+        log.debug("running %s for %s", self, arn)
+        account = self._get_account_from_arn(arn)
+        conn_details: dict[str, str] = {
+            "account_number": account,
+            "assume_role": self.config.get("aws_rolename"),
+            "session_name": "aardvark",
+            "region": self.config.get("aws_region", "us-east-1"),
+            "arn_partition": self.config.get("aws_arn_partition", "aws"),
+        }
+        iam_client = await sync_to_async(boto3_cached_conn)("iam", **conn_details)
+        try:
+            job_id = await self._generate_service_last_accessed_details(iam_client, arn)
+        except iam_client.exceptions.NoSuchEntityException:
+            log.info("ARN %s no longer exists in AWS IAM", arn)
+            return data
+
+        aa_details = await self._get_service_last_accessed_details(iam_client, job_id)
+        result = map(self._transform_result, aa_details["ServicesLastAccessed"])
+        data["access_advisor"] = list(result)
+        return data
diff --git a/aardvark/retrievers/plugin.py b/aardvark/retrievers/plugin.py
new file mode 100644
index 0000000..8da8a29
--- /dev/null
+++ b/aardvark/retrievers/plugin.py
@@ -0,0 +1,29 @@
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING, Any
+
+from aardvark.plugins import AardvarkPlugin
+
+if TYPE_CHECKING:
+    from dynaconf.utils import DynaconfDict
+
+log = logging.getLogger("aardvark")
+
+
+class RetrieverPlugin(AardvarkPlugin):
+    _name: str
+
+    def __init__(self, name: str, alternative_config: DynaconfDict | None = None):
+        super().__init__(alternative_config=alternative_config)
+        self._name = name
+
+    async def run(self, arn: str, data: dict[str, Any]) -> dict[str, Any]:
+        raise NotImplementedError
+
+    @property
+    def name(self):
+        return self._name
+
+    def __str__(self):
+        return f"Retriever({self.name})"
diff --git a/aardvark/retrievers/runner.py b/aardvark/retrievers/runner.py
new file mode 100644
index 0000000..0644927
--- /dev/null
+++ b/aardvark/retrievers/runner.py
@@ -0,0 +1,275 @@
+import asyncio
+import logging
+import re
+from copy import copy
+from typing import Any
+
+from asgiref.sync import sync_to_async
+from 
botocore.exceptions import ClientError +from cloudaux.aws.iam import list_roles, list_users +from cloudaux.aws.sts import boto3_cached_conn +from dynaconf import Dynaconf +from swag_client import InvalidSWAGDataException +from swag_client.backend import SWAGManager +from swag_client.util import parse_swag_config_options + +from aardvark.exceptions import RetrieverError +from aardvark.persistence.sqlalchemy import SQLAlchemyPersistence +from aardvark.plugins import AardvarkPlugin +from aardvark.retrievers import RetrieverPlugin +from aardvark.retrievers.access_advisor import AccessAdvisorRetriever + +log = logging.getLogger("aardvark") +re_account_id = re.compile(r"\d{12}") +EMPTY_QUEUE_DELAY = 1 +EMPTY_QUEUE_RETRIES = 5 + + +class RetrieverRunner(AardvarkPlugin): + """Scheduling and execution for data retrieval tasks.""" + + retrievers: list[RetrieverPlugin] + account_queue: asyncio.Queue + arn_queue: asyncio.Queue + results_queue: asyncio.Queue + failure_queue: asyncio.Queue + failed_arns: list[str] + tasks: list[asyncio.Future] + num_workers: int + swag: SWAGManager + swag_config: dict[str, str] + accounts_complete: bool + persistence: SQLAlchemyPersistence + + def __init__( + self, + alternative_config: Dynaconf = None, + ): + super().__init__(alternative_config=alternative_config) + self.tasks = [] + self.retrievers = [] + self.failed_arns = [] + self.num_workers = self.config.get("updater_num_threads") + self.swag_config = self.config.get("swag") + swag_opts = parse_swag_config_options(self.swag_config["opts"]) + self.swag = SWAGManager(**swag_opts) + self.accounts_complete = False + self.persistence = SQLAlchemyPersistence(alternative_config=alternative_config) + + def register_retriever(self, r: RetrieverPlugin): + """Add a retriever instance to be called during the run process.""" + self.retrievers.append(r) + + async def _run_retrievers(self, arn: str) -> dict[str, Any]: + """Run retriever plugins for a given ARN. + + Retriever plugins are executed in the order in which they are registered. Each retriever + is passed the result from the previous one, starting with a dict containing an `arn` element. + + Note: The data from all previous retriever plugins is mutable by subsequent ones. + """ + data = { + "arn": arn, + } + # Iterate through retrievers, passing the results from the previous to the next. 
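+        # For example, with retrievers [AccessAdvisorRetriever(), TagRetriever()] (the latter
+        # hypothetical), the first run() call receives {"arn": arn} and returns it with an
+        # "access_advisor" key added; the second receives that dict and could add a "tags" key.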
+        for r in self.retrievers:
+            try:
+                data = await r.run(arn, data)
+            except Exception as e:
+                log.exception("failed to run %s on ARN %s", r, arn)
+                raise RetrieverError from e
+        return data
+
+    async def _retriever_loop(self, name: str):
+        """Loop to consume from self.arn_queue and call the retriever runner function."""
+        log.debug("creating %s", name)
+        while True:
+            log.debug("getting arn from queue")
+            arn = await self.arn_queue.get()
+            log.debug("%s retrieving data for %s", name, arn)
+            try:
+                data = await self._run_retrievers(arn)
+            except Exception:
+                log.exception("failed to run retriever on ARN %s", arn)
+                self.failed_arns.append(arn)
+                await self.failure_queue.put(arn)
+                self.arn_queue.task_done()
+                continue
+            # TODO: handle nested data from retrievers in persistence layer
+            await self.results_queue.put(data)
+            self.arn_queue.task_done()
+
+    async def _results_loop(self, name: str):
+        """Loop to consume from self.results_queue and handle results."""
+        log.debug("creating %s", name)
+        while True:
+            data = await self.results_queue.get()
+            log.debug("%s storing results for %s", name, data["arn"])
+            try:
+                await sync_to_async(self.persistence.store_role_data)({data["arn"]: data["access_advisor"]})
+            except Exception:
+                log.exception("exception occurred in results loop")
+                await self.failure_queue.put(data)
+            self.results_queue.task_done()
+
+    async def _get_arns_for_account(self, account: str):
+        """Retrieve ARNs for roles, users, policies, and groups in an account and add them to the ARN queue."""
+        conn_details: dict[str, str] = {
+            "account_number": account,
+            "assume_role": self.config.get("aws_rolename"),
+            "session_name": "aardvark",
+            "region": self.config.get("aws_region", "us-east-1"),
+            "arn_partition": self.config.get("aws_arn_partition", "aws"),
+        }
+        client = await sync_to_async(boto3_cached_conn)("iam", service_type="client", **conn_details)
+
+        for role in await sync_to_async(list_roles)(**conn_details):
+            await self.arn_queue.put(role["Arn"])
+
+        for user in await sync_to_async(list_users)(**conn_details):
+            await self.arn_queue.put(user["Arn"])
+
+        for page in await sync_to_async(client.get_paginator("list_policies").paginate)(Scope="Local"):
+            for policy in page["Policies"]:
+                await self.arn_queue.put(policy["Arn"])
+
+        for page in await sync_to_async(client.get_paginator("list_groups").paginate)():
+            for group in page["Groups"]:
+                await self.arn_queue.put(group["Arn"])
+
+    async def _arn_lookup_loop(self, name: str):
+        """Loop to consume from self.account_queue to retrieve and enqueue ARNs for each account."""
+        log.debug("creating %s", name)
+        while True:
+            log.debug("getting account from queue")
+            account = await self.account_queue.get()
+            log.debug("%s retrieving ARNs for %s", name, account)
+            try:
+                await self._get_arns_for_account(account)
+            except ClientError:
+                log.exception("exception occurred in arn lookup loop for account %s", account)
+                await self.failure_queue.put(account)
+            self.account_queue.task_done()
+
+    async def _get_swag_accounts(self) -> list[dict]:
+        """Retrieve AWS accounts from SWAG based on the SWAG options in the application configuration."""
+        log.debug("getting accounts from SWAG")
+        try:
+            all_accounts: list[dict] = await sync_to_async(self.swag.get_all)(
+                search_filter=self.swag_config["filter"]
+            )
+            swag_service = self.swag_config["service_enabled_requirement"]
+            if swag_service:
+                all_accounts = await sync_to_async(self.swag.get_service_enabled)(
+                    swag_service, accounts_list=all_accounts
+                )
+        except (KeyError, InvalidSWAGDataException, ClientError) as e:
+            log.exception("SWAG lookup failed; SWAG may be misconfigured or unavailable")
+            message = "could not retrieve SWAG data"
+            raise RetrieverError(message) from e
+
+        return all_accounts
+
+    async def _queue_all_accounts(self):
+        """Add all accounts to the account queue.
+
+        Perform a SWAG lookup and add all returned accounts to `self.account_queue`."""
+        for account in await self._get_swag_accounts():
+            await self.account_queue.put(account["id"])
+
+    async def _queue_arns(self, arns: list[str]):
+        """Add a list of ARNs to the ARN queue."""
+        for arn in arns:
+            await self.arn_queue.put(arn)
+
+    async def _queue_accounts(self, account_names: list[str]):
+        """Add requested accounts to the account queue.
+
+        Given a list of account names and/or IDs, use SWAG to look up account numbers where needed
+        and add each account number to `self.account_queue`."""
+        # Iterate over the original list and mutate only the copy, so that removals
+        # don't cause the loop to skip entries.
+        accounts = copy(account_names)
+        for account in account_names:
+            if re_account_id.match(account):
+                await self.account_queue.put(account)
+                accounts.remove(account)
+
+        all_accounts = await self._get_swag_accounts()
+
+        # TODO(psanders): Consider refactoring. This could be expensive for organizations
+        # with many accounts and many aliases.
+        for account in all_accounts:
+            # Check if the account name matches one we want. If so, queue it and carry on.
+            if account.get("name") in accounts:
+                await self.account_queue.put(account["id"])
+                continue
+            # Now check the account's aliases; stop at the first match so an account
+            # with several matching aliases is only queued once.
+            alias_key = "aliases" if account["schemaVersion"] == "2" else "alias"
+            for alias in account.get(alias_key, []):
+                if alias in accounts:
+                    await self.account_queue.put(account["id"])
+                    break
+
+    def cancel(self):
+        """Send a cancel signal to all running workers."""
+        log.info("stopping runner tasks")
+        for task in self.tasks:
+            task.cancel()
+            log.info("task %s canceled", task)
+
+    async def run(self, accounts: list[str] | None = None, arns: list[str] | None = None):
+        """Prep account queue and kick off ARN lookup, retriever, and results workers.
+
+        Populate the ARN queue with ARNs if provided. Otherwise, use SWAG to look up account
+        numbers and put those in the account queue.
+
+        After that, we start `updater_num_threads` workers for each queue. If ARNs are
+        provided, the account-queue workers will sit idle since no accounts are queued.
+        """
+        self.register_retriever(AccessAdvisorRetriever())
+        log.debug("starting retriever")
+
+        self.arn_queue = asyncio.Queue()
+        self.account_queue = asyncio.Queue()
+        self.results_queue = asyncio.Queue()
+        self.failure_queue = asyncio.Queue()
+
+        lookup_accounts = True
+        if arns:
+            await self._queue_arns(arns)
+            lookup_accounts = False
+
+        # We only need to do account lookups if ARNs were not provided.
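+        # Rough shape of the pipeline assembled below:
+        #   account_queue -> [arn-lookup workers] -> arn_queue -> [retriever workers]
+        #   -> results_queue -> [results workers] -> persistence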
+ if lookup_accounts: + if accounts: + await self._queue_accounts(accounts) + else: + await self._queue_all_accounts() + + for i in range(self.num_workers): + name = f"arn-lookup-worker-{i}" + task = asyncio.create_task(self._arn_lookup_loop(name)) + self.tasks.append(task) + + for i in range(self.num_workers): + name = f"retriever-worker-{i}" + task = asyncio.create_task(self._retriever_loop(name)) + self.tasks.append(task) + + for i in range(self.num_workers): + name = f"results-worker-{i}" + task = asyncio.create_task(self._results_loop(name)) + self.tasks.append(task) + + await self.account_queue.join() + await self.arn_queue.join() + await self.results_queue.join() + + # Clean up our workers + self.cancel() + + while not self.failure_queue.empty(): + failure = await self.failure_queue.get() + log.error("failure: %s", failure) + + await asyncio.gather(*self.tasks, return_exceptions=True) diff --git a/aardvark/updater/__init__.py b/aardvark/updater/__init__.py index 70c77b3..e69de29 100644 --- a/aardvark/updater/__init__.py +++ b/aardvark/updater/__init__.py @@ -1,254 +0,0 @@ -import time - -from blinker import Signal -from cloudaux.aws.iam import list_roles, list_users -from cloudaux.aws.sts import boto3_cached_conn -from cloudaux.aws.decorators import rate_limited - - -class JobNotComplete(Exception): - pass - - -class JobFailed(Exception): - pass - - -class AccountToUpdate(object): - on_ready = Signal() - on_complete = Signal() - on_error = Signal() - on_failure = Signal() - - def __init__(self, current_app, account_number, role_name, arns_list): - self.current_app = current_app - self.account_number = account_number - self.role_name = role_name - self.arn_list = arns_list - self.conn_details = { - 'account_number': account_number, - 'assume_role': role_name, - 'session_name': 'aardvark', - 'region': self.current_app.config.get('REGION') or 'us-east-1', - 'arn_partition': self.current_app.config.get('ARN_PARTITION') or 'aws' - } - self.max_access_advisor_job_wait = 5 * 60 # Wait 5 minutes before giving up on jobs - - def update_account(self): - """ - Updates Access Advisor data for a given AWS account. - 1) Gets list of IAM Role ARNs in target account. - 2) Gets IAM credentials in target account. - 3) Calls GenerateServiceLastAccessedDetails for each role - 4) Calls GetServiceLastAccessedDetails for each role to retrieve data - - :return: Return code and JSON Access Advisor data for given account - """ - self.on_ready.send(self) - arns = self._get_arns() - - if not arns: - self.current_app.logger.warn("Zero ARNs collected. 
Exiting") - exit(-1) - - client = self._get_client() - try: - details = self._call_access_advisor(client, list(arns)) - except Exception as e: - self.on_failure.send(self, error=e) - self.current_app.logger.exception('Failed to call access advisor', exc_info=True) - return 255, None - else: - self.on_complete.send(self) - return 0, details - - def _get_arns(self): - """ - Gets a list of all Role ARNs in a given account, optionally limited by - class property ARN filter - :return: list of role ARNs - """ - client = self._get_client() - - account_arns = set() - - for role in list_roles(**self.conn_details): - account_arns.add(role['Arn']) - - for user in list_users(**self.conn_details): - account_arns.add(user['Arn']) - - for page in client.get_paginator('list_policies').paginate(Scope='Local'): - for policy in page['Policies']: - account_arns.add(policy['Arn']) - - for page in client.get_paginator('list_groups').paginate(): - for group in page['Groups']: - account_arns.add(group['Arn']) - - result_arns = set() - for arn in self.arn_list: - if arn.lower() == 'all': - return account_arns - - if arn not in account_arns: - self.current_app.logger.warn("Provided ARN {arn} not found in account.".format(arn=arn)) - continue - - result_arns.add(arn) - - self.current_app.logger.debug("got %d arns", len(result_arns)) - return list(result_arns) - - def _get_client(self): - """ - Assumes into the target account and obtains IAM client - - :return: boto3 IAM client in target account & role - """ - try: - client = boto3_cached_conn( - 'iam', **self.conn_details) - - if not client: - raise ValueError(f"boto3_cached_conn returned null IAM client for {self.account_number}") - - return client - - except Exception as e: - self.on_failure.send(self, error=e) - self.current_app.logger.exception(f"Failed to obtain boto3 IAM client for account {self.account_number}.", exc_info=False) - raise e - - def _call_access_advisor(self, iam, arns): - jobs = self._generate_job_ids(iam, arns) - details = self._process_jobs(iam, jobs) - if arns and not details: - self.current_app.logger.error("Didn't get any results from Access Advisor") - return details - - @rate_limited() - def _generate_service_last_accessed_details(self, iam, arn): - """ Wrapping the actual AWS API calls for rate limiting protection. """ - self.current_app.logger.debug('generating last accessed details for role %s', arn) - return iam.generate_service_last_accessed_details(Arn=arn)['JobId'] - - @rate_limited() - def _get_service_last_accessed_details(self, iam, job_id, marker=None): - """ Wrapping the actual AWS API calls for rate limiting protection. """ - self.current_app.logger.debug('getting last accessed details for job %s', job_id) - params = { - 'JobId': job_id, - } - if marker: - params['Marker'] = marker - return iam.get_service_last_accessed_details(**params) - - def _generate_job_ids(self, iam, arns): - jobs = {} - for role_arn in arns: - try: - job_id = self._generate_service_last_accessed_details(iam, role_arn) - jobs[job_id] = role_arn - except iam.exceptions.NoSuchEntityException: - """ We're here because this ARN disappeared since the call to self._get_arns(). Log the missing ARN and move along. 
""" - self.current_app.logger.info('ARN {arn} found gone when fetching details'.format(arn=role_arn)) - except Exception as e: - self.on_error.send(self, error=e) - self.current_app.logger.error('Could not gather data from {0}.'.format(role_arn), exc_info=True) - return jobs - - def _get_job_results(self, iam, job_id, role_arn): - last_accessed_details = [] - marker = None # Marker is used for pagination - while True: - try: - response = self._get_service_last_accessed_details(iam, job_id, marker=marker) - except Exception as e: - self.on_error.send(self, error=e) - self.current_app.logger.error(f'Could not gather data for role {role_arn}.', exc_info=True) - raise - - # Check job status. Possible values are IN_PROGRESS, COMPLETED, and FAILED. - if response['JobStatus'] == 'IN_PROGRESS': - raise JobNotComplete() - elif response['JobStatus'] == 'FAILED': - message = response.get("Error", {}).get("Message", "Unknown error") - raise JobFailed(message) - - # Status should only be COMPLETED if we've made it this far. - if response['JobStatus'] != 'COMPLETED': - raise Exception(f"Unknown job status {response['JobStatus']}") - - # Add results to list - last_accessed_details.extend(response.get('ServicesLastAccessed', [])) - - # Check for pagination token, save it to marker if it exists - if response.get('IsTruncated', False): - marker = response.get('Marker') - else: - break - return last_accessed_details - - def _process_jobs(self, iam, jobs): - access_details = {} - job_queue = list(jobs.keys()) - last_job_completion_time = time.time() - - while job_queue: - - # Check for timeout - now = time.time() - if now - last_job_completion_time > self.max_access_advisor_job_wait: - # We ran out of time, some jobs are unfinished - self._log_unfinished_jobs(job_queue, jobs) - break - - # Pull next job ID - job_id = job_queue.pop() - role_arn = jobs[job_id] - try: - last_accessed_details = self._get_job_results(iam, job_id, role_arn) - except JobNotComplete: - job_queue.append(job_id) - continue - except JobFailed as e: - log_str = f"Job {job_id} for ARN {role_arn} failed: {e}" - - failing_arns = self.current_app.config.get('FAILING_ARNS', {}) - if role_arn in failing_arns: - self.current_app.logger.info(log_str) - else: - self.current_app.logger.error(log_str) - continue - except Exception as e: - self.on_error.send(self, error=e) - self.current_app.logger.error('Could not gather data from {0}.'.format(role_arn), exc_info=True) - continue - - # Job status must be COMPLETED. Save result. 
- last_job_completion_time = time.time() - updated_list = [] - - for detail in last_accessed_details: - # AWS gives a datetime, convert to epoch - last_auth = detail.get('LastAuthenticated') - if last_auth: - last_auth = int(time.mktime(last_auth.timetuple()) * 1000) - else: - last_auth = 0 - - detail['LastAuthenticated'] = last_auth - updated_list.append(detail) - - access_details[role_arn] = updated_list - - return access_details - - def _log_unfinished_jobs(self, job_queue, job_details): - for job_id in job_queue: - role_arn = job_details[job_id] - self.current_app.logger.error("Job {job_id} for ARN {arn} didn't finish".format( - job_id=job_id, - arn=role_arn, - )) diff --git a/aardvark/utils/sqla_regex.py b/aardvark/utils/sqla_regex.py index 8efd384..f98acb2 100644 --- a/aardvark/utils/sqla_regex.py +++ b/aardvark/utils/sqla_regex.py @@ -8,13 +8,14 @@ import re import sqlite3 -from sqlalchemy import String as _String, event, exc +from sqlalchemy import String as _String +from sqlalchemy import event, exc from sqlalchemy.engine import Engine from sqlalchemy.ext.compiler import compiles from sqlalchemy.sql.expression import BinaryExpression, func, literal from sqlalchemy.sql.operators import custom_op -__all__ = ['String'] +__all__ = ["String"] class String(_String): @@ -23,29 +24,31 @@ class String(_String): Supports additional operators that can be used while constructing filter expressions. """ - class comparator_factory(_String.comparator_factory): + + class ComparatorFactory(_String.comparator_factory): """Contains implementation of :class:`String` operators related to regular expressions. """ + def regexp(self, other): - return RegexMatchExpression(self.expr, literal(other), custom_op('~')) + return RegexMatchExpression(self.expr, literal(other), custom_op("~")) def iregexp(self, other): - return RegexMatchExpression(self.expr, literal(other), custom_op('~*')) + return RegexMatchExpression(self.expr, literal(other), custom_op("~*")) def not_regexp(self, other): - return RegexMatchExpression(self.expr, literal(other), custom_op('!~')) + return RegexMatchExpression(self.expr, literal(other), custom_op("!~")) def not_iregexp(self, other): - return RegexMatchExpression(self.expr, literal(other), custom_op('!~*')) + return RegexMatchExpression(self.expr, literal(other), custom_op("!~*")) class RegexMatchExpression(BinaryExpression): """Represents matching of a column againsts a regular expression.""" -@compiles(RegexMatchExpression, 'sqlite') -def sqlite_regex_match(element, compiler, **kw): +@compiles(RegexMatchExpression, "sqlite") +def sqlite_regex_match(element, compiler, **_): """Compile the SQL expression representing a regular expression match for the SQLite engine. 
""" @@ -54,12 +57,14 @@ def sqlite_regex_match(element, compiler, **kw): try: func_name, _ = SQLITE_REGEX_FUNCTIONS[operator] except (KeyError, ValueError) as e: - would_be_sql_string = ' '.join((compiler.process(element.left), - operator, - compiler.process(element.right))) + would_be_sql_string = " ".join((compiler.process(element.left), operator, compiler.process(element.right))) + message = f"unknown regular expression match operator: {operator}" raise exc.StatementError( - "unknown regular expression match operator: %s" % operator, - would_be_sql_string, None, e) + message, + would_be_sql_string, + None, + e, + ) from e # compile the expression as an invocation of the custom function regex_func = getattr(func, func_name) @@ -67,8 +72,8 @@ def sqlite_regex_match(element, compiler, **kw): return compiler.process(regex_func_call) -@event.listens_for(Engine, 'connect') -def sqlite_engine_connect(dbapi_connection, connection_record): +@event.listens_for(Engine, "connect") +def sqlite_engine_connect(dbapi_connection, _): """Listener for the event of establishing connection to a SQLite database. Creates the functions handling regular expression operators @@ -84,12 +89,11 @@ def sqlite_engine_connect(dbapi_connection, connection_record): # Mapping from the regular expression matching operators # to named Python functions that implement them for SQLite. SQLITE_REGEX_FUNCTIONS = { - '~': ('REGEXP', - lambda value, regex: bool(re.match(regex, value))), - '~*': ('IREGEXP', - lambda value, regex: bool(re.match(regex, value, re.IGNORECASE))), - '!~': ('NOT_REGEXP', - lambda value, regex: not re.match(regex, value)), - '!~*': ('NOT_IREGEXP', - lambda value, regex: not re.match(regex, value, re.IGNORECASE)), + "~": ("REGEXP", lambda value, regex: bool(re.match(regex, value))), + "~*": ("IREGEXP", lambda value, regex: bool(re.match(regex, value, re.IGNORECASE))), + "!~": ("NOT_REGEXP", lambda value, regex: not re.match(regex, value)), + "!~*": ( + "NOT_IREGEXP", + lambda value, regex: not re.match(regex, value, re.IGNORECASE), + ), } diff --git a/aardvark/view.py b/aardvark/view.py index af5f3d0..e69de29 100644 --- a/aardvark/view.py +++ b/aardvark/view.py @@ -1,200 +0,0 @@ -import datetime - -from flask import abort, jsonify -from flask import Blueprint -from flask_restful import Api, Resource, reqparse -from flask import Flask -import sqlalchemy as sa - -from aardvark.model import AWSIAMObject - - -mod = Blueprint('advisor', __name__) -api = Api(mod) -app = Flask(__name__) - - -class RoleSearch(Resource): - """ - Search for roles by phrase, regex, or by ARN. 
- """ - def __init__(self): - super(RoleSearch, self).__init__() - self.reqparse = reqparse.RequestParser() - - def combine(self, aa): - del aa['count'] - del aa['page'] - del aa['total'] - - usage = dict() - for arn, services in aa.items(): - for service in services: - namespace = service.get('serviceNamespace') - last_authenticated = service.get('lastAuthenticated') - if namespace not in usage: - usage[namespace] = service - else: - count_entities = usage[namespace]['totalAuthenticatedEntities'] + service['totalAuthenticatedEntities'] - if last_authenticated > usage[namespace]['lastAuthenticated']: - usage[namespace] = service - usage[namespace]['totalAuthenticatedEntities'] = count_entities - - for namespace, service in usage.items(): - last_authenticated = service['lastAuthenticated'] - dt_last_authenticated = datetime.datetime.fromtimestamp(last_authenticated / 1e3) - dt_starting = datetime.datetime.utcnow() - datetime.timedelta(days=90) - usage[namespace]['USED_LAST_90_DAYS'] = dt_last_authenticated > dt_starting - - return jsonify(usage) - - # undocumented convenience pass-through so we can query directly from browser - @app.route('/advisors') - def get(self): - return(self.post()) - - @app.route('/advisors') - def post(self): - """Get access advisor data for role(s) - Returns access advisor information for role(s) that match filters - --- - consumes: - - 'application/json' - produces: - - 'application/json' - - parameters: - - name: page - in: query - type: integer - description: return results from given page of total results - required: false - - name: count - in: query - type: integer - description: specifies how many results should be return per page - required: false - - name: combine - in: query - type: boolean - description: combine access advisor data for all results [Default False] - required: false - - name: query - in: body - schema: - $ref: '#/definitions/QueryBody' - description: | - one or more query parameters in a JSON blob. Filter - parameters build on eachother. - - Options are: - - 1) arn list - a list of one or more specific arns - - 2) phrase matching - search for ARNs like the one supplied - - 3) regex - match a supplied regular expression. 
- - definitions: - AdvisorData: - type: object - properties: - lastAuthenticated: - type: number - lastAuthenticatedEntity: - type: string - lastUpdated: - type: string - serviceName: - type: string - serviceNamespace: - type: string - totalAuthenticatedEntities: - type: number - QueryBody: - type: object - properties: - phrase: - type: string - regex: - type: string - arn: - type: array - items: string - Results: - type: array - items: - $ref: '#/definitions/AdvisorData' - - responses: - 200: - description: Query successful, results in body - schema: - $ref: '#/definitions/AdvisorData' - 400: - description: Bad request - error message in body - """ - self.reqparse.add_argument('page', type=int, default=1) - self.reqparse.add_argument('count', type=int, default=30) - self.reqparse.add_argument('combine', type=str, default='false') - self.reqparse.add_argument('phrase', default=None) - self.reqparse.add_argument('regex', default=None) - self.reqparse.add_argument('arn', default=None, action='append') - try: - args = self.reqparse.parse_args() - except Exception as e: - abort(400, str(e)) - - page = args.pop('page') - count = args.pop('count') - combine = args.pop('combine', 'false') - combine = combine.lower() == 'true' - phrase = args.pop('phrase', '') - arns = args.pop('arn', []) - regex = args.pop('regex', '') - items = None - - # default unfiltered query - query = AWSIAMObject.query - - try: - if phrase: - query = query.filter(AWSIAMObject.arn.ilike('%' + phrase + '%')) - - if arns: - query = query.filter( - sa.func.lower(AWSIAMObject.arn).in_([arn.lower() for arn in arns])) - - if regex: - query = query.filter(AWSIAMObject.arn.regexp(regex)) - - items = query.paginate(page, count) - except Exception as e: - abort(400, str(e)) - - if not items: - items = AWSIAMObject.query.paginate(page, count) - - values = dict(page=items.page, total=items.total, count=len(items.items)) - for item in items.items: - item_values = [] - for advisor_data in item.usage: - item_values.append(dict( - lastAuthenticated=advisor_data.lastAuthenticated, - serviceName=advisor_data.serviceName, - serviceNamespace=advisor_data.serviceNamespace, - lastAuthenticatedEntity=advisor_data.lastAuthenticatedEntity, - totalAuthenticatedEntities=advisor_data.totalAuthenticatedEntities, - lastUpdated=item.lastUpdated - )) - values[item.arn] = item_values - - if combine and items.total > len(items.items): - abort(400, "Error: Please specify a count of at least {}.".format(items.total)) - elif combine: - return self.combine(values) - - return jsonify(values) - - -api.add_resource(RoleSearch, '/advisors') diff --git a/docker-compose.yml b/docker-compose.yml index 9fc52db..7ad25f7 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -7,6 +7,7 @@ services: image: netflixoss/aardvark:latest volumes: - data:/data + - ./settings.local.yaml.bak:/etc/aardvark/settings.local.yaml.bak env_file: .env command: [ "aardvark", "create_db" ] @@ -19,8 +20,9 @@ services: - 5000:5000 volumes: - data:/data + - ./settings.local.yaml.bak:/etc/aardvark/settings.local.yaml.bak env_file: .env - command: [ "aardvark", "start_api", "-b", "0.0.0.0:5000" ] + command: [ "flask", "run", "-h", "0.0.0.0", "-p", "5000" ] collector: build: . 
@@ -30,8 +32,11 @@ services: restart: always volumes: - data:/data + - ./settings.local.yaml.bak:/etc/aardvark/settings.local.yaml.bak env_file: .env - command: [ "aardvark", "update", "-a", "$AARDVARK_ACCOUNTS" ] + command: [ "aardvark", "update" ] + # If you're not using SWAG, you'll need to specify one or more accounts here: + # command: [ "aardvark", "update", "-a", "123456789012", "-a", "234567890123" ] volumes: data: diff --git a/docs/images/aardvark_logo.jpg b/docs/images/aardvark_logo.jpg deleted file mode 100644 index 2b39fb6..0000000 Binary files a/docs/images/aardvark_logo.jpg and /dev/null differ diff --git a/docs/images/aardvark_logo_small.png b/docs/images/aardvark_logo_small.png new file mode 100644 index 0000000..608e121 Binary files /dev/null and b/docs/images/aardvark_logo_small.png differ diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..7264da4 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,67 @@ +[build-system] +requires = ["hatchling", "hatch-vcs", "hatch-build-scripts"] +build-backend = "hatchling.build" + +[project] +name = "aardvark" +description = "A multi-account AWS IAM Access Advisor API" +readme = "README.md" +authors = [{name = "Netflix", email = "aardvark-maintainers@netflix.com"}] +license = {file = "LICENSE"} +requires-python = ">=3.10" +dependencies = [ + "asgiref", + "blinker", + "click", + "cloudaux", + "confuse", + "dynaconf", + "flasgger", + "flask", + "SQLAlchemy", + "setuptools", + "swag_client", + "toml", +] +dynamic = ["version"] + +[project.scripts] +aardvark = "aardvark.manage:cli" + +[project.urls] +repository = "https://github.com/Netflix-Skunkworks/aardvark" + +[tool.hatch.build] +include = ["aardvark", "README.md", "LICENSE", "aardvark/config_default.yaml"] + +[[tool.hatch.build.hooks.build-scripts.scripts]] +commands = [ + # the version file is updated as part of the build. this command tells git to ignore it. 
+ # if you actually need to update the version file, you'll need to run the inverse of this command: + # git update-index --no-skip-worktree aardvark/__version__.py + "git update-index --skip-worktree aardvark/__version__.py", +] +artifacts = [] + +[tool.hatch.build.hooks.vcs] +version-file = "aardvark/__version__.py" + +[tool.hatch.envs.test] +extra-dependencies = [ + "pytest", + "pytest-asyncio", +] + +[[tool.hatch.envs.test.matrix]] +python = ["3.10", "3.11"] + +[tool.hatch.version] +source = "vcs" +scheme = "standard" + +[tool.hatch.version.raw-options] +version_scheme = "python-simplified-semver" +local_scheme = "no-local-version" + +[tool.ruff.lint.per-file-ignores] +"**/tests/*" = ["ARG001", "ARG002", "PLR2004", "S101", "S105", "SLF001"] diff --git a/requirements.txt b/requirements.txt index 037bf30..e69de29 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,61 +0,0 @@ -blinker -boto3==1.20.43 -botocore==1.23.43 -bunch==1.0.1 -cloudaux>=1.8.0 # pinned -Flask==1.0.2 -Flask-RESTful==0.3.5 -Flask-Script==2.0.5 -Flask-SQLAlchemy>=2.5 # pinned -SQLAlchemy==1.3.10 -swag-client==0.4.6 - -# indirect -aniso8601==8.0.0 -astroid==2.4.2 -attrs==19.3.0 -certifi==2023.7.22 -chardet==3.0.4 -Click==7.0 -click-log==0.3.2 -decorator==4.4.0 -deepdiff==3.3.0 -defusedxml==0.6.0 -docutils==0.15.2 -dogpile.cache==0.8.0 -flagpole==1.1.1 -flasgger==0.9.5 -gunicorn==19.7.1 -idna==2.8 -importlib-metadata -inflection==0.3.1 -isort==4.3.21 -itsdangerous==1.1.0 -Jinja2==3.0.3 -jmespath==0.9.4 -joblib==0.14.0 -jsonpickle==1.2 -jsonschema==3.1.1 -lazy-object-proxy==1.4.2 -MarkupSafe -marshmallow==2.20.5 -mccabe==0.6.1 -mistune==0.8.4 -more-itertools==7.2.0 -pexpect==4.7.0 -psycopg2-binary==2.9.3 -ptyprocess==0.6.0 -pylint==2.6.0 -pyrsistent==0.15.4 -python-dateutil==2.8.0 -pytz==2017.2 -PyYAML -requests==2.31.0 -retrying==1.3.3 -simplejson==3.16.0 -six==1.12.0 -tabulate==0.8.5 -tqdm==4.40.0 -Werkzeug==0.16.0 -wrapt==1.11.2 -zipp==0.6.0 diff --git a/settings.yaml b/settings.yaml new file mode 100644 index 0000000..e8dac13 --- /dev/null +++ b/settings.yaml @@ -0,0 +1,35 @@ +--- +default: + AWS_ARN_PARTITION: aws + AWS_REGION: us-east-1 + AWS_ROLENAME: Aardvark + SQLALCHEMY_DATABASE_URI: "sqlite:///:memory:" + SQLALCHEMY_TRACK_MODIFICATIONS: false + UPDATER_NUM_THREADS: 1 + SWAG: + bucket: "swag-bucket" + filter: "" + service_enabled_requirement: "" + opts: + swag.schema_version: 2 + swag.type: "s3" + swag.bucket_name: "swag-bucket" + swag.data_file: "v2/accounts.json" + swag.region: "us-east-1" +testing: + AWS_ARN_PARTITION: test + AWS_REGION: test + AWS_ROLENAME: test + SQLALCHEMY_DATABASE_URI: "sqlite:///:memory:" + SQLALCHEMY_TRACK_MODIFICATIONS: false + UPDATER_NUM_THREADS: 1 + SWAG: + bucket: "swag-bucket" + filter: "mock swag filter" + service_enabled_requirement: "glowcloud" + opts: + swag.schema_version: 2 + swag.type: "s3" + swag.bucket_name: "swag-bucket" + swag.data_file: "v2/accounts.json" + swag.region: "us-east-1" diff --git a/setup.py b/setup.py index a373418..e69de29 100644 --- a/setup.py +++ b/setup.py @@ -1,20 +0,0 @@ -from setuptools import setup - - -setup( - name="aardvark", - author="Patrick Kelley, Travis McPeak, Patrick Sanders", - author_email="aardvark-maintainers@netflix.com", - url="https://github.com/Netflix-Skunkworks/aardvark", - setup_requires="setupmeta", - versioning="dev", - extras_require={ - 'tests': ['pexpect>=4.2.1'], - }, - entry_points={ - 'console_scripts': [ - 'aardvark = aardvark.manage:main', - ], - }, - python_requires=">=3.8,<3.10", -) diff --git 
a/test/test_config.py b/test/test_config.py deleted file mode 100644 index 3c80fa9..0000000 --- a/test/test_config.py +++ /dev/null @@ -1,567 +0,0 @@ -'''Test cases for the manage.config() function. - -We test the following for the configurable parameters: -- Exactly the expected parameters appear in the config file. -- The expected parameters have the expected values in the config file. - - -In the event of test failures it may be helpful to examine the command -used to launch aardvark config, any responses to prompts and the -resulting configuration file. To this end, these artifacts will be -saved in locations controlled by the environment variable -SWAG_CONFIG_TEST_ARCHIVE_DIR. Set this environment variable to the -absolute path to the directory to use to archive test artifacts. The -default is '/tmp'. - -By default, artifacts will only be saved in the event of test -failures. To force archiving of test artifacts regardless of test -status, set the SWAG_CONFIG_TEST_ALWAYS_ARCHIVE to 1 (or anything -truthy). - -Archiving occurs in tearDown(). The config file, the command line call -to "aardvark config" and a transcript of the command line interaction -will be saved in the archive directory in files named "command.*" and -"config.py.*". The wildcard will be replaced by default with the name -of the "test_*" function that was running when the failure occcured. - - -The command lines used to call aardvark config in each test case will -always be archived. Set the SWAG_CONFIG_TEST_COMMAND_ARCHIVE_DIR -environment variable to the absolute path to the directory to use to -archive these commands. The default is to use the same value as -SWAG_CONFIG_TEST_ARCHIVE_DIR. This will create a record of all the -command lines executed in execution of these test cases, in files -names "commands.[TestClassName]". - -''' - -#adding for py3 support -from __future__ import absolute_import - -import inspect -import os -import shutil -import tempfile - -import unittest - -from aardvark import manage -import pexpect - -# These are fast command line script interactions, eight seconds is forever. -EXPECT_TIMEOUT = 8 - -CONFIG_FILENAME = 'config.py' - -ALWAYS_ARCHIVE = os.environ.get('SWAG_CONFIG_TEST_ALWAYS_ARCHIVE') - -# Locations where we will archive test artifacts. -DEFAULT_ARTIFACT_ARCHIVE_DIR = '/tmp' -ARTIFACT_ARCHIVE_DIR = ( - os.environ.get('SWAG_CONFIG_TEST_ARCHIVE_DIR') or - DEFAULT_ARTIFACT_ARCHIVE_DIR - ) -COMMAND_ARCHIVE_DIR = ( - os.environ.get('SWAG_CONFIG_TEST_COMMAND_ARCHIVE_DIR') or - ARTIFACT_ARCHIVE_DIR - ) - -DEFAULT_LOCALDB_FILENAME = 'aardvark.db' -DEFAULT_AARDVARK_ROLE = 'Aardvark' -DEFAULT_NUM_THREADS = 5 - -# Specification of option names, default values, methods of extracting -# from config file, etc. The keys here are what we use as the 'handle' -# for each configurable option throughout this test file. -CONFIG_OPTIONS = { - 'swag_bucket': { - 'short': '-b', - 'long': '--swag-bucket', - 'config_key': 'SWAG_OPTS', - 'config_prompt': r'(?i).*SWAG.*BUCKET.*:', - 'getval': lambda x: x.get('swag.bucket_name') if x else None, - 'default': manage.DEFAULT_SWAG_BUCKET - }, - 'aardvark_role': { - 'short': '-a', - 'long': '--aardvark-role', - 'config_key': 'ROLENAME', - 'config_prompt': r'(?i).*ROLE.*NAME.*:', - 'getval': lambda x: x, - 'default': manage.DEFAULT_AARDVARK_ROLE - }, - 'db_uri': { - 'short': '-d', - 'long': '--db-uri', - 'config_key': 'SQLALCHEMY_DATABASE_URI', - 'getval': lambda x: x, - 'config_prompt': r'(?i).*DATABASE.*URI.*:', - 'default': None # need to be in tmpdir. 
- }, - 'num_threads': { - 'short': None, - 'long': '--num-threads', - 'config_key': 'NUM_THREADS', - 'config_prompt': r'(?i).*THREADS.*:', - 'getval': lambda x: x, - 'default': manage.DEFAULT_NUM_THREADS - }, - } - -# Syntax sugar for getting default parameters for each option. Note -# that we reset the db_uri value after we change the working directory -# in setUpClass(). -DEFAULT_PARAMETERS = dict([ - (k, v['default']) for k, v in CONFIG_OPTIONS.items() - ]) -DEFAULT_PARAMETERS_NO_SWAG = dict(DEFAULT_PARAMETERS, **{'swag_bucket': None}) - -# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -# Uncomment to show lower level logging statements. -# import logging -# logger = logging.getLogger() -# logger.setLevel(logging.DEBUG) -# shandler = logging.StreamHandler() -# shandler.setLevel(logging.INFO) # Pick one. -# -# formatter = logging.Formatter( -# '%(asctime)s - %(name)s - %(levelname)s - %(message)s' -# ) -# shandler.setFormatter(formatter) -# logger.addHandler(shandler) - - -# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -def default_db_uri(): - '''Return the default db_uri value at runtime.''' - return '{localdb}:///{path}/{filename}'.format( - localdb=manage.LOCALDB, - path=os.getcwd(), - filename=manage.DEFAULT_LOCALDB_FILENAME - ) - - -# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -def get_config_option_string(cmdline_option_spec, short_flags=True): - '''Construct the options string for a call to aardvark config.''' - - option_substrings = [] - - for param, value in cmdline_option_spec.items(): - flag = ( - CONFIG_OPTIONS[param]['short'] - if short_flags and CONFIG_OPTIONS[param]['short'] - else CONFIG_OPTIONS[param]['long'] - ) - option_substrings.append('{} {}'.format(flag, value)) - - return ' '.join(option_substrings) - - -# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -def load_configfile(cmdline_option_spec): - '''Evaluate the config values for the fields in cmdline_option_spec.''' - - all_config = {} - with open(CONFIG_FILENAME) as in_file: - exec(in_file.read(), all_config) - # print all_config.keys() - found_config = dict([ - (k, v['getval'](all_config.get(v['config_key']))) - for (k, v) in CONFIG_OPTIONS.items() - ]) - return found_config - - -# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -def get_expected_config(option_spec): - '''Return a dict with the values that should be set by a config file.''' - - include_swag = ('swag_bucket' in option_spec) - default_parameters = ( - DEFAULT_PARAMETERS if include_swag else DEFAULT_PARAMETERS_NO_SWAG - ) - expected_config = dict(default_parameters) - expected_config.update(option_spec) - - return expected_config - - -# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -class TestConfigBase(unittest.TestCase): - '''Base class for config test cases.''' - - # Throughout, the dicts cmdline_option_spec and config_option_spec - # are defined with keys matching the keys in CONFIG_SPEC and the - # values defining the value for the corresponding parameter, to be - # delivered via a command line parameter to 'aardvark config' or - # via entry after the appropriate prompt interactively. 
- - # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - @classmethod - def setUpClass(cls): - '''Test case class common fixture setup.''' - - cls.tmpdir = tempfile.mkdtemp() - cls.original_working_dir = os.getcwd() - os.chdir(cls.tmpdir) - - cls.commands_issued = [] - - # These depend on the current working directory set above. - CONFIG_OPTIONS['db_uri']['default'] = default_db_uri() - DEFAULT_PARAMETERS['db_uri'] = CONFIG_OPTIONS['db_uri']['default'] - DEFAULT_PARAMETERS_NO_SWAG['db_uri'] = ( - CONFIG_OPTIONS['db_uri']['default'] - ) - - # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - @classmethod - def tearDownClass(cls): - '''Test case class common fixture teardown.''' - - os.chdir(cls.original_working_dir) - cls.clean_tmpdir() - os.rmdir(cls.tmpdir) - - command_archive_filename = '.'.join(['commands', cls.__name__]) - command_archive_path = os.path.join( - COMMAND_ARCHIVE_DIR, command_archive_filename - ) - - with open(command_archive_path, 'w') as fptr: - fptr.write('\n'.join(cls.commands_issued) + '\n') - - # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - @classmethod - def clean_tmpdir(cls): - '''Remove all content from cls.tmpdir.''' - for root, dirs, files in os.walk(cls.tmpdir, topdown=False): - for name in files: - os.remove(os.path.join(root, name)) - for name in dirs: - os.rmdir(os.path.join(root, name)) - - # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - def setUp(self): - '''Test case common fixture setup.''' - self.clean_tmpdir() - self.assertFalse(os.path.exists(CONFIG_FILENAME)) - self.last_transcript = [] - self.archive_case_artifacts_as = None - - # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - def tearDown(self): - '''Test case common fixture teardown.''' - - # Archive the last command and config file created, if indicated. 
- if self.archive_case_artifacts_as: - - command_archive_path = self.archive_path( - 'command', self.archive_case_artifacts_as - ) - config_archive_path = self.archive_path( - CONFIG_FILENAME, self.archive_case_artifacts_as - ) - - with open(command_archive_path, 'w') as fptr: - fptr.write(self.last_config_command + '\n') - if self.last_transcript: - fptr.write( - '\n'.join( - map(lambda x: str(x), self.last_transcript) - ) + '\n' - ) - - if os.path.exists(CONFIG_FILENAME): - shutil.copyfile(CONFIG_FILENAME, config_archive_path) - else: - with open(config_archive_path, 'w') as fptr: - fptr.write( - '(no {} file found in {})\n'.format( - CONFIG_FILENAME, os.getcwd() - ) - ) - - # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - def archive_path(self, filename, suffix): - '''Return the path to an archive file.''' - archive_filename = '.'.join([filename, suffix]) - archive_path = os.path.join(ARTIFACT_ARCHIVE_DIR, archive_filename) - return archive_path - - # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - def call_aardvark_config( - self, - cmdline_option_spec=None, - input_option_spec=None, - prompt=True, - short_flags=False - ): - '''Call aardvark config and interact as necessary.''' - - cmdline_option_spec = cmdline_option_spec or {} - input_option_spec = input_option_spec or {} - - command = 'aardvark config' + ('' if prompt else ' --no-prompt') - self.last_config_command = '{} {}'.format( - command, - get_config_option_string( - cmdline_option_spec, short_flags=short_flags - ) - ) - - self.commands_issued.append(self.last_config_command) - spawn_config = pexpect.spawn(self.last_config_command) - - self.conduct_config_prompt_sequence( - spawn_config, input_option_spec - ) - - # If we didn't wrap up the session, something's amiss. - self.assertFalse(spawn_config.isalive()) - - # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - def conduct_config_prompt_sequence(self, spawned, input_option_spec): - '''Carry out the steps in the config prompt sequence.''' - - # The order is all that tells us which of these match in a pexpect - # call, so we can't use a dict here. - control_prompts = [ - (pexpect.EOF, 'eof'), - (pexpect.TIMEOUT, 'timeout') - ] - config_option_prompts = [ - (v['config_prompt'], k) - for k, v in CONFIG_OPTIONS.items() - ] - expect_prompts = [ - (r'(?i).*Do you use SWAG.*:', 'use_swag'), - ] - - expect_prompts.extend(config_option_prompts) - expect_prompts.extend(control_prompts) - - response_spec = input_option_spec - response_spec['use_swag'] = ( - 'y' if 'swag_bucket' in input_option_spec else 'N' - ) - - while spawned.isalive(): - - prompt_index = spawned.expect( - [x[0] for x in expect_prompts], timeout=EXPECT_TIMEOUT - ) - self.last_transcript.append(spawned.after) - - prompt_received = expect_prompts[prompt_index][1] - if prompt_received in [x[1] for x in control_prompts]: - return - - response = response_spec.get(prompt_received) - response = '' if response is None else response - spawned.sendline(str(response)) - - # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - def case_worker( - self, - cmdline_option_spec=None, - input_option_spec=None, - prompt=False, - short_flags=False, - expect_config_file=True, - archive_as=None - ): - '''Carry out common test steps. - - Parameters: - - cmdline_option_spec (dict or None): - A dictionary specifying the options and values to pass - to aardvark config as command line flags. 
- - input_option_spec (dict or None): - A dictionary specifying the options and values to pass - to aardvark config as prompted interactive input. - - prompt (bool, default: False): - If False, set the --no-prompt option when calling - aardvark config. - - short_flags (bool, default: False): - If True, set the --short-flags option when calling - aardvark config. - - expect_config_file (bool, default: True): - If True, test for the presence and correctness of a - config file after calling aardvark config. If false, - test for the absence of a config file. - - archive_as (str or None): - The "unique" string to use when constructing a - filename for archiving artifacts of this test case. If - None, the name of the nearest calling function in the - call stack named "test_*" will be used, if one can be - found; otherwise a possibly non-unique string will be - used. - - ''' - - cmdline_option_spec = cmdline_option_spec or {} - input_option_spec = input_option_spec or {} - # Combined, for validation. - option_spec = dict(cmdline_option_spec, **input_option_spec) - - if not archive_as: - # Get the calling test case function's name, for - # archiving. We'll take the first caller in the stack - # whose name starts with 'test_'. - caller_names = [ - inspect.getframeinfo(frame[0]).function - for frame in inspect.stack() - if inspect.getframeinfo(frame[0]).function.startswith('test_') - ] - archive_as = caller_names[0] if caller_names else 'unknown_test' - - # Turn on failed-case archive. - self.archive_case_artifacts_as = archive_as - - self.assertFalse(os.path.exists(CONFIG_FILENAME)) - self.call_aardvark_config( - cmdline_option_spec=cmdline_option_spec, - input_option_spec=input_option_spec, - prompt=prompt, - short_flags=short_flags - ) - - if expect_config_file: - self.assertTrue(os.path.exists(CONFIG_FILENAME)) - - found_config = load_configfile(cmdline_option_spec) - expected_config = get_expected_config(option_spec) - - self.assertCountEqual(expected_config.keys(), found_config.keys()) - for k, v in found_config.items(): - self.assertEqual((k, v), (k, expected_config[k])) - - else: - self.assertFalse(os.path.exists(CONFIG_FILENAME)) - - # Turn off failed-case archive unless we're forcing archiving. 
- if not ALWAYS_ARCHIVE: - self.archive_case_artifacts_as = None - - -# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -class TestConfigNoPrompt(TestConfigBase): - '''Test cases for config --no-prompt.''' - - # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - def test_no_prompt_defaults(self): - '''Test with no-prompt and all default arguments.''' - - self.case_worker() - - # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - def test_no_prompt_all_parameters(self): - '''Test with no-prompt and all parameters.''' - - cmdline_option_spec = { - 'swag_bucket': 'bucket_123', - 'aardvark_role': 'role_123', - 'db_uri': 'db_uri_123', - 'num_threads': 4 - } - - self.case_worker( - cmdline_option_spec=cmdline_option_spec, - ) - - # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - def test_no_prompt_all_parameters_short(self): - '''Test with no-prompt and short parameters.''' - - cmdline_option_spec = { - 'swag_bucket': 'bucket_123', - 'aardvark_role': 'role_123', - 'db_uri': 'db_uri_123', - 'num_threads': 4 - } - - self.case_worker( - cmdline_option_spec=cmdline_option_spec, - short_flags=True - ) - - # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - def test_no_prompt_no_swag(self): - '''Test with no-prompt and all non-swag parameters.''' - - cmdline_option_spec = { - 'aardvark_role': 'role_123', - 'db_uri': 'db_uri_123', - 'num_threads': 4 - } - - self.case_worker( - cmdline_option_spec=cmdline_option_spec, - ) - - -# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -class TestConfigPrompt(TestConfigBase): - '''Test cases for config with prompting.''' - - # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - def test_prompted_defaults(self): - '''Test with no parameters specified.''' - - self.case_worker( - prompt=True - ) - - # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - def test_prompted_all_cmdline_parameters(self): - '''Test with all parameters passed as options.''' - - cmdline_option_spec = { - 'swag_bucket': 'bucket_123', - 'aardvark_role': 'role_123', - 'db_uri': 'db_uri_123', - 'num_threads': 4 - } - - self.case_worker( - cmdline_option_spec=cmdline_option_spec, - prompt=True - ) - - # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - def test_prompted_no_swag(self): - '''Test with all non-swag parameters interactively.''' - - input_option_spec = { - 'aardvark_role': 'role_123', - 'db_uri': 'db_uri_123', - 'num_threads': 4 - } - - self.case_worker( - input_option_spec=input_option_spec, - prompt=True - ) - -# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -# Define test suites. -# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -load_case = unittest.TestLoader().loadTestsFromTestCase -all_suites = { - 'testconfignoprompt': load_case(TestConfigNoPrompt), - 'testconfigprompt': load_case(TestConfigPrompt) - } - -master_suite = unittest.TestSuite(all_suites.values()) - -# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -if __name__ == '__main__': - unittest.main() diff --git a/test/test_docker.py b/test/test_docker.py deleted file mode 100644 index 80f72fe..0000000 --- a/test/test_docker.py +++ /dev/null @@ -1,979 +0,0 @@ -''' -Test cases for docker container creation. -''' - -#adding for py3 support -from __future__ import absolute_import - -import logging -import os -import random -import re -import shutil -import tempfile - -import unittest - -import pexpect - - -# Configure logging. 
Troubleshooting the pexpect interactions in -# particular needs a lot of tracing. -FILENAME = os.path.split(__file__)[-1] -shandler = logging.StreamHandler() -sformatter = logging.Formatter( - '%(asctime)s - %(name)s - %(levelname)s - %(message)s' - ) -shandler.setFormatter(sformatter) -logger = logging.getLogger() -logger.setLevel(logging.WARNING) -# logger.setLevel(logging.INFO) -# logger.setLevel(logging.DEBUG) -logger.addHandler(shandler) - -# These need to be removed from the test run environment if present -# before configuring the environment for the pexpect call to the make -# process. -BUILD_CONTROL_ENV_VARIABLES = [ - 'AARDVARK_ROLE', - 'AARDVARK_DB_URI', - 'SWAG_BUCKET', - 'AARDVARK_IMAGES_TAG' - ] - -# We'll copy the docker directory contents to the temporary working -# directory each time. -DOCKER_PATH = os.path.join( - os.path.dirname(os.path.dirname(os.path.realpath(__file__))), - "docker" - ) - -# An index of possible docker images and their pseudo-artifacts. -# TODO: Preserved for reference purposes; the DOCKER_IMAGES variable -# isn't used as of this commenting. -DOCKER_IMAGES = { - 'aardvark-base': 'aardvark-base-docker-build', - 'aardvark-data-volume': 'aardvark-data-docker-build', - 'aardvark-data-volume': 'aardvark-data-docker-run', - 'aardvark-apiserver': 'aardvark-apiserver-docker-build', - 'aardvark-collector': 'aardvark-collector-docker-build', - } - -# The subdirectory of the working directory where pseudo-artifacts -# are created. -ARTIFACT_DIRECTORY = 'artifacts' - -# A few constants that are checked when testing container -# config settings. -CONTAINER_CONFIG_PATH = '/etc/aardvark/config.py' - -EXPECTED_SQLITE_DB_URI = 'sqlite:////usr/share/aardvark-data/aardvark.db' -EXPECTED_SQL_TRACK_MODS = False - -# Making targets can take some time, and depends on network connection -# speed. Set the NETWORK_SPEED_FACTOR environment variable to increase -# if necessary. -PEXPECT_TIMEOUTS = { - 'default': 30, - 'container_command': 2, - 'aardvark': 30, - 'aardvark-all': 240, - 'aardvark-base': 180, - 'aardvark-sqlite': 300, - } -NETWORK_SPEED_FACTOR = 1.0 - -# The key that will uniquely identify a docker construct of the -# indicated type. -UID_KEY = { - 'image': 'id', - 'container': 'id', - 'volume': 'name' - } - -# A message for unittest.skipIf. -SKIP_NO_IMAGES_TAG_MSG = ( - "Remove aardvark-*:latest images and the aardvark-data volume as desired" - " and set the RUN_AARDVARK_DOCKER_TESTS_NO_IMAGES_TAG environment variable" - " to 1 to enable tests that don't set the AARDVARK_IMAGES_TAG." 
- ) - - -# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -def sqlite_db_uri(path): - '''Return the default db_uri value at runtime.''' - return 'sqlite:///{}/aardvark.db'.format(path) - - -# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -def copy_recipes(src, dest): - '''Copy Make- and Dockerfiles from src to dest''' - [ - shutil.copyfile(os.path.join(src, f), os.path.join(dest, f)) - for f in os.listdir(src) - if (f.startswith("Dockerfile") or f.startswith("Makefile")) - ] - - -# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -def interact( - spawn, - command, - response_filter=None, - timeout=PEXPECT_TIMEOUTS['container_command'] - ): # pylint: disable=bad-continuation - '''Send command to spawn and return filtered response.''' - - def default_response_filter(response): - '''Define a default method for the response filter.''' - return response - - if not response_filter: - response_filter = default_response_filter - - shell_prompt = r'root@\w+:[\S]+# ' - expect_prompt = [shell_prompt, pexpect.EOF, pexpect.TIMEOUT] - - eof_position = expect_prompt.index(pexpect.EOF) - timeout_position = expect_prompt.index(pexpect.TIMEOUT) - - eof_msg_template = 'unexpected EOF in pexpect connection to {}' - responses = [] - result = None - - logger.info('COMMAND:\t%s', command) - spawn.sendline(command) - - # Read until we time out, indicating no more lines are pending; - # watch for the response that has our command echoed on one line - # and our response on the next. - while(True): - - prompt_index = spawn.expect(expect_prompt, timeout=timeout) - - import json - if prompt_index < timeout_position: - responses.append( - spawn.before.decode('utf-8').replace('\r\n', '\n') - ) - # print '--------' - logger.debug(' BEFORE: %s', json.dumps(responses[-1])) - logger.debug( - ' AFTER: %s', - json.dumps(spawn.after.decode("utf-8").replace('\r\n', '\n')) - ) - if command and responses[-1].startswith(command + '\n'): - result = response_filter( - responses[-1].replace(command + '\n', '', 1).strip() - ) - logger.info(" result found: %s", result) - # print '--------\n' - - elif prompt_index == timeout_position: - logger.debug(' TIMEOUT') - break - - elif prompt_index == eof_position: - raise RuntimeError(eof_msg_template.format(command)) - - logger.debug('Responses:\n%s', '\n'.join(responses)) - logger.debug('') - - return result - - -# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -class TestDockerBase(unittest.TestCase): - '''Base class for docker container construction test cases.''' - - # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - @classmethod - def setUpClass(cls): - '''Test case class common fixture setup.''' - - # These will record the constructs at class start; we'll log - # any discrepancy at class termination. 
- cls.constructs = {'original': {}, 'current': {}} - - for construct_type in ['image', 'container', 'volume']: - cls.constructs['original'][construct_type] = ( - cls.get_docker_constructs(construct_type) - ) - cls.constructs['current'][construct_type] = list( - cls.constructs['original'][construct_type] - ) - - cls.tmpdir = tempfile.mkdtemp() - - cls.original_working_dir = os.getcwd() - os.chdir(cls.tmpdir) - - # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - @classmethod - def tearDownClass(cls): - '''Test case class common fixture teardown.''' - - os.chdir(cls.original_working_dir) - cls.clean_tmpdir() - os.rmdir(cls.tmpdir) - - cls.warn_on_hanging_constructs() - - # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - @classmethod - def warn_on_hanging_constructs(cls): - '''Log warnings for undeleted docker constructs created by tests.''' - - construct_fields = { - 'container': ['id', 'name'], - 'image': ['id', 'name', 'tag'], - 'volume': ['name'], - } - - hanging_constructs = {} - - for construct_type in construct_fields.keys(): - - hanging_constructs[construct_type] = [ - c for c in cls.constructs['current'][construct_type] - if c[UID_KEY[construct_type]] not in [ - x[UID_KEY[construct_type]] - for x in cls.constructs['original'][construct_type] - ] - ] - - for construct in hanging_constructs[construct_type]: - logger.warning( - 'a failed test case left behind %s:\t%s', - construct_type, - '\t'.join([ - construct[field] - for field in construct_fields[construct_type] - ]) - ) - - # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - @classmethod - def clean_tmpdir(cls): - '''Remove all content from cls.tmpdir.''' - - for root, dirs, files in os.walk(cls.tmpdir, topdown=False): - - for name in files: - - path = os.path.join(root, name) - os.remove(path) - - for name in dirs: - - path = os.path.join(root, name) - if os.path.islink(path): - os.remove(path) - else: - os.rmdir(path) - - # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - def setUp(self): - '''Test case common fixture setup.''' - - self.clean_tmpdir() - - # This copies the Makefile and the Dockerfiles. - copy_recipes(DOCKER_PATH, self.tmpdir) - - # We track to make sure we don't disturb any files that were - # present in the working directory. - self.initial_contents = os.listdir(self.tmpdir) - - # If true, we will stop containers and delete images created - # during each test case. - self.delete_artifacts = True - - # A (almost certainly) unique string for unique test case - # artifact names. - self.testcase_tag = ( - 'test{:08X}'.format(random.randrange(16 ** 8)).lower() - ) - - # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - def tearDown(self): - '''Test case common fixture teardown.''' - - logger.info("======= tearDown: %s", self.delete_artifacts) - - self.log_docker_constructs() - - new_containers = self.new_constructs('container') - new_images = self.new_constructs('image') - new_volumes = self.new_constructs('volume') - - if self.delete_artifacts: - # Every case should clean up its docker images and containers. - - for container in new_containers: - # These should have been launched with the --rm flag, - # so they should be removed once stopped. 
- logger.info("REMOVING %s", container['id']) - pexpect.run('docker stop {}'.format(container['id'])) - - for image in new_images: - logger.info("REMOVING %s", image['id']) - pexpect.run('docker rmi {}'.format(image['id'])) - - for volume in new_volumes: - logger.info("REMOVING %s", volume['name']) - pexpect.run('docker volume rm {}'.format(volume['name'])) - - else: - # We'll leave behind any new docker constructs, so we need - # to update the "original docker volumes". - self.constructs['current']['container'].extend(new_containers) - self.constructs['current']['image'].extend(new_images) - self.constructs['current']['volume'].extend(new_volumes) - - # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - def log_docker_constructs(self, **kwargs): - '''Log docker construct status.''' - - def log_listing(caption, records, *fields): - '''Helper method for log message construction.''' - - return '{}:\n '.format(caption) + '\n '.join([ - '\t'.join( - [record[field] for field in fields] - ) for record in records - ]) - - construct_fields = { - 'container': ['id', 'name'], - 'image': ['id', 'name', 'tag'], - 'volume': ['name'], - } - - for construct_type in construct_fields.keys(): - - logger.info("------------%ss:", construct_type) - logger.info(log_listing( - 'Original', - self.constructs['original'][construct_type], - *construct_fields[construct_type] - )) - logger.info(log_listing( - 'Current', - self.get_docker_constructs(construct_type), - *construct_fields[construct_type] - )) - logger.info(log_listing( - 'New', - self.new_constructs(construct_type), - *construct_fields[construct_type] - )) - - # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - @classmethod - def get_docker_constructs(cls, construct_type, *expected_constructs): - '''Get a list of images, containers or volumes.''' - return { - 'image': cls.get_docker_image_list, - 'container': cls.get_docker_container_list, - 'volume': cls.get_docker_volume_list - }[construct_type](*expected_constructs) - - # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - @staticmethod - def get_docker_image_list(*expected_images): - '''Get the output from "docker image ls" for specified image names.''' - - image_listing_pattern = ( - r'(?P[^\s]+)\s+' - r'(?P[^\s]+)\s+' - r'(?P[0-9a-f]+)\s+' - r'(?P.+ago)\s+' - r'(?P[^\s]+)' - r'\s*$' - ) - image_listing_re = re.compile(image_listing_pattern) - - docker_images_response = pexpect.run('docker image ls') - - image_list = [] - expected_image_nametag_pairs = [ - (x.split(':') + ['latest'])[0:2] for x in expected_images - ] if expected_images else None - - docker_images_response_l = docker_images_response.decode('utf-8').split('\n') - - for line in docker_images_response_l: - match = image_listing_re.match(line) - if ( - match and ( - not expected_images or [ - match.groupdict()['name'], match.groupdict()['tag'] - ] in expected_image_nametag_pairs - ) - ): - image_list.append(match.groupdict()) - - return image_list - - # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - @staticmethod - def get_docker_container_list(*expected_containers): - '''Get the output from "docker ps -a" for specified container names.''' - - container_listing_pattern = ( - r'(?P[0-9a-f]+)\s+' - r'(?P[^\s]+)\s+' - r'(?P"[^"]+")\s+' - r'(?P.+ago)\s+' - r'(?P(Created|Exited.*ago|Up \d+ \S+))\s+' - r'(?P[^\s]+)?\s+' - r'(?P[a-z]+_[a-z]+)' - # r'\s*$' - ) - container_listing_re = re.compile(container_listing_pattern) - - docker_containers_response = pexpect.run('docker ps -a') - - container_list = [] - # 
-    # - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
-    @staticmethod
-    def get_docker_container_list(*expected_containers):
-        '''Get the output from "docker ps -a" for specified container names.'''
-
-        container_listing_pattern = (
-            r'(?P<id>[0-9a-f]+)\s+'
-            r'(?P<image>[^\s]+)\s+'
-            r'(?P<command>"[^"]+")\s+'
-            r'(?P<created>.+ago)\s+'
-            r'(?P<status>(Created|Exited.*ago|Up \d+ \S+))\s+'
-            r'(?P<ports>[^\s]+)?\s+'
-            r'(?P<name>[a-z]+_[a-z]+)'
-            # r'\s*$'
-        )
-        container_listing_re = re.compile(container_listing_pattern)
-
-        docker_containers_response = pexpect.run('docker ps -a')
-
-        container_list = []
-        # expected_container_nametag_pairs = [
-        #     (x.split(':') + ['latest'])[0:2] for x in expected_containers
-        # ] if expected_containers else []
-
-        docker_containers_response_l = docker_containers_response.decode('utf-8').split('\n')
-
-        for line in docker_containers_response_l:
-            match = container_listing_re.match(line)
-            if match:
-                container_list.append(match.groupdict())
-
-        return container_list
-
-    # - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
-    @staticmethod
-    def get_docker_volume_list(*expected_volumes):
-        '''Get the output from "docker volume ls" for specified volumes.'''
-
-        volume_listing_pattern = (
-            r'(?P<driver>\S+)\s+'
-            r'(?P<name>\S+)'
-            # r'\s*$'
-        )
-        volume_listing_re = re.compile(volume_listing_pattern)
-
-        docker_volumes_response = pexpect.run('docker volume ls')
-
-        docker_volumes_response_l = docker_volumes_response.decode('utf-8').split('\n')
-
-        volume_list = []
-
-        for line in docker_volumes_response_l:
-            match = volume_listing_re.match(line)
-            if match:
-                volume_list.append(match.groupdict())
-
-        return volume_list
-
-    # - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
-    def require_filenames_in_directory(self, patterns=None, directory='.'):
-        '''Check that filenames are found in the indicated directory.
-
-        Each pattern in the list of patterns must match exactly one
-        file in the indicated directory.
-
-        '''
-
-        failure_string_template = (
-            'Unexpected or missing filename match result in {}'
-            ' for pattern r\'{}\':\n{}\n'
-            'Directory contents:\n{}'
-        )
-
-        if patterns:
-            self.assertTrue(os.path.exists(directory))
-            all_filenames = os.listdir(directory)
-            for pattern in patterns:
-                matching_files = [
-                    x for x in all_filenames
-                    if re.match(pattern, x)
-                ]
-                self.assertTrue(
-                    len(matching_files) == 1,
-                    failure_string_template.format(
-                        directory,
-                        pattern,
-                        matching_files,
-                        os.listdir(directory)
-                    )
-                )
-
-    # - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
-    def new_constructs(self, construct_type):
-        '''Get a list of images, containers or volumes not already recorded.'''
-        return [
-            c for c in self.get_docker_constructs(construct_type)
-            if c[UID_KEY[construct_type]] not in [
-                o[UID_KEY[construct_type]]
-                for o in self.constructs['original'][construct_type]
-            ]
-        ]
-
-    # - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
-    def get_container_details(self, image):
-        '''Run a docker container shell and retrieve several details.'''
-
-        # (detail name, command, result filter) for extracting details
-        # from container command lines.
-        shell_commands = (
-            ('pwd', 'pwd', None),
-            ('config_file', 'ls {}'.format(CONTAINER_CONFIG_PATH), None),
-            ('config_contents', 'cat {}'.format(CONTAINER_CONFIG_PATH), None),
-        )
-
-        command = "docker run --rm -it {} bash".format(image)
-        logger.info('IMAGE: %s', image)
-        logger.info('CONTAINER LAUNCH COMMAND: %s', command)
-        spawn = pexpect.spawn(command)
-
-        container_details = {}
-
-        for field, shell_command, response_filter in shell_commands:
-            container_details[field] = interact(
-                spawn, shell_command, response_filter
-            )
-
-        # Exit the container.
-        spawn.sendcontrol('d')
-
-        # "Expand" the config records if we found a config file.
- if container_details['config_file'] == CONTAINER_CONFIG_PATH: - try: - exec(container_details['config_contents'], container_details) - except SyntaxError: - pass - # The '__builtins__' are noise: - if '__builtins__' in container_details: - del container_details['__builtins__'] - - return container_details - - # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - def check_container_details(self, image, expected_details=None): - '''Validate docker container details.''' - - # A helper method to retrieve container details of interest, - # necessary for configuration items that aren't simple config - # file values; e.g. "SWAG_BUCKET". - # As written, this won't allow tests to catch *missing* - # entries if the *expected* value is None. - def get_detail_value(data, detail): - '''A helper method to retrieve container details of interest.''' - - def default_get_detail_value_method(data): - '''Define a default method for get_detail_value.''' - return data.get(detail) - - method = { - 'SWAG_BUCKET': lambda data: ( - data.get('SWAG_OPTS') or {} - ).get('swag.bucket_name') - }.get(detail) - - if method is None: - method = default_get_detail_value_method - - return method(data) - - image_name, image_tag = image.split(':') - - expected_details = expected_details or {} - assert '_common' in expected_details - - expected_container_details = dict(expected_details['_common']) - if image_name in expected_details: - expected_container_details.update(expected_details[image_name]) - - container_details = self.get_container_details(image) - - clean_comparison = { - 'image': image, - 'missing': {}, - 'incorrect': {} - } - comparison = { - 'image': image, - 'missing': {}, - 'incorrect': {} - } - - for k, v in expected_container_details.items(): - - logger.info(' -- checking configuration item %s...', k) - logger.info(' expected: %s', v) - - actual = get_detail_value(container_details, k) - - # if k not in container_details: - # TODO: Note that this fails if we're *expecting* None. - if actual is None: - comparison['missing'][k] = v - logger.info(' actual: -') - - elif actual != v: - comparison['incorrect'][k] = { - 'expected': v, - 'actual': actual - } - logger.info(comparison['incorrect'][k]) - - else: - logger.info(' actual: %s', actual) - - logger.info("comparing %s", image) - logger.info(comparison) - logger.info('') - self.assertEqual(comparison, clean_comparison) - - # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - def case_worker( - self, - target, - expected_artifacts=None, - expected_docker_images=None, - expected_details=None, - expect_aardvark=True, - add_env=None, - set_images_tag=True, - ): - '''Carry out common test steps. - ''' - - logger.info(' -' * 8 + ' working case: %s' + ' -' * 8, target) - - # Unless we finish without a failure Exception, tell tearDown - # not to clean up artifacts. We reset this below. - self.delete_artifacts = False - - # A unique string to add to certain test case artifact names - # to avoid clobbering/colliding. - if set_images_tag: - images_tag = self.testcase_tag - logger.info('Test case images tag is %s', images_tag) - else: - images_tag = 'latest' - logger.info('Default test case images tag will be "latest"') - - expected_artifacts = expected_artifacts or [] - expected_docker_images = expected_docker_images or [] - tagged_expected_docker_images = [ - x + ':{}'.format(images_tag) - for x in expected_docker_images - ] - - # expected_details is a two level dict so a straightforward - # dict update isn't possible. 
- expected_details = expected_details or {} - expected_details['_common'] = expected_details.get('_common') or {} - expected_details['_common']['config_file'] = CONTAINER_CONFIG_PATH - - # Environment variables to add to the pexpect interaction with - # containers. - add_env = dict(add_env or {}) - if set_images_tag: - add_env = dict(add_env, AARDVARK_IMAGES_TAG=images_tag) - - # Fetch the default environment settings so we can update - # those with specific case settings. - spawn_env = dict( - map( - lambda x: x.strip().split('=', 1), - pexpect.run('env').strip().decode("utf-8").split("\n") - ) - ) - - # Remove any build control variables we inherit from the test - # environment - we want complete control over which are - # visible to the make process in the pexpect call. - spawn_env = { - k: v for (k, v) in spawn_env.items() - if k not in BUILD_CONTROL_ENV_VARIABLES - } - - command = 'make {}'.format(target) - logger.info('COMMAND: %s', command) - - # TODO: A sort of halfhearted attempt at adjusting for network - # conditions. Need some kind of not-too-slow way to check for - # network speed, say a sample download or something. - (result, exitstatus) = pexpect.run( - command, - timeout=( - PEXPECT_TIMEOUTS.get(target) or PEXPECT_TIMEOUTS['default'] * - NETWORK_SPEED_FACTOR - ), - withexitstatus=True, - env=dict(spawn_env, **add_env) - ) - - self.assertEqual( - exitstatus, 0, - 'command "{}" exited with exit status {}'.format( - command, exitstatus - ) - ) - - # Sanity check - we didn't delete the Makefile or any of the - # Dockerfiles. - self.assertEqual( - [x for x in self.initial_contents if x not in os.listdir('.')], - [] - ) - - if expected_docker_images: - self.assertCountEqual( - [ - [x['name'], x['tag']] - for x in self.get_docker_image_list( - *tagged_expected_docker_images - ) - ], - [x.split(':') for x in tagged_expected_docker_images] - ) - - for image in tagged_expected_docker_images: - self.check_container_details(image, expected_details) - - if expect_aardvark: - self.require_filenames_in_directory([r'aardvark$']) - - if expected_artifacts: - self.require_filenames_in_directory( - expected_artifacts, - directory=ARTIFACT_DIRECTORY - ) - - # We made it through, tell tearDown we can clean up artifacts. 
- self.delete_artifacts = True - - -# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -class TestDockerContainerConstruction(TestDockerBase): - '''Test cases for docker container construction.''' - - # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - def test_make_aardvark(self): - '''Test "make aardvark".''' - - self.case_worker( - target='aardvark', - expect_aardvark=True, - ) - - # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - def test_make_aardvark_base(self): - '''Test "make aardvark-base".''' - - self.case_worker( - target='aardvark-base', - expected_docker_images=[ - 'aardvark-base' - ], - expected_details={ - '_common': { - 'pwd': '/etc/aardvark', - 'NUM_THREADS': 5, - 'ROLENAME': 'Aardvark', - 'SQLALCHEMY_DATABASE_URI': EXPECTED_SQLITE_DB_URI, - 'SQLALCHEMY_TRACK_MODIFICATIONS': EXPECTED_SQL_TRACK_MODS, - } - }, - expected_artifacts=[ - 'aardvark-base-docker-build' - ], - ) - - # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - def test_make_aardvark_base_set_env_variables(self): - '''Test "make aardvark-base" with build time environment variables.''' - - aardvark_db_uri = 'blort://blah.bleh.bloo/bing/bang/bong' - aardvark_role = 'llama' - swag_bucket = 'ponponpon' - - self.case_worker( - target='aardvark-base', - expected_docker_images=[ - 'aardvark-base' - ], - expected_details={ - '_common': { - 'pwd': '/etc/aardvark', - 'NUM_THREADS': 5, - 'ROLENAME': aardvark_role, - 'SQLALCHEMY_DATABASE_URI': aardvark_db_uri, - 'SQLALCHEMY_TRACK_MODIFICATIONS': EXPECTED_SQL_TRACK_MODS, - 'SWAG_BUCKET': swag_bucket, - } - }, - expected_artifacts=[ - 'aardvark-base-docker-build' - ], - add_env={ - 'AARDVARK_DB_URI': aardvark_db_uri, - 'AARDVARK_ROLE': aardvark_role, - 'SWAG_BUCKET': swag_bucket, - } - ) - - # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - def test_make_aardvark_all(self): - '''Test "make aardvark-all".''' - - self.case_worker( - target='aardvark-all', - expected_docker_images=[ - 'aardvark-base', - 'aardvark-collector', - 'aardvark-apiserver', - ], - expected_details={ - '_common': { - 'pwd': '/usr/share/aardvark-data', - 'NUM_THREADS': 5, - 'ROLENAME': 'Aardvark', - 'SQLALCHEMY_DATABASE_URI': EXPECTED_SQLITE_DB_URI, - 'SQLALCHEMY_TRACK_MODIFICATIONS': EXPECTED_SQL_TRACK_MODS, - }, - 'aardvark-base': { - 'pwd': '/etc/aardvark', - }, - }, - expected_artifacts=[ - 'aardvark-base-docker-build', - 'aardvark-apiserver-docker-build', - 'aardvark-collector-docker-build', - ], - ) - - # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - @unittest.skipIf( - not os.environ.get('RUN_AARDVARK_DOCKER_TESTS_NO_IMAGES_TAG'), - SKIP_NO_IMAGES_TAG_MSG - ) - def test_make_aardvark_all_no_images_tag(self): - '''Test "make aardvark-all" without specifying the images tag.''' - - self.case_worker( - target='aardvark-all', - expected_docker_images=[ - 'aardvark-base', - 'aardvark-collector', - 'aardvark-apiserver', - ], - expected_details={ - '_common': { - 'pwd': '/usr/share/aardvark-data', - 'NUM_THREADS': 5, - 'ROLENAME': 'Aardvark', - 'SQLALCHEMY_DATABASE_URI': EXPECTED_SQLITE_DB_URI, - 'SQLALCHEMY_TRACK_MODIFICATIONS': EXPECTED_SQL_TRACK_MODS, - }, - 'aardvark-base': { - 'pwd': '/etc/aardvark', - }, - }, - expected_artifacts=[ - 'aardvark-base-docker-build', - 'aardvark-apiserver-docker-build', - 'aardvark-collector-docker-build', - ], - set_images_tag=False - ) - - # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - def test_make_aardvark_sqlite(self): - '''Test "make aardvark-sqlite".''' - - self.case_worker( - 
target='aardvark-sqlite', - expected_docker_images=[ - 'aardvark-base', - 'aardvark-data-init', - 'aardvark-collector', - 'aardvark-apiserver', - ], - expected_details={ - '_common': { - 'pwd': '/usr/share/aardvark-data', - 'NUM_THREADS': 5, - 'ROLENAME': 'Aardvark', - 'SQLALCHEMY_DATABASE_URI': EXPECTED_SQLITE_DB_URI, - 'SQLALCHEMY_TRACK_MODIFICATIONS': EXPECTED_SQL_TRACK_MODS, - }, - 'aardvark-base': { - 'pwd': '/etc/aardvark', - }, - }, - expected_artifacts=[ - 'aardvark-base-docker-build', - 'aardvark-data-docker-build', - 'aardvark-data-docker-run', - 'aardvark-apiserver-docker-build', - 'aardvark-collector-docker-build', - ], - ) - - # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - @unittest.skipIf( - not os.environ.get('RUN_AARDVARK_DOCKER_TESTS_NO_IMAGES_TAG'), - SKIP_NO_IMAGES_TAG_MSG - ) - def test_make_aardvark_sqlite_no_images_tag(self): - '''Test "make aardvark-sqlite" without specifying the images tag.''' - - self.case_worker( - target='aardvark-sqlite', - expected_docker_images=[ - 'aardvark-base', - 'aardvark-data-init', - 'aardvark-collector', - 'aardvark-apiserver', - ], - expected_details={ - '_common': { - 'pwd': '/usr/share/aardvark-data', - 'NUM_THREADS': 5, - 'ROLENAME': 'Aardvark', - 'SQLALCHEMY_DATABASE_URI': EXPECTED_SQLITE_DB_URI, - 'SQLALCHEMY_TRACK_MODIFICATIONS': EXPECTED_SQL_TRACK_MODS, - }, - 'aardvark-base': { - 'pwd': '/etc/aardvark', - }, - }, - expected_artifacts=[ - 'aardvark-base-docker-build', - 'aardvark-data-docker-build', - 'aardvark-data-docker-run', - 'aardvark-apiserver-docker-build', - 'aardvark-collector-docker-build', - ], - set_images_tag=False - ) - - -# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -# Define test suites. -# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -load_case = unittest.TestLoader().loadTestsFromTestCase -all_suites = { - 'testdockercontainerconstruction': load_case( - TestDockerContainerConstruction - ), - } - -master_suite = unittest.TestSuite(all_suites.values()) - -# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..dddffef --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,39 @@ +import pytest +from dynaconf import Dynaconf + +import aardvark.config +from aardvark import init_logging +from aardvark.config import settings + +init_logging() + + +@pytest.fixture(scope="session", autouse=True) +def set_test_settings(): + settings.configure(FORCE_ENV_FOR_DYNACONF="testing") + + +@pytest.fixture +def temp_config_file(tmp_path): + config_path = tmp_path / "settings.yaml" + return str(config_path) + + +@pytest.fixture(autouse=True) +def patch_config(monkeypatch, tmp_path): + db_path = tmp_path / "aardvark-test.db" + db_uri = f"sqlite:///{db_path}" + + config = Dynaconf( + envvar_prefix="AARDVARK", + settings_files=[ + "tests/settings.yaml", + ], + environments=True, + ) + config.configure(FORCE_ENV_FOR_DYNACONF="testing") + config.set("sqlalchemy_database_uri", db_uri) + + # Monkeypatch the actual config object so we don't poison it for future tests + monkeypatch.setattr(aardvark.config, "settings", config) + return config diff --git a/tests/persistence/__init__.py b/tests/persistence/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/persistence/test_sqlalchemy.py 
b/tests/persistence/test_sqlalchemy.py new file mode 100644 index 0000000..797d480 --- /dev/null +++ b/tests/persistence/test_sqlalchemy.py @@ -0,0 +1,200 @@ +# ruff: noqa: DTZ005 +import datetime + +import pytest +from dynaconf.utils import DynaconfDict +from sqlalchemy.exc import OperationalError + +from aardvark.persistence import PersistencePlugin +from aardvark.persistence.sqlalchemy import SQLAlchemyPersistence +from aardvark.plugins import AardvarkPlugin + +TIMESTAMP = datetime.datetime.now() +ADVISOR_DATA = { + "arn:aws:iam::123456789012:role/SpongebobSquarepants": [ + { + "LastAuthenticated": TIMESTAMP - datetime.timedelta(days=45), + "ServiceName": "Krabby Patty", + "ServiceNamespace": "krbpty", + "LastAuthenticatedEntity": "arn:aws:iam::123456789012:role/SpongebobSquarepants", + "TotalAuthenticatedEntities": 1, + }, + ], + "arn:aws:iam::123456789012:role/SheldonJPlankton": [ + { + "LastAuthenticated": TIMESTAMP - datetime.timedelta(days=100), + "ServiceName": "Chum Bucket", + "ServiceNamespace": "chb", + "LastAuthenticatedEntity": "arn:aws:iam::123456789012:role/SheldonJPlankton", + "TotalAuthenticatedEntities": 1, + }, + ], +} + + +@pytest.fixture +def temp_sqlite_db_config(): + db_uri = "sqlite:///:memory:" + return DynaconfDict( + { + "sqlalchemy_database_uri": db_uri, + } + ) + + +def test_sqlalchemypersistence(): + sap = SQLAlchemyPersistence() + assert isinstance(sap, AardvarkPlugin) + assert isinstance(sap, PersistencePlugin) + assert sap.config + + +def test_sqlalchemypersistence_custom_config(): + custom_config = DynaconfDict({"test_key": "test_value"}) + custom_config["test_key"] = "test_value" + sap = SQLAlchemyPersistence(alternative_config=custom_config, initialize=False) + assert isinstance(sap, AardvarkPlugin) + assert isinstance(sap, PersistencePlugin) + assert sap.config + assert sap.config["test_key"] == "test_value" + + +def test_init_db(temp_sqlite_db_config): + sap = SQLAlchemyPersistence(alternative_config=temp_sqlite_db_config, initialize=False) + sap.init_db() + assert sap.sa_engine + assert sap.session_factory + from aardvark.persistence.sqlalchemy.models import AdvisorData, AWSIAMObject + + with sap.session_scope() as session: + session.query(AdvisorData).all() + session.query(AWSIAMObject).all() + + +def test_teardown_db(temp_sqlite_db_config): + sap = SQLAlchemyPersistence(alternative_config=temp_sqlite_db_config, initialize=False) + sap.init_db() + sap.teardown_db() + from aardvark.persistence.sqlalchemy.models import AdvisorData, AWSIAMObject + + with ( + sap.session_scope() as session, + pytest.raises(OperationalError), + ): + session.query(AdvisorData).all() + + with ( + sap.session_scope() as session, + pytest.raises(OperationalError), + ): + session.query(AWSIAMObject).all() + + +def test_create_iam_object(temp_sqlite_db_config): + sap = SQLAlchemyPersistence(alternative_config=temp_sqlite_db_config) + iam_object = sap.create_iam_object("arn:aws:iam::123456789012:role/SpongebobSquarepants", datetime.datetime.now()) + assert iam_object.id + assert iam_object.arn == "arn:aws:iam::123456789012:role/SpongebobSquarepants" + + +def test_create_or_update_advisor_data(temp_sqlite_db_config): + from aardvark.persistence.sqlalchemy.models import AdvisorData + + sap = SQLAlchemyPersistence(alternative_config=temp_sqlite_db_config) + now = datetime.datetime.now() + # 10 days ago + original_timestamp = int((now - datetime.timedelta(days=10)).timestamp() * 1000) + # 5 days ago + update_timestamp = int((now - datetime.timedelta(days=5)).timestamp() * 1000) + + 
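+    # Note: AdvisorData stores last-authenticated times as epoch milliseconds,
+    # hence the int(dt.timestamp() * 1000) conversions above.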
# Create advisor data record + with sap.session_scope() as session: + sap.create_or_update_advisor_data( + 1, + original_timestamp, + "Aardvark Test", + "adv", + "arn:aws:iam::123456789012:role/PatrickStar", + 999, + session=session, + ) + + with sap.session_scope() as session: + record: AdvisorData = session.query(AdvisorData).filter(AdvisorData.id == 1).scalar() + + assert record + assert record.item_id == 1 + assert record.lastAuthenticated == int((now - datetime.timedelta(days=10)).timestamp() * 1000) + assert record.lastAuthenticatedEntity == "arn:aws:iam::123456789012:role/PatrickStar" + assert record.serviceName == "Aardvark Test" + assert record.serviceNamespace == "adv" + assert record.totalAuthenticatedEntities == 999 + + # Update advisor data record with new timestamp, plus update service name, last authenticated entity, and total + # authenticated entities + with sap.session_scope() as session: + sap.create_or_update_advisor_data( + 1, + update_timestamp, + "Aardvark Test v2", + "adv", + "arn:aws:iam::123456789012:role/SquidwardTentacles", + 1000, + session=session, + ) + + with sap.session_scope() as session: + record: AdvisorData = session.query(AdvisorData).filter(AdvisorData.id == 1).scalar() + + assert record + assert record.item_id == 1 + assert record.lastAuthenticated == int((now - datetime.timedelta(days=5)).timestamp() * 1000) + assert record.lastAuthenticatedEntity == "arn:aws:iam::123456789012:role/SquidwardTentacles" + assert record.serviceName == "Aardvark Test v2" + assert record.serviceNamespace == "adv" + assert record.totalAuthenticatedEntities == 1000 + + +def test_get_or_create_iam_object(temp_sqlite_db_config): + sap = SQLAlchemyPersistence(alternative_config=temp_sqlite_db_config) + + # create a new IAM object + new_object = sap.get_or_create_iam_object("arn:aws:iam::123456789012:role/SquidwardTentacles") + assert new_object.id + assert new_object.arn == "arn:aws:iam::123456789012:role/SquidwardTentacles" + object_id = new_object.id + object_arn = new_object.arn + + # make the same call and make sure we get the same entry we created before + retrieved_object = sap.get_or_create_iam_object("arn:aws:iam::123456789012:role/SquidwardTentacles") + assert retrieved_object.id + assert retrieved_object.arn == "arn:aws:iam::123456789012:role/SquidwardTentacles" + assert retrieved_object.id == object_id + assert retrieved_object.arn == object_arn + + +def test_store_role_data(temp_sqlite_db_config): + sap = SQLAlchemyPersistence(alternative_config=temp_sqlite_db_config) + sap.store_role_data(ADVISOR_DATA) + + +def test_get_role_data(temp_sqlite_db_config): + sap = SQLAlchemyPersistence(alternative_config=temp_sqlite_db_config) + sap.store_role_data(ADVISOR_DATA) + role_data = sap.get_role_data() + assert role_data["arn:aws:iam::123456789012:role/SpongebobSquarepants"] + assert len(role_data["arn:aws:iam::123456789012:role/SpongebobSquarepants"]) == 1 + assert role_data["arn:aws:iam::123456789012:role/SheldonJPlankton"] + assert len(role_data["arn:aws:iam::123456789012:role/SheldonJPlankton"]) == 1 + + +def test_get_role_data_combine(temp_sqlite_db_config): + sap = SQLAlchemyPersistence(alternative_config=temp_sqlite_db_config) + sap.store_role_data(ADVISOR_DATA) + role_data = sap.get_role_data(combine=True) + assert role_data["krbpty"] + assert role_data["krbpty"]["USED_LAST_90_DAYS"] + assert role_data["krbpty"]["serviceName"] == "Krabby Patty" + assert role_data["chb"] + assert not role_data["chb"]["USED_LAST_90_DAYS"] + assert role_data["chb"]["serviceName"] == 
"Chum Bucket" diff --git a/tests/retrievers/__init__.py b/tests/retrievers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/retrievers/conftest.py b/tests/retrievers/conftest.py new file mode 100644 index 0000000..3b3ffe8 --- /dev/null +++ b/tests/retrievers/conftest.py @@ -0,0 +1,49 @@ +import os +from typing import Any + +import pytest +from dynaconf import Dynaconf + +from aardvark.retrievers import RetrieverPlugin +from aardvark.retrievers.runner import RetrieverRunner + + +class RetrieverStub(RetrieverPlugin): + def __init__(self, alternative_config: Dynaconf = None): + super().__init__("retriever_stub", alternative_config=alternative_config) + + async def run(self, arn: str, data: dict[str, Any]) -> dict[str, Any]: + data["retriever_stub"] = {"success": True} + return data + + +class FailingRetriever(RetrieverPlugin): + def __init__(self, alternative_config: Dynaconf = None): + super().__init__("retriever_stub", alternative_config=alternative_config) + + async def run(self, arn: str, data: dict[str, Any]) -> dict[str, Any]: + raise Exception("Oh no! Retriever failed") # noqa + + +@pytest.fixture +def mock_retriever(): + return RetrieverStub() + + +@pytest.fixture +def mock_failing_retriever(): + return FailingRetriever() + + +@pytest.fixture +def runner(): + return RetrieverRunner() + + +@pytest.fixture +def aws_credentials(): + """Mocked AWS Credentials for moto.""" + os.environ["AWS_ACCESS_KEY_ID"] = "testing" + os.environ["AWS_SECRET_ACCESS_KEY"] = "testing" + os.environ["AWS_SECURITY_TOKEN"] = "testing" + os.environ["AWS_SESSION_TOKEN"] = "testing" diff --git a/tests/retrievers/test_access_advisor.py b/tests/retrievers/test_access_advisor.py new file mode 100644 index 0000000..2a02426 --- /dev/null +++ b/tests/retrievers/test_access_advisor.py @@ -0,0 +1,207 @@ +import datetime +from unittest.mock import MagicMock, patch + +import pytest + +from aardvark.exceptions import AccessAdvisorError +from aardvark.retrievers.access_advisor import AccessAdvisorRetriever + + +def test_generate_service_last_accessed_details(event_loop): + iam_client = MagicMock() + iam_client.generate_service_last_accessed_details.return_value = {"JobId": "abc123"} + aar = AccessAdvisorRetriever() + job_id = event_loop.run_until_complete(aar._generate_service_last_accessed_details(iam_client, "abc123")) + assert job_id == "abc123" + + +def test_get_service_last_accessed_details(event_loop): + iam_client = MagicMock() + iam_client.get_service_last_accessed_details.side_effect = [ + {"JobStatus": "IN_PROGRESS"}, + {"JobStatus": "IN_PROGRESS"}, + { + "JobStatus": "COMPLETED", + "ServicesLastAccessed": [ + { + "ServiceName": "AWS Lambda", + "LastAuthenticated": datetime.datetime(2020, 4, 12, 15, 30, tzinfo=datetime.timezone.utc), + "ServiceNamespace": "lambda", + "LastAuthenticatedEntity": "arn:aws:iam::123456789012:user/admin", + "TotalAuthenticatedEntities": 6, + }, + ], + }, + ] + aar = AccessAdvisorRetriever() + aar.backoff_base = 0.1 + aa_data = event_loop.run_until_complete(aar._get_service_last_accessed_details(iam_client, "abc123")) + assert aa_data["ServicesLastAccessed"][0]["ServiceName"] == "AWS Lambda" + assert aa_data["ServicesLastAccessed"][0]["LastAuthenticatedEntity"] == "arn:aws:iam::123456789012:user/admin" + + +def test_get_service_last_accessed_details_failure(event_loop): + iam_client = MagicMock() + iam_client.get_service_last_accessed_details.side_effect = [ + {"JobStatus": "IN_PROGRESS"}, + {"JobStatus": "FAILED", "Error": "Oh no!"}, + ] + aar = 
AccessAdvisorRetriever() + aar.backoff_base = 0.1 + with pytest.raises(AccessAdvisorError) as e: + _ = event_loop.run_until_complete(aar._get_service_last_accessed_details(iam_client, "abc123")) + assert str(e.value) == "Access Advisor job failed: Oh no!" + + +def test_get_service_last_accessed_details_too_many_retries(event_loop): + iam_client = MagicMock() + iam_client.get_service_last_accessed_details.side_effect = [ + {"JobStatus": "IN_PROGRESS"}, + {"JobStatus": "IN_PROGRESS"}, + {"JobStatus": "IN_PROGRESS"}, + {"JobStatus": "IN_PROGRESS"}, + {"JobStatus": "IN_PROGRESS"}, + {"JobStatus": "IN_PROGRESS"}, + ] + aar = AccessAdvisorRetriever() + aar.max_retries = 5 + aar.backoff_base = 0.1 + with pytest.raises(AccessAdvisorError) as e: + _ = event_loop.run_until_complete(aar._get_service_last_accessed_details(iam_client, "abc123")) + assert str(e.value) == "Access Advisor job failed: exceeded max retries" + + +@pytest.mark.parametrize( + ("arn", "expected"), + [ + ("arn:aws:iam::123456789012:role/roleName", "123456789012"), # Role ARN + ( + "arn:aws:iam::123456789012:role/thisIsAPath/roleName", + "123456789012", + ), # Role ARN with path + ("arn:aws:iam::223456789012:policy/policyName", "223456789012"), # Policy ARN + ("arn:aws:iam::323456789012:user/userName", "323456789012"), # User ARN + ], +) +def test_get_account_from_arn(arn, expected): + result = AccessAdvisorRetriever._get_account_from_arn(arn) + assert result == expected + + +@pytest.mark.parametrize( + ("service_last_accessed", "expected"), + [ + ( + # datetime object for LastAuthenticated + { + "ServiceName": "AWS Lambda", + "LastAuthenticated": datetime.datetime(2020, 4, 12, 15, 30, tzinfo=datetime.timezone.utc), + "ServiceNamespace": "lambda", + "LastAuthenticatedEntity": "arn:aws:iam::123456789012:user/admin", + "TotalAuthenticatedEntities": 6, + }, + { + "ServiceName": "AWS Lambda", + "LastAuthenticated": 1586705400000, + "ServiceNamespace": "lambda", + "LastAuthenticatedEntity": "arn:aws:iam::123456789012:user/admin", + "TotalAuthenticatedEntities": 6, + }, + ), + ( + # empty string for LastAuthenticated + { + "ServiceName": "AWS Lambda", + "LastAuthenticated": "", + "ServiceNamespace": "lambda", + "LastAuthenticatedEntity": "", + "TotalAuthenticatedEntities": 0, + }, + { + "ServiceName": "AWS Lambda", + "LastAuthenticated": 0, + "ServiceNamespace": "lambda", + "LastAuthenticatedEntity": "", + "TotalAuthenticatedEntities": 0, + }, + ), + ], +) +def test_transform_result(service_last_accessed, expected): + result = AccessAdvisorRetriever._transform_result(service_last_accessed) + assert result == expected + + +@pytest.mark.parametrize( + ("arn", "data", "expected"), + [ + # Empty input data + ( + "arn:aws:iam::123456789012:user/admin", + {}, + { + "access_advisor": [ + { + "LastAuthenticated": 1586705400000, + "LastAuthenticatedEntity": "arn:aws:iam::123456789012:user/admin", + "ServiceName": "AWS Lambda", + "ServiceNamespace": "lambda", + "TotalAuthenticatedEntities": 6, + } + ] + }, + ), + # Non-empty input data + ( + "arn:aws:iam::123456789012:user/admin", + {"data_from_other_retrievers": "hello"}, + { + "access_advisor": [ + { + "LastAuthenticated": 1586705400000, + "LastAuthenticatedEntity": "arn:aws:iam::123456789012:user/admin", + "ServiceName": "AWS Lambda", + "ServiceNamespace": "lambda", + "TotalAuthenticatedEntities": 6, + } + ], + "data_from_other_retrievers": "hello", + }, + ), + ], +) +@patch("aardvark.retrievers.access_advisor.retriever.boto3_cached_conn") +def test_run(mock_boto3_cached_conn, event_loop, 
arn, data, expected): + mock_iam_client = MagicMock() + mock_iam_client.generate_service_last_accessed_details.return_value = {"JobId": "abc123"} + mock_iam_client.get_service_last_accessed_details.return_value = { + "JobStatus": "COMPLETED", + "ServicesLastAccessed": [ + { + "ServiceName": "AWS Lambda", + "LastAuthenticated": datetime.datetime(2020, 4, 12, 15, 30, tzinfo=datetime.timezone.utc), + "ServiceNamespace": "lambda", + "LastAuthenticatedEntity": "arn:aws:iam::123456789012:user/admin", + "TotalAuthenticatedEntities": 6, + }, + ], + } + mock_boto3_cached_conn.return_value = mock_iam_client + aar = AccessAdvisorRetriever() + result = event_loop.run_until_complete(aar.run("arn:aws:iam::123456789012:user/admin", data)) + assert result["access_advisor"] + assert result == expected + + +@pytest.mark.parametrize(("arn", "data", "expected"), [("arn", {}, {})]) +@patch("aardvark.retrievers.access_advisor.retriever.boto3_cached_conn") +def test_run_missing_arn(mock_boto3_cached_conn, event_loop, arn, data, expected): + mock_iam_client = MagicMock() + mock_iam_client.exceptions.NoSuchEntityException = Exception + mock_iam_client.generate_service_last_accessed_details.side_effect = ( + mock_iam_client.exceptions.NoSuchEntityException() + ) + mock_boto3_cached_conn.return_value = mock_iam_client + aar = AccessAdvisorRetriever() + result = event_loop.run_until_complete(aar.run("arn:aws:iam::123456789012:user/admin", {})) + assert result == expected diff --git a/tests/retrievers/test_retriever_plugin.py b/tests/retrievers/test_retriever_plugin.py new file mode 100644 index 0000000..71b9791 --- /dev/null +++ b/tests/retrievers/test_retriever_plugin.py @@ -0,0 +1,11 @@ +import pytest + +from aardvark.retrievers import RetrieverPlugin + + +def test_retriever_plugin(event_loop): + retriever = RetrieverPlugin("test_retriever") + assert retriever.name == "test_retriever" + assert str(retriever) == "Retriever(test_retriever)" + with pytest.raises(NotImplementedError): + event_loop.run_until_complete(retriever.run("arn:foo:bar", {"data": "information"})) diff --git a/tests/retrievers/test_runner.py b/tests/retrievers/test_runner.py new file mode 100644 index 0000000..3488f1f --- /dev/null +++ b/tests/retrievers/test_runner.py @@ -0,0 +1,272 @@ +import asyncio +from unittest.mock import AsyncMock, MagicMock, call, patch + +import pytest +from swag_client.exceptions import InvalidSWAGDataException + +from aardvark.exceptions import RetrieverError +from aardvark.retrievers.runner import RetrieverRunner + + +def test_register_retriever(runner, mock_retriever): + runner.register_retriever(mock_retriever) + assert len(runner.retrievers) == 1 + assert runner.retrievers[0].name == "retriever_stub" + + +@pytest.mark.asyncio +async def test_run_retrievers(runner, mock_retriever): + runner.register_retriever(mock_retriever) + result = await runner._run_retrievers("abc123") + assert result + assert result["arn"] == "abc123" + assert result["retriever_stub"]["success"] + + +@pytest.mark.asyncio +async def test_run_retrievers_failure(runner, mock_failing_retriever): + runner.register_retriever(mock_failing_retriever) + with pytest.raises(RetrieverError): + await runner._run_retrievers("abc123") + + +@pytest.mark.asyncio +async def test_retriever_loop(runner, mock_retriever): + runner.register_retriever(mock_retriever) + arn_queue = asyncio.Queue() + await arn_queue.put("abc123") + assert not arn_queue.empty() + runner.arn_queue = arn_queue + results_queue = asyncio.Queue() + runner.results_queue = results_queue + 
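+    # _retriever_loop runs as a background worker; arn_queue.join() returns once
+    # every queued ARN has been processed, after which the worker can be cancelled.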
task = asyncio.create_task(runner._retriever_loop("")) + await arn_queue.join() + task.cancel() + result = await runner.results_queue.get() + assert result + assert result["arn"] == "abc123" + assert result["retriever_stub"]["success"] + + +@pytest.mark.asyncio +async def test_retriever_loop_failure(runner, mock_failing_retriever): + runner.register_retriever(mock_failing_retriever) + arn_queue = asyncio.Queue() + await arn_queue.put("abc123") + assert not arn_queue.empty() + runner.arn_queue = arn_queue + results_queue = asyncio.Queue() + runner.results_queue = results_queue + failure_queue = asyncio.Queue() + runner.failure_queue = failure_queue + task = asyncio.create_task(runner._retriever_loop("")) + await arn_queue.join() + task.cancel() + assert len(runner.failed_arns) == 1 + assert runner.failed_arns[0] == "abc123" + assert runner.results_queue.empty() + assert not runner.failure_queue.empty() + failed = await runner.failure_queue.get() + assert failed == "abc123" + + +@pytest.mark.asyncio +async def test_results_loop(runner, mock_retriever): + runner.register_retriever(mock_retriever) + results_queue = asyncio.Queue() + await results_queue.put({"arn": "abc123", "access_advisor": {"access": "advised"}}) + runner.results_queue = results_queue + expected = {"abc123": {"access": "advised"}} + runner.persistence.store_role_data = MagicMock() + task = asyncio.create_task(runner._results_loop("")) + await runner.results_queue.join() + task.cancel() + runner.persistence.store_role_data.assert_called() + runner.persistence.store_role_data.assert_called_with(expected) + + +@patch("aardvark.retrievers.runner.boto3_cached_conn") +@patch( + "aardvark.retrievers.runner.list_roles", + return_value=[{"Arn": "role1"}, {"Arn": "role2"}], +) +@patch( + "aardvark.retrievers.runner.list_users", + return_value=[{"Arn": "user1"}, {"Arn": "user2"}], +) +@pytest.mark.asyncio +async def test_get_arns_for_account(mock_list_users, mock_list_roles, mock_boto3_cached_conn, runner): + paginator = MagicMock() + paginator.paginate.side_effect = ( + [{"Policies": [{"Arn": "policy1"}]}, {"Policies": [{"Arn": "policy2"}]}], + [{"Groups": [{"Arn": "group1"}]}, {"Groups": [{"Arn": "group2"}]}], + ) + mock_iam_client = MagicMock() + mock_iam_client.get_paginator.return_value = paginator + mock_boto3_cached_conn.return_value = mock_iam_client + runner.arn_queue = asyncio.Queue() + await runner._get_arns_for_account("012345678901") + assert not runner.arn_queue.empty() + expected = [ + "role1", + "role2", + "user1", + "user2", + "policy1", + "policy2", + "group1", + "group2", + ] + for arn in expected: + assert runner.arn_queue.get_nowait() == arn + + +@patch("aardvark.retrievers.runner.RetrieverRunner._get_arns_for_account") +@pytest.mark.asyncio +async def test_arn_lookup_loop(mock_get_arns_for_account, runner): + account_queue = asyncio.Queue() + account_queue.put_nowait("123456789012") + account_queue.put_nowait("223456789012") + runner.account_queue = account_queue + task = asyncio.create_task(runner._arn_lookup_loop("")) + await account_queue.join() + task.cancel() + assert mock_get_arns_for_account.call_args_list == [ + call("123456789012"), + call("223456789012"), + ] + + +@pytest.mark.asyncio +async def test_get_swag_accounts(): + swag_response = {"foo": "bar"} + runner = RetrieverRunner() + runner.swag = MagicMock() + runner.swag.get_all.return_value = swag_response + runner.swag.get_service_enabled.return_value = swag_response + result = await runner._get_swag_accounts() + assert result == swag_response + 
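+    # The filter ("mock swag filter") and service name ("glowcloud") asserted below
+    # are assumed to come from the testing environment in tests/settings.yaml.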
runner.swag.get_all.assert_called_with("mock swag filter") + runner.swag.get_service_enabled.assert_called_with("glowcloud", accounts_list={"foo": "bar"}) + + +@pytest.mark.asyncio +async def test_get_swag_accounts_failure(): + swag_response = {"foo": "bar"} + runner = RetrieverRunner() + runner.swag = MagicMock() + runner.swag.get_all.side_effect = InvalidSWAGDataException + runner.swag.get_service_enabled.return_value = swag_response + with pytest.raises(RetrieverError): + await runner._get_swag_accounts() + + +@pytest.mark.asyncio +async def test_queue_all_accounts(runner): + expected_account_ids = ["123456789012", "223456789012", "323456789012"] + account_queue = asyncio.Queue() + runner.account_queue = account_queue + runner._get_swag_accounts = AsyncMock() + runner._get_swag_accounts.return_value = [{"id": account_id} for account_id in expected_account_ids] + await runner._queue_all_accounts() + for account_id in expected_account_ids: + assert account_queue.get_nowait() == account_id + assert account_queue.empty() + + +@pytest.mark.asyncio +async def test_queue_accounts(runner): + swag_accounts = [ + { + "schemaVersion": "2", + "id": "123456789012", + "name": "test", + }, + { + "schemaVersion": "2", + "id": "223456789012", + "name": "staging", + "aliases": ["stage"], + }, + { + "schemaVersion": "2", + "id": "323456789012", + "name": "prod", + }, + ] + expected_account_ids = ["423456789012", "123456789012", "223456789012"] + account_queue = asyncio.Queue() + runner.account_queue = account_queue + runner._get_swag_accounts = AsyncMock() + runner._get_swag_accounts.return_value = swag_accounts + await runner._queue_accounts(["test", "stage", "423456789012"]) + for account_id in expected_account_ids: + assert account_queue.get_nowait() == account_id + assert account_queue.empty() + + +@pytest.mark.asyncio +async def test_queue_arns(runner): + arn_queue = asyncio.Queue() + runner.arn_queue = arn_queue + arns = ["arn1", "arn2"] + await runner._queue_arns(arns) + for arn in arns: + assert arn_queue.get_nowait() == arn + + +@pytest.mark.asyncio +async def test_run(): + runner = RetrieverRunner() + runner._queue_accounts = AsyncMock() + runner._queue_arns = AsyncMock() + runner._queue_all_accounts = AsyncMock() + runner._arn_lookup_loop = AsyncMock() + runner._retriever_loop = AsyncMock() + runner._results_loop = AsyncMock() + await runner.run() + runner._queue_accounts.assert_not_called() + runner._queue_arns.assert_not_called() + runner._queue_all_accounts.assert_called() + runner._arn_lookup_loop.assert_called_with("arn-lookup-worker-0") + runner._retriever_loop.assert_called_with("retriever-worker-0") + runner._results_loop.assert_called_with("results-worker-0") + assert len(runner.tasks) == 3 + + +@pytest.mark.asyncio +async def test_run_with_accounts(): + runner = RetrieverRunner() + runner._queue_accounts = AsyncMock() + runner._queue_arns = AsyncMock() + runner._queue_all_accounts = AsyncMock() + runner._arn_lookup_loop = AsyncMock() + runner._retriever_loop = AsyncMock() + runner._results_loop = AsyncMock() + await runner.run(accounts=["test", "prod"]) + runner._queue_accounts.assert_called_with(["test", "prod"]) + runner._queue_arns.assert_not_called() + runner._queue_all_accounts.assert_not_called() + runner._arn_lookup_loop.assert_called_with("arn-lookup-worker-0") + runner._retriever_loop.assert_called_with("retriever-worker-0") + runner._results_loop.assert_called_with("results-worker-0") + assert len(runner.tasks) == 3 + + +@pytest.mark.asyncio +async def test_run_with_arns(): 
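+    # Passing explicit ARNs skips account enumeration and ARN lookup entirely,
+    # so only the retriever and results workers are started (hence two tasks).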
+ runner = RetrieverRunner() + runner._queue_accounts = AsyncMock() + runner._queue_arns = AsyncMock() + runner._queue_all_accounts = AsyncMock() + runner._arn_lookup_loop = AsyncMock() + runner._retriever_loop = AsyncMock() + runner._results_loop = AsyncMock() + await runner.run(arns=["arn1", "arn2"]) + runner._queue_accounts.assert_not_called() + runner._queue_arns.assert_called_with(["arn1", "arn2"]) + runner._queue_all_accounts.assert_not_called() + runner._arn_lookup_loop.assert_not_called() + runner._retriever_loop.assert_called_with("retriever-worker-0") + runner._results_loop.assert_called_with("results-worker-0") + assert len(runner.tasks) == 2 diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..a095a55 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,32 @@ +import yaml + +from aardvark.config import create_config + + +def test_create_config(temp_config_file): + create_config( + aardvark_role="role", + swag_bucket="bucket", + swag_filter="filter", + swag_service_enabled_requirement="service", + arn_partition="aws", + sqlalchemy_database_uri="sqlite://////////////hi.db", + sqlalchemy_track_modifications=True, + num_threads=99, + region="us-underground-5", + filename=temp_config_file, + environment="testtesttest", + ) + + with open(temp_config_file) as f: + file_data = yaml.safe_load(f) + + assert file_data["testtesttest"]["AWS_ROLENAME"] == "role" + assert file_data["testtesttest"]["AWS_REGION"] == "us-underground-5" + assert file_data["testtesttest"]["AWS_ARN_PARTITION"] == "aws" + assert file_data["testtesttest"]["SWAG"]["bucket"] == "bucket" + assert file_data["testtesttest"]["SWAG"]["filter"] == "filter" + assert file_data["testtesttest"]["SWAG"]["service_enabled_requirement"] == "service" + assert file_data["testtesttest"]["UPDATER_NUM_THREADS"] == 99 + assert file_data["testtesttest"]["SQLALCHEMY_DATABASE_URI"] == "sqlite://////////////hi.db" + assert file_data["testtesttest"]["SQLALCHEMY_TRACK_MODIFICATIONS"] is True
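The assertions above pin down the file layout `create_config` writes: everything nests under the given environment key, with the SWAG options grouped under `SWAG`. Reconstructed as a Python literal (values are simply the test inputs):

```python
# settings.yaml layout implied by the assertions above (illustrative values):
expected = {
    "testtesttest": {
        "AWS_ROLENAME": "role",
        "AWS_REGION": "us-underground-5",
        "AWS_ARN_PARTITION": "aws",
        "SWAG": {
            "bucket": "bucket",
            "filter": "filter",
            "service_enabled_requirement": "service",
        },
        "UPDATER_NUM_THREADS": 99,
        "SQLALCHEMY_DATABASE_URI": "sqlite://////////////hi.db",
        "SQLALCHEMY_TRACK_MODIFICATIONS": True,
    }
}
```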