diff --git a/fast_graphrag/_llm/_llm_openai.py b/fast_graphrag/_llm/_llm_openai.py index 84cbe1c..caa6928 100644 --- a/fast_graphrag/_llm/_llm_openai.py +++ b/fast_graphrag/_llm/_llm_openai.py @@ -33,12 +33,21 @@ class OpenAILLMService(BaseLLMService): model: Optional[str] = field(default="gpt-4o-mini") mode: instructor.Mode = field(default=instructor.Mode.JSON) client: Literal["openai", "azure"] = field(default="openai") + api_version: Optional[str] = field(default=None) def __post_init__(self): if self.client == "azure": - assert self.base_url is not None, "Azure OpenAI requires a base url." + assert ( + self.base_url is not None and self.api_version is not None + ), "Azure OpenAI requires a base url and an api version." self.llm_async_client = instructor.from_openai( - AsyncAzureOpenAI(base_url=self.base_url, api_key=self.api_key, timeout=TIMEOUT_SECONDS), mode=self.mode + AsyncAzureOpenAI( + azure_endpoint=self.base_url, + api_key=self.api_key, + api_version=self.api_version, + timeout=TIMEOUT_SECONDS, + ), + mode=self.mode, ) elif self.client == "openai": self.llm_async_client = instructor.from_openai( @@ -123,11 +132,16 @@ class OpenAIEmbeddingService(BaseEmbeddingService): max_elements_per_request: int = field(default=32) model: Optional[str] = field(default="text-embedding-3-small") client: Literal["openai", "azure"] = field(default="openai") + api_version: Optional[str] = field(default=None) def __post_init__(self): if self.client == "azure": - assert self.base_url is not None, "Azure OpenAI requires a base url." - self.embedding_async_client = AsyncAzureOpenAI(base_url=self.base_url, api_key=self.api_key) + assert ( + self.base_url is not None and self.api_version is not None + ), "Azure OpenAI requires a base url and an api version." + self.embedding_async_client = AsyncAzureOpenAI( + azure_endpoint=self.base_url, api_key=self.api_key, api_version=self.api_version + ) elif self.client == "openai": self.embedding_async_client = AsyncOpenAI(base_url=self.base_url, api_key=self.api_key) else: