Litellm code qa common config (#7116)
* feat(base_llm): initial commit for common base config class

Addresses code qa critique andrewyng/aisuite#113 (comment)

* feat(base_llm/): add transform request/response abstract methods to base config class

---------

Co-authored-by: Krrish Dholakia <[email protected]>
ishaan-jaff and krrishdholakia authored Dec 9, 2024
1 parent 3161619 commit 0fa05b1
Showing 19 changed files with 851 additions and 1,131 deletions.
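Before the per-file diffs, a note on the overall shape of the change: the heart of the commit is a shared BaseConfig class in litellm/llms/base_llm/transformation.py (one of the changed files not rendered below), which provider configs subclass. A minimal sketch of that pattern, inferred from the imports and method signatures visible in this diff rather than copied from the commit:

from abc import ABC, abstractmethod
from typing import List


class BaseConfig(ABC):
    """Sketch of the common config base class; the real one carries more hooks."""

    @abstractmethod
    def transform_request(
        self,
        model: str,
        messages: List[dict],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """Build the provider-specific request body from OpenAI-style inputs."""
        ...

    @abstractmethod
    def transform_response(self, model: str, raw_response: dict, **kwargs) -> dict:
        """Map the provider's raw response back to a common response shape."""
        ...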
5 changes: 3 additions & 2 deletions litellm/__init__.py
@@ -22,6 +22,7 @@
     DEFAULT_FLUSH_INTERVAL_SECONDS,
     ROUTER_MAX_FALLBACKS,
     DEFAULT_MAX_RETRIES,
+    LITELLM_CHAT_PROVIDERS,
 )
 from litellm.types.guardrails import GuardrailItem
 from litellm.proxy._types import (
@@ -1128,10 +1129,10 @@ class LlmProviders(str, Enum):
 from .llms.bedrock.embed.cohere_transformation import BedrockCohereEmbeddingConfig
 from .llms.OpenAI.openai import (
     OpenAIConfig,
-    OpenAITextCompletionConfig,
     MistralEmbeddingConfig,
     DeepInfraConfig,
 )
+from litellm.llms.OpenAI.completion.transformation import OpenAITextCompletionConfig
 from .llms.groq.chat.transformation import GroqChatConfig
 from .llms.azure_ai.chat.transformation import AzureAIStudioConfig
 from .llms.mistral.mistral_chat_transformation import MistralConfig
@@ -1165,7 +1166,7 @@ class LlmProviders(str, Enum):
     FireworksAIEmbeddingConfig,
 )
 from .llms.jina_ai.embedding.transformation import JinaAIEmbeddingConfig
-from .llms.xai.chat.xai_transformation import XAIChatConfig
+from .llms.xai.chat.transformation import XAIChatConfig
 from .llms.volcengine import VolcEngineConfig
 from .llms.text_completion_codestral import MistralTextCompletionConfig
 from .llms.azure.azure import (
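Moving OpenAITextCompletionConfig into llms/OpenAI/completion/transformation.py should be invisible to callers, since litellm/__init__.py re-exports it from the new location. A quick sanity check, assuming the class still constructs with default arguments:

import litellm

# Same public symbol; it is now sourced from the relocated module.
config = litellm.OpenAITextCompletionConfig()
print(type(config).__module__)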
64 changes: 64 additions & 0 deletions litellm/constants.py
@@ -2,3 +2,67 @@
 DEFAULT_BATCH_SIZE = 512
 DEFAULT_FLUSH_INTERVAL_SECONDS = 5
 DEFAULT_MAX_RETRIES = 2
+LITELLM_CHAT_PROVIDERS = [
+    "openai",
+    "openai_like",
+    "xai",
+    "custom_openai",
+    "text-completion-openai",
+    "cohere",
+    "cohere_chat",
+    "clarifai",
+    "anthropic",
+    "replicate",
+    "huggingface",
+    "together_ai",
+    "openrouter",
+    "vertex_ai",
+    "vertex_ai_beta",
+    "palm",
+    "gemini",
+    "ai21",
+    "baseten",
+    "azure",
+    "azure_text",
+    "azure_ai",
+    "sagemaker",
+    "sagemaker_chat",
+    "bedrock",
+    "vllm",
+    "nlp_cloud",
+    "petals",
+    "oobabooga",
+    "ollama",
+    "ollama_chat",
+    "deepinfra",
+    "perplexity",
+    "anyscale",
+    "mistral",
+    "groq",
+    "nvidia_nim",
+    "cerebras",
+    "ai21_chat",
+    "volcengine",
+    "codestral",
+    "text-completion-codestral",
+    "deepseek",
+    "sambanova",
+    "maritalk",
+    "voyage",
+    "cloudflare",
+    "xinference",
+    "fireworks_ai",
+    "friendliai",
+    "watsonx",
+    "watsonx_text",
+    "triton",
+    "predibase",
+    "databricks",
+    "empower",
+    "github",
+    "custom",
+    "litellm_proxy",
+    "hosted_vllm",
+    "lm_studio",
+    "galadriel",
+]
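LITELLM_CHAT_PROVIDERS is a plain list of provider identifier strings, so consuming it is a membership test. The helper below is hypothetical (not part of this commit), shown only to illustrate the intended use:

from litellm.constants import LITELLM_CHAT_PROVIDERS


def is_chat_provider(provider: str) -> bool:
    # Hypothetical helper: check a provider string against the new registry.
    return provider in LITELLM_CHAT_PROVIDERS


assert is_chat_provider("groq")
assert not is_chat_provider("not-a-real-provider")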
26 changes: 3 additions & 23 deletions litellm/litellm_core_utils/get_supported_openai_params.py
@@ -255,27 +255,7 @@ def get_supported_openai_params(  # noqa: PLR0915
     elif custom_llm_provider == "watsonx":
         return litellm.IBMWatsonXChatConfig().get_supported_openai_params(model=model)
     elif custom_llm_provider == "custom_openai" or "text-completion-openai":
-        return [
-            "functions",
-            "function_call",
-            "temperature",
-            "top_p",
-            "n",
-            "stream",
-            "stream_options",
-            "stop",
-            "max_tokens",
-            "presence_penalty",
-            "frequency_penalty",
-            "logit_bias",
-            "user",
-            "response_format",
-            "seed",
-            "tools",
-            "tool_choice",
-            "max_retries",
-            "logprobs",
-            "top_logprobs",
-            "extra_headers",
-        ]
+        return litellm.OpenAITextCompletionConfig().get_supported_openai_params(
+            model=model
+        )
     return None
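The branch now delegates to OpenAITextCompletionConfig instead of returning a hardcoded list, so the supported-params logic lives in one place. A usage sketch through the public helper, assuming the standard litellm entry point:

import litellm

# Internally this now resolves via OpenAITextCompletionConfig().get_supported_openai_params()
params = litellm.get_supported_openai_params(
    model="gpt-3.5-turbo-instruct",
    custom_llm_provider="text-completion-openai",
)
print(params)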
71 changes: 69 additions & 2 deletions litellm/llms/OpenAI/chat/gpt_transformation.py
@@ -3,13 +3,26 @@
 """
 
 import types
-from typing import List, Optional, Union
+from typing import TYPE_CHECKING, Any, List, Optional, Union, cast
+
+import httpx
 
 import litellm
+from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
 from litellm.types.llms.openai import AllMessageValues, ChatCompletionUserMessage
+from litellm.types.utils import ModelResponse
+
+from ..common_utils import OpenAIError
+
+if TYPE_CHECKING:
+    from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+
+    LoggingClass = LiteLLMLoggingObj
+else:
+    LoggingClass = Any
 
 
-class OpenAIGPTConfig:
+class OpenAIGPTConfig(BaseConfig):
     """
     Reference: https://platform.openai.com/docs/api-reference/chat/create
 
@@ -168,3 +181,57 @@ def _transform_messages(
         self, messages: List[AllMessageValues]
     ) -> List[AllMessageValues]:
         return messages
+
+    def transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
+    ) -> dict:
+        """
+        Transform the overall request to be sent to the API.
+
+        Returns:
+            dict: The transformed request. Sent as the body of the API call.
+        """
+        raise NotImplementedError
+
+    def transform_response(
+        self,
+        model: str,
+        raw_response: dict,
+        model_response: ModelResponse,
+        logging_obj: LoggingClass,
+        api_key: str,
+        request_data: dict,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        encoding: str,
+    ) -> ModelResponse:
+        """
+        Transform the response from the API.
+
+        Returns:
+            dict: The transformed response.
+        """
+        raise NotImplementedError
+
+    def get_error_class(
+        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
+    ) -> BaseLLMException:
+        return OpenAIError(
+            status_code=status_code,
+            message=error_message,
+            headers=cast(httpx.Headers, headers),
+        )
+
+    def validate_environment(
+        self,
+        api_key: str,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+    ) -> dict:
+        raise NotImplementedError
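With OpenAIGPTConfig now inheriting from BaseConfig, provider configs built on top of it override these hooks rather than reimplementing request plumbing. A minimal sketch; MyProviderConfig is hypothetical and assumes no other abstract hooks need overriding:

from typing import List

from litellm.llms.OpenAI.chat.gpt_transformation import OpenAIGPTConfig
from litellm.types.llms.openai import AllMessageValues


class MyProviderConfig(OpenAIGPTConfig):
    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        # OpenAI-compatible body: model plus messages, with any mapped params merged in.
        return {"model": model, "messages": messages, **optional_params}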
36 changes: 35 additions & 1 deletion litellm/llms/OpenAI/common_utils.py
@@ -3,10 +3,44 @@
 """
 
 import json
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
 
+import httpx
 import openai
 
+from litellm.llms.base_llm.transformation import BaseLLMException
+
+
+class OpenAIError(BaseLLMException):
+    def __init__(
+        self,
+        status_code: int,
+        message: str,
+        request: Optional[httpx.Request] = None,
+        response: Optional[httpx.Response] = None,
+        headers: Optional[httpx.Headers] = None,
+    ):
+        self.status_code = status_code
+        self.message = message
+        self.headers = headers
+        if request:
+            self.request = request
+        else:
+            self.request = httpx.Request(method="POST", url="https://api.openai.com/v1")
+        if response:
+            self.response = response
+        else:
+            self.response = httpx.Response(
+                status_code=status_code, request=self.request
+            )
+        super().__init__(
+            status_code=status_code,
+            message=self.message,
+            headers=self.headers,
+            request=self.request,
+            response=self.response,
+        )
+
 
 ####### Error Handling Utils for OpenAI API #######################
 ###################################################################
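Because OpenAIError synthesizes an httpx.Request and httpx.Response when none are supplied, it can be raised from transformation code that never performed a network call. A short usage sketch:

from litellm.llms.OpenAI.common_utils import OpenAIError

try:
    raise OpenAIError(status_code=429, message="Rate limit exceeded")
except OpenAIError as e:
    # The synthesized response carries the same status code as the error.
    print(e.status_code, e.response.status_code)  # 429 429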
(The remaining 14 changed files were not loaded in this page capture.)
