Skip to content

Commit

Permalink
CRF: unpack embedding result
Browse files Browse the repository at this point in the history
  • Loading branch information
fantix committed May 1, 2024
1 parent 60b6c85 commit e60188b
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions edgedb/ai/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def query_rag(

def stream_rag(
self, message: str, context: typing.Optional[types.QueryContext] = None
):
) -> typing.Iterator[str]:
with httpx_sse.connect_sse(
self.client,
"post",
Expand All @@ -138,12 +138,14 @@ def stream_rag(
for sse in event_source.iter_sse():
yield sse.data

def generate_embeddings(self, *inputs: str, model: str):
def generate_embeddings(
self, *inputs: str, model: str
) -> list[list[float]]:
resp = self.client.post(
"/embeddings", json={"input": inputs, "model": model}
)
resp.raise_for_status()
return resp.json()
return [data["embedding"] for data in resp.json()["data"]]


class AsyncEdgeDBAI(BaseEdgeDBAI):
Expand All @@ -167,7 +169,7 @@ async def query_rag(

async def stream_rag(
self, message: str, context: typing.Optional[types.QueryContext] = None
):
) -> typing.Iterator[str]:
async with httpx_sse.aconnect_sse(
self.client,
"post",
Expand All @@ -181,9 +183,11 @@ async def stream_rag(
async for sse in event_source.aiter_sse():
yield sse.data

async def generate_embeddings(self, *inputs: str, model: str):
async def generate_embeddings(
self, *inputs: str, model: str
) -> list[list[float]]:
resp = await self.client.post(
"/embeddings", json={"input": inputs, "model": model}
)
resp.raise_for_status()
return resp.json()
return [data["embedding"] for data in resp.json()["data"]]

0 comments on commit e60188b

Please sign in to comment.