diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 26e7695014f..f386f469cf6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -84,12 +84,6 @@ repos: ^boefjes/tools | ^keiko/templates | ^mula/whitelist\.py$ | - ^octopoes/bits | - ^octopoes/octopoes/core | - ^octopoes/octopoes/events | - ^octopoes/octopoes/models | - ^octopoes/octopoes/repositories | - ^octopoes/octopoes/xtdb | ^octopoes/tools | ^rocky/whitelist\.py$ | /tests/ | diff --git a/octopoes/.ci/docker-compose.yml b/octopoes/.ci/docker-compose.yml index 8818557e11a..b4f234f4749 100644 --- a/octopoes/.ci/docker-compose.yml +++ b/octopoes/.ci/docker-compose.yml @@ -20,7 +20,7 @@ services: args: ENVIRONMENT: dev context: . - command: pytest -x tests/integration --timeout=300 + command: pytest tests/integration --timeout=300 depends_on: - xtdb - ci_octopoes diff --git a/octopoes/bits/check_csp_header/check_csp_header.py b/octopoes/bits/check_csp_header/check_csp_header.py index 143039f3f5a..aafe0466abc 100644 --- a/octopoes/bits/check_csp_header/check_csp_header.py +++ b/octopoes/bits/check_csp_header/check_csp_header.py @@ -14,7 +14,7 @@ def run(input_ooi: HTTPHeader, additional_oois: list, config: dict[str, str]) -> if header.key.lower() != "content-security-policy": return - findings: [str] = [] + findings: list[str] = [] if "http://" in header.value: findings.append("Http should not be used in the CSP settings of an HTTP Header.") @@ -100,10 +100,10 @@ def run(input_ooi: HTTPHeader, additional_oois: list, config: dict[str, str]) -> def _ip_valid(source: str) -> bool: "Check if there are IP's in this source, return False if the address found was to be non global. Ignores non ips" - ip = NON_DECIMAL_FILTER.sub("", source) - if ip: + ip_str = NON_DECIMAL_FILTER.sub("", source) + if ip_str: try: - ip = ipaddress.ip_address(ip) + ip = ipaddress.ip_address(ip_str) if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_multicast or ip.is_reserved: return False except ValueError: @@ -121,7 +121,7 @@ def _create_kat_finding(header: Reference, kat_id: str, description: str) -> Ite ) -def _source_valid(policy: [str]) -> bool: +def _source_valid(policy: list[str]) -> bool: for value in policy: if not ( re.search(r"\S+\.\S{2,3}([\s]+|$|;|:[0-9]+)", value) diff --git a/octopoes/bits/check_hsts_header/check_hsts_header.py b/octopoes/bits/check_hsts_header/check_hsts_header.py index 73e9811b51c..fa6038c9291 100644 --- a/octopoes/bits/check_hsts_header/check_hsts_header.py +++ b/octopoes/bits/check_hsts_header/check_hsts_header.py @@ -14,7 +14,7 @@ def run(input_ooi: HTTPHeader, additional_oois: list, config: dict[str, str]) -> one_year = datetime.timedelta(days=365).total_seconds() max_age = int(config.get("max-age", one_year)) if config else one_year - findings: [str] = [] + findings: list[str] = [] headervalue = header.value.lower() if "includesubdomains" not in headervalue: diff --git a/octopoes/bits/runner.py b/octopoes/bits/runner.py index c0f69be51cf..f19773cd6a6 100644 --- a/octopoes/bits/runner.py +++ b/octopoes/bits/runner.py @@ -1,7 +1,7 @@ from collections.abc import Iterator from importlib import import_module from inspect import isfunction, signature -from typing import Any, Protocol, cast +from typing import Any, Protocol from bits.definitions import BitDefinition from octopoes.models import OOI @@ -22,7 +22,6 @@ def __init__(self, bit_definition: BitDefinition): def run(self, *args, **kwargs) -> list[OOI]: module = import_module(self.module) - module = cast(Runnable, module) if not hasattr(module, "run") or not isfunction(module.run): raise ModuleException(f"Module {module} does not define a run function") @@ -38,7 +37,7 @@ def __str__(self): def _bit_run_signature(input_ooi: OOI, additional_oois: list[OOI], config: dict[str, str]) -> Iterator[OOI]: - ... + yield input_ooi BIT_SIGNATURE = signature(_bit_run_signature) diff --git a/octopoes/bits/spf_discovery/internetnl_spf_parser.py b/octopoes/bits/spf_discovery/internetnl_spf_parser.py index 11e84f082c8..711da223c63 100644 --- a/octopoes/bits/spf_discovery/internetnl_spf_parser.py +++ b/octopoes/bits/spf_discovery/internetnl_spf_parser.py @@ -1,6 +1,5 @@ # Copyright: 2022, ECP, NLnet Labs and the Internet.nl contributors # SPDX-License-Identifier: Apache-2.0 -import contextlib import ipaddress from pyparsing import ( @@ -40,16 +39,13 @@ def _parse_ipv6(tokens): """ match = str(tokens[0]) - ipv6 = None try: - ipv6 = ipaddress.IPv6Address(match) + return str(ipaddress.IPv6Address(match)) except ipaddress.AddressValueError: - with contextlib.suppress(ipaddress.AddressValueError, ipaddress.NetmaskValueError): - ipv6 = ipaddress.IPv6Network(match, strict=False) - - if not ipv6: - raise ParseException("Non valid IPv6 address/network.") - return str(ipv6) + try: + return str(ipaddress.IPv6Network(match, strict=False)) + except (ipaddress.AddressValueError, ipaddress.NetmaskValueError) as e: + raise ParseException("Non valid IPv6 address/network.") from e SP = White(ws=" ", exact=1).suppress() diff --git a/octopoes/bits/spf_discovery/spf_discovery.py b/octopoes/bits/spf_discovery/spf_discovery.py index acfa9e1412b..74cdd131781 100644 --- a/octopoes/bits/spf_discovery/spf_discovery.py +++ b/octopoes/bits/spf_discovery/spf_discovery.py @@ -47,7 +47,7 @@ def run(input_ooi: DNSTXTRecord, additional_oois, config: dict[str, str]) -> Ite yield Finding(finding_type=ft.reference, ooi=input_ooi.reference, description="This SPF record is invalid") -def parse_ip_qualifiers(mechanism: str, input_ooi: DNSTXTRecord, spf_record: DNSSPFRecord) -> Iterator[str]: +def parse_ip_qualifiers(mechanism: str, input_ooi: DNSTXTRecord, spf_record: DNSSPFRecord) -> Iterator[OOI]: # split mechanism into qualifier and ip qualifier, ip = mechanism.split(":", 1) ip = mechanism[4:] @@ -72,7 +72,7 @@ def parse_ip_qualifiers(mechanism: str, input_ooi: DNSTXTRecord, spf_record: DNS ) -def parse_a_mx_qualifiers(mechanism: str, input_ooi: DNSTXTRecord, spf_record: DNSSPFRecord) -> Iterator[str]: +def parse_a_mx_qualifiers(mechanism: str, input_ooi: DNSTXTRecord, spf_record: DNSSPFRecord) -> Iterator[OOI]: if mechanism == "a" or mechanism == "mx": yield DNSSPFMechanismHostname(spf_record=spf_record.reference, hostname=input_ooi.hostname, mechanism=mechanism) else: @@ -86,7 +86,7 @@ def parse_a_mx_qualifiers(mechanism: str, input_ooi: DNSTXTRecord, spf_record: D spf_record=spf_record.reference, hostname=hostname.reference, mechanism=mechanism_type ) if mechanism.startswith("a/") or mechanism.startswith("mx/"): - mechanism_type, domain = mechanism.split("/", 1)[1] + mechanism_type, domain = mechanism.split("/", 1) # TODO: fix prefix lengths domain = domain.split("/")[0] hostname = Hostname(name=domain, network=Network(name=input_ooi.hostname.tokenized.network.name).reference) @@ -98,7 +98,7 @@ def parse_a_mx_qualifiers(mechanism: str, input_ooi: DNSTXTRecord, spf_record: D def parse_ptr_exists_include_mechanism( mechanism: str, input_ooi: DNSTXTRecord, spf_record: DNSSPFRecord -) -> Iterator[str]: +) -> Iterator[OOI]: if mechanism == "ptr": yield DNSSPFMechanismHostname(spf_record=spf_record.reference, hostname=input_ooi.hostname, mechanism="ptr") else: @@ -113,7 +113,7 @@ def parse_ptr_exists_include_mechanism( ) -def parse_redirect_mechanism(mechanism: str, input_ooi: DNSTXTRecord, spf_record: DNSSPFRecord) -> Iterator[str]: +def parse_redirect_mechanism(mechanism: str, input_ooi: DNSTXTRecord, spf_record: DNSSPFRecord) -> Iterator[OOI]: mechanism_type, domain = mechanism.split("=", 1) # currently, the model only supports hostnames and not domains if domain.startswith("_"): diff --git a/octopoes/octopoes/api/models.py b/octopoes/octopoes/api/models.py index d1172fc526e..81cbd842b12 100644 --- a/octopoes/octopoes/api/models.py +++ b/octopoes/octopoes/api/models.py @@ -1,11 +1,11 @@ import uuid from datetime import datetime -from typing import Any +from typing import Annotated, Any from pydantic import AwareDatetime, BaseModel, Field from octopoes.models import Reference -from octopoes.models.types import OOIType +from octopoes.models.types import ConcreteOOIType, OOIType class ServiceHealth(BaseModel): @@ -59,18 +59,21 @@ class ScanProfileDeclaration(BaseModel): valid_time: datetime +ValidatedOOIType = Annotated[ConcreteOOIType, Field(discriminator="object_type")] + + # API models (timezone validation and pydantic parsing) class ValidatedObservation(_BaseObservation): """Used by Octopoes API to validate and parse correctly""" - result: list[OOIType] + result: list[ValidatedOOIType] valid_time: AwareDatetime class ValidatedDeclaration(BaseModel): """Used by Octopoes API to validate and parse correctly""" - ooi: OOIType + ooi: ValidatedOOIType valid_time: AwareDatetime method: str | None = "manual" task_id: uuid.UUID | None = Field(default_factory=uuid.uuid4) @@ -79,7 +82,7 @@ class ValidatedDeclaration(BaseModel): class ValidatedAffirmation(BaseModel): """Used by Octopoes API to validate and parse correctly""" - ooi: OOIType + ooi: ValidatedOOIType valid_time: AwareDatetime method: str | None = "hydration" task_id: uuid.UUID | None = Field(default_factory=uuid.uuid4) diff --git a/octopoes/octopoes/api/router.py b/octopoes/octopoes/api/router.py index 522fb564731..5c6faa49fa7 100644 --- a/octopoes/octopoes/api/router.py +++ b/octopoes/octopoes/api/router.py @@ -368,10 +368,13 @@ def get_scan_profile_inheritance( reference: Reference = Depends(extract_reference), ) -> list[InheritanceSection]: ooi = octopoes.get_ooi(reference, valid_time) + if not ooi.scan_profile: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="OOI does not have a scanprofile") + start = InheritanceSection( reference=ooi.reference, level=ooi.scan_profile.level, scan_profile_type=ooi.scan_profile.scan_profile_type ) - if ooi.scan_profile.scan_profile_type == ScanProfileType.DECLARED: + if ooi.scan_profile.scan_profile_type == ScanProfileType.DECLARED.value: return [start] return octopoes.get_scan_profile_inheritance(reference, valid_time, [start]) diff --git a/octopoes/octopoes/connector/octopoes.py b/octopoes/octopoes/connector/octopoes.py index f7648e36576..eeb038ae357 100644 --- a/octopoes/octopoes/connector/octopoes.py +++ b/octopoes/octopoes/connector/octopoes.py @@ -97,7 +97,7 @@ def list_objects( scan_level: set[ScanLevel] = DEFAULT_SCAN_LEVEL_FILTER, scan_profile_type: set[ScanProfileType] = DEFAULT_SCAN_PROFILE_TYPE_FILTER, ) -> Paginated[OOIType]: - params: dict[str, str | int | list[str] | set[str]] = { + params: dict[str, str | int | list[str] | set[str | int]] = { "types": [t.__name__ for t in types], "valid_time": str(valid_time), "offset": offset, @@ -169,7 +169,7 @@ def list_origins( "source": source, "result": result, "task_id": str(task_id) if task_id else None, - "origin_type": origin_type, + "origin_type": str(origin_type) if origin_type else None, }, ) diff --git a/octopoes/octopoes/core/service.py b/octopoes/octopoes/core/service.py index 368e7b301b5..fc64b97046a 100644 --- a/octopoes/octopoes/core/service.py +++ b/octopoes/octopoes/core/service.py @@ -1,8 +1,9 @@ import json from collections import Counter -from collections.abc import Callable +from collections.abc import Callable, ValuesView from datetime import datetime, timezone from logging import getLogger +from typing import overload from bits.definitions import get_bit_definitions from bits.runner import BitRunner @@ -77,7 +78,17 @@ def __init__( self.origin_parameter_repository = origin_parameter_repository self.scan_profile_repository = scan_profile_repository + @overload + def _populate_scan_profiles(self, oois: ValuesView[OOI], valid_time: datetime) -> ValuesView[OOI]: + ... + + @overload def _populate_scan_profiles(self, oois: list[OOI], valid_time: datetime) -> list[OOI]: + ... + + def _populate_scan_profiles( + self, oois: list[OOI] | ValuesView[OOI], valid_time: datetime + ) -> list[OOI] | ValuesView[OOI]: logger.debug("Populating scan profiles for %s oois", len(oois)) ooi_cache: dict[str, OOI] = {str(ooi.reference): ooi for ooi in oois} @@ -132,7 +143,7 @@ def get_ooi_tree( reference: Reference, valid_time: datetime, search_types: set[type[OOI]] | None = None, - depth: int | None = 1, + depth: int = 1, ): tree = self.ooi_repository.get_tree(reference, valid_time, search_types, depth) self._populate_scan_profiles(tree.store.values(), valid_time) @@ -154,7 +165,7 @@ def save_origin(self, origin: Origin, oois: list[OOI], valid_time: datetime) -> try: self.ooi_repository.get(origin.source, valid_time) except ObjectNotFoundException: - return + raise ValueError("Origin source of observation does not exist") for ooi in oois: self.ooi_repository.save(ooi, valid_time=valid_time) @@ -176,7 +187,7 @@ def _run_inference(self, origin: Origin, valid_time: datetime) -> None: return try: - level = self.scan_profile_repository.get(origin.source, valid_time).level + level = self.scan_profile_repository.get(origin.source, valid_time).level.value except ObjectNotFoundException: level = 0 @@ -343,6 +354,9 @@ def process_event(self, event: DBEvent): # OOI events def _on_create_ooi(self, event: OOIDBEvent) -> None: + if event.new_data is None: + raise ValueError("Create event new_data should not be None") + ooi = event.new_data # keep old scan profile, or create new scan profile @@ -393,6 +407,9 @@ def _on_create_ooi(self, event: OOIDBEvent) -> None: self.origin_parameter_repository.save(origin_parameter, event.valid_time) def _on_update_ooi(self, event: OOIDBEvent) -> None: + if event.new_data is None: + raise ValueError("Update event new_data should not be None") + inference_origins = self.origin_repository.list_origins(event.valid_time, source=event.new_data.reference) inference_params = self.origin_parameter_repository.list_by_reference( event.new_data.reference, valid_time=event.valid_time @@ -405,6 +422,9 @@ def _on_update_ooi(self, event: OOIDBEvent) -> None: self._run_inference(inference_origin, event.valid_time) def _on_delete_ooi(self, event: OOIDBEvent) -> None: + if event.old_data is None: + raise ValueError("Update event old_data should not be None") + reference = event.old_data.reference # delete related origins to which it is a source @@ -426,38 +446,53 @@ def _on_delete_ooi(self, event: OOIDBEvent) -> None: # Origin events def _on_create_origin(self, event: OriginDBEvent) -> None: + if event.new_data is None: + raise ValueError("Create event new_data should not be None") + if event.new_data.origin_type == OriginType.INFERENCE: self._run_inference(event.new_data, event.valid_time) def _on_update_origin(self, event: OriginDBEvent) -> None: + if event.new_data is None or event.old_data is None: + raise ValueError("Update event new_data and old_data should not be None") + dereferenced_oois = event.old_data - event.new_data for reference in dereferenced_oois: self._delete_ooi(reference, event.valid_time) def _on_delete_origin(self, event: OriginDBEvent) -> None: + if event.old_data is None: + raise ValueError("Delete event old_data should not be None") + for reference in event.old_data.result: self._delete_ooi(reference, event.valid_time) # Origin parameter events def _on_create_origin_parameter(self, event: OriginParameterDBEvent) -> None: + if event.new_data is None: + raise ValueError("Create event new_data should not be None") + # Run the bit/origin try: origin = self.origin_repository.get(event.new_data.origin_id, event.valid_time) self._run_inference(origin, event.valid_time) except ObjectNotFoundException: - return + pass def _on_update_origin_parameter(self, event: OriginParameterDBEvent) -> None: # update of origin_parameter is not possible, since both fields are unique ... def _on_delete_origin_parameter(self, event: OriginParameterDBEvent) -> None: + if event.old_data is None: + raise ValueError("Delete event old_data should not be None") + # Run the bit/origin try: origin = self.origin_repository.get(event.old_data.origin_id, event.valid_time) self._run_inference(origin, event.valid_time) except ObjectNotFoundException: - return + pass def _run_inferences(self, event: ScanProfileDBEvent) -> None: inference_origins = self.origin_repository.list_origins(event.valid_time, source=event.reference) @@ -505,14 +540,16 @@ def get_scan_profile_inheritance( segment = path.segments[0] for neighbour in neighbours: segment_inheritance = get_max_scan_level_inheritance(segment) - if ( - segment_inheritance is None - or neighbour.reference in visited - or neighbour.scan_profile.level < last_inheritance_level - ): + if segment_inheritance is None or neighbour.reference in visited: + continue + + if neighbour.scan_profile is None: + raise ValueError("neighbour scan_profile is None") + + if neighbour.scan_profile.level < last_inheritance_level: continue - inherited_level = min(get_max_scan_level_inheritance(segment), neighbour.scan_profile.level) + inherited_level = min(get_max_scan_level_inheritance(segment) or 0, neighbour.scan_profile.level) inheritances.append( InheritanceSection( segment=str(segment), @@ -549,7 +586,7 @@ def get_scan_profile_inheritance( def recalculate_bits(self) -> int: valid_time = datetime.now(timezone.utc) - bit_counter = Counter() + bit_counter: Counter[str] = Counter() # loop over all bit definitions and add origins and origin params bit_definitions = get_bit_definitions() diff --git a/octopoes/octopoes/events/events.py b/octopoes/octopoes/events/events.py index 60480e1cf7d..1e1d23cc3e8 100644 --- a/octopoes/octopoes/events/events.py +++ b/octopoes/octopoes/events/events.py @@ -19,7 +19,7 @@ class DBEvent(BaseModel): entity_type: str operation_type: OperationType valid_time: datetime - client: str | None = None + client: str @property def primary_key(self) -> str: @@ -33,7 +33,9 @@ class OOIDBEvent(DBEvent): @property def primary_key(self) -> str: - return self.new_data.primary_key if self.new_data else self.old_data.primary_key + # There doesn't seem to be an easy way to tell mypy that if new_data is + # None then old_data is never None. + return self.new_data.primary_key if self.new_data else self.old_data.primary_key # type: ignore[union-attr] class OriginDBEvent(DBEvent): @@ -43,7 +45,9 @@ class OriginDBEvent(DBEvent): @property def primary_key(self) -> str: - return self.new_data.id if self.new_data else self.old_data.id + # There doesn't seem to be an easy way to tell mypy that if new_data is + # None then old_data is never None. + return self.new_data.id if self.new_data else self.old_data.id # type: ignore[union-attr] class OriginParameterDBEvent(DBEvent): @@ -53,7 +57,9 @@ class OriginParameterDBEvent(DBEvent): @property def primary_key(self) -> str: - return self.new_data.id if self.new_data else self.old_data.id + # There doesn't seem to be an easy way to tell mypy that if new_data is + # None then old_data is never None. + return self.new_data.id if self.new_data else self.old_data.id # type: ignore[union-attr] class ScanProfileDBEvent(DBEvent): diff --git a/octopoes/octopoes/events/manager.py b/octopoes/octopoes/events/manager.py index cf09350e246..e6c4f16ca85 100644 --- a/octopoes/octopoes/events/manager.py +++ b/octopoes/octopoes/events/manager.py @@ -78,8 +78,6 @@ def publish(self, event: DBEvent) -> None: raise def _publish(self, event: DBEvent) -> None: - event.client = self.client - # schedule celery event processor self.celery_app.send_task( "octopoes.tasks.tasks.handle_event", @@ -98,10 +96,11 @@ def _publish(self, event: DBEvent) -> None: if not isinstance(event, ScanProfileDBEvent): return - incremented = (event.operation_type == OperationType.CREATE and event.new_data.level > 0) or ( + # There doesn't seem to be an easy way to tell mypy when old_data or new_data is None. + incremented = (event.operation_type == OperationType.CREATE and event.new_data.level > 0) or ( # type: ignore[union-attr] event.operation_type == OperationType.UPDATE and event.old_data - and event.new_data.level > event.old_data.level + and event.new_data.level > event.old_data.level # type: ignore[union-attr] ) if incremented: @@ -109,7 +108,7 @@ def _publish(self, event: DBEvent) -> None: { "primary_key": event.reference, "object_type": event.reference.class_, - "scan_profile": event.new_data.dict(), + "scan_profile": event.new_data.dict(), # type: ignore[union-attr] } ) @@ -123,7 +122,7 @@ def _publish(self, event: DBEvent) -> None: logger.debug( "Published scan_profile_increment [primary_key=%s] [level=%s]", format_id_short(event.primary_key), - event.new_data.level, + event.new_data.level, # type: ignore[union-attr] ) # publish mutations @@ -131,8 +130,8 @@ def _publish(self, event: DBEvent) -> None: if event.operation_type != OperationType.DELETE: mutation.value = AbstractOOI( - primary_key=event.new_data.reference, - object_type=event.new_data.reference.class_, + primary_key=event.new_data.reference, # type: ignore[union-attr] + object_type=event.new_data.reference.class_, # type: ignore[union-attr] scan_profile=event.new_data, ) diff --git a/octopoes/octopoes/models/__init__.py b/octopoes/octopoes/models/__init__.py index ad582b5c709..bb8f88a1891 100644 --- a/octopoes/octopoes/models/__init__.py +++ b/octopoes/octopoes/models/__init__.py @@ -1,6 +1,5 @@ from __future__ import annotations -import abc from enum import Enum, IntEnum from typing import ( Any, @@ -79,7 +78,7 @@ class ScanProfileType(Enum): EMPTY = "empty" -class ScanProfileBase(BaseModel, abc.ABC): +class ScanProfileBase(BaseModel): scan_profile_type: str reference: Reference level: ScanLevel @@ -113,14 +112,14 @@ class InheritedScanProfile(ScanProfileBase): ScanProfile = EmptyScanProfile | InheritedScanProfile | DeclaredScanProfile -class OOI(BaseModel, abc.ABC): - object_type: Literal["OOI"] +class OOI(BaseModel): + object_type: str scan_profile: ScanProfile | None = None - _natural_key_attrs: list[str] = [] - _reverse_relation_names: dict[str, str] = {} - _information_value: list[str] = [] + _natural_key_attrs: ClassVar[list[str]] = [] + _reverse_relation_names: ClassVar[dict[str, str]] = {} + _information_value: ClassVar[list[str]] = [] _traversable: ClassVar[bool] = True primary_key: str = "" @@ -192,7 +191,7 @@ def reference(self) -> Reference: @classmethod def get_reverse_relation_name(cls, attr: str) -> str: - return cls._reverse_relation_names.default.get(attr, f"{cls.get_object_type()}_{attr}") + return cls._reverse_relation_names.get(attr, f"{cls.get_object_type()}_{attr}") @classmethod def get_tokenized_primary_key(cls, natural_key: str): @@ -234,10 +233,10 @@ def format_id_short(id_: str) -> str: class PrimaryKeyToken(RootModel): root: dict[str, str | PrimaryKeyToken] - def __getattr__(self, item) -> str | PrimaryKeyToken: + def __getattr__(self, item) -> Any: return self.root[item] - def __getitem__(self, item) -> str | PrimaryKeyToken: + def __getitem__(self, item) -> Any: return self.root[item] @@ -251,12 +250,11 @@ def get_leaf_subclasses(cls: type[OOI]) -> set[type[OOI]]: return set().union(*child_sets) -def build_token_tree(ooi_class: type[OOI]) -> dict: - tokens = {} +def build_token_tree(ooi_class: type[OOI]) -> dict[str, dict | str]: + tokens: dict[str, dict | str] = {} - for attribute in ooi_class._natural_key_attrs.default: + for attribute in ooi_class._natural_key_attrs: field = ooi_class.model_fields[attribute] - value = "" if field.annotation in (Reference, Reference | None): from octopoes.models.types import related_object_type @@ -265,7 +263,7 @@ def build_token_tree(ooi_class: type[OOI]) -> dict: trees = [build_token_tree(related_class) for related_class in get_leaf_subclasses(related_class)] # combine trees - value = {key: value_ for tree in trees for key, value_ in tree.items()} - - tokens[attribute] = value + tokens[attribute] = {key: value_ for tree in trees for key, value_ in tree.items()} + else: + tokens[attribute] = "" return tokens diff --git a/octopoes/octopoes/models/ooi/dns/records.py b/octopoes/octopoes/models/ooi/dns/records.py index 9fb835ec602..dc510f46000 100644 --- a/octopoes/octopoes/models/ooi/dns/records.py +++ b/octopoes/octopoes/models/ooi/dns/records.py @@ -1,4 +1,3 @@ -import abc import hashlib from enum import Enum from typing import Literal @@ -9,7 +8,7 @@ from octopoes.models.persistence import ReferenceField -class DNSRecord(OOI, abc.ABC): +class DNSRecord(OOI): hostname: Reference = ReferenceField(Hostname, max_issue_scan_level=0, max_inherit_scan_level=2) dns_record_type: Literal["A", "AAAA", "CAA", "CNAME", "MX", "NS", "PTR", "SOA", "SRV", "TXT"] value: str diff --git a/octopoes/octopoes/models/ooi/web.py b/octopoes/octopoes/models/ooi/web.py index 81be2c9a1a2..4e8f65101ba 100644 --- a/octopoes/octopoes/models/ooi/web.py +++ b/octopoes/octopoes/models/ooi/web.py @@ -1,4 +1,3 @@ -from abc import ABC from enum import Enum from typing import Literal @@ -47,7 +46,7 @@ class WebScheme(Enum): HTTPS = "https" -class WebURL(OOI, ABC): +class WebURL(OOI): network: Reference = ReferenceField(Network) scheme: WebScheme diff --git a/octopoes/octopoes/models/origin.py b/octopoes/octopoes/models/origin.py index 767d8ecd334..8b87d332810 100644 --- a/octopoes/octopoes/models/origin.py +++ b/octopoes/octopoes/models/origin.py @@ -17,7 +17,7 @@ class Origin(BaseModel): origin_type: OriginType method: str source: Reference - result: list[Reference] | None = Field(default_factory=list) + result: list[Reference] = Field(default_factory=list) task_id: UUID | None = None def __sub__(self, other) -> set[Reference]: diff --git a/octopoes/octopoes/models/path.py b/octopoes/octopoes/models/path.py index be03804ce50..b1b0a40dade 100644 --- a/octopoes/octopoes/models/path.py +++ b/octopoes/octopoes/models/path.py @@ -61,7 +61,10 @@ def reverse(self) -> Segment: self.source_type, ) - def __eq__(self, other: Segment) -> bool: + def __eq__(self, other: object) -> bool: + if not isinstance(other, Segment): + return NotImplemented + return ( self.source_type == other.source_type and self.direction == other.direction @@ -100,7 +103,10 @@ def __str__(self) -> str: segments = ".".join(map(str, self.segments)) return f"{start_type}.{segments}" - def __eq__(self, other: Path): + def __eq__(self, other: object) -> bool: + if not isinstance(other, Path | str): + return NotImplemented + return str(self) == str(other) def __lt__(self, other): diff --git a/octopoes/octopoes/models/tree.py b/octopoes/octopoes/models/tree.py index 118918ad1ea..e6efdc3a3e2 100644 --- a/octopoes/octopoes/models/tree.py +++ b/octopoes/octopoes/models/tree.py @@ -26,12 +26,12 @@ def filter_children(self, filter_fn: Callable[[ReferenceNode], bool]): return filter_fn(self) def collect_references(self) -> set[Reference]: - child_references = set() + child_references: set[Reference] = set() for child_name, children in self.children.items(): child_references_ = [child.collect_references() for child in children] # merge list of sets - child_references_ = set().union(*child_references_) - child_references = child_references.union(child_references_) + child_references_merged = set().union(*child_references_) + child_references = child_references.union(child_references_merged) return {self.reference}.union(child_references) diff --git a/octopoes/octopoes/models/types.py b/octopoes/octopoes/models/types.py index 0df693c6f3d..41d1f911f1f 100644 --- a/octopoes/octopoes/models/types.py +++ b/octopoes/octopoes/models/types.py @@ -93,9 +93,8 @@ | ResolvedHostname | NXDOMAIN ) -FindingTypeType = ( - FindingType - | ADRFindingType +ConcreteFindingTypeType = ( + ADRFindingType | KATFindingType | CVEFindingType | RetireJSFindingType @@ -103,7 +102,9 @@ | CAPECFindingType | SnykFindingType ) -NetworkType = Network | IPAddress | IPAddressV4 | IPAddressV6 | AutonomousSystem | IPV4NetBlock | IPV6NetBlock | IPPort +FindingTypeType = FindingType | ConcreteFindingTypeType +ConcreteNetworkType = Network | IPAddressV4 | IPAddressV6 | AutonomousSystem | IPV4NetBlock | IPV6NetBlock | IPPort +NetworkType = ConcreteNetworkType | IPAddress ServiceType = Service | IPService | TLSCipher SoftwareType = Software | SoftwareInstance WebType = ( @@ -135,11 +136,11 @@ ConfigType = Config ReportsType = ReportData -OOIType = ( +ConcreteOOIType = ( CertificateType | DnsType | DnsRecordType - | NetworkType + | ConcreteNetworkType | ServiceType | SoftwareType | WebType @@ -151,12 +152,14 @@ | EmailSecurityType | Finding | MutedFinding - | FindingTypeType + | ConcreteFindingTypeType | ConfigType | Question | ReportsType ) +OOIType = ConcreteOOIType | NetworkType | FindingTypeType + def get_all_types(cls_: type[OOI]) -> Iterator[type[OOI]]: yield cls_ diff --git a/octopoes/octopoes/repositories/ooi_repository.py b/octopoes/octopoes/repositories/ooi_repository.py index 668e6db2954..c9e27537a1e 100644 --- a/octopoes/octopoes/repositories/ooi_repository.py +++ b/octopoes/octopoes/repositories/ooi_repository.py @@ -119,7 +119,7 @@ def get_tree( reference: Reference, valid_time: datetime, search_types: set[type[OOI]] | None = None, - depth: int | None = 1, + depth: int = 1, ) -> ReferenceTree: raise NotImplementedError @@ -133,6 +133,7 @@ def list_findings( self, severities, exclude_muted, + only_muted, offset, limit, valid_time, @@ -158,7 +159,7 @@ def to_reference_node(self, pk_prefix: str) -> ReferenceNode | None: # Apparently relations can be joined to Null values..?!? if pk_prefix not in self.root: return None - reference = Reference.from_str(self.root.pop(pk_prefix)) + reference = Reference.from_str(cast(str, self.root.pop(pk_prefix))) children = {} for name, value in self.root.items(): if isinstance(value, XTDBReferenceNode): @@ -229,11 +230,14 @@ def deserialize(cls, data: dict[str, Any]) -> OOI: def get(self, reference: Reference, valid_time: datetime) -> OOI: try: res = self.session.client.get_entity(str(reference), valid_time) - return self.deserialize(res) except HTTPError as e: if e.response.status_code == HTTPStatus.NOT_FOUND: raise ObjectNotFoundException(str(reference)) + raise + + return self.deserialize(res) + def get_history( self, reference: Reference, @@ -259,6 +263,8 @@ def get_history( if e.response.status_code == HTTPStatus.NOT_FOUND: raise ObjectNotFoundException(str(reference)) + raise + def load_bulk(self, references: set[Reference], valid_time: datetime) -> dict[str, OOI]: ids = list(map(str, references)) query = generate_pull_query(FieldSet.ALL_FIELDS, {self.pk_prefix: ids}) @@ -363,11 +369,11 @@ def get_tree( reference: Reference, valid_time: datetime, search_types: set[type[OOI]] | None = None, - depth: int | None = 1, + depth: int = 1, ) -> ReferenceTree: if search_types is None: search_types = {OOI} - search_types = to_concrete(search_types) + concrete_search_types = to_concrete(search_types) results = self._get_tree_level({reference}, depth=depth, valid_time=valid_time) @@ -376,7 +382,7 @@ def get_tree( except IndexError: raise ObjectNotFoundException(str(reference)) - reference_node.filter_children(lambda child_node: child_node.reference.class_type in search_types) + reference_node.filter_children(lambda child_node: child_node.reference.class_type in concrete_search_types) store = self.load_bulk(reference_node.collect_references(), valid_time) return ReferenceTree(root=reference_node, store=store) @@ -401,7 +407,7 @@ def _get_related_objects(self, references: set[Reference], valid_time: datetime def _get_tree_level( self, references: set[Reference], - depth: int | None = 1, + depth: int = 1, exclude: set[Reference] | None = None, valid_time: datetime | None = None, ) -> list[ReferenceNode]: @@ -514,7 +520,7 @@ def construct_neighbour_query_multi(cls, references: set[Reference], paths: set[ return query def get_neighbours( - self, reference: Reference, valid_time: datetime, paths: set[Path] = None + self, reference: Reference, valid_time: datetime, paths: set[Path] | None = None ) -> dict[Path, list[OOI]]: query = self.construct_neighbour_query(reference, paths) @@ -586,6 +592,7 @@ def save(self, ooi: OOI, valid_time: datetime, end_valid_time: datetime | None = valid_time=valid_time, old_data=old_ooi, new_data=new_ooi, + client=self.event_manager.client, ) # After transaction, send event @@ -604,6 +611,7 @@ def delete(self, reference: Reference, valid_time: datetime) -> None: operation_type=OperationType.DELETE, valid_time=valid_time, old_data=ooi, + client=self.event_manager.client, ) self.session.listen_post_commit(lambda: self.event_manager.publish(event)) diff --git a/octopoes/octopoes/repositories/origin_parameter_repository.py b/octopoes/octopoes/repositories/origin_parameter_repository.py index 13d5c3719d9..4326ab728b3 100644 --- a/octopoes/octopoes/repositories/origin_parameter_repository.py +++ b/octopoes/octopoes/repositories/origin_parameter_repository.py @@ -107,6 +107,7 @@ def save(self, origin_parameter: OriginParameter, valid_time: datetime) -> None: valid_time=valid_time, old_data=old_origin_parameter, new_data=origin_parameter, + client=self.event_manager.client, ) self.session.listen_post_commit(lambda: self.event_manager.publish(event)) @@ -117,5 +118,6 @@ def delete(self, origin_parameter: OriginParameter, valid_time: datetime) -> Non operation_type=OperationType.DELETE, valid_time=valid_time, old_data=origin_parameter, + client=self.event_manager.client, ) self.session.listen_post_commit(lambda: self.event_manager.publish(event)) diff --git a/octopoes/octopoes/repositories/origin_repository.py b/octopoes/octopoes/repositories/origin_repository.py index 109ef4ab925..805d0eb2713 100644 --- a/octopoes/octopoes/repositories/origin_repository.py +++ b/octopoes/octopoes/repositories/origin_repository.py @@ -134,6 +134,7 @@ def save(self, origin: Origin, valid_time: datetime) -> None: valid_time=valid_time, old_data=old_origin, new_data=origin, + client=self.event_manager.client, ) self.session.listen_post_commit(lambda: self.event_manager.publish(event)) @@ -144,5 +145,6 @@ def delete(self, origin: Origin, valid_time: datetime) -> None: operation_type=OperationType.DELETE, valid_time=valid_time, old_data=origin, + client=self.event_manager.client, ) self.session.listen_post_commit(lambda: self.event_manager.publish(event)) diff --git a/octopoes/octopoes/repositories/scan_profile_repository.py b/octopoes/octopoes/repositories/scan_profile_repository.py index 7ed63b4d9e5..4a5aabd6b72 100644 --- a/octopoes/octopoes/repositories/scan_profile_repository.py +++ b/octopoes/octopoes/repositories/scan_profile_repository.py @@ -105,6 +105,7 @@ def save( reference=new_scan_profile.reference, old_data=old_scan_profile, new_data=new_scan_profile, + client=self.event_manager.client, ) self.session.listen_post_commit(lambda: self.event_manager.publish(event)) @@ -116,6 +117,7 @@ def delete(self, scan_profile: ScanProfileBase, valid_time: datetime) -> None: reference=scan_profile.reference, valid_time=valid_time, old_data=scan_profile, + client=self.event_manager.client, ) self.session.listen_post_commit(lambda: self.event_manager.publish(event)) diff --git a/octopoes/octopoes/xtdb/client.py b/octopoes/octopoes/xtdb/client.py index dd2ce2cc9dc..7471e6fe4ad 100644 --- a/octopoes/octopoes/xtdb/client.py +++ b/octopoes/octopoes/xtdb/client.py @@ -40,8 +40,8 @@ def __init__(self, base_url: str): self._base_url = base_url self.headers["Accept"] = "application/json" - def request(self, method: str, url: str | bytes, **kwargs) -> requests.Response: - return super().request(method, self._base_url + str(url), **kwargs) + def request(self, method: str | bytes, url: str | bytes, *args, **kwargs) -> requests.Response: + return super().request(method, self._base_url + str(url), *args, **kwargs) class XTDBStatus(BaseModel): @@ -187,8 +187,8 @@ class XTDBSession: def __init__(self, client: XTDBHTTPClient): self.client = client - self._operations = [] - self.post_commit_callbacks = [] + self._operations: list[Operation] = [] + self.post_commit_callbacks: list[Callable[[], None]] = [] def __enter__(self): return self diff --git a/octopoes/octopoes/xtdb/query.py b/octopoes/octopoes/xtdb/query.py index 1253b5e4410..f7e59bbb923 100644 --- a/octopoes/octopoes/xtdb/query.py +++ b/octopoes/octopoes/xtdb/query.py @@ -43,7 +43,6 @@ class Aliased: Ref = type[OOI] | Aliased -A = Aliased @dataclass @@ -92,9 +91,12 @@ def from_path(cls, path: Path) -> "Query": ooi_type = path.segments[-1].target_type query = cls(ooi_type) - target_ref = None + target_ref: Ref alias_map: dict[str, Ref] = {} + if not path.segments: + return query + for segment in path.segments: source_ref = alias_map.get(segment.source_type.get_object_type(), segment.source_type) @@ -103,9 +105,9 @@ def from_path(cls, path: Path) -> "Query": if segment.target_type.get_object_type() not in alias_map: target_ref = segment.target_type - alias_map[target_ref.get_object_type()] = target_ref + alias_map[segment.target_type.get_object_type()] = target_ref else: - target_ref = A(segment.target_type) + target_ref = Aliased(segment.target_type) alias_map[segment.target_type.get_object_type()] = target_ref if segment.direction is Direction.OUTGOING: @@ -113,8 +115,8 @@ def from_path(cls, path: Path) -> "Query": else: query = query.where(target_ref, **{segment.property_name: source_ref}) - if target_ref: # Make sure we use the last reference in the path as a target - query.result_type = target_ref + # Make sure we use the last reference in the path as a target + query.result_type = target_ref return query @@ -265,5 +267,8 @@ def _get_object_alias(self, object_type: Ref) -> str: def __str__(self) -> str: return self._compile() - def __eq__(self, other: "Query"): + def __eq__(self, other: object): + if not isinstance(other, Query): + return NotImplemented + return str(self) == str(other) diff --git a/octopoes/octopoes/xtdb/query_builder.py b/octopoes/octopoes/xtdb/query_builder.py index 97de6f301d7..172cfc745c2 100644 --- a/octopoes/octopoes/xtdb/query_builder.py +++ b/octopoes/octopoes/xtdb/query_builder.py @@ -1,13 +1,11 @@ import re -from collections.abc import Iterator +from collections.abc import Iterable, Mapping +from typing import Any -from octopoes.xtdb.related_field_generator import ( - FieldSet, - RelatedFieldNode, -) +from octopoes.xtdb.related_field_generator import FieldSet, RelatedFieldNode -def join_csv(values: Iterator[any]) -> str: +def join_csv(values: Iterable[Any]) -> str: return " ".join(values) @@ -19,8 +17,8 @@ def str_val(val): def generate_pull_query( - field_set: FieldSet | None = FieldSet.ALL_FIELDS, - where: dict[str, str | int | list[str | int] | set[str | int]] | None = None, + field_set: FieldSet = FieldSet.ALL_FIELDS, + where: Mapping[str, str | int | list[str] | list[int] | set[str] | set[int]] | None = None, offset: int | None = None, limit: int | None = None, field_node: RelatedFieldNode | None = None, diff --git a/octopoes/octopoes/xtdb/related_field_generator.py b/octopoes/octopoes/xtdb/related_field_generator.py index 2b7975b0992..ac48a8f0e88 100644 --- a/octopoes/octopoes/xtdb/related_field_generator.py +++ b/octopoes/octopoes/xtdb/related_field_generator.py @@ -6,7 +6,7 @@ def __init__( self, data_model: Datamodel, object_types: set[str], - path: tuple[ForeignKey, ...] | None = (), + path: tuple[ForeignKey, ...] = (), ): self.data_model = data_model self.object_types = object_types @@ -89,15 +89,15 @@ def generate_field(self, field_set: FieldSet, pk_prefix: str): # Loop outgoing QueryNodes fields = [f"{queried_fields}"] - for key, node in self.relations_out.items(): - cls, attr_name = key + for key_out, node in self.relations_out.items(): + cls, attr_name = key_out deeper_fields = node.generate_field(field_set, pk_prefix) field_query = f"{{(:{cls}/{attr_name} {{:as {attr_name}}}) {deeper_fields}}}" fields.append(field_query) # Loop incoming QueryNodes - for key, node in self.relations_in.items(): - foreign_cls, attr_name, reverse_name = key + for key_in, node in self.relations_in.items(): + foreign_cls, attr_name, reverse_name = key_in deeper_fields = node.generate_field(field_set, pk_prefix) field_query = f"{{(:{foreign_cls}/_{attr_name} {{:as {reverse_name}}}) {deeper_fields}}}" fields.append(field_query) @@ -144,9 +144,9 @@ def to_dict(self): """ d = {} if self.relations_out: - for p, v in self.relations_out.items(): - d[f"{p[0]}/{p[1]}"] = v.to_dict() + for key_out, node in self.relations_out.items(): + d[f"{key_out[0]}/{key_out[1]}"] = node.to_dict() if self.relations_in: - for p, v in self.relations_in.items(): - d[f"{p[0]}/_{p[1]} as {p[0]}/_{p[1]}"] = v.to_dict() + for key_in, node in self.relations_in.items(): + d[f"{key_in[0]}/_{key_in[1]} as {key_in[0]}/_{key_in[1]}"] = node.to_dict() return d diff --git a/octopoes/tests/conftest.py b/octopoes/tests/conftest.py index 8c8864349ad..0793c13463e 100644 --- a/octopoes/tests/conftest.py +++ b/octopoes/tests/conftest.py @@ -243,6 +243,7 @@ class MockEventManager: def __init__(self): self.queue = [] self.processed = [0] + self.client = "test" def publish(self, event) -> None: self.queue.append(event) @@ -314,7 +315,7 @@ def mock_xtdb_session(): @pytest.fixture def origin_repository(mock_xtdb_session): - yield XTDBOriginRepository(Mock(spec=EventManager), mock_xtdb_session) + yield XTDBOriginRepository(Mock(spec=EventManager, client="test"), mock_xtdb_session) def seed_system(xtdb_ooi_repository: XTDBOOIRepository, xtdb_origin_repository: XTDBOriginRepository, valid_time): diff --git a/octopoes/tests/robot/03_deletion_propagation.robot b/octopoes/tests/robot/03_deletion_propagation.robot index 1bdd468ff6e..1b54eb04295 100644 --- a/octopoes/tests/robot/03_deletion_propagation.robot +++ b/octopoes/tests/robot/03_deletion_propagation.robot @@ -3,12 +3,10 @@ Resource robot.resource Test Setup Setup Test Test Teardown Teardown Test - - -*** Test Cases *** -Propagate Deletion - Insert Empty Normalizer Output - Await Sync +# *** Test Cases *** +# Propagate Deletion +# Insert Empty Normalizer Output +# Await Sync # This test fails because of circular origins proving eachother's existence # Object List Should Be Empty diff --git a/octopoes/tests/test_event_manager.py b/octopoes/tests/test_event_manager.py index 229508cf18b..133aca06ef7 100644 --- a/octopoes/tests/test_event_manager.py +++ b/octopoes/tests/test_event_manager.py @@ -13,7 +13,9 @@ def test_event_manager_create_ooi(mocker, network): mocker.patch.object(uuid, "uuid4", return_value="1754a4c8-f0b8-42c8-b294-5706ce23a47d") manager = EventManager("test", "amqp://test-queue-uri", celery_mock, "queue", lambda x: channel_mock) - event = OOIDBEvent(operation_type=OperationType.CREATE, valid_time=datetime(2023, 1, 1), new_data=network) + event = OOIDBEvent( + operation_type=OperationType.CREATE, valid_time=datetime(2023, 1, 1), new_data=network, client="test" + ) manager.publish(event) celery_mock.send_task.assert_called_once_with( @@ -51,6 +53,7 @@ def test_event_manager_create_empty_scan_profile(mocker, empty_scan_profile): valid_time=datetime(2023, 1, 1), new_data=empty_scan_profile, reference="test_reference", + client="test", ) manager.publish(event) @@ -92,6 +95,7 @@ def test_event_manager_create_declared_scan_profile(mocker, declared_scan_profil valid_time=datetime(2023, 1, 1), new_data=declared_scan_profile, reference="test_reference", + client="test", ) manager.publish(event) @@ -144,6 +148,7 @@ def test_event_manager_delete_empty_scan_profile(mocker, empty_scan_profile): valid_time=datetime(2023, 1, 1), old_data=empty_scan_profile, reference="test_reference", + client="test", ) manager.publish(event) diff --git a/octopoes/tests/test_octopoes_service.py b/octopoes/tests/test_octopoes_service.py index eb452ce64be..5eac373b197 100644 --- a/octopoes/tests/test_octopoes_service.py +++ b/octopoes/tests/test_octopoes_service.py @@ -6,7 +6,7 @@ from bits.definitions import BitDefinition from octopoes.events.events import OOIDBEvent, OperationType, OriginDBEvent, ScanProfileDBEvent -from octopoes.models import EmptyScanProfile, Reference +from octopoes.models import EmptyScanProfile, Reference, ScanLevel from octopoes.models.ooi.dns.zone import Hostname from octopoes.models.ooi.network import IPAddress, IPAddressV4, Network from octopoes.models.origin import Origin, OriginType @@ -112,7 +112,7 @@ def test_on_create_scan_profile(octopoes_service, new_data, old_data, bit_runner source=Reference.from_str("Hostname|internet|example.com"), ) ] - octopoes_service.scan_profile_repository.get.return_value = Mock(level=2) + octopoes_service.scan_profile_repository.get.return_value = Mock(level=ScanLevel.L2) octopoes_service.ooi_repository.get.return_value = Mock() octopoes_service.origin_parameter_repository.list_by_origin.return_value = {} octopoes_service.ooi_repository.load_bulk.return_value = {} @@ -127,6 +127,7 @@ def test_on_create_scan_profile(octopoes_service, new_data, old_data, bit_runner old_data=old_data, new_data=new_data, reference="test_reference", + client="_dev", ) octopoes_service.process_event(event) diff --git a/octopoes/tests/test_query.py b/octopoes/tests/test_query.py index 0469831afb8..dfe545db486 100644 --- a/octopoes/tests/test_query.py +++ b/octopoes/tests/test_query.py @@ -10,7 +10,7 @@ from octopoes.models.ooi.service import IPService, Service from octopoes.models.ooi.web import Website from octopoes.models.path import Path -from octopoes.xtdb.query import A, InvalidField, Query +from octopoes.xtdb.query import Aliased, InvalidField, Query def test_basic_field_where_clause(): @@ -201,8 +201,8 @@ def test_create_query_from_path_abstract(): def test_aliased_query(): - h1 = A(Hostname, UUID("4b4afa7e-5b76-4506-a373-069216b051c2")) - h2 = A(Hostname, UUID("98076f7a-7606-47ac-85b7-b511ee21ae42")) + h1 = Aliased(Hostname, UUID("4b4afa7e-5b76-4506-a373-069216b051c2")) + h2 = Aliased(Hostname, UUID("98076f7a-7606-47ac-85b7-b511ee21ae42")) query = ( Query(DNSAAAARecord) .where(DNSAAAARecord, hostname=h1) @@ -269,8 +269,8 @@ def test_build_system_query_with_path_segments(mocker): uuid_mock = mocker.patch("octopoes.xtdb.query.uuid4") uuid_mock.side_effect = uuid_batch - resolved_hostname_alias = A(ResolvedHostname) - hostname_alias = A(Hostname) + resolved_hostname_alias = Aliased(ResolvedHostname) + hostname_alias = Aliased(Hostname) query = ( Query(hostname_alias) diff --git a/rocky/katalogus/views/mixins.py b/rocky/katalogus/views/mixins.py index de171ad336f..5e75ae2b8c6 100644 --- a/rocky/katalogus/views/mixins.py +++ b/rocky/katalogus/views/mixins.py @@ -112,7 +112,7 @@ def run_boefje_for_oois( self.run_boefje(boefje, None) for ooi in oois: - if ooi.scan_profile.level < boefje.scan_level: + if ooi.scan_profile and ooi.scan_profile.level < boefje.scan_level: try: self.raise_clearance_level(ooi.reference, boefje.scan_level) except IndemnificationNotPresentException: diff --git a/rocky/katalogus/views/plugin_detail.py b/rocky/katalogus/views/plugin_detail.py index d761534d984..9649738826c 100644 --- a/rocky/katalogus/views/plugin_detail.py +++ b/rocky/katalogus/views/plugin_detail.py @@ -237,7 +237,7 @@ def get_oois(self, selected_oois: list[str]) -> dict[str, Any]: oois_without_clearance = [] for ooi in selected_oois: ooi_object = self.get_single_ooi(pk=ooi) - if ooi_object.scan_profile.level >= self.plugin.scan_level.value: + if ooi_object.scan_profile and ooi_object.scan_profile.level >= self.plugin.scan_level.value: oois_with_clearance.append(ooi_object) else: oois_without_clearance.append(ooi_object.primary_key) diff --git a/rocky/reports/forms.py b/rocky/reports/forms.py index ad399ad07a2..ed34728cf30 100644 --- a/rocky/reports/forms.py +++ b/rocky/reports/forms.py @@ -2,7 +2,6 @@ from django.utils.translation import gettext_lazy as _ from tools.forms.base import BaseRockyForm -from octopoes.models.types import OOIType from reports.report_types.definitions import Report @@ -13,7 +12,7 @@ class OOITypeMultiCheckboxForReportForm(BaseRockyForm): widget=forms.CheckboxSelectMultiple, ) - def __init__(self, ooi_types: list[OOIType], *args, **kwargs): + def __init__(self, ooi_types: list[str], *args, **kwargs): super().__init__(*args, **kwargs) self.fields["ooi_type"].choices = ((ooi_type, ooi_type) for ooi_type in ooi_types) diff --git a/rocky/reports/report_types/definitions.py b/rocky/reports/report_types/definitions.py index 4d36332f775..ec4fc7c399b 100644 --- a/rocky/reports/report_types/definitions.py +++ b/rocky/reports/report_types/definitions.py @@ -4,7 +4,7 @@ from typing import Any, TypedDict from octopoes.connector.octopoes import OctopoesAPIConnector -from octopoes.models.types import OOIType +from octopoes.models import OOI REPORTS_DIR = Path(__file__).parent logger = getLogger(__name__) @@ -27,7 +27,7 @@ def __init__(self, octopoes_api_connector: OctopoesAPIConnector): class Report(BaseReport): plugins: ReportPlugins - input_ooi_types: set[OOIType] + input_ooi_types: set[type[OOI]] def generate_data(self, input_ooi: str, valid_time: datetime) -> dict[str, Any]: raise NotImplementedError @@ -46,7 +46,7 @@ def class_attributes(cls) -> dict[str, Any]: class MultiReport(BaseReport): plugins: ReportPlugins - input_ooi_types: set[OOIType] + input_ooi_types: set[type[OOI]] def post_process_data(self, data: dict[str, Any]) -> dict[str, Any]: raise NotImplementedError diff --git a/rocky/reports/report_types/dns_report/report.py b/rocky/reports/report_types/dns_report/report.py index 6fbfbef699b..00e8123c215 100644 --- a/rocky/reports/report_types/dns_report/report.py +++ b/rocky/reports/report_types/dns_report/report.py @@ -42,7 +42,7 @@ def generate_data(self, input_ooi: str, valid_time: datetime) -> dict[str, Any]: records.append( { "type": ooi.dns_record_type, - "ttl": round(ooi.ttl / 60), + "ttl": round(ooi.ttl / 60) if ooi.ttl else "", "name": ooi.hostname.tokenized.name, "content": ooi.value, } diff --git a/rocky/reports/report_types/ipv6_report/report.py b/rocky/reports/report_types/ipv6_report/report.py index 0ceedb1d1cd..a1e191f9515 100644 --- a/rocky/reports/report_types/ipv6_report/report.py +++ b/rocky/reports/report_types/ipv6_report/report.py @@ -36,9 +36,9 @@ def generate_data(self, input_ooi: str, valid_time: datetime) -> dict[str, Any]: """ try: ooi = self.octopoes_api_connector.get(Reference.from_str(input_ooi), valid_time) - except ObjectNotFoundException as e: - logger.error("No data found for OOI '%s' on date %s.", str(e), str(valid_time)) - raise ObjectNotFoundException(e) + except ObjectNotFoundException: + logger.error("No data found for OOI '%s' on date %s.", str(ooi), str(valid_time)) + raise if ooi.reference.class_type == IPAddressV4 or ooi.reference.class_type == IPAddressV6: path = Path.parse("IPAddress.
dict[str, Any]: try: ooi = self.octopoes_api_connector.get(Reference.from_str(input_ooi), valid_time) - except ObjectNotFoundException as e: - logger.error("No data found for OOI '%s' on date %s.", str(e), str(valid_time)) - raise ObjectNotFoundException(e) + except ObjectNotFoundException: + logger.error("No data found for OOI '%s' on date %s.", str(ooi), str(valid_time)) + raise if ooi.reference.class_type == Hostname: hostnames = [ooi] diff --git a/rocky/reports/report_types/name_server_report/report.py b/rocky/reports/report_types/name_server_report/report.py index bcba4e0481d..904c0ce3770 100644 --- a/rocky/reports/report_types/name_server_report/report.py +++ b/rocky/reports/report_types/name_server_report/report.py @@ -71,9 +71,9 @@ def generate_data(self, input_ooi: str, valid_time: datetime) -> dict[str, Any]: try: ooi = self.octopoes_api_connector.get(Reference.from_str(input_ooi), valid_time) - except ObjectNotFoundException as e: - logger.error("No data found for OOI '%s' on date %s.", str(e), str(valid_time)) - raise ObjectNotFoundException(e) + except ObjectNotFoundException: + logger.error("No data found for OOI '%s' on date %s.", str(ooi), str(valid_time)) + raise if ooi.reference.class_type == Hostname: hostnames = [ooi] diff --git a/rocky/reports/report_types/open_ports_report/report.py b/rocky/reports/report_types/open_ports_report/report.py index 39ab58ae74f..2289b42aa0b 100644 --- a/rocky/reports/report_types/open_ports_report/report.py +++ b/rocky/reports/report_types/open_ports_report/report.py @@ -25,9 +25,9 @@ class OpenPortsReport(Report): def generate_data(self, input_ooi: str, valid_time: datetime) -> dict[str, Any]: try: ooi = self.octopoes_api_connector.get(Reference.from_str(input_ooi), valid_time) - except ObjectNotFoundException as e: - logger.error("No data found for OOI '%s' on date %s.", str(e), str(valid_time)) - raise ObjectNotFoundException(e) + except ObjectNotFoundException: + logger.error("No data found for OOI '%s' on date %s.", str(ooi), str(valid_time)) + raise if ooi.reference.class_type == Hostname: path = Path.parse("Hostname. dict[str, Any]: ) ] - results[ref] = {"ports": port_numbers, "hostnames": hostnames, "services": services} + results[str(ref)] = {"ports": port_numbers, "hostnames": hostnames, "services": services} return results diff --git a/rocky/reports/report_types/rpki_report/report.py b/rocky/reports/report_types/rpki_report/report.py index 359bc9ab67c..7f377e45c83 100644 --- a/rocky/reports/report_types/rpki_report/report.py +++ b/rocky/reports/report_types/rpki_report/report.py @@ -32,9 +32,9 @@ class RPKIReport(Report): def generate_data(self, input_ooi: str, valid_time: datetime) -> dict[str, Any]: try: ooi = self.octopoes_api_connector.get(Reference.from_str(input_ooi), valid_time) - except ObjectNotFoundException as e: - logger.error("No data found for OOI '%s' on date %s.", str(e), str(valid_time)) - raise ObjectNotFoundException(e) + except ObjectNotFoundException: + logger.error("No data found for OOI '%s' on date %s.", str(ooi), str(valid_time)) + raise if ooi.reference.class_type == Hostname: ips = self.octopoes_api_connector.query( diff --git a/rocky/reports/report_types/safe_connections_report/report.py b/rocky/reports/report_types/safe_connections_report/report.py index 91ea4b35b30..7099d0992c0 100644 --- a/rocky/reports/report_types/safe_connections_report/report.py +++ b/rocky/reports/report_types/safe_connections_report/report.py @@ -26,9 +26,9 @@ class SafeConnectionsReport(Report): def generate_data(self, input_ooi: str, valid_time: datetime) -> dict[str, Any]: try: ooi = self.octopoes_api_connector.get(Reference.from_str(input_ooi), valid_time) - except ObjectNotFoundException as e: - logger.error("No data found for OOI '%s' on date %s.", str(e), str(valid_time)) - raise ObjectNotFoundException(e) + except ObjectNotFoundException: + logger.error("No data found for OOI '%s' on date %s.", str(ooi), str(valid_time)) + raise if ooi.reference.class_type == Hostname: ips = self.octopoes_api_connector.query( diff --git a/rocky/reports/report_types/systems_report/report.py b/rocky/reports/report_types/systems_report/report.py index b95a299d0e0..fd14a5eaceb 100644 --- a/rocky/reports/report_types/systems_report/report.py +++ b/rocky/reports/report_types/systems_report/report.py @@ -42,9 +42,9 @@ def generate_data(self, input_ooi: str, valid_time: datetime) -> dict[str, Any]: try: ooi = self.octopoes_api_connector.get(Reference.from_str(input_ooi), valid_time) - except ObjectNotFoundException as e: - logger.error("No data found for OOI '%s' on date %s.", str(e), str(valid_time)) - raise ObjectNotFoundException(e) + except ObjectNotFoundException: + logger.error("No data found for OOI '%s' on date %s.", str(ooi), str(valid_time)) + raise if ooi.reference.class_type == Hostname: ips = self.octopoes_api_connector.query( @@ -53,7 +53,7 @@ def generate_data(self, input_ooi: str, valid_time: datetime) -> dict[str, Any]: elif ooi.reference.class_type in (IPAddressV4, IPAddressV6): ips = [ooi] - ip_services = {} + ip_services: dict[Reference, dict[str, list[Reference | str]]] = {} service_mapping = { "http": SystemType.WEB, diff --git a/rocky/reports/report_types/tls_report/report.py b/rocky/reports/report_types/tls_report/report.py index a2d06411f4c..199416dcee1 100644 --- a/rocky/reports/report_types/tls_report/report.py +++ b/rocky/reports/report_types/tls_report/report.py @@ -24,8 +24,8 @@ class TLSReport(Report): template_path = "tls_report/report.html" def generate_data(self, input_ooi: str, valid_time: datetime) -> dict[str, Any]: - suites = {} - findings = [] + suites: dict = {} + findings: list[Finding] = [] suites_with_findings = [] ref = Reference.from_str(input_ooi) tree = self.octopoes_api_connector.get_tree( @@ -40,7 +40,7 @@ def generate_data(self, input_ooi: str, valid_time: datetime) -> dict[str, Any]: for protocol, cipher_suites in suites.items(): for suite in cipher_suites: for finding in findings: - if suite["cipher_suite_name"] in finding.description: + if finding.description and suite["cipher_suite_name"] in finding.description: suites_with_findings.append(suite["cipher_suite_name"]) return { diff --git a/rocky/reports/report_types/vulnerability_report/report.py b/rocky/reports/report_types/vulnerability_report/report.py index c7e8ad9ee9f..854d72c3686 100644 --- a/rocky/reports/report_types/vulnerability_report/report.py +++ b/rocky/reports/report_types/vulnerability_report/report.py @@ -31,7 +31,7 @@ class VulnerabilityReport(Report): input_ooi_types = {Hostname, IPAddressV4, IPAddressV6} template_path = "vulnerability_report/report.html" - def get_finding_valid_time_history(self, reference: str) -> list[datetime]: + def get_finding_valid_time_history(self, reference: Reference) -> list[datetime]: transaction_record = self.octopoes_api_connector.get_history(reference=reference) valid_time_history = [transaction.valid_time for transaction in transaction_record] return valid_time_history @@ -43,9 +43,9 @@ def get_findings(self, input_ooi: str, valid_time: datetime) -> dict[str, Findin try: ooi = self.octopoes_api_connector.get(Reference.from_str(input_ooi), valid_time) - except ObjectNotFoundException as e: - logger.error("No data found for OOI '%s' on date %s.", str(e), str(valid_time)) - raise ObjectNotFoundException(e) + except ObjectNotFoundException: + logger.error("No data found for OOI '%s' on date %s.", str(ooi), str(valid_time)) + raise if ooi.reference.class_type == Hostname: ips = self.octopoes_api_connector.query( @@ -108,7 +108,7 @@ def generate_data(self, input_ooi: str, valid_time: datetime) -> dict[str, dict[ for finding in findings: if finding.finding_type.tokenized.id == finding_type.id: - time_history = self.get_finding_valid_time_history(finding.primary_key) + time_history = self.get_finding_valid_time_history(finding.reference) if time_history: first_seen = str(time_history[0]) @@ -127,11 +127,13 @@ def generate_data(self, input_ooi: str, valid_time: datetime) -> dict[str, dict[ vulnerabilities[finding_type.id] = { "cvss": { - "class": str(finding_type.risk_severity.value).lower(), + "class": str(finding_type.risk_severity.value).lower() if finding_type.risk_severity else "-", "score": finding_type.risk_score, "risk_level": " ".join( [str(finding_type.risk_severity.value).title(), str(finding_type.risk_score)] - ), + ) + if finding_type.risk_severity + else "-", }, "occurrences": occurrences[finding_type.id], "description": finding_type.description or "-", @@ -139,7 +141,7 @@ def generate_data(self, input_ooi: str, valid_time: datetime) -> dict[str, dict[ "findings": filtered_findings, } - if finding_type.risk_severity.value == critical: + if finding_type.risk_severity and finding_type.risk_severity.value == critical: total_criticals += 1 if finding_type.recommendation: diff --git a/rocky/reports/report_types/web_system_report/report.py b/rocky/reports/report_types/web_system_report/report.py index bb766c74d6e..4f80de26f65 100644 --- a/rocky/reports/report_types/web_system_report/report.py +++ b/rocky/reports/report_types/web_system_report/report.py @@ -115,9 +115,9 @@ def generate_data(self, input_ooi: str, valid_time: datetime) -> dict[str, Any]: try: ooi = self.octopoes_api_connector.get(Reference.from_str(input_ooi), valid_time) - except ObjectNotFoundException as e: - logger.error("No data found for OOI '%s' on date %s.", str(e), str(valid_time)) - raise ObjectNotFoundException(e) + except ObjectNotFoundException: + logger.error("No data found for OOI '%s' on date %s.", str(ooi), str(valid_time)) + raise if ooi.reference.class_type == Hostname: hostnames = [ooi] diff --git a/rocky/reports/views/base.py b/rocky/reports/views/base.py index 973cb0ce1ce..290174e9f17 100644 --- a/rocky/reports/views/base.py +++ b/rocky/reports/views/base.py @@ -1,4 +1,4 @@ -from collections.abc import Sequence +from collections.abc import Iterable, Sequence from datetime import datetime from logging import getLogger from typing import Any @@ -16,7 +16,6 @@ from tools.view_helpers import BreadcrumbsMixin from octopoes.models import OOI -from octopoes.models.types import OOIType from reports.forms import OOITypeMultiCheckboxForReportForm from reports.report_types.definitions import MultiReport, Report, ReportType from reports.report_types.helpers import get_plugins_for_report_ids, get_report_by_id @@ -107,7 +106,7 @@ def get_oois(self) -> list[OOI]: logger.warning("No data could be found for '%s' ", ooi_id) return oois - def get_ooi_filter_forms(self, ooi_types: set[OOIType]) -> dict[str, Form]: + def get_ooi_filter_forms(self, ooi_types: Iterable[type[OOI]]) -> dict[str, Form]: return { "ooi_type_form": OOITypeMultiCheckboxForReportForm( sorted([ooi_class.get_ooi_type() for ooi_class in ooi_types]), self.request.GET diff --git a/rocky/rocky/bytes_client.py b/rocky/rocky/bytes_client.py index 6714074059f..e2e11e69652 100644 --- a/rocky/rocky/bytes_client.py +++ b/rocky/rocky/bytes_client.py @@ -139,7 +139,7 @@ def get_raw_metas(self, boefje_meta_id: uuid.UUID, organization_code: str) -> li return metas - def get_normalizer_meta(self, normalizer_meta_id: str) -> dict: + def get_normalizer_meta(self, normalizer_meta_id: uuid.UUID) -> dict: # Note: we assume organization permissions are handled before requesting raw data. response = self.session.get(f"{self.base_url}/bytes/normalizer_meta/{normalizer_meta_id}") diff --git a/rocky/rocky/views/finding_list.py b/rocky/rocky/views/finding_list.py index 6e9ee4e5dd6..042fd67a250 100644 --- a/rocky/rocky/views/finding_list.py +++ b/rocky/rocky/views/finding_list.py @@ -35,13 +35,13 @@ def generate_findings_metadata( for finding in findings[: FindingList.HARD_LIMIT]: finding_type = finding.finding_type - if not severity_filter or finding_type.risk_severity in severity_filter: + if not severity_filter or (finding_type.risk_severity and finding_type.risk_severity in severity_filter): findings_meta.append( { "finding_number": 0, "finding": finding, "finding_type": finding_type, - "severity": finding_type.risk_severity.name, + "severity": finding_type.risk_severity.name if finding_type.risk_severity else "", "risk_level_score": finding_type.risk_score, } ) diff --git a/rocky/rocky/views/mixins.py b/rocky/rocky/views/mixins.py index 6e6c41872f4..adf78a3f60a 100644 --- a/rocky/rocky/views/mixins.py +++ b/rocky/rocky/views/mixins.py @@ -87,11 +87,13 @@ class OctopoesView(ObservedAtMixin, OrganizationView): def get_single_ooi(self, pk: str) -> OOI: try: ref = Reference.from_str(pk) - return self.octopoes_api_connector.get(ref, valid_time=self.observed_at) + ooi = self.octopoes_api_connector.get(ref, valid_time=self.observed_at) except Exception as e: # TODO: raise the exception but let the handling be done by the method that implements "get_single_ooi" self.handle_connector_exception(e) + return ooi + def get_origins( self, reference: Reference, @@ -102,7 +104,7 @@ def get_origins( origin_data = [OriginData(origin=origin) for origin in origins] for origin in origin_data: - if origin.origin.origin_type != OriginType.OBSERVATION: + if origin.origin.origin_type != OriginType.OBSERVATION or not origin.origin.task_id: continue try: diff --git a/rocky/rocky/views/ooi_detail_related_object.py b/rocky/rocky/views/ooi_detail_related_object.py index 89f1373972a..82c2e3eb115 100644 --- a/rocky/rocky/views/ooi_detail_related_object.py +++ b/rocky/rocky/views/ooi_detail_related_object.py @@ -47,7 +47,7 @@ def count_findings_per_severity(self) -> Counter: counter = Counter({severity: 0 for severity in RiskLevelSeverity}) for finding in self.get_findings(): finding_type: FindingType | None = self.tree.store.get(str(finding.finding_type), None) - if finding_type is not None: + if finding_type is not None and finding_type.risk_severity is not None: counter.update([finding_type.risk_severity]) else: counter.update([RiskLevelSeverity.UNKNOWN]) @@ -55,7 +55,7 @@ def count_findings_per_severity(self) -> Counter: def get_finding_details_sorted_by_score_desc(self) -> list[tuple[Finding, FindingType]]: finding_details = self.get_finding_details() - return list(sorted(finding_details, key=lambda x: x[1].risk_score, reverse=True)) + return list(sorted(finding_details, key=lambda x: x[1].risk_score or 0, reverse=True)) def get_finding_details(self) -> list[tuple[Finding, FindingType]]: return [(finding, self.tree.store[str(finding.finding_type)]) for finding in self.get_findings()] diff --git a/rocky/rocky/views/ooi_mute.py b/rocky/rocky/views/ooi_mute.py index 7fe6ce95b46..a95759ffe71 100644 --- a/rocky/rocky/views/ooi_mute.py +++ b/rocky/rocky/views/ooi_mute.py @@ -47,7 +47,7 @@ def post(self, request, *args, **kwargs): messages.add_message(self.request, messages.WARNING, _("Please select at least one finding.")) return redirect(reverse("finding_list", kwargs={"organization_code": self.organization.code})) if unmute: - mutes_finding_refs = [MutedFinding(finding=finding) for finding in selected_findings] + mutes_finding_refs = [MutedFinding(finding=finding).reference for finding in selected_findings] self.octopoes_api_connector.delete_many(mutes_finding_refs, datetime.now(timezone.utc)) messages.add_message(self.request, messages.SUCCESS, _("Finding(s) successfully unmuted.")) diff --git a/rocky/rocky/views/scan_profile.py b/rocky/rocky/views/scan_profile.py index f7d98a407fd..1c3f3205508 100644 --- a/rocky/rocky/views/scan_profile.py +++ b/rocky/rocky/views/scan_profile.py @@ -105,7 +105,7 @@ def post(self, request, *args, **kwargs): def get_initial(self): initial = super().get_initial() - if not isinstance(self.ooi.scan_profile, InheritedScanProfile): + if self.ooi.scan_profile and not isinstance(self.ooi.scan_profile, InheritedScanProfile): initial["level"] = self.ooi.scan_profile.level return initial @@ -116,7 +116,7 @@ class ScanProfileResetView(OOIDetailView): def get(self, request, *args, **kwargs): result = super().get(request, *args, **kwargs) - if self.ooi.scan_profile.scan_profile_type != "declared": + if not self.ooi.scan_profile or self.ooi.scan_profile.scan_profile_type != "declared": messages.add_message( self.request, messages.WARNING,