Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix octopoes typing #2555

Merged
merged 12 commits into from
Mar 1, 2024
6 changes: 0 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,6 @@ repos:
^boefjes/tools |
^keiko/templates |
^mula/whitelist\.py$ |
^octopoes/bits |
^octopoes/octopoes/core |
^octopoes/octopoes/events |
^octopoes/octopoes/models |
^octopoes/octopoes/repositories |
^octopoes/octopoes/xtdb |
^octopoes/tools |
^rocky/whitelist\.py$ |
/tests/ |
Expand Down
2 changes: 1 addition & 1 deletion octopoes/.ci/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ services:
args:
ENVIRONMENT: dev
context: .
command: pytest -x tests/integration --timeout=300
command: pytest tests/integration --timeout=300
underdarknl marked this conversation as resolved.
Show resolved Hide resolved
depends_on:
- xtdb
- ci_octopoes
Expand Down
10 changes: 5 additions & 5 deletions octopoes/bits/check_csp_header/check_csp_header.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def run(input_ooi: HTTPHeader, additional_oois: list, config: dict[str, str]) ->
if header.key.lower() != "content-security-policy":
return

findings: [str] = []
findings: list[str] = []

if "http://" in header.value:
findings.append("Http should not be used in the CSP settings of an HTTP Header.")
Expand Down Expand Up @@ -100,10 +100,10 @@ def run(input_ooi: HTTPHeader, additional_oois: list, config: dict[str, str]) ->

def _ip_valid(source: str) -> bool:
"Check if there are IP's in this source, return False if the address found was to be non global. Ignores non ips"
ip = NON_DECIMAL_FILTER.sub("", source)
if ip:
ip_str = NON_DECIMAL_FILTER.sub("", source)
if ip_str:
try:
ip = ipaddress.ip_address(ip)
ip = ipaddress.ip_address(ip_str)
if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_multicast or ip.is_reserved:
return False
except ValueError:
Expand All @@ -121,7 +121,7 @@ def _create_kat_finding(header: Reference, kat_id: str, description: str) -> Ite
)


def _source_valid(policy: [str]) -> bool:
def _source_valid(policy: list[str]) -> bool:
for value in policy:
if not (
re.search(r"\S+\.\S{2,3}([\s]+|$|;|:[0-9]+)", value)
Expand Down
2 changes: 1 addition & 1 deletion octopoes/bits/check_hsts_header/check_hsts_header.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def run(input_ooi: HTTPHeader, additional_oois: list, config: dict[str, str]) ->
one_year = datetime.timedelta(days=365).total_seconds()

max_age = int(config.get("max-age", one_year)) if config else one_year
findings: [str] = []
findings: list[str] = []

headervalue = header.value.lower()
if "includesubdomains" not in headervalue:
Expand Down
5 changes: 2 additions & 3 deletions octopoes/bits/runner.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections.abc import Iterator
from importlib import import_module
from inspect import isfunction, signature
from typing import Any, Protocol, cast
from typing import Any, Protocol

from bits.definitions import BitDefinition
from octopoes.models import OOI
Expand All @@ -22,7 +22,6 @@ def __init__(self, bit_definition: BitDefinition):

def run(self, *args, **kwargs) -> list[OOI]:
module = import_module(self.module)
module = cast(Runnable, module)

if not hasattr(module, "run") or not isfunction(module.run):
raise ModuleException(f"Module {module} does not define a run function")
Expand All @@ -38,7 +37,7 @@ def __str__(self):


def _bit_run_signature(input_ooi: OOI, additional_oois: list[OOI], config: dict[str, str]) -> Iterator[OOI]:
...
yield input_ooi


BIT_SIGNATURE = signature(_bit_run_signature)
14 changes: 5 additions & 9 deletions octopoes/bits/spf_discovery/internetnl_spf_parser.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Copyright: 2022, ECP, NLnet Labs and the Internet.nl contributors
# SPDX-License-Identifier: Apache-2.0
import contextlib
import ipaddress

from pyparsing import (
Expand Down Expand Up @@ -40,16 +39,13 @@ def _parse_ipv6(tokens):

"""
match = str(tokens[0])
ipv6 = None
try:
ipv6 = ipaddress.IPv6Address(match)
return str(ipaddress.IPv6Address(match))
except ipaddress.AddressValueError:
with contextlib.suppress(ipaddress.AddressValueError, ipaddress.NetmaskValueError):
ipv6 = ipaddress.IPv6Network(match, strict=False)

if not ipv6:
raise ParseException("Non valid IPv6 address/network.")
return str(ipv6)
try:
return str(ipaddress.IPv6Network(match, strict=False))
except (ipaddress.AddressValueError, ipaddress.NetmaskValueError) as e:
raise ParseException("Non valid IPv6 address/network.") from e


SP = White(ws=" ", exact=1).suppress()
Expand Down
10 changes: 5 additions & 5 deletions octopoes/bits/spf_discovery/spf_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def run(input_ooi: DNSTXTRecord, additional_oois, config: dict[str, str]) -> Ite
yield Finding(finding_type=ft.reference, ooi=input_ooi.reference, description="This SPF record is invalid")


def parse_ip_qualifiers(mechanism: str, input_ooi: DNSTXTRecord, spf_record: DNSSPFRecord) -> Iterator[str]:
def parse_ip_qualifiers(mechanism: str, input_ooi: DNSTXTRecord, spf_record: DNSSPFRecord) -> Iterator[OOI]:
# split mechanism into qualifier and ip
qualifier, ip = mechanism.split(":", 1)
ip = mechanism[4:]
Expand All @@ -72,7 +72,7 @@ def parse_ip_qualifiers(mechanism: str, input_ooi: DNSTXTRecord, spf_record: DNS
)


def parse_a_mx_qualifiers(mechanism: str, input_ooi: DNSTXTRecord, spf_record: DNSSPFRecord) -> Iterator[str]:
def parse_a_mx_qualifiers(mechanism: str, input_ooi: DNSTXTRecord, spf_record: DNSSPFRecord) -> Iterator[OOI]:
if mechanism == "a" or mechanism == "mx":
yield DNSSPFMechanismHostname(spf_record=spf_record.reference, hostname=input_ooi.hostname, mechanism=mechanism)
else:
Expand All @@ -86,7 +86,7 @@ def parse_a_mx_qualifiers(mechanism: str, input_ooi: DNSTXTRecord, spf_record: D
spf_record=spf_record.reference, hostname=hostname.reference, mechanism=mechanism_type
)
if mechanism.startswith("a/") or mechanism.startswith("mx/"):
mechanism_type, domain = mechanism.split("/", 1)[1]
mechanism_type, domain = mechanism.split("/", 1)
underdarknl marked this conversation as resolved.
Show resolved Hide resolved
# TODO: fix prefix lengths
domain = domain.split("/")[0]
hostname = Hostname(name=domain, network=Network(name=input_ooi.hostname.tokenized.network.name).reference)
Expand All @@ -98,7 +98,7 @@ def parse_a_mx_qualifiers(mechanism: str, input_ooi: DNSTXTRecord, spf_record: D

def parse_ptr_exists_include_mechanism(
mechanism: str, input_ooi: DNSTXTRecord, spf_record: DNSSPFRecord
) -> Iterator[str]:
) -> Iterator[OOI]:
if mechanism == "ptr":
yield DNSSPFMechanismHostname(spf_record=spf_record.reference, hostname=input_ooi.hostname, mechanism="ptr")
else:
Expand All @@ -113,7 +113,7 @@ def parse_ptr_exists_include_mechanism(
)


def parse_redirect_mechanism(mechanism: str, input_ooi: DNSTXTRecord, spf_record: DNSSPFRecord) -> Iterator[str]:
def parse_redirect_mechanism(mechanism: str, input_ooi: DNSTXTRecord, spf_record: DNSSPFRecord) -> Iterator[OOI]:
mechanism_type, domain = mechanism.split("=", 1)
# currently, the model only supports hostnames and not domains
if domain.startswith("_"):
Expand Down
13 changes: 8 additions & 5 deletions octopoes/octopoes/api/models.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import uuid
from datetime import datetime
from typing import Any
from typing import Annotated, Any

from pydantic import AwareDatetime, BaseModel, Field

from octopoes.models import Reference
from octopoes.models.types import OOIType
from octopoes.models.types import ConcreteOOIType, OOIType


class ServiceHealth(BaseModel):
Expand Down Expand Up @@ -59,18 +59,21 @@ class ScanProfileDeclaration(BaseModel):
valid_time: datetime


ValidatedOOIType = Annotated[ConcreteOOIType, Field(discriminator="object_type")]
underdarknl marked this conversation as resolved.
Show resolved Hide resolved


# API models (timezone validation and pydantic parsing)
class ValidatedObservation(_BaseObservation):
"""Used by Octopoes API to validate and parse correctly"""

result: list[OOIType]
result: list[ValidatedOOIType]
valid_time: AwareDatetime


class ValidatedDeclaration(BaseModel):
"""Used by Octopoes API to validate and parse correctly"""

ooi: OOIType
ooi: ValidatedOOIType
valid_time: AwareDatetime
method: str | None = "manual"
task_id: uuid.UUID | None = Field(default_factory=uuid.uuid4)
Expand All @@ -79,7 +82,7 @@ class ValidatedDeclaration(BaseModel):
class ValidatedAffirmation(BaseModel):
"""Used by Octopoes API to validate and parse correctly"""

ooi: OOIType
ooi: ValidatedOOIType
valid_time: AwareDatetime
method: str | None = "hydration"
task_id: uuid.UUID | None = Field(default_factory=uuid.uuid4)
13 changes: 8 additions & 5 deletions octopoes/octopoes/api/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from octopoes.version import __version__
from octopoes.xtdb.client import XTDBSession
from octopoes.xtdb.exceptions import XTDBException
from octopoes.xtdb.query import A
from octopoes.xtdb.query import Aliased
from octopoes.xtdb.query import Query as XTDBQuery

logger = getLogger(__name__)
Expand Down Expand Up @@ -144,7 +144,7 @@ def query(
@router.get("/query-many", tags=["Objects"])
def query_many(
path: str,
sources: list[Reference] = Query(),
sources: list[str] = Query(),
octopoes: OctopoesService = Depends(octopoes_service),
valid_time: datetime = Depends(extract_valid_time),
):
Expand Down Expand Up @@ -175,7 +175,7 @@ def query_many(
raise HTTPException(status_code=400, detail="No path components provided.")

q = XTDBQuery.from_path(object_path)
source_alias = A(object_path.segments[0].source_type, field="primary_key")
source_alias = Aliased(object_path.segments[0].source_type, field="primary_key")

return octopoes.ooi_repository.query(
q.find(source_alias)
Expand Down Expand Up @@ -407,10 +407,13 @@ def get_scan_profile_inheritance(
reference: Reference = Depends(extract_reference),
) -> list[InheritanceSection]:
ooi = octopoes.get_ooi(reference, valid_time)
if not ooi.scan_profile:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="OOI does not have a scanprofile")

start = InheritanceSection(
reference=ooi.reference, level=ooi.scan_profile.level, scan_profile_type=ooi.scan_profile.scan_profile_type
)
if ooi.scan_profile.scan_profile_type == ScanProfileType.DECLARED:
if ooi.scan_profile.scan_profile_type == ScanProfileType.DECLARED.value:
return [start]
return octopoes.get_scan_profile_inheritance(reference, valid_time, [start])

Expand All @@ -427,11 +430,11 @@ def list_findings(
) -> Paginated[Finding]:
return octopoes.ooi_repository.list_findings(
severities,
valid_time,
exclude_muted,
only_muted,
offset,
limit,
valid_time,
)


Expand Down
8 changes: 4 additions & 4 deletions octopoes/octopoes/connector/octopoes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from collections.abc import Set
from collections.abc import Sequence, Set
from datetime import datetime
from uuid import UUID

Expand Down Expand Up @@ -91,7 +91,7 @@ def list_objects(
scan_level: set[ScanLevel] = DEFAULT_SCAN_LEVEL_FILTER,
scan_profile_type: set[ScanProfileType] = DEFAULT_SCAN_PROFILE_TYPE_FILTER,
) -> Paginated[OOIType]:
params: dict[str, str | int | list[str] | set[str]] = {
params: dict[str, str | int | list[str] | set[str | int]] = {
"types": [t.__name__ for t in types],
"valid_time": str(valid_time),
"offset": offset,
Expand Down Expand Up @@ -163,7 +163,7 @@ def list_origins(
"source": source,
"result": result,
"task_id": str(task_id) if task_id else None,
"origin_type": origin_type,
"origin_type": str(origin_type) if origin_type else None,
},
)

Expand Down Expand Up @@ -292,7 +292,7 @@ def query_many(
self,
path: str,
valid_time: datetime,
sources: list[OOI | Reference | str],
sources: Sequence[OOI | Reference | str],
) -> list[tuple[str, OOIType]]:
if not sources:
return []
Expand Down
Loading