Skip to content

Commit

Permalink
Fix azure llm service
Browse files Browse the repository at this point in the history
  • Loading branch information
liukidar committed Dec 5, 2024
1 parent 1577a44 commit 97328c4
Showing 1 changed file with 18 additions and 4 deletions.
22 changes: 18 additions & 4 deletions fast_graphrag/_llm/_llm_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,21 @@ class OpenAILLMService(BaseLLMService):
model: Optional[str] = field(default="gpt-4o-mini")
mode: instructor.Mode = field(default=instructor.Mode.JSON)
client: Literal["openai", "azure"] = field(default="openai")
api_version: Optional[str] = field(default=None)

def __post_init__(self):
if self.client == "azure":
assert self.base_url is not None, "Azure OpenAI requires a base url."
assert (
self.base_url is not None and self.api_version is not None
), "Azure OpenAI requires a base url and an api version."
self.llm_async_client = instructor.from_openai(
AsyncAzureOpenAI(base_url=self.base_url, api_key=self.api_key, timeout=TIMEOUT_SECONDS), mode=self.mode
AsyncAzureOpenAI(
azure_endpoint=self.base_url,
api_key=self.api_key,
api_version=self.api_version,
timeout=TIMEOUT_SECONDS,
),
mode=self.mode,
)
elif self.client == "openai":
self.llm_async_client = instructor.from_openai(
Expand Down Expand Up @@ -123,11 +132,16 @@ class OpenAIEmbeddingService(BaseEmbeddingService):
max_elements_per_request: int = field(default=32)
model: Optional[str] = field(default="text-embedding-3-small")
client: Literal["openai", "azure"] = field(default="openai")
api_version: Optional[str] = field(default=None)

def __post_init__(self):
if self.client == "azure":
assert self.base_url is not None, "Azure OpenAI requires a base url."
self.embedding_async_client = AsyncAzureOpenAI(base_url=self.base_url, api_key=self.api_key)
assert (
self.base_url is not None and self.api_version is not None
), "Azure OpenAI requires a base url and an api version."
self.embedding_async_client = AsyncAzureOpenAI(
azure_endpoint=self.base_url, api_key=self.api_key, api_version=self.api_version
)
elif self.client == "openai":
self.embedding_async_client = AsyncOpenAI(base_url=self.base_url, api_key=self.api_key)
else:
Expand Down

0 comments on commit 97328c4

Please sign in to comment.