
Commit

Merge branch 'main' into fix-typing-octopoes
dekkers committed Feb 29, 2024
commit 604f06e (2 parents: 40b2b8c + aa231a3)
Showing 24 changed files with 499 additions and 111 deletions.
bytes/bytes/api/root.py (4 changes: 2 additions & 2 deletions)
@@ -43,12 +43,12 @@ def validation_exception_handler(_: Request, exc: RequestValidationError | Valid
 
 
 @router.get("/", include_in_schema=False)
-def health() -> RedirectResponse:
+def root() -> RedirectResponse:
     return RedirectResponse(url="/health")
 
 
 @router.get("/health", response_model=ServiceHealth)
-def root() -> ServiceHealth:
+def health() -> ServiceHealth:
     bytes_health = ServiceHealth(service="bytes", healthy=True, version=__version__)
     return bytes_health
 
octopoes/octopoes/api/router.py (48 changes: 47 additions & 1 deletion)
@@ -32,6 +32,7 @@
 from octopoes.version import __version__
 from octopoes.xtdb.client import XTDBSession
 from octopoes.xtdb.exceptions import XTDBException
+from octopoes.xtdb.query import Aliased
 from octopoes.xtdb.query import Query as XTDBQuery
 
 logger = getLogger(__name__)
@@ -140,6 +141,51 @@ def query(
     return octopoes.ooi_repository.query(xtdb_query, valid_time)
 
 
+@router.get("/query-many", tags=["Objects"])
+def query_many(
+    path: str,
+    sources: list[str] = Query(),
+    octopoes: OctopoesService = Depends(octopoes_service),
+    valid_time: datetime = Depends(extract_valid_time),
+):
+    """
+    How does this work and why do we do this?
+    We want to fetch all results but be able to tie these back to the source that was used for a result.
+    If we query "Network.hostname" for a list of Networks ids, how do we know which hostname lives on which network?
+    The answer is to add the network id to the "select" statement, so the result is of the form
+    [(network_id_1, hostname1), (network_id_2, hostname3), ...]
+    Because you can only select variables in Datalog, "network_id_1" needs to be an Alias. Hence `source_alias`.
+    We need to tie that to the Network primary_key and add a where-in clause. The example projected on the code:
+    q = XTDBQuery.from_path(object_path)  # Adds "where ?Hostname.network = ?Network
+    q.find(source_alias).pull(query.result_type)  # "select ?network_id, ?Hostname
+    .where(object_path.segments[0].source_type, primary_key=source_alias)  # where ?Network.primary_key = ?network_id
+    .where_in(object_path.segments[0].source_type, primary_key=sources)  # and ?Network.primary_key in ["1", ...]"
+    """
+
+    if not sources:
+        return []
+
+    object_path = ObjectPath.parse(path)
+    if not object_path.segments:
+        raise HTTPException(status_code=400, detail="No path components provided.")
+
+    q = XTDBQuery.from_path(object_path)
+    source_alias = Aliased(object_path.segments[0].source_type, field="primary_key")
+
+    return octopoes.ooi_repository.query(
+        q.find(source_alias)
+        .pull(q.result_type)
+        .where(object_path.segments[0].source_type, primary_key=source_alias)
+        .where_in(object_path.segments[0].source_type, primary_key=sources),
+        valid_time,
+    )
+
+
 @router.post("/objects/load_bulk", tags=["Objects"])
 def load_objects_bulk(
     octopoes: OctopoesService = Depends(octopoes_service),
@@ -384,11 +430,11 @@ def list_findings(
 ) -> Paginated[Finding]:
     return octopoes.ooi_repository.list_findings(
         severities,
-        valid_time,
         exclude_muted,
         only_muted,
         offset,
         limit,
+        valid_time,
     )


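The docstring above already walks through how the new /query-many endpoint ties each result back to its source via an aliased primary_key and a where-in clause. As a rough illustration, the sketch below calls the endpoint over plain HTTP; the host, the "_dev" organisation code, and the Network primary keys are made-up values, and the response shape is inferred from the docstring rather than taken from this diff.

# Sketch only, not part of this commit: exercising the new /query-many endpoint.
# Host, organisation code and primary keys below are hypothetical.
import requests

response = requests.get(
    "http://localhost:8001/_dev/query-many",
    params={
        "path": "Network.hostname",
        "sources": ["Network|internet", "Network|dmz"],  # sent as repeated ?sources= params
        "valid_time": "2024-02-29T00:00:00+00:00",
    },
)
response.raise_for_status()

# Each row pairs a source primary_key with the object found for it, e.g.
# [["Network|internet", {...Hostname...}], ["Network|dmz", {...Hostname...}], ...]
for source_pk, hostname in response.json():
    print(source_pk, hostname)
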
octopoes/octopoes/connector/octopoes.py (25 changes: 22 additions & 3 deletions)
@@ -1,5 +1,5 @@
 import json
-from collections.abc import Set
+from collections.abc import Sequence, Set
 from datetime import datetime
 from uuid import UUID
 
@@ -272,13 +272,13 @@ def query(
         self,
         path: str,
         valid_time: datetime,
-        source: Reference | str | None = None,
+        source: OOI | Reference | str | None = None,
         offset: int = DEFAULT_OFFSET,
         limit: int = DEFAULT_LIMIT,
     ) -> list[OOI]:
         params = {
             "path": path,
-            "source": source,
+            "source": source.reference if isinstance(source, OOI) else source,
             "valid_time": str(valid_time),
             "offset": offset,
             "limit": limit,
@@ -287,3 +287,22 @@ def query(
             TypeAdapter(OOIType).validate_python(ooi)
             for ooi in self.session.get(f"/{self.client}/query", params=params).json()
         ]
+
+    def query_many(
+        self,
+        path: str,
+        valid_time: datetime,
+        sources: Sequence[OOI | Reference | str],
+    ) -> list[tuple[str, OOIType]]:
+        if not sources:
+            return []
+
+        params = {
+            "path": path,
+            "sources": [str(ooi) for ooi in sources],
+            "valid_time": str(valid_time),
+        }
+
+        result = self.session.get(f"/{self.client}/query-many", params=params).json()
+
+        return TypeAdapter(list[tuple[str, OOIType]]).validate_python(result)
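The connector changes mirror the endpoint: query() now also accepts an OOI and sends its reference, and query_many() fans one path out over several sources while keeping each result paired with its source. A minimal usage sketch follows; the connector class name and constructor arguments are assumed from context (they are not shown in this diff), and the URL, organisation code and primary keys are hypothetical.

# Sketch only, not part of this commit. The class name, constructor arguments and
# all identifiers below are assumptions for illustration.
from datetime import datetime, timezone

from octopoes.connector.octopoes import OctopoesAPIConnector

connector = OctopoesAPIConnector("http://localhost:8001", "_dev")
now = datetime.now(timezone.utc)

# source may now be an OOI instance; its .reference is sent instead of the object itself.
hostnames = connector.query("Network.hostname", now, source="Network|internet")

# query_many() returns (source primary_key, OOI) pairs, one per match.
for source_pk, hostname in connector.query_many(
    "Network.hostname", now, sources=["Network|internet", "Network|dmz"]
):
    print(source_pk, hostname)
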
octopoes/octopoes/repositories/ooi_repository.py (35 changes: 26 additions & 9 deletions)
@@ -127,11 +127,11 @@ def count_findings_by_severity(self, valid_time: datetime) -> Counter:
     def list_findings(
         self,
         severities,
-        valid_time,
         exclude_muted,
         only_muted,
         offset,
         limit,
+        valid_time,
     ) -> Paginated[Finding]:
         raise NotImplementedError
 
@@ -141,7 +141,7 @@ def get_bit_configs(self, source: OOI, bit_definition: BitDefinition, valid_time
     def list_related(self, ooi: OOI, path: Path, valid_time: datetime) -> list[OOI]:
         raise NotImplementedError
 
-    def query(self, query: Query, valid_time: datetime) -> list[OOI]:
+    def query(self, query: Query, valid_time: datetime) -> list[OOI | tuple]:
         raise NotImplementedError
 
 
@@ -665,16 +665,17 @@ def list_related(self, ooi: OOI, path: Path, valid_time: datetime) -> list[OOI]:
         path_start_alias = path.segments[0].source_type
         query = Query.from_path(path).where(path_start_alias, primary_key=ooi.primary_key)
 
-        return self.query(query, valid_time)
+        # query() can return different types depending on the query
+        return self.query(query, valid_time)  # type: ignore[return-value]
 
     def list_findings(
         self,
         severities: set[RiskLevelSeverity],
-        valid_time: datetime,
         exclude_muted=False,
         only_muted=False,
         offset=DEFAULT_OFFSET,
         limit=DEFAULT_LIMIT,
+        valid_time: datetime | None = None,
     ) -> Paginated[Finding]:
         # clause to find risk_severity
         concrete_finding_types = to_concrete({FindingType})
@@ -738,12 +739,28 @@ def list_findings(
         }}
         """
 
-        res = self.session.client.query(finding_query, valid_time)
-        findings = [self.deserialize(x[0]) for x in res]
         return Paginated(
             count=count,
-            items=findings,
+            items=[x[0] for x in self.query(finding_query, valid_time)],
         )
 
-    def query(self, query: Query, valid_time: datetime) -> list[OOI]:
-        return [self.deserialize(row[0]) for row in self.session.client.query(query, valid_time=valid_time)]
+    def query(self, query: str | Query, valid_time: datetime) -> list[OOI | tuple]:
+        results = self.session.client.query(query, valid_time=valid_time)
+
+        parsed_results: list[OOI | tuple] = []
+        for result in results:
+            parsed_result = []
+
+            for item in result:
+                try:
+                    parsed_result.append(self.deserialize(item))
+                except (ValueError, TypeError):
+                    parsed_result.append(item)
+
+            if len(parsed_result) == 1:
+                parsed_results.append(parsed_result[0])
+                continue
+
+            parsed_results.append(tuple(parsed_result))
+
+        return parsed_results
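The reworked repository query() applies one row-parsing rule: every item in a result row is deserialized into an OOI when that succeeds and kept as-is otherwise, and a single-column row collapses to a bare value while wider rows stay tuples, which is what lets query-many return (primary_key, OOI) pairs. A standalone sketch of that rule, with illustrative names that are not part of the commit:

# Sketch only: the row-parsing rule used by the new query() above.
def parse_rows(rows, deserialize):
    parsed = []
    for row in rows:
        items = []
        for item in row:
            try:
                items.append(deserialize(item))  # e.g. a pulled OOI document
            except (ValueError, TypeError):
                items.append(item)  # e.g. a plain primary_key string from find()
        # one column -> bare value (a plain list of OOIs); more columns -> a tuple per row
        parsed.append(items[0] if len(items) == 1 else tuple(items))
    return parsed
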
(The remaining changed files in this commit are not shown here.)
