diff --git a/libs/langchain/langchain/llms/cerebriumai.py b/libs/langchain/langchain/llms/cerebriumai.py index 0a162f5dfeaa4..75c7c7b5fa701 100644 --- a/libs/langchain/langchain/llms/cerebriumai.py +++ b/libs/langchain/langchain/llms/cerebriumai.py @@ -1,13 +1,13 @@ import logging -from typing import Any, Dict, List, Mapping, Optional +from typing import Any, Dict, List, Mapping, Optional, cast import requests -from langchain_core.pydantic_v1 import Extra, Field, root_validator +from langchain_core.pydantic_v1 import Extra, Field, SecretStr, root_validator from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms.base import LLM from langchain.llms.utils import enforce_stop_tokens -from langchain.utils import get_from_dict_or_env +from langchain.utils import convert_to_secret_str, get_from_dict_or_env logger = logging.getLogger(__name__) @@ -15,8 +15,9 @@ class CerebriumAI(LLM): """CerebriumAI large language models. - To use, you should have the ``cerebrium`` python package installed, and the - environment variable ``CEREBRIUMAI_API_KEY`` set with your API key. + To use, you should have the ``cerebrium`` python package installed. + You should also have the environment variable ``CEREBRIUMAI_API_KEY`` + set with your API key or pass it as a named argument in the constructor. Any parameters that are valid to be passed to the call can be passed in, even if not explicitly saved on this class. @@ -25,7 +26,7 @@ class CerebriumAI(LLM): .. code-block:: python from langchain.llms import CerebriumAI - cerebrium = CerebriumAI(endpoint_url="") + cerebrium = CerebriumAI(endpoint_url="", cerebriumai_api_key="my-api-key") """ @@ -36,7 +37,7 @@ class CerebriumAI(LLM): """Holds any model parameters valid for `create` call not explicitly specified.""" - cerebriumai_api_key: Optional[str] = None + cerebriumai_api_key: Optional[SecretStr] = None class Config: """Configuration for this pydantic config.""" @@ -64,8 +65,8 @@ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]: @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - cerebriumai_api_key = get_from_dict_or_env( - values, "cerebriumai_api_key", "CEREBRIUMAI_API_KEY" + cerebriumai_api_key = convert_to_secret_str( + get_from_dict_or_env(values, "cerebriumai_api_key", "CEREBRIUMAI_API_KEY") ) values["cerebriumai_api_key"] = cerebriumai_api_key return values @@ -91,7 +92,9 @@ def _call( **kwargs: Any, ) -> str: headers: Dict = { - "Authorization": self.cerebriumai_api_key, + "Authorization": cast( + SecretStr, self.cerebriumai_api_key + ).get_secret_value(), "Content-Type": "application/json", } params = self.model_kwargs or {} diff --git a/libs/langchain/tests/integration_tests/llms/test_cerebrium.py b/libs/langchain/tests/integration_tests/llms/test_cerebriumai.py similarity index 100% rename from libs/langchain/tests/integration_tests/llms/test_cerebrium.py rename to libs/langchain/tests/integration_tests/llms/test_cerebriumai.py diff --git a/libs/langchain/tests/unit_tests/llms/test_cerebriumai.py b/libs/langchain/tests/unit_tests/llms/test_cerebriumai.py new file mode 100644 index 0000000000000..b7d343081dd98 --- /dev/null +++ b/libs/langchain/tests/unit_tests/llms/test_cerebriumai.py @@ -0,0 +1,33 @@ +"""Test CerebriumAI llm""" + + +from langchain_core.pydantic_v1 import SecretStr +from pytest import CaptureFixture, MonkeyPatch + +from langchain.llms.cerebriumai import CerebriumAI + + +def test_api_key_is_secret_string() -> None: + llm = CerebriumAI(cerebriumai_api_key="test-cerebriumai-api-key") + assert isinstance(llm.cerebriumai_api_key, SecretStr) + + +def test_api_key_masked_when_passed_via_constructor(capsys: CaptureFixture) -> None: + llm = CerebriumAI(cerebriumai_api_key="secret-api-key") + print(llm.cerebriumai_api_key, end="") + captured = capsys.readouterr() + + assert captured.out == "**********" + assert repr(llm.cerebriumai_api_key) == "SecretStr('**********')" + + +def test_api_key_masked_when_passed_from_env( + monkeypatch: MonkeyPatch, capsys: CaptureFixture +) -> None: + monkeypatch.setenv("CEREBRIUMAI_API_KEY", "secret-api-key") + llm = CerebriumAI() + print(llm.cerebriumai_api_key, end="") + captured = capsys.readouterr() + + assert captured.out == "**********" + assert repr(llm.cerebriumai_api_key) == "SecretStr('**********')"