Skip to content

Commit

Permalink
add xai llm provider
Browse files Browse the repository at this point in the history
  • Loading branch information
djcopley committed Nov 29, 2024
1 parent cf620d5 commit 2593884
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/shelloracle/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,9 @@ def _providers() -> dict[str, type[Provider]]:
from shelloracle.providers.localai import LocalAI
from shelloracle.providers.ollama import Ollama
from shelloracle.providers.openai import OpenAI
from shelloracle.providers.xai import XAI

return {Ollama.name: Ollama, OpenAI.name: OpenAI, LocalAI.name: LocalAI}
return {Ollama.name: Ollama, OpenAI.name: OpenAI, LocalAI.name: LocalAI, XAI.name: XAI}


def get_provider(name: str) -> type[Provider]:
Expand Down
38 changes: 38 additions & 0 deletions src/shelloracle/providers/xai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from collections.abc import AsyncIterator

from openai import APIError, AsyncOpenAI

from shelloracle.providers import Provider, ProviderError, Setting, system_prompt


class XAI(Provider):
    """Provider backed by xAI's Grok models via their OpenAI-compatible API."""

    name = "XAI"

    # User-configured credentials and model; an empty api_key is rejected in __init__.
    api_key = Setting(default="")
    model = Setting(default="grok-beta")

    def __init__(self) -> None:
        """Initialize the async client pointed at the xAI endpoint.

        Raises:
            ProviderError: If no API key has been configured.
        """
        if not self.api_key:
            msg = "No API key provided"
            raise ProviderError(msg)
        # xAI exposes an OpenAI-compatible REST API, so the OpenAI SDK is
        # reused with only the base_url swapped out.
        self.client = AsyncOpenAI(
            api_key=self.api_key,
            base_url="https://api.x.ai/v1",
        )

    async def generate(self, prompt: str) -> AsyncIterator[str]:
        """Stream the model's completion for ``prompt`` one text delta at a time.

        Args:
            prompt: The user's natural-language request.

        Yields:
            Incremental text fragments of the model's response.

        Raises:
            ProviderError: If the xAI API returns an error.
        """
        try:
            stream = await self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": prompt},
                ],
                stream=True,
            )
            async for chunk in stream:
                # Some stream chunks (e.g. a trailing usage chunk) carry an
                # empty `choices` list; guard before indexing to avoid an
                # IndexError that would escape the APIError handler below.
                if not chunk.choices:
                    continue
                if chunk.choices[0].delta.content is not None:
                    yield chunk.choices[0].delta.content
        except APIError as e:
            msg = f"Something went wrong while querying XAI: {e}"
            raise ProviderError(msg) from e

0 comments on commit 2593884

Please sign in to comment.