Skip to content

Commit

Permalink
Fix octopoes typing
Browse files Browse the repository at this point in the history
  • Loading branch information
dekkers committed Feb 23, 2024
1 parent 9e6fcc3 commit 925fee0
Show file tree
Hide file tree
Showing 55 changed files with 272 additions and 206 deletions.
6 changes: 0 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,6 @@ repos:
^boefjes/tools |
^keiko/templates |
^mula/whitelist\.py$ |
^octopoes/bits |
^octopoes/octopoes/core |
^octopoes/octopoes/events |
^octopoes/octopoes/models |
^octopoes/octopoes/repositories |
^octopoes/octopoes/xtdb |
^octopoes/tools |
^rocky/whitelist\.py$ |
/tests/ |
Expand Down
2 changes: 1 addition & 1 deletion octopoes/.ci/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ services:
args:
ENVIRONMENT: dev
context: .
command: pytest -x tests/integration --timeout=300
command: pytest tests/integration --timeout=300
depends_on:
- xtdb
- ci_octopoes
Expand Down
10 changes: 5 additions & 5 deletions octopoes/bits/check_csp_header/check_csp_header.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def run(input_ooi: HTTPHeader, additional_oois: list, config: dict[str, str]) ->
if header.key.lower() != "content-security-policy":
return

findings: [str] = []
findings: list[str] = []

if "http://" in header.value:
findings.append("Http should not be used in the CSP settings of an HTTP Header.")
Expand Down Expand Up @@ -100,10 +100,10 @@ def run(input_ooi: HTTPHeader, additional_oois: list, config: dict[str, str]) ->

def _ip_valid(source: str) -> bool:
"Check if there are IP's in this source, return False if the address found was to be non global. Ignores non ips"
ip = NON_DECIMAL_FILTER.sub("", source)
if ip:
ip_str = NON_DECIMAL_FILTER.sub("", source)
if ip_str:
try:
ip = ipaddress.ip_address(ip)
ip = ipaddress.ip_address(ip_str)
if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_multicast or ip.is_reserved:
return False
except ValueError:
Expand All @@ -121,7 +121,7 @@ def _create_kat_finding(header: Reference, kat_id: str, description: str) -> Ite
)


def _source_valid(policy: [str]) -> bool:
def _source_valid(policy: list[str]) -> bool:
for value in policy:
if not (
re.search(r"\S+\.\S{2,3}([\s]+|$|;|:[0-9]+)", value)
Expand Down
2 changes: 1 addition & 1 deletion octopoes/bits/check_hsts_header/check_hsts_header.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def run(input_ooi: HTTPHeader, additional_oois: list, config: dict[str, str]) ->
one_year = datetime.timedelta(days=365).total_seconds()

max_age = int(config.get("max-age", one_year)) if config else one_year
findings: [str] = []
findings: list[str] = []

headervalue = header.value.lower()
if "includesubdomains" not in headervalue:
Expand Down
5 changes: 2 additions & 3 deletions octopoes/bits/runner.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections.abc import Iterator
from importlib import import_module
from inspect import isfunction, signature
from typing import Any, Protocol, cast
from typing import Any, Protocol

from bits.definitions import BitDefinition
from octopoes.models import OOI
Expand All @@ -22,7 +22,6 @@ def __init__(self, bit_definition: BitDefinition):

def run(self, *args, **kwargs) -> list[OOI]:
module = import_module(self.module)
module = cast(Runnable, module)

if not hasattr(module, "run") or not isfunction(module.run):
raise ModuleException(f"Module {module} does not define a run function")
Expand All @@ -38,7 +37,7 @@ def __str__(self):


def _bit_run_signature(input_ooi: OOI, additional_oois: list[OOI], config: dict[str, str]) -> Iterator[OOI]:
...
yield input_ooi


BIT_SIGNATURE = signature(_bit_run_signature)
14 changes: 5 additions & 9 deletions octopoes/bits/spf_discovery/internetnl_spf_parser.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Copyright: 2022, ECP, NLnet Labs and the Internet.nl contributors
# SPDX-License-Identifier: Apache-2.0
import contextlib
import ipaddress

from pyparsing import (
Expand Down Expand Up @@ -40,16 +39,13 @@ def _parse_ipv6(tokens):
"""
match = str(tokens[0])
ipv6 = None
try:
ipv6 = ipaddress.IPv6Address(match)
return str(ipaddress.IPv6Address(match))
except ipaddress.AddressValueError:
with contextlib.suppress(ipaddress.AddressValueError, ipaddress.NetmaskValueError):
ipv6 = ipaddress.IPv6Network(match, strict=False)

if not ipv6:
raise ParseException("Non valid IPv6 address/network.")
return str(ipv6)
try:
return str(ipaddress.IPv6Network(match, strict=False))
except (ipaddress.AddressValueError, ipaddress.NetmaskValueError) as e:
raise ParseException("Non valid IPv6 address/network.") from e


SP = White(ws=" ", exact=1).suppress()
Expand Down
10 changes: 5 additions & 5 deletions octopoes/bits/spf_discovery/spf_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def run(input_ooi: DNSTXTRecord, additional_oois, config: dict[str, str]) -> Ite
yield Finding(finding_type=ft.reference, ooi=input_ooi.reference, description="This SPF record is invalid")


def parse_ip_qualifiers(mechanism: str, input_ooi: DNSTXTRecord, spf_record: DNSSPFRecord) -> Iterator[str]:
def parse_ip_qualifiers(mechanism: str, input_ooi: DNSTXTRecord, spf_record: DNSSPFRecord) -> Iterator[OOI]:
# split mechanism into qualifier and ip
qualifier, ip = mechanism.split(":", 1)
ip = mechanism[4:]
Expand All @@ -72,7 +72,7 @@ def parse_ip_qualifiers(mechanism: str, input_ooi: DNSTXTRecord, spf_record: DNS
)


def parse_a_mx_qualifiers(mechanism: str, input_ooi: DNSTXTRecord, spf_record: DNSSPFRecord) -> Iterator[str]:
def parse_a_mx_qualifiers(mechanism: str, input_ooi: DNSTXTRecord, spf_record: DNSSPFRecord) -> Iterator[OOI]:
if mechanism == "a" or mechanism == "mx":
yield DNSSPFMechanismHostname(spf_record=spf_record.reference, hostname=input_ooi.hostname, mechanism=mechanism)
else:
Expand All @@ -86,7 +86,7 @@ def parse_a_mx_qualifiers(mechanism: str, input_ooi: DNSTXTRecord, spf_record: D
spf_record=spf_record.reference, hostname=hostname.reference, mechanism=mechanism_type
)
if mechanism.startswith("a/") or mechanism.startswith("mx/"):
mechanism_type, domain = mechanism.split("/", 1)[1]
mechanism_type, domain = mechanism.split("/", 1)
# TODO: fix prefix lengths
domain = domain.split("/")[0]
hostname = Hostname(name=domain, network=Network(name=input_ooi.hostname.tokenized.network.name).reference)
Expand All @@ -98,7 +98,7 @@ def parse_a_mx_qualifiers(mechanism: str, input_ooi: DNSTXTRecord, spf_record: D

def parse_ptr_exists_include_mechanism(
mechanism: str, input_ooi: DNSTXTRecord, spf_record: DNSSPFRecord
) -> Iterator[str]:
) -> Iterator[OOI]:
if mechanism == "ptr":
yield DNSSPFMechanismHostname(spf_record=spf_record.reference, hostname=input_ooi.hostname, mechanism="ptr")
else:
Expand All @@ -113,7 +113,7 @@ def parse_ptr_exists_include_mechanism(
)


def parse_redirect_mechanism(mechanism: str, input_ooi: DNSTXTRecord, spf_record: DNSSPFRecord) -> Iterator[str]:
def parse_redirect_mechanism(mechanism: str, input_ooi: DNSTXTRecord, spf_record: DNSSPFRecord) -> Iterator[OOI]:
mechanism_type, domain = mechanism.split("=", 1)
# currently, the model only supports hostnames and not domains
if domain.startswith("_"):
Expand Down
13 changes: 8 additions & 5 deletions octopoes/octopoes/api/models.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import uuid
from datetime import datetime
from typing import Any
from typing import Annotated, Any

from pydantic import AwareDatetime, BaseModel, Field

from octopoes.models import Reference
from octopoes.models.types import OOIType
from octopoes.models.types import ConcreteOOIType, OOIType


class ServiceHealth(BaseModel):
Expand Down Expand Up @@ -59,18 +59,21 @@ class ScanProfileDeclaration(BaseModel):
valid_time: datetime


ValidatedOOIType = Annotated[ConcreteOOIType, Field(discriminator="object_type")]


# API models (timezone validation and pydantic parsing)
class ValidatedObservation(_BaseObservation):
"""Used by Octopoes API to validate and parse correctly"""

result: list[OOIType]
result: list[ValidatedOOIType]
valid_time: AwareDatetime


class ValidatedDeclaration(BaseModel):
"""Used by Octopoes API to validate and parse correctly"""

ooi: OOIType
ooi: ValidatedOOIType
valid_time: AwareDatetime
method: str | None = "manual"
task_id: uuid.UUID | None = Field(default_factory=uuid.uuid4)
Expand All @@ -79,7 +82,7 @@ class ValidatedDeclaration(BaseModel):
class ValidatedAffirmation(BaseModel):
"""Used by Octopoes API to validate and parse correctly"""

ooi: OOIType
ooi: ValidatedOOIType
valid_time: AwareDatetime
method: str | None = "hydration"
task_id: uuid.UUID | None = Field(default_factory=uuid.uuid4)
5 changes: 4 additions & 1 deletion octopoes/octopoes/api/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,10 +368,13 @@ def get_scan_profile_inheritance(
reference: Reference = Depends(extract_reference),
) -> list[InheritanceSection]:
ooi = octopoes.get_ooi(reference, valid_time)
if not ooi.scan_profile:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="OOI does not have a scanprofile")

start = InheritanceSection(
reference=ooi.reference, level=ooi.scan_profile.level, scan_profile_type=ooi.scan_profile.scan_profile_type
)
if ooi.scan_profile.scan_profile_type == ScanProfileType.DECLARED:
if ooi.scan_profile.scan_profile_type == ScanProfileType.DECLARED.value:
return [start]
return octopoes.get_scan_profile_inheritance(reference, valid_time, [start])

Expand Down
4 changes: 2 additions & 2 deletions octopoes/octopoes/connector/octopoes.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def list_objects(
scan_level: set[ScanLevel] = DEFAULT_SCAN_LEVEL_FILTER,
scan_profile_type: set[ScanProfileType] = DEFAULT_SCAN_PROFILE_TYPE_FILTER,
) -> Paginated[OOIType]:
params: dict[str, str | int | list[str] | set[str]] = {
params: dict[str, str | int | list[str] | set[str | int]] = {
"types": [t.__name__ for t in types],
"valid_time": str(valid_time),
"offset": offset,
Expand Down Expand Up @@ -169,7 +169,7 @@ def list_origins(
"source": source,
"result": result,
"task_id": str(task_id) if task_id else None,
"origin_type": origin_type,
"origin_type": str(origin_type) if origin_type else None,
},
)

Expand Down
Loading

0 comments on commit 925fee0

Please sign in to comment.