From 0fa05b1d918ca62191457bda5853bb66bc759bfe Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 9 Dec 2024 10:27:16 -0800 Subject: [PATCH] Litellm code qa common config (#7116) * feat(base_llm): initial commit for common base config class Addresses code qa critique https://github.com/andrewyng/aisuite/issues/113#issuecomment-2512369132 * feat(base_llm/): add transform request/response abstract methods to base config class --------- Co-authored-by: Krrish Dholakia --- litellm/__init__.py | 5 +- litellm/constants.py | 64 ++ .../get_supported_openai_params.py | 26 +- .../llms/OpenAI/chat/gpt_transformation.py | 71 +- litellm/llms/OpenAI/common_utils.py | 36 +- litellm/llms/OpenAI/completion/handler.py | 314 +++++++++ .../llms/OpenAI/completion/transformation.py | 192 ++++++ litellm/llms/OpenAI/openai.py | 482 +------------- litellm/llms/azure_text.py | 3 +- litellm/llms/base_llm/transformation.py | 126 ++++ litellm/llms/databricks/chat/old_handler.py | 611 ------------------ litellm/llms/openai_like/chat/handler.py | 2 +- .../llms/together_ai/completion/handler.py | 2 +- .../together_ai/completion/transformation.py | 2 +- ...ai_transformation.py => transformation.py} | 2 - litellm/main.py | 3 +- litellm/utils.py | 12 +- tests/llm_translation/test_xai.py | 2 +- tests/local_testing/test_config.py | 27 + 19 files changed, 851 insertions(+), 1131 deletions(-) create mode 100644 litellm/llms/OpenAI/completion/handler.py create mode 100644 litellm/llms/OpenAI/completion/transformation.py create mode 100644 litellm/llms/base_llm/transformation.py delete mode 100644 litellm/llms/databricks/chat/old_handler.py rename litellm/llms/xai/chat/{xai_transformation.py => transformation.py} (97%) diff --git a/litellm/__init__.py b/litellm/__init__.py index 4b872b01430b..9f56897d7255 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -22,6 +22,7 @@ DEFAULT_FLUSH_INTERVAL_SECONDS, ROUTER_MAX_FALLBACKS, DEFAULT_MAX_RETRIES, + LITELLM_CHAT_PROVIDERS, ) from litellm.types.guardrails import GuardrailItem from litellm.proxy._types import ( @@ -1128,10 +1129,10 @@ class LlmProviders(str, Enum): from .llms.bedrock.embed.cohere_transformation import BedrockCohereEmbeddingConfig from .llms.OpenAI.openai import ( OpenAIConfig, - OpenAITextCompletionConfig, MistralEmbeddingConfig, DeepInfraConfig, ) +from litellm.llms.OpenAI.completion.transformation import OpenAITextCompletionConfig from .llms.groq.chat.transformation import GroqChatConfig from .llms.azure_ai.chat.transformation import AzureAIStudioConfig from .llms.mistral.mistral_chat_transformation import MistralConfig @@ -1165,7 +1166,7 @@ class LlmProviders(str, Enum): FireworksAIEmbeddingConfig, ) from .llms.jina_ai.embedding.transformation import JinaAIEmbeddingConfig -from .llms.xai.chat.xai_transformation import XAIChatConfig +from .llms.xai.chat.transformation import XAIChatConfig from .llms.volcengine import VolcEngineConfig from .llms.text_completion_codestral import MistralTextCompletionConfig from .llms.azure.azure import ( diff --git a/litellm/constants.py b/litellm/constants.py index 97dc6c7348bd..c0aa2a36907e 100644 --- a/litellm/constants.py +++ b/litellm/constants.py @@ -2,3 +2,67 @@ DEFAULT_BATCH_SIZE = 512 DEFAULT_FLUSH_INTERVAL_SECONDS = 5 DEFAULT_MAX_RETRIES = 2 +LITELLM_CHAT_PROVIDERS = [ + "openai", + "openai_like", + "xai", + "custom_openai", + "text-completion-openai", + "cohere", + "cohere_chat", + "clarifai", + "anthropic", + "replicate", + "huggingface", + "together_ai", + "openrouter", + "vertex_ai", + "vertex_ai_beta", + 
"palm", + "gemini", + "ai21", + "baseten", + "azure", + "azure_text", + "azure_ai", + "sagemaker", + "sagemaker_chat", + "bedrock", + "vllm", + "nlp_cloud", + "petals", + "oobabooga", + "ollama", + "ollama_chat", + "deepinfra", + "perplexity", + "anyscale", + "mistral", + "groq", + "nvidia_nim", + "cerebras", + "ai21_chat", + "volcengine", + "codestral", + "text-completion-codestral", + "deepseek", + "sambanova", + "maritalk", + "voyage", + "cloudflare", + "xinference", + "fireworks_ai", + "friendliai", + "watsonx", + "watsonx_text", + "triton", + "predibase", + "databricks", + "empower", + "github", + "custom", + "litellm_proxy", + "hosted_vllm", + "lm_studio", + "galadriel", +] diff --git a/litellm/litellm_core_utils/get_supported_openai_params.py b/litellm/litellm_core_utils/get_supported_openai_params.py index 554df8092b34..8ac4834e79a5 100644 --- a/litellm/litellm_core_utils/get_supported_openai_params.py +++ b/litellm/litellm_core_utils/get_supported_openai_params.py @@ -255,27 +255,7 @@ def get_supported_openai_params( # noqa: PLR0915 elif custom_llm_provider == "watsonx": return litellm.IBMWatsonXChatConfig().get_supported_openai_params(model=model) elif custom_llm_provider == "custom_openai" or "text-completion-openai": - return [ - "functions", - "function_call", - "temperature", - "top_p", - "n", - "stream", - "stream_options", - "stop", - "max_tokens", - "presence_penalty", - "frequency_penalty", - "logit_bias", - "user", - "response_format", - "seed", - "tools", - "tool_choice", - "max_retries", - "logprobs", - "top_logprobs", - "extra_headers", - ] + return litellm.OpenAITextCompletionConfig().get_supported_openai_params( + model=model + ) return None diff --git a/litellm/llms/OpenAI/chat/gpt_transformation.py b/litellm/llms/OpenAI/chat/gpt_transformation.py index c0c7e14dd8f0..272b1e9f9e8d 100644 --- a/litellm/llms/OpenAI/chat/gpt_transformation.py +++ b/litellm/llms/OpenAI/chat/gpt_transformation.py @@ -3,13 +3,26 @@ """ import types -from typing import List, Optional, Union +from typing import TYPE_CHECKING, Any, List, Optional, Union, cast + +import httpx import litellm +from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException from litellm.types.llms.openai import AllMessageValues, ChatCompletionUserMessage +from litellm.types.utils import ModelResponse + +from ..common_utils import OpenAIError + +if TYPE_CHECKING: + from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj + + LoggingClass = LiteLLMLoggingObj +else: + LoggingClass = Any -class OpenAIGPTConfig: +class OpenAIGPTConfig(BaseConfig): """ Reference: https://platform.openai.com/docs/api-reference/chat/create @@ -168,3 +181,57 @@ def _transform_messages( self, messages: List[AllMessageValues] ) -> List[AllMessageValues]: return messages + + def transform_request( + self, + model: str, + messages: List[AllMessageValues], + optional_params: dict, + litellm_params: dict, + headers: dict, + ) -> dict: + """ + Transform the overall request to be sent to the API. + + Returns: + dict: The transformed request. Sent as the body of the API call. + """ + raise NotImplementedError + + def transform_response( + self, + model: str, + raw_response: dict, + model_response: ModelResponse, + logging_obj: LoggingClass, + api_key: str, + request_data: dict, + messages: List[AllMessageValues], + optional_params: dict, + encoding: str, + ) -> ModelResponse: + """ + Transform the response from the API. + + Returns: + dict: The transformed response. 
+ """ + raise NotImplementedError + + def get_error_class( + self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers] + ) -> BaseLLMException: + return OpenAIError( + status_code=status_code, + message=error_message, + headers=cast(httpx.Headers, headers), + ) + + def validate_environment( + self, + api_key: str, + headers: dict, + model: str, + messages: List[AllMessageValues], + ) -> dict: + raise NotImplementedError diff --git a/litellm/llms/OpenAI/common_utils.py b/litellm/llms/OpenAI/common_utils.py index 01c3ae9435e9..5da8c4925f62 100644 --- a/litellm/llms/OpenAI/common_utils.py +++ b/litellm/llms/OpenAI/common_utils.py @@ -3,10 +3,44 @@ """ import json -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional +import httpx import openai +from litellm.llms.base_llm.transformation import BaseLLMException + + +class OpenAIError(BaseLLMException): + def __init__( + self, + status_code: int, + message: str, + request: Optional[httpx.Request] = None, + response: Optional[httpx.Response] = None, + headers: Optional[httpx.Headers] = None, + ): + self.status_code = status_code + self.message = message + self.headers = headers + if request: + self.request = request + else: + self.request = httpx.Request(method="POST", url="https://api.openai.com/v1") + if response: + self.response = response + else: + self.response = httpx.Response( + status_code=status_code, request=self.request + ) + super().__init__( + status_code=status_code, + message=self.message, + headers=self.headers, + request=self.request, + response=self.response, + ) + ####### Error Handling Utils for OpenAI API ####################### ################################################################### diff --git a/litellm/llms/OpenAI/completion/handler.py b/litellm/llms/OpenAI/completion/handler.py new file mode 100644 index 000000000000..43f1ed923299 --- /dev/null +++ b/litellm/llms/OpenAI/completion/handler.py @@ -0,0 +1,314 @@ +import json +from typing import Callable, List, Optional, Union + +from openai import AsyncOpenAI, OpenAI + +import litellm +from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj +from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper +from litellm.llms.base import BaseLLM +from litellm.types.llms.openai import AllMessageValues, OpenAITextCompletionUserMessage +from litellm.types.utils import ModelResponse, TextCompletionResponse + +from ..common_utils import OpenAIError +from .transformation import OpenAITextCompletionConfig + + +class OpenAITextCompletion(BaseLLM): + openai_text_completion_global_config = OpenAITextCompletionConfig() + + def __init__(self) -> None: + super().__init__() + + def validate_environment(self, api_key): + headers = { + "content-type": "application/json", + } + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + return headers + + def completion( + self, + model_response: ModelResponse, + api_key: str, + model: str, + messages: Union[List[AllMessageValues], List[OpenAITextCompletionUserMessage]], + timeout: float, + logging_obj: LiteLLMLoggingObj, + optional_params: dict, + print_verbose: Optional[Callable] = None, + api_base: Optional[str] = None, + acompletion: bool = False, + litellm_params=None, + logger_fn=None, + client=None, + organization: Optional[str] = None, + headers: Optional[dict] = None, + ): + try: + if headers is None: + headers = self.validate_environment(api_key=api_key) + if model is None or messages is None: + raise OpenAIError(status_code=422, 
message="Missing model or messages") + + # don't send max retries to the api, if set + + prompt = self.openai_text_completion_global_config._transform_prompt( + messages + ) + + data = {"model": model, "prompt": prompt, **optional_params} + max_retries = data.pop("max_retries", 2) + ## LOGGING + logging_obj.pre_call( + input=messages, + api_key=api_key, + additional_args={ + "headers": headers, + "api_base": api_base, + "complete_input_dict": data, + }, + ) + if acompletion is True: + if optional_params.get("stream", False): + return self.async_streaming( + logging_obj=logging_obj, + api_base=api_base, + api_key=api_key, + data=data, + headers=headers, + model_response=model_response, + model=model, + timeout=timeout, + max_retries=max_retries, + client=client, + organization=organization, + ) + else: + return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, prompt=prompt, api_key=api_key, logging_obj=logging_obj, model=model, timeout=timeout, max_retries=max_retries, organization=organization, client=client) # type: ignore + elif optional_params.get("stream", False): + return self.streaming( + logging_obj=logging_obj, + api_base=api_base, + api_key=api_key, + data=data, + headers=headers, + model_response=model_response, + model=model, + timeout=timeout, + max_retries=max_retries, # type: ignore + client=client, + organization=organization, + ) + else: + if client is None: + openai_client = OpenAI( + api_key=api_key, + base_url=api_base, + http_client=litellm.client_session, + timeout=timeout, + max_retries=max_retries, # type: ignore + organization=organization, + ) + else: + openai_client = client + + raw_response = openai_client.completions.with_raw_response.create(**data) # type: ignore + response = raw_response.parse() + response_json = response.model_dump() + + ## LOGGING + logging_obj.post_call( + input=prompt, + api_key=api_key, + original_response=response_json, + additional_args={ + "headers": headers, + "api_base": api_base, + }, + ) + + ## RESPONSE OBJECT + return TextCompletionResponse(**response_json) + except Exception as e: + status_code = getattr(e, "status_code", 500) + error_headers = getattr(e, "headers", None) + error_text = getattr(e, "text", str(e)) + error_response = getattr(e, "response", None) + if error_headers is None and error_response: + error_headers = getattr(error_response, "headers", None) + raise OpenAIError( + status_code=status_code, message=error_text, headers=error_headers + ) + + async def acompletion( + self, + logging_obj, + api_base: str, + data: dict, + headers: dict, + model_response: ModelResponse, + prompt: str, + api_key: str, + model: str, + timeout: float, + max_retries: int, + organization: Optional[str] = None, + client=None, + ): + try: + if client is None: + openai_aclient = AsyncOpenAI( + api_key=api_key, + base_url=api_base, + http_client=litellm.aclient_session, + timeout=timeout, + max_retries=max_retries, + organization=organization, + ) + else: + openai_aclient = client + + raw_response = await openai_aclient.completions.with_raw_response.create( + **data + ) + response = raw_response.parse() + response_json = response.model_dump() + + ## LOGGING + logging_obj.post_call( + input=prompt, + api_key=api_key, + original_response=response, + additional_args={ + "headers": headers, + "api_base": api_base, + }, + ) + ## RESPONSE OBJECT + response_obj = TextCompletionResponse(**response_json) + response_obj._hidden_params.original_response = json.dumps(response_json) + return response_obj + 
except Exception as e: + status_code = getattr(e, "status_code", 500) + error_headers = getattr(e, "headers", None) + error_text = getattr(e, "text", str(e)) + error_response = getattr(e, "response", None) + if error_headers is None and error_response: + error_headers = getattr(error_response, "headers", None) + raise OpenAIError( + status_code=status_code, message=error_text, headers=error_headers + ) + + def streaming( + self, + logging_obj, + api_key: str, + data: dict, + headers: dict, + model_response: ModelResponse, + model: str, + timeout: float, + api_base: Optional[str] = None, + max_retries=None, + client=None, + organization=None, + ): + + if client is None: + openai_client = OpenAI( + api_key=api_key, + base_url=api_base, + http_client=litellm.client_session, + timeout=timeout, + max_retries=max_retries, # type: ignore + organization=organization, + ) + else: + openai_client = client + + try: + raw_response = openai_client.completions.with_raw_response.create(**data) + response = raw_response.parse() + except Exception as e: + status_code = getattr(e, "status_code", 500) + error_headers = getattr(e, "headers", None) + error_text = getattr(e, "text", str(e)) + error_response = getattr(e, "response", None) + if error_headers is None and error_response: + error_headers = getattr(error_response, "headers", None) + raise OpenAIError( + status_code=status_code, message=error_text, headers=error_headers + ) + streamwrapper = CustomStreamWrapper( + completion_stream=response, + model=model, + custom_llm_provider="text-completion-openai", + logging_obj=logging_obj, + stream_options=data.get("stream_options", None), + ) + + try: + for chunk in streamwrapper: + yield chunk + except Exception as e: + status_code = getattr(e, "status_code", 500) + error_headers = getattr(e, "headers", None) + error_text = getattr(e, "text", str(e)) + error_response = getattr(e, "response", None) + if error_headers is None and error_response: + error_headers = getattr(error_response, "headers", None) + raise OpenAIError( + status_code=status_code, message=error_text, headers=error_headers + ) + + async def async_streaming( + self, + logging_obj, + api_key: str, + data: dict, + headers: dict, + model_response: ModelResponse, + model: str, + timeout: float, + max_retries: int, + api_base: Optional[str] = None, + client=None, + organization=None, + ): + if client is None: + openai_client = AsyncOpenAI( + api_key=api_key, + base_url=api_base, + http_client=litellm.aclient_session, + timeout=timeout, + max_retries=max_retries, + organization=organization, + ) + else: + openai_client = client + + raw_response = await openai_client.completions.with_raw_response.create(**data) + response = raw_response.parse() + streamwrapper = CustomStreamWrapper( + completion_stream=response, + model=model, + custom_llm_provider="text-completion-openai", + logging_obj=logging_obj, + stream_options=data.get("stream_options", None), + ) + + try: + async for transformed_chunk in streamwrapper: + yield transformed_chunk + except Exception as e: + status_code = getattr(e, "status_code", 500) + error_headers = getattr(e, "headers", None) + error_text = getattr(e, "text", str(e)) + error_response = getattr(e, "response", None) + if error_headers is None and error_response: + error_headers = getattr(error_response, "headers", None) + raise OpenAIError( + status_code=status_code, message=error_text, headers=error_headers + ) diff --git a/litellm/llms/OpenAI/completion/transformation.py b/litellm/llms/OpenAI/completion/transformation.py new 
file mode 100644 index 000000000000..a701627a3f09 --- /dev/null +++ b/litellm/llms/OpenAI/completion/transformation.py @@ -0,0 +1,192 @@ +""" +Support for gpt model family +""" + +import types +from typing import List, Optional, Union, cast + +import litellm +from litellm.llms.base_llm.transformation import BaseConfig +from litellm.types.llms.openai import ( + AllMessageValues, + AllPromptValues, + OpenAITextCompletionUserMessage, +) +from litellm.types.utils import Choices, Message, ModelResponse, TextCompletionResponse + +from ...prompt_templates.common_utils import convert_content_list_to_str +from ..chat.gpt_transformation import OpenAIGPTConfig +from ..common_utils import OpenAIError +from .utils import is_tokens_or_list_of_tokens + + +class OpenAITextCompletionConfig(OpenAIGPTConfig): + """ + Reference: https://platform.openai.com/docs/api-reference/completions/create + + The class `OpenAITextCompletionConfig` provides configuration for the OpenAI's text completion API interface. Below are the parameters: + + - `best_of` (integer or null): This optional parameter generates server-side completions and returns the one with the highest log probability per token. + + - `echo` (boolean or null): This optional parameter will echo back the prompt in addition to the completion. + + - `frequency_penalty` (number or null): Defaults to 0. It is a numbers from -2.0 to 2.0, where positive values decrease the model's likelihood to repeat the same line. + + - `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion. + + - `logprobs` (integer or null): This optional parameter includes the log probabilities on the most likely tokens as well as the chosen tokens. + + - `max_tokens` (integer or null): This optional parameter sets the maximum number of tokens to generate in the completion. + + - `n` (integer or null): This optional parameter sets how many completions to generate for each prompt. + + - `presence_penalty` (number or null): Defaults to 0 and can be between -2.0 and 2.0. Positive values increase the model's likelihood to talk about new topics. + + - `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens. + + - `suffix` (string or null): Defines the suffix that comes after a completion of inserted text. + + - `temperature` (number or null): This optional parameter defines the sampling temperature to use. + + - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling. 
+ """ + + best_of: Optional[int] = None + echo: Optional[bool] = None + frequency_penalty: Optional[int] = None + logit_bias: Optional[dict] = None + logprobs: Optional[int] = None + max_tokens: Optional[int] = None + n: Optional[int] = None + presence_penalty: Optional[int] = None + stop: Optional[Union[str, list]] = None + suffix: Optional[str] = None + + def __init__( + self, + best_of: Optional[int] = None, + echo: Optional[bool] = None, + frequency_penalty: Optional[int] = None, + logit_bias: Optional[dict] = None, + logprobs: Optional[int] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + presence_penalty: Optional[int] = None, + stop: Optional[Union[str, list]] = None, + suffix: Optional[str] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + ) -> None: + locals_ = locals().copy() + for key, value in locals_.items(): + if key != "self" and value is not None: + setattr(self.__class__, key, value) + + @classmethod + def get_config(cls): + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } + + def _transform_prompt( + self, + messages: Union[List[AllMessageValues], List[OpenAITextCompletionUserMessage]], + ) -> AllPromptValues: + if len(messages) == 1: # base case + message_content = messages[0].get("content") + if ( + message_content + and isinstance(message_content, list) + and is_tokens_or_list_of_tokens(message_content) + ): + openai_prompt: AllPromptValues = cast(AllPromptValues, message_content) + else: + openai_prompt = "" + content = convert_content_list_to_str( + cast(AllMessageValues, messages[0]) + ) + openai_prompt += content + else: + prompt_str_list: List[str] = [] + for m in messages: + try: # expect list of int/list of list of int to be a 1 message array only. 
+ content = convert_content_list_to_str(cast(AllMessageValues, m)) + prompt_str_list.append(content) + except Exception as e: + raise e + openai_prompt = prompt_str_list + return openai_prompt + + def convert_to_chat_model_response_object( + self, + response_object: Optional[TextCompletionResponse] = None, + model_response_object: Optional[ModelResponse] = None, + ): + try: + ## RESPONSE OBJECT + if response_object is None or model_response_object is None: + raise ValueError("Error in response object format") + choice_list = [] + for idx, choice in enumerate(response_object["choices"]): + message = Message( + content=choice["text"], + role="assistant", + ) + choice = Choices( + finish_reason=choice["finish_reason"], index=idx, message=message + ) + choice_list.append(choice) + model_response_object.choices = choice_list + + if "usage" in response_object: + setattr(model_response_object, "usage", response_object["usage"]) + + if "id" in response_object: + model_response_object.id = response_object["id"] + + if "model" in response_object: + model_response_object.model = response_object["model"] + + model_response_object._hidden_params["original_response"] = ( + response_object # track original response, if users make a litellm.text_completion() request, we can return the original response + ) + return model_response_object + except Exception as e: + raise e + + def get_supported_openai_params(self, model: str) -> List: + return [ + "functions", + "function_call", + "temperature", + "top_p", + "n", + "stream", + "stream_options", + "stop", + "max_tokens", + "presence_penalty", + "frequency_penalty", + "logit_bias", + "user", + "response_format", + "seed", + "tools", + "tool_choice", + "max_retries", + "logprobs", + "top_logprobs", + "extra_headers", + ] diff --git a/litellm/llms/OpenAI/openai.py b/litellm/llms/OpenAI/openai.py index 66ce75701886..108a31d19a91 100644 --- a/litellm/llms/OpenAI/openai.py +++ b/litellm/llms/OpenAI/openai.py @@ -34,37 +34,8 @@ from ...types.llms.openai import * from ..base import BaseLLM -from ..prompt_templates.common_utils import convert_content_list_to_str from ..prompt_templates.factory import custom_prompt, prompt_factory -from .common_utils import drop_params_from_unprocessable_entity_error -from .completion.utils import is_tokens_or_list_of_tokens - - -class OpenAIError(Exception): - def __init__( - self, - status_code, - message, - request: Optional[httpx.Request] = None, - response: Optional[httpx.Response] = None, - headers: Optional[httpx.Headers] = None, - ): - self.status_code = status_code - self.message = message - self.headers = headers - if request: - self.request = request - else: - self.request = httpx.Request(method="POST", url="https://api.openai.com/v1") - if response: - self.response = response - else: - self.response = httpx.Response( - status_code=status_code, request=self.request - ) - super().__init__( - self.message - ) # Call the base class constructor with the parameters it needs +from .common_utils import OpenAIError, drop_params_from_unprocessable_entity_error class MistralEmbeddingConfig: @@ -379,155 +350,6 @@ def map_openai_params( ) -class OpenAITextCompletionConfig: - """ - Reference: https://platform.openai.com/docs/api-reference/completions/create - - The class `OpenAITextCompletionConfig` provides configuration for the OpenAI's text completion API interface. 
Below are the parameters: - - - `best_of` (integer or null): This optional parameter generates server-side completions and returns the one with the highest log probability per token. - - - `echo` (boolean or null): This optional parameter will echo back the prompt in addition to the completion. - - - `frequency_penalty` (number or null): Defaults to 0. It is a numbers from -2.0 to 2.0, where positive values decrease the model's likelihood to repeat the same line. - - - `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion. - - - `logprobs` (integer or null): This optional parameter includes the log probabilities on the most likely tokens as well as the chosen tokens. - - - `max_tokens` (integer or null): This optional parameter sets the maximum number of tokens to generate in the completion. - - - `n` (integer or null): This optional parameter sets how many completions to generate for each prompt. - - - `presence_penalty` (number or null): Defaults to 0 and can be between -2.0 and 2.0. Positive values increase the model's likelihood to talk about new topics. - - - `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens. - - - `suffix` (string or null): Defines the suffix that comes after a completion of inserted text. - - - `temperature` (number or null): This optional parameter defines the sampling temperature to use. - - - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling. - """ - - best_of: Optional[int] = None - echo: Optional[bool] = None - frequency_penalty: Optional[int] = None - logit_bias: Optional[dict] = None - logprobs: Optional[int] = None - max_tokens: Optional[int] = None - n: Optional[int] = None - presence_penalty: Optional[int] = None - stop: Optional[Union[str, list]] = None - suffix: Optional[str] = None - temperature: Optional[float] = None - top_p: Optional[float] = None - - def __init__( - self, - best_of: Optional[int] = None, - echo: Optional[bool] = None, - frequency_penalty: Optional[int] = None, - logit_bias: Optional[dict] = None, - logprobs: Optional[int] = None, - max_tokens: Optional[int] = None, - n: Optional[int] = None, - presence_penalty: Optional[int] = None, - stop: Optional[Union[str, list]] = None, - suffix: Optional[str] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - ) -> None: - locals_ = locals().copy() - for key, value in locals_.items(): - if key != "self" and value is not None: - setattr(self.__class__, key, value) - - @classmethod - def get_config(cls): - return { - k: v - for k, v in cls.__dict__.items() - if not k.startswith("__") - and not isinstance( - v, - ( - types.FunctionType, - types.BuiltinFunctionType, - classmethod, - staticmethod, - ), - ) - and v is not None - } - - def _transform_prompt( - self, - messages: Union[List[AllMessageValues], List[OpenAITextCompletionUserMessage]], - ) -> AllPromptValues: - if len(messages) == 1: # base case - message_content = messages[0].get("content") - if ( - message_content - and isinstance(message_content, list) - and is_tokens_or_list_of_tokens(message_content) - ): - openai_prompt: AllPromptValues = cast(AllPromptValues, message_content) - else: - openai_prompt = "" - content = convert_content_list_to_str( - cast(AllMessageValues, messages[0]) - ) - openai_prompt += content - else: - prompt_str_list: List[str] = [] - for m in messages: - try: # expect list of int/list of list of int to be a 1 
message array only. - content = convert_content_list_to_str(cast(AllMessageValues, m)) - prompt_str_list.append(content) - except Exception as e: - raise e - openai_prompt = prompt_str_list - return openai_prompt - - def convert_to_chat_model_response_object( - self, - response_object: Optional[TextCompletionResponse] = None, - model_response_object: Optional[ModelResponse] = None, - ): - try: - ## RESPONSE OBJECT - if response_object is None or model_response_object is None: - raise ValueError("Error in response object format") - choice_list = [] - for idx, choice in enumerate(response_object["choices"]): - message = Message( - content=choice["text"], - role="assistant", - ) - choice = Choices( - finish_reason=choice["finish_reason"], index=idx, message=message - ) - choice_list.append(choice) - model_response_object.choices = choice_list - - if "usage" in response_object: - setattr(model_response_object, "usage", response_object["usage"]) - - if "id" in response_object: - model_response_object.id = response_object["id"] - - if "model" in response_object: - model_response_object.model = response_object["model"] - - model_response_object._hidden_params["original_response"] = ( - response_object # track original response, if users make a litellm.text_completion() request, we can return the original response - ) - return model_response_object - except Exception as e: - raise e - - class OpenAIChatCompletion(BaseLLM): def __init__(self) -> None: @@ -710,7 +532,7 @@ def completion( # type: ignore # noqa: PLR0915 custom_llm_provider=custom_llm_provider, ) if messages is not None and custom_llm_provider is not None: - provider_config = ProviderConfigManager.get_provider_config( + provider_config = ProviderConfigManager.get_provider_chat_config( model=model, provider=LlmProviders(custom_llm_provider) ) messages = provider_config._transform_messages(messages) @@ -1584,306 +1406,6 @@ async def ahealth_check( return response -class OpenAITextCompletion(BaseLLM): - openai_text_completion_global_config = OpenAITextCompletionConfig() - - def __init__(self) -> None: - super().__init__() - - def validate_environment(self, api_key): - headers = { - "content-type": "application/json", - } - if api_key: - headers["Authorization"] = f"Bearer {api_key}" - return headers - - def completion( - self, - model_response: ModelResponse, - api_key: str, - model: str, - messages: Union[List[AllMessageValues], List[OpenAITextCompletionUserMessage]], - timeout: float, - logging_obj: LiteLLMLoggingObj, - optional_params: dict, - print_verbose: Optional[Callable] = None, - api_base: Optional[str] = None, - acompletion: bool = False, - litellm_params=None, - logger_fn=None, - client=None, - organization: Optional[str] = None, - headers: Optional[dict] = None, - ): - try: - if headers is None: - headers = self.validate_environment(api_key=api_key) - if model is None or messages is None: - raise OpenAIError(status_code=422, message="Missing model or messages") - - # don't send max retries to the api, if set - - prompt = self.openai_text_completion_global_config._transform_prompt( - messages - ) - - data = {"model": model, "prompt": prompt, **optional_params} - max_retries = data.pop("max_retries", 2) - ## LOGGING - logging_obj.pre_call( - input=messages, - api_key=api_key, - additional_args={ - "headers": headers, - "api_base": api_base, - "complete_input_dict": data, - }, - ) - if acompletion is True: - if optional_params.get("stream", False): - return self.async_streaming( - logging_obj=logging_obj, - api_base=api_base, - 
api_key=api_key, - data=data, - headers=headers, - model_response=model_response, - model=model, - timeout=timeout, - max_retries=max_retries, - client=client, - organization=organization, - ) - else: - return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, prompt=prompt, api_key=api_key, logging_obj=logging_obj, model=model, timeout=timeout, max_retries=max_retries, organization=organization, client=client) # type: ignore - elif optional_params.get("stream", False): - return self.streaming( - logging_obj=logging_obj, - api_base=api_base, - api_key=api_key, - data=data, - headers=headers, - model_response=model_response, - model=model, - timeout=timeout, - max_retries=max_retries, # type: ignore - client=client, - organization=organization, - ) - else: - if client is None: - openai_client = OpenAI( - api_key=api_key, - base_url=api_base, - http_client=litellm.client_session, - timeout=timeout, - max_retries=max_retries, # type: ignore - organization=organization, - ) - else: - openai_client = client - - raw_response = openai_client.completions.with_raw_response.create(**data) # type: ignore - response = raw_response.parse() - response_json = response.model_dump() - - ## LOGGING - logging_obj.post_call( - input=prompt, - api_key=api_key, - original_response=response_json, - additional_args={ - "headers": headers, - "api_base": api_base, - }, - ) - - ## RESPONSE OBJECT - return TextCompletionResponse(**response_json) - except Exception as e: - status_code = getattr(e, "status_code", 500) - error_headers = getattr(e, "headers", None) - error_text = getattr(e, "text", str(e)) - error_response = getattr(e, "response", None) - if error_headers is None and error_response: - error_headers = getattr(error_response, "headers", None) - raise OpenAIError( - status_code=status_code, message=error_text, headers=error_headers - ) - - async def acompletion( - self, - logging_obj, - api_base: str, - data: dict, - headers: dict, - model_response: ModelResponse, - prompt: str, - api_key: str, - model: str, - timeout: float, - max_retries: int, - organization: Optional[str] = None, - client=None, - ): - try: - if client is None: - openai_aclient = AsyncOpenAI( - api_key=api_key, - base_url=api_base, - http_client=litellm.aclient_session, - timeout=timeout, - max_retries=max_retries, - organization=organization, - ) - else: - openai_aclient = client - - raw_response = await openai_aclient.completions.with_raw_response.create( - **data - ) - response = raw_response.parse() - response_json = response.model_dump() - - ## LOGGING - logging_obj.post_call( - input=prompt, - api_key=api_key, - original_response=response, - additional_args={ - "headers": headers, - "api_base": api_base, - }, - ) - ## RESPONSE OBJECT - response_obj = TextCompletionResponse(**response_json) - response_obj._hidden_params.original_response = json.dumps(response_json) - return response_obj - except Exception as e: - status_code = getattr(e, "status_code", 500) - error_headers = getattr(e, "headers", None) - error_text = getattr(e, "text", str(e)) - error_response = getattr(e, "response", None) - if error_headers is None and error_response: - error_headers = getattr(error_response, "headers", None) - raise OpenAIError( - status_code=status_code, message=error_text, headers=error_headers - ) - - def streaming( - self, - logging_obj, - api_key: str, - data: dict, - headers: dict, - model_response: ModelResponse, - model: str, - timeout: float, - api_base: Optional[str] = None, - max_retries=None, 
- client=None, - organization=None, - ): - - if client is None: - openai_client = OpenAI( - api_key=api_key, - base_url=api_base, - http_client=litellm.client_session, - timeout=timeout, - max_retries=max_retries, # type: ignore - organization=organization, - ) - else: - openai_client = client - - try: - raw_response = openai_client.completions.with_raw_response.create(**data) - response = raw_response.parse() - except Exception as e: - status_code = getattr(e, "status_code", 500) - error_headers = getattr(e, "headers", None) - error_text = getattr(e, "text", str(e)) - error_response = getattr(e, "response", None) - if error_headers is None and error_response: - error_headers = getattr(error_response, "headers", None) - raise OpenAIError( - status_code=status_code, message=error_text, headers=error_headers - ) - streamwrapper = CustomStreamWrapper( - completion_stream=response, - model=model, - custom_llm_provider="text-completion-openai", - logging_obj=logging_obj, - stream_options=data.get("stream_options", None), - ) - - try: - for chunk in streamwrapper: - yield chunk - except Exception as e: - status_code = getattr(e, "status_code", 500) - error_headers = getattr(e, "headers", None) - error_text = getattr(e, "text", str(e)) - error_response = getattr(e, "response", None) - if error_headers is None and error_response: - error_headers = getattr(error_response, "headers", None) - raise OpenAIError( - status_code=status_code, message=error_text, headers=error_headers - ) - - async def async_streaming( - self, - logging_obj, - api_key: str, - data: dict, - headers: dict, - model_response: ModelResponse, - model: str, - timeout: float, - max_retries: int, - api_base: Optional[str] = None, - client=None, - organization=None, - ): - if client is None: - openai_client = AsyncOpenAI( - api_key=api_key, - base_url=api_base, - http_client=litellm.aclient_session, - timeout=timeout, - max_retries=max_retries, - organization=organization, - ) - else: - openai_client = client - - raw_response = await openai_client.completions.with_raw_response.create(**data) - response = raw_response.parse() - streamwrapper = CustomStreamWrapper( - completion_stream=response, - model=model, - custom_llm_provider="text-completion-openai", - logging_obj=logging_obj, - stream_options=data.get("stream_options", None), - ) - - try: - async for transformed_chunk in streamwrapper: - yield transformed_chunk - except Exception as e: - status_code = getattr(e, "status_code", 500) - error_headers = getattr(e, "headers", None) - error_text = getattr(e, "text", str(e)) - error_response = getattr(e, "response", None) - if error_headers is None and error_response: - error_headers = getattr(error_response, "headers", None) - raise OpenAIError( - status_code=status_code, message=error_text, headers=error_headers - ) - - class OpenAIFilesAPI(BaseLLM): """ OpenAI methods to support for batches diff --git a/litellm/llms/azure_text.py b/litellm/llms/azure_text.py index c75965a8f500..9f52f214d70e 100644 --- a/litellm/llms/azure_text.py +++ b/litellm/llms/azure_text.py @@ -20,7 +20,8 @@ ) from .base import BaseLLM -from .OpenAI.openai import OpenAITextCompletion, OpenAITextCompletionConfig +from .OpenAI.completion.handler import OpenAITextCompletion +from .OpenAI.completion.transformation import OpenAITextCompletionConfig from .prompt_templates.factory import custom_prompt, prompt_factory openai_text_completion_config = OpenAITextCompletionConfig() diff --git a/litellm/llms/base_llm/transformation.py 
b/litellm/llms/base_llm/transformation.py new file mode 100644 index 000000000000..580ba94f3c2b --- /dev/null +++ b/litellm/llms/base_llm/transformation.py @@ -0,0 +1,126 @@ +""" +Common base config for all LLM providers +""" + +import types +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, List, Optional + +import httpx + +from litellm.types.llms.openai import AllMessageValues +from litellm.types.utils import ModelResponse + +if TYPE_CHECKING: + from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj + + LoggingClass = LiteLLMLoggingObj +else: + LoggingClass = Any + + +class BaseLLMException(Exception): + def __init__( + self, + status_code: int, + message: str, + headers: Optional[httpx.Headers] = None, + request: Optional[httpx.Request] = None, + response: Optional[httpx.Response] = None, + ): + self.status_code = status_code + self.message: str = message + self.headers = headers + self.request = httpx.Request(method="POST", url="https://docs.litellm.ai/docs") + self.response = httpx.Response(status_code=status_code, request=self.request) + super().__init__( + self.message + ) # Call the base class constructor with the parameters it needs + + +class BaseConfig(ABC): + def __init__(self): + pass + + @classmethod + def get_config(cls): + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } + + @abstractmethod + def get_supported_openai_params(self, model: str) -> list: + pass + + @abstractmethod + def map_openai_params( + self, + non_default_params: dict, + optional_params: dict, + model: str, + drop_params: bool, + ) -> dict: + pass + + @abstractmethod + def validate_environment( + self, + api_key: str, + headers: dict, + model: str, + messages: List[AllMessageValues], + ) -> dict: + pass + + @abstractmethod + def transform_request( + self, + model: str, + messages: List[AllMessageValues], + optional_params: dict, + litellm_params: dict, + headers: dict, + ) -> dict: + pass + + @abstractmethod + def _transform_messages( + self, messages: List[AllMessageValues] + ) -> List[AllMessageValues]: + pass + + @abstractmethod + def transform_response( + self, + model: str, + raw_response: dict, + model_response: ModelResponse, + logging_obj: LoggingClass, + api_key: str, + request_data: dict, + messages: List[AllMessageValues], + optional_params: dict, + encoding: str, + ) -> ModelResponse: + pass + + @abstractmethod + def get_error_class( + self, + error_message: str, + status_code: int, + headers: dict, + ) -> BaseLLMException: + pass diff --git a/litellm/llms/databricks/chat/old_handler.py b/litellm/llms/databricks/chat/old_handler.py deleted file mode 100644 index 95cc1cfc6d0b..000000000000 --- a/litellm/llms/databricks/chat/old_handler.py +++ /dev/null @@ -1,611 +0,0 @@ -# What is this? 
-## Handler file for databricks API https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request -import copy -import json -import os -import time -import types -from enum import Enum -from functools import partial -from typing import Any, Callable, List, Literal, Optional, Tuple, Union - -import httpx # type: ignore -import requests # type: ignore - -import litellm -from litellm import LlmProviders -from litellm.litellm_core_utils.core_helpers import map_finish_reason -from litellm.llms.custom_httpx.http_handler import ( - AsyncHTTPHandler, - HTTPHandler, - get_async_httpx_client, -) -from litellm.llms.databricks.exceptions import DatabricksError -from litellm.llms.databricks.streaming_utils import ModelResponseIterator -from litellm.types.llms.openai import ( - ChatCompletionDeltaChunk, - ChatCompletionResponseMessage, - ChatCompletionToolCallChunk, - ChatCompletionToolCallFunctionChunk, - ChatCompletionUsageBlock, -) -from litellm.types.utils import ( - CustomStreamingDecoder, - GenericStreamingChunk, - ProviderField, -) -from litellm.utils import ( - CustomStreamWrapper, - EmbeddingResponse, - ModelResponse, - ProviderConfigManager, - Usage, -) - -from ...base import BaseLLM -from ...prompt_templates.factory import custom_prompt, prompt_factory -from .transformation import DatabricksConfig - - -async def make_call( - client: Optional[AsyncHTTPHandler], - api_base: str, - headers: dict, - data: str, - model: str, - messages: list, - logging_obj, - streaming_decoder: Optional[CustomStreamingDecoder] = None, -): - if client is None: - client = get_async_httpx_client( - llm_provider=litellm.LlmProviders.DATABRICKS - ) # Create a new client if none provided - response = await client.post(api_base, headers=headers, data=data, stream=True) - - if response.status_code != 200: - raise DatabricksError(status_code=response.status_code, message=response.text) - - if streaming_decoder is not None: - completion_stream: Any = streaming_decoder.aiter_bytes( - response.aiter_bytes(chunk_size=1024) - ) - else: - completion_stream = ModelResponseIterator( - streaming_response=response.aiter_lines(), sync_stream=False - ) - # LOGGING - logging_obj.post_call( - input=messages, - api_key="", - original_response=completion_stream, # Pass the completion stream for logging - additional_args={"complete_input_dict": data}, - ) - - return completion_stream - - -def make_sync_call( - client: Optional[HTTPHandler], - api_base: str, - headers: dict, - data: str, - model: str, - messages: list, - logging_obj, - streaming_decoder: Optional[CustomStreamingDecoder] = None, -): - if client is None: - client = litellm.module_level_client # Create a new client if none provided - - response = client.post(api_base, headers=headers, data=data, stream=True) - - if response.status_code != 200: - raise DatabricksError(status_code=response.status_code, message=response.read()) - - if streaming_decoder is not None: - completion_stream = streaming_decoder.iter_bytes( - response.iter_bytes(chunk_size=1024) - ) - else: - completion_stream = ModelResponseIterator( - streaming_response=response.iter_lines(), sync_stream=True - ) - - # LOGGING - logging_obj.post_call( - input=messages, - api_key="", - original_response="first stream response received", - additional_args={"complete_input_dict": data}, - ) - - return completion_stream - - -class DatabricksChatCompletion(BaseLLM): - def __init__(self) -> None: - super().__init__() - - # makes headers for API call - def _get_databricks_credentials( - 
self, api_key: Optional[str], api_base: Optional[str], headers: Optional[dict] - ) -> Tuple[str, dict]: - headers = headers or {"Content-Type": "application/json"} - try: - from databricks.sdk import WorkspaceClient - - databricks_client = WorkspaceClient() - api_base = api_base or f"{databricks_client.config.host}/serving-endpoints" - - if api_key is None: - databricks_auth_headers: dict[str, str] = ( - databricks_client.config.authenticate() - ) - headers = {**databricks_auth_headers, **headers} - - return api_base, headers - except ImportError: - raise DatabricksError( - status_code=400, - message=( - "If the Databricks base URL and API key are not set, the databricks-sdk " - "Python library must be installed. Please install the databricks-sdk, set " - "{LLM_PROVIDER}_API_BASE and {LLM_PROVIDER}_API_KEY environment variables, " - "or provide the base URL and API key as arguments." - ), - ) - - def _validate_environment( - self, - api_key: Optional[str], - api_base: Optional[str], - endpoint_type: Literal["chat_completions", "embeddings"], - custom_endpoint: Optional[bool], - headers: Optional[dict], - ) -> Tuple[str, dict]: - if api_key is None and headers is None: - if custom_endpoint: - raise DatabricksError( - status_code=400, - message="Missing API Key - A call is being made to LLM Provider but no key is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params", - ) - else: - api_base, headers = self._get_databricks_credentials( - api_base=api_base, api_key=api_key, headers=headers - ) - - if api_base is None: - if custom_endpoint: - raise DatabricksError( - status_code=400, - message="Missing API Base - A call is being made to LLM Provider but no api base is set either in the environment variables ({LLM_PROVIDER}_API_KEY) or via params", - ) - else: - api_base, headers = self._get_databricks_credentials( - api_base=api_base, api_key=api_key, headers=headers - ) - - if headers is None: - headers = { - "Authorization": "Bearer {}".format(api_key), - "Content-Type": "application/json", - } - else: - if api_key is not None: - headers.update({"Authorization": "Bearer {}".format(api_key)}) - - if api_key is not None: - headers["Authorization"] = f"Bearer {api_key}" - - if endpoint_type == "chat_completions" and custom_endpoint is not True: - api_base = "{}/chat/completions".format(api_base) - elif endpoint_type == "embeddings" and custom_endpoint is not True: - api_base = "{}/embeddings".format(api_base) - return api_base, headers - - async def acompletion_stream_function( - self, - model: str, - messages: list, - custom_llm_provider: str, - api_base: str, - custom_prompt_dict: dict, - model_response: ModelResponse, - print_verbose: Callable, - encoding, - api_key, - logging_obj, - stream, - data: dict, - optional_params=None, - litellm_params=None, - logger_fn=None, - headers={}, - client: Optional[AsyncHTTPHandler] = None, - streaming_decoder: Optional[CustomStreamingDecoder] = None, - ) -> CustomStreamWrapper: - - data["stream"] = True - completion_stream = await make_call( - client=client, - api_base=api_base, - headers=headers, - data=json.dumps(data), - model=model, - messages=messages, - logging_obj=logging_obj, - streaming_decoder=streaming_decoder, - ) - streamwrapper = CustomStreamWrapper( - completion_stream=completion_stream, - model=model, - custom_llm_provider=custom_llm_provider, - logging_obj=logging_obj, - ) - return streamwrapper - - async def acompletion_function( - self, - model: str, - messages: list, - api_base: str, - custom_prompt_dict: dict, 
- model_response: ModelResponse, - custom_llm_provider: str, - print_verbose: Callable, - encoding, - api_key, - logging_obj, - stream, - data: dict, - base_model: Optional[str], - optional_params: dict, - litellm_params=None, - logger_fn=None, - headers={}, - timeout: Optional[Union[float, httpx.Timeout]] = None, - ) -> ModelResponse: - if timeout is None: - timeout = httpx.Timeout(timeout=600.0, connect=5.0) - - self.async_handler = get_async_httpx_client( - llm_provider=litellm.LlmProviders.DATABRICKS, - params={"timeout": timeout}, - ) - - try: - response = await self.async_handler.post( - api_base, headers=headers, data=json.dumps(data) - ) - response.raise_for_status() - - response_json = response.json() - except httpx.HTTPStatusError as e: - raise DatabricksError( - status_code=e.response.status_code, - message=e.response.text, - ) - except httpx.TimeoutException: - raise DatabricksError(status_code=408, message="Timeout error occurred.") - except Exception as e: - raise DatabricksError(status_code=500, message=str(e)) - - logging_obj.post_call( - input=messages, - api_key="", - original_response=response_json, - additional_args={"complete_input_dict": data}, - ) - response = ModelResponse(**response_json) - - response.model = custom_llm_provider + "/" + (response.model or "") - - if base_model is not None: - response._hidden_params["model"] = base_model - return response - - def completion( - self, - model: str, - messages: list, - api_base: str, - custom_llm_provider: str, - custom_prompt_dict: dict, - model_response: ModelResponse, - print_verbose: Callable, - encoding, - api_key: Optional[str], - logging_obj, - optional_params: dict, - acompletion=None, - litellm_params=None, - logger_fn=None, - headers: Optional[dict] = None, - timeout: Optional[Union[float, httpx.Timeout]] = None, - client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, - custom_endpoint: Optional[bool] = None, - streaming_decoder: Optional[ - CustomStreamingDecoder - ] = None, # if openai-compatible api needs custom stream decoder - e.g. 
sagemaker - ): - custom_endpoint = custom_endpoint or optional_params.pop( - "custom_endpoint", None - ) - base_model: Optional[str] = optional_params.pop("base_model", None) - api_base, headers = self._validate_environment( - api_base=api_base, - api_key=api_key, - endpoint_type="chat_completions", - custom_endpoint=custom_endpoint, - headers=headers, - ) - ## Load Config - config = litellm.DatabricksConfig().get_config() - for k, v in config.items(): - if ( - k not in optional_params - ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in - optional_params[k] = v - - stream: bool = optional_params.get("stream", None) or False - optional_params.pop( - "max_retries", None - ) # [TODO] add max retry support at llm api call level - optional_params["stream"] = stream - - if messages is not None and custom_llm_provider is not None: - provider_config = ProviderConfigManager.get_provider_config( - model=model, provider=LlmProviders(custom_llm_provider) - ) - messages = provider_config._transform_messages(messages) - - data = { - "model": model, - "messages": messages, - **optional_params, - } - - ## LOGGING - logging_obj.pre_call( - input=messages, - api_key=api_key, - additional_args={ - "complete_input_dict": data, - "api_base": api_base, - "headers": headers, - }, - ) - if acompletion is True: - if client is not None and isinstance(client, HTTPHandler): - client = None - if ( - stream is not None and stream is True - ): # if function call - fake the streaming (need complete blocks for output parsing in openai format) - print_verbose("makes async anthropic streaming POST request") - data["stream"] = stream - return self.acompletion_stream_function( - model=model, - messages=messages, - data=data, - api_base=api_base, - custom_prompt_dict=custom_prompt_dict, - model_response=model_response, - print_verbose=print_verbose, - encoding=encoding, - api_key=api_key, - logging_obj=logging_obj, - optional_params=optional_params, - stream=stream, - litellm_params=litellm_params, - logger_fn=logger_fn, - headers=headers, - client=client, - custom_llm_provider=custom_llm_provider, - streaming_decoder=streaming_decoder, - ) - else: - return self.acompletion_function( - model=model, - messages=messages, - data=data, - api_base=api_base, - custom_prompt_dict=custom_prompt_dict, - custom_llm_provider=custom_llm_provider, - model_response=model_response, - print_verbose=print_verbose, - encoding=encoding, - api_key=api_key, - logging_obj=logging_obj, - optional_params=optional_params, - stream=stream, - litellm_params=litellm_params, - logger_fn=logger_fn, - headers=headers, - timeout=timeout, - base_model=base_model, - ) - else: - ## COMPLETION CALL - if stream is True: - completion_stream = make_sync_call( - client=( - client - if client is not None and isinstance(client, HTTPHandler) - else None - ), - api_base=api_base, - headers=headers, - data=json.dumps(data), - model=model, - messages=messages, - logging_obj=logging_obj, - streaming_decoder=streaming_decoder, - ) - # completion_stream.__iter__() - return CustomStreamWrapper( - completion_stream=completion_stream, - model=model, - custom_llm_provider=custom_llm_provider, - logging_obj=logging_obj, - ) - else: - if client is None or not isinstance(client, HTTPHandler): - client = HTTPHandler(timeout=timeout) # type: ignore - try: - response = client.post( - api_base, headers=headers, data=json.dumps(data) - ) - response.raise_for_status() - - response_json = response.json() - except httpx.HTTPStatusError as 
e: - raise DatabricksError( - status_code=e.response.status_code, - message=e.response.text, - ) - except httpx.TimeoutException: - raise DatabricksError( - status_code=408, message="Timeout error occurred." - ) - except Exception as e: - raise DatabricksError(status_code=500, message=str(e)) - - response = ModelResponse(**response_json) - - response.model = custom_llm_provider + "/" + (response.model or "") - - if base_model is not None: - response._hidden_params["model"] = base_model - - return response - - async def aembedding( - self, - input: list, - data: dict, - model_response: ModelResponse, - timeout: float, - api_key: str, - api_base: str, - logging_obj, - headers: dict, - client=None, - ) -> EmbeddingResponse: - response = None - try: - if client is None or isinstance(client, AsyncHTTPHandler): - self.async_client = get_async_httpx_client( - llm_provider=litellm.LlmProviders.DATABRICKS, - params={"timeout": timeout}, - ) - else: - self.async_client = client - - try: - response = await self.async_client.post( - api_base, - headers=headers, - data=json.dumps(data), - ) # type: ignore - - response.raise_for_status() - - response_json = response.json() - except httpx.HTTPStatusError as e: - raise DatabricksError( - status_code=e.response.status_code, - message=response.text if response else str(e), - ) - except httpx.TimeoutException: - raise DatabricksError( - status_code=408, message="Timeout error occurred." - ) - except Exception as e: - raise DatabricksError(status_code=500, message=str(e)) - - ## LOGGING - logging_obj.post_call( - input=input, - api_key=api_key, - additional_args={"complete_input_dict": data}, - original_response=response_json, - ) - return EmbeddingResponse(**response_json) - except Exception as e: - ## LOGGING - logging_obj.post_call( - input=input, - api_key=api_key, - original_response=str(e), - ) - raise e - - def embedding( - self, - model: str, - input: list, - timeout: float, - logging_obj, - api_key: Optional[str], - api_base: Optional[str], - optional_params: dict, - model_response: Optional[litellm.utils.EmbeddingResponse] = None, - client=None, - aembedding=None, - headers: Optional[dict] = None, - ) -> EmbeddingResponse: - api_base, headers = self._validate_environment( - api_base=api_base, - api_key=api_key, - endpoint_type="embeddings", - custom_endpoint=False, - headers=headers, - ) - model = model - data = {"model": model, "input": input, **optional_params} - - ## LOGGING - logging_obj.pre_call( - input=input, - api_key=api_key, - additional_args={"complete_input_dict": data, "api_base": api_base}, - ) - - if aembedding is True: - return self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, headers=headers) # type: ignore - if client is None or isinstance(client, AsyncHTTPHandler): - self.client = HTTPHandler(timeout=timeout) # type: ignore - else: - self.client = client - - ## EMBEDDING CALL - try: - response = self.client.post( - api_base, - headers=headers, - data=json.dumps(data), - ) # type: ignore - - response.raise_for_status() # type: ignore - - response_json = response.json() # type: ignore - except httpx.HTTPStatusError as e: - raise DatabricksError( - status_code=e.response.status_code, - message=e.response.text, - ) - except httpx.TimeoutException: - raise DatabricksError(status_code=408, message="Timeout error occurred.") - except Exception as e: - raise DatabricksError(status_code=500, message=str(e)) - - ## LOGGING - 
-        logging_obj.post_call(
-            input=input,
-            api_key=api_key,
-            additional_args={"complete_input_dict": data},
-            original_response=response_json,
-        )
-
-        return litellm.EmbeddingResponse(**response_json)
diff --git a/litellm/llms/openai_like/chat/handler.py b/litellm/llms/openai_like/chat/handler.py
index baa9703049a5..831051a2c285 100644
--- a/litellm/llms/openai_like/chat/handler.py
+++ b/litellm/llms/openai_like/chat/handler.py
@@ -277,7 +277,7 @@ def completion(
         optional_params["stream"] = stream
 
         if messages is not None and custom_llm_provider is not None:
-            provider_config = ProviderConfigManager.get_provider_config(
+            provider_config = ProviderConfigManager.get_provider_chat_config(
                 model=model, provider=LlmProviders(custom_llm_provider)
             )
             messages = provider_config._transform_messages(messages)
diff --git a/litellm/llms/together_ai/completion/handler.py b/litellm/llms/together_ai/completion/handler.py
index fab2a39c571f..fac879447334 100644
--- a/litellm/llms/together_ai/completion/handler.py
+++ b/litellm/llms/together_ai/completion/handler.py
@@ -12,7 +12,7 @@
 from litellm.types.llms.openai import AllMessageValues, OpenAITextCompletionUserMessage
 from litellm.utils import ModelResponse
 
-from ...OpenAI.openai import OpenAITextCompletion
+from ...OpenAI.completion.handler import OpenAITextCompletion
 from .transformation import TogetherAITextCompletionConfig
 
 together_ai_text_completion_global_config = TogetherAITextCompletionConfig()
diff --git a/litellm/llms/together_ai/completion/transformation.py b/litellm/llms/together_ai/completion/transformation.py
index 65b9ad69bfac..6ec855de8bae 100644
--- a/litellm/llms/together_ai/completion/transformation.py
+++ b/litellm/llms/together_ai/completion/transformation.py
@@ -15,7 +15,7 @@
     OpenAITextCompletionUserMessage,
 )
 
-from ...OpenAI.openai import OpenAITextCompletionConfig
+from ...OpenAI.completion.transformation import OpenAITextCompletionConfig
 
 
 class TogetherAITextCompletionConfig(OpenAITextCompletionConfig):
diff --git a/litellm/llms/xai/chat/xai_transformation.py b/litellm/llms/xai/chat/transformation.py
similarity index 97%
rename from litellm/llms/xai/chat/xai_transformation.py
rename to litellm/llms/xai/chat/transformation.py
index 3bd41ed90730..ac3c4236511a 100644
--- a/litellm/llms/xai/chat/xai_transformation.py
+++ b/litellm/llms/xai/chat/transformation.py
@@ -22,8 +22,6 @@ def get_supported_openai_params(self, model: str) -> list:
             "logit_bias",
             "logprobs",
             "max_tokens",
-            "messages",
-            "model",
             "n",
             "presence_penalty",
             "response_format",
diff --git a/litellm/main.py b/litellm/main.py
index f574b9339c9b..b640707f8445 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -121,7 +121,8 @@
 from .llms.huggingface_restapi import Huggingface
 from .llms.OpenAI.audio_transcriptions import OpenAIAudioTranscription
 from .llms.OpenAI.chat.o1_handler import OpenAIO1ChatCompletion
-from .llms.OpenAI.openai import OpenAIChatCompletion, OpenAITextCompletion
+from .llms.OpenAI.completion.handler import OpenAITextCompletion
+from .llms.OpenAI.openai import OpenAIChatCompletion
 from .llms.openai_like.embedding.handler import OpenAILikeEmbeddingHandler
 from .llms.predibase import PredibaseChatCompletion
 from .llms.prompt_templates.common_utils import get_completion_messages
diff --git a/litellm/utils.py b/litellm/utils.py
index bd36e211cff7..a400c3038a65 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -6220,14 +6220,14 @@ def validate_chat_completion_user_messages(messages: List[AllMessageValues]):
     return messages
 
 
-from litellm.llms.OpenAI.chat.gpt_transformation import OpenAIGPTConfig
+from litellm.llms.base_llm.transformation import BaseConfig
 
 
 class ProviderConfigManager:
     @staticmethod
-    def get_provider_config(
+    def get_provider_chat_config(
         model: str, provider: litellm.LlmProviders
-    ) -> OpenAIGPTConfig:
+    ) -> BaseConfig:
         """
         Returns the provider config for a given provider.
         """
@@ -6239,8 +6239,12 @@ def get_provider_config(
             return litellm.GroqChatConfig()
         elif litellm.LlmProviders.DATABRICKS == provider:
            return litellm.DatabricksConfig()
+        elif litellm.LlmProviders.XAI == provider:
+            return litellm.XAIChatConfig()
+        elif litellm.LlmProviders.TEXT_COMPLETION_OPENAI == provider:
+            return litellm.OpenAITextCompletionConfig()
 
-        return OpenAIGPTConfig()
+        return litellm.OpenAIGPTConfig()
 
 
 def get_end_user_id_for_cost_tracking(
diff --git a/tests/llm_translation/test_xai.py b/tests/llm_translation/test_xai.py
index 3701d39ce9a6..cbe0beab0866 100644
--- a/tests/llm_translation/test_xai.py
+++ b/tests/llm_translation/test_xai.py
@@ -17,7 +17,7 @@
 from litellm import Choices, Message, ModelResponse, EmbeddingResponse, Usage
 from litellm import completion
 from unittest.mock import patch
-from litellm.llms.xai.chat.xai_transformation import XAIChatConfig, XAI_API_BASE
+from litellm.llms.xai.chat.transformation import XAIChatConfig, XAI_API_BASE
 
 
 def test_xai_chat_config_get_openai_compatible_provider_info():
diff --git a/tests/local_testing/test_config.py b/tests/local_testing/test_config.py
index 28d144e4dc70..065ba59d64a1 100644
--- a/tests/local_testing/test_config.py
+++ b/tests/local_testing/test_config.py
@@ -288,3 +288,30 @@ async def _monkey_patch_get_config(*args, **kwargs):
         assert len(llm_router.model_list) == len(model_list)
     else:
         assert len(llm_router.model_list) == len(model_list) + prev_llm_router_val
+
+
+# def test_provider_config_manager():
+#     from litellm import LITELLM_CHAT_PROVIDERS, LlmProviders
+#     from litellm.utils import ProviderConfigManager
+#     from litellm.llms.base_llm.transformation import BaseConfig
+#     from litellm.llms.OpenAI.chat.gpt_transformation import OpenAIGPTConfig
+
+#     for provider in LITELLM_CHAT_PROVIDERS:
+#         assert isinstance(
+#             ProviderConfigManager.get_provider_chat_config(
+#                 model="gpt-3.5-turbo", provider=LlmProviders(provider)
+#             ),
+#             BaseConfig,
+#         ), f"Provider {provider} is not a subclass of BaseConfig"
+
+#         if (
+#             provider != litellm.LlmProviders.OPENAI
+#             and provider != litellm.LlmProviders.OPENAI_LIKE
+#             and provider != litellm.LlmProviders.CUSTOM_OPENAI
+#         ):
+#             config = ProviderConfigManager.get_provider_chat_config(
+#                 model="gpt-3.5-turbo", provider=LlmProviders(provider)
+#             )
+#             assert (
+#                 config.__class__.__name__ != "OpenAIGPTConfig"
+#             ), f"Provider {provider} is an instance of OpenAIGPTConfig"
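
For reviewers, a minimal usage sketch of the renamed accessor that the commented-out test above exercises. This is illustrative only and not part of the patch; it assumes the imports resolve as introduced in this change, and note that in the branches shown above dispatch keys on the provider while the model string is simply passed through.

    # Usage sketch (not part of the patch): resolve a provider's chat config
    # through ProviderConfigManager.get_provider_chat_config.
    from litellm import LlmProviders
    from litellm.utils import ProviderConfigManager

    config = ProviderConfigManager.get_provider_chat_config(
        model="gpt-3.5-turbo", provider=LlmProviders.XAI
    )
    print(type(config).__name__)  # expected: XAIChatConfig, per the new elif branch in utils.py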