Skip to content

Commit

Permalink
feat: mask api_key for konko (#14010)
Browse files Browse the repository at this point in the history
for #12165
  • Loading branch information
chyroc authored Jan 1, 2024
1 parent 62d32bd commit a4ae4bc
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 13 deletions.
28 changes: 16 additions & 12 deletions libs/community/langchain_community/chat_models/konko.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
)
from langchain_core.messages import AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.utils import get_from_dict_or_env
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env

from langchain_community.adapters.openai import (
convert_dict_to_message,
Expand Down Expand Up @@ -72,8 +72,8 @@ def is_lc_serializable(cls) -> bool:
"""What sampling temperature to use."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
openai_api_key: Optional[str] = None
konko_api_key: Optional[str] = None
openai_api_key: Optional[SecretStr] = None
konko_api_key: Optional[SecretStr] = None
request_timeout: Optional[Union[float, Tuple[float, float]]] = None
"""Timeout for requests to Konko completion API."""
max_retries: int = 6
Expand All @@ -88,8 +88,8 @@ def is_lc_serializable(cls) -> bool:
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["konko_api_key"] = get_from_dict_or_env(
values, "konko_api_key", "KONKO_API_KEY"
values["konko_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "konko_api_key", "KONKO_API_KEY")
)
try:
import konko
Expand Down Expand Up @@ -128,37 +128,41 @@ def _default_params(self) -> Dict[str, Any]:

@staticmethod
def get_available_models(
konko_api_key: Optional[str] = None,
openai_api_key: Optional[str] = None,
konko_api_key: Union[str, SecretStr, None] = None,
openai_api_key: Union[str, SecretStr, None] = None,
konko_api_base: str = DEFAULT_API_BASE,
) -> Set[str]:
"""Get available models from Konko API."""

# Try to retrieve the OpenAI API key if it's not passed as an argument
if not openai_api_key:
try:
openai_api_key = os.environ["OPENAI_API_KEY"]
openai_api_key = convert_to_secret_str(os.environ["OPENAI_API_KEY"])
except KeyError:
pass # It's okay if it's not set, we just won't use it
elif isinstance(openai_api_key, str):
openai_api_key = convert_to_secret_str(openai_api_key)

# Try to retrieve the Konko API key if it's not passed as an argument
if not konko_api_key:
try:
konko_api_key = os.environ["KONKO_API_KEY"]
konko_api_key = convert_to_secret_str(os.environ["KONKO_API_KEY"])
except KeyError:
raise ValueError(
"Konko API key must be passed as keyword argument or "
"set in environment variable KONKO_API_KEY."
)
elif isinstance(konko_api_key, str):
konko_api_key = convert_to_secret_str(konko_api_key)

models_url = f"{konko_api_base}/models"

headers = {
"Authorization": f"Bearer {konko_api_key}",
"Authorization": f"Bearer {konko_api_key.get_secret_value()}",
}

if openai_api_key:
headers["X-OpenAI-Api-Key"] = openai_api_key
headers["X-OpenAI-Api-Key"] = openai_api_key.get_secret_value()

models_response = requests.get(models_url, headers=headers)

Expand Down
44 changes: 43 additions & 1 deletion libs/community/tests/integration_tests/chat_models/test_konko.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,57 @@
"""Evaluate ChatKonko Interface."""
from typing import Any
from typing import Any, cast

import pytest
from langchain_core.callbacks import CallbackManager
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult
from langchain_core.pydantic_v1 import SecretStr
from pytest import CaptureFixture, MonkeyPatch

from langchain_community.chat_models.konko import ChatKonko
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler


def test_konko_key_masked_when_passed_from_env(
monkeypatch: MonkeyPatch, capsys: CaptureFixture
) -> None:
"""Test initialization with an API key provided via an env variable"""
monkeypatch.setenv("OPENAI_API_KEY", "test-openai-key")
monkeypatch.setenv("KONKO_API_KEY", "test-konko-key")

chat = ChatKonko()

print(chat.openai_api_key, end="")
captured = capsys.readouterr()
assert captured.out == "**********"

print(chat.konko_api_key, end="")
captured = capsys.readouterr()
assert captured.out == "**********"


def test_konko_key_masked_when_passed_via_constructor(
capsys: CaptureFixture,
) -> None:
"""Test initialization with an API key provided via the initializer"""
chat = ChatKonko(openai_api_key="test-openai-key", konko_api_key="test-konko-key")

print(chat.konko_api_key, end="")
captured = capsys.readouterr()
assert captured.out == "**********"

print(chat.konko_secret_key, end="")
captured = capsys.readouterr()
assert captured.out == "**********"


def test_uses_actual_secret_value_from_secret_str() -> None:
"""Test that actual secret is retrieved using `.get_secret_value()`."""
chat = ChatKonko(openai_api_key="test-openai-key", konko_api_key="test-konko-key")
assert cast(SecretStr, chat.konko_api_key).get_secret_value() == "test-openai-key"
assert cast(SecretStr, chat.konko_secret_key).get_secret_value() == "test-konko-key"


def test_konko_chat_test() -> None:
"""Evaluate basic ChatKonko functionality."""
chat_instance = ChatKonko(max_tokens=10)
Expand Down

0 comments on commit a4ae4bc

Please sign in to comment.