refactor and add more options
ej52 committed Nov 9, 2023
1 parent 08f8a89 commit 2ba7c58
Showing 10 changed files with 329 additions and 179 deletions.
158 changes: 96 additions & 62 deletions custom_components/ollama_conversation/__init__.py
@@ -5,44 +5,52 @@
"""
from __future__ import annotations

import json
from typing import Literal

from homeassistant.components import conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import MATCH_ALL
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryNotReady, ConfigEntryAuthFailed, TemplateError
from homeassistant.exceptions import ConfigEntryNotReady, HomeAssistantError, TemplateError
from homeassistant.helpers import intent, template
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.util import ulid

from .api import (
OllamaApiClient,
OllamaApiClientAuthenticationError,
OllamaApiClientError,
)
from .api import OllamaApiClient
from .const import (
DOMAIN, LOGGER,

CONF_BASE_URL,
CONF_CHAT_MODEL,
CONF_PROMPT,
CONF_TOP_K,
CONF_TOP_P,
CONF_MODEL,
CONF_CTX_SIZE,
CONF_MAX_TOKENS,
CONF_MIROSTAT_MODE,
CONF_MIROSTAT_ETA,
CONF_MIROSTAT_TAU,
CONF_TEMPERATURE,
CONF_TOP_K,
CONF_TOP_P,
CONF_PROMPT_SYSTEM,

DEFAULT_CHAT_MODEL,
DEFAULT_PROMPT,
DEFAULT_MODEL,
DEFAULT_CTX_SIZE,
DEFAULT_MAX_TOKENS,
DEFAULT_MIROSTAT_MODE,
DEFAULT_MIROSTAT_ETA,
DEFAULT_MIROSTAT_TAU,
DEFAULT_TEMPERATURE,
DEFAULT_TOP_K,
DEFAULT_TOP_P
DEFAULT_TOP_P,
DEFAULT_PROMPT_SYSTEM
)
from .coordinator import OllamaDataUpdateCoordinator
from .exceptions import (
ApiClientError,
ApiCommError,
ApiJsonError,
ApiTimeoutError
)
from .helpers import get_exposed_entities

# https://developers.home-assistant.io/docs/config_entries_index/#setting-up-an-entry
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
@@ -63,10 +71,8 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
try:
response = await client.async_get_heartbeat()
if not response:
raise OllamaApiClientError("Invalid Ollama server")
except OllamaApiClientAuthenticationError as exception:
raise ConfigEntryAuthFailed(exception) from exception
except OllamaApiClientError as err:
raise ApiClientError("Invalid Ollama server")
except ApiClientError as err:
raise ConfigEntryNotReady(err) from err

entry.async_on_unload(entry.add_update_listener(async_reload_entry))
@@ -95,12 +101,7 @@ def __init__(self, hass: HomeAssistant, entry: ConfigEntry, client: OllamaApiCli
self.hass = hass
self.entry = entry
self.client = client
self.history: dict[str, list[dict]] = {}

@property
def attribution(self):
"""Return the attribution."""
return {"name": "Powered by Ollama", "url": "https://github.com/ej52/hass-ollama-conversation"}
self.history: dict[str, dict] = {}

@property
def supported_languages(self) -> list[str] | Literal["*"]:
@@ -111,73 +112,106 @@ async def async_process(
self, user_input: conversation.ConversationInput
) -> conversation.ConversationResult:
"""Process a sentence."""

intent_response = intent.IntentResponse(language=user_input.language)

model = self.entry.options.get(CONF_CHAT_MODEL, DEFAULT_CHAT_MODEL)
prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
raw_system_prompt = self.entry.options.get(CONF_PROMPT_SYSTEM, DEFAULT_PROMPT_SYSTEM)
exposed_entities = get_exposed_entities(self.hass)

if user_input.conversation_id in self.history:
conversation_id = user_input.conversation_id
context = self.history[conversation_id]
messages = self.history[conversation_id]
else:
conversation_id = ulid.ulid()
context = None
try:
system_prompt = self._async_generate_prompt(raw_system_prompt, exposed_entities)
except TemplateError as err:
LOGGER.error("Error rendering system prompt: %s", err)
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"I had a problem with my system prompt, please check the logs for more information.",
)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)
messages = {
"system": system_prompt,
"context": None,
}

messages["prompt"] = user_input.text

try:
system_prompt = self._async_generate_prompt(prompt)
except TemplateError as err:
LOGGER.error("Error rendering prompt: %s", err)
response = await self.query(messages)
except (
ApiCommError,
ApiJsonError,
ApiTimeoutError
) as err:
LOGGER.error("Error generating prompt: %s", err)
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Sorry, I had a problem with my template: {err}",
f"Something went wrong, {err}",
)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)

payload = {
"model": model,
"context": context,
"system": system_prompt,
"prompt": user_input.text,
"stream": False,
"options": {
"top_k": self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K),
"top_p": self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P),
"num_ctx": self.entry.options.get(CONF_CTX_SIZE, DEFAULT_CTX_SIZE),
"num_predict": self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS),
"temperature": self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
}
}

LOGGER.debug("Prompt for %s: %s", model, json.dumps(payload))

try:
result = await self.client.async_generate(payload)
except OllamaApiClientError as err:
except HomeAssistantError as err:
LOGGER.error("Something went wrong: %s", err)
intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_error(
intent.IntentResponseErrorCode.UNKNOWN,
f"Sorry, I had a problem talking to the Ollama server: {err}",
f"Something went wrong, please check the logs for more information.",
)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)

LOGGER.debug("Response %s", json.dumps(result))

self.history[conversation_id] = result["context"]
intent_response.async_set_speech(result["response"])
messages["context"] = response["context"]
self.history[conversation_id] = messages

intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_speech(response["response"])
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)

def _async_generate_prompt(self, raw_prompt: str) -> str:
def _async_generate_prompt(self, raw_prompt: str, exposed_entities) -> str:
"""Generate a prompt for the user."""
return template.Template(raw_prompt, self.hass).async_render(
{
"ha_name": self.hass.config.location_name,
"exposed_entities": exposed_entities,
},
parse_result=False,
)

async def query(
self,
messages
):
"""Process a sentence."""
model = self.entry.options.get(CONF_MODEL, DEFAULT_MODEL)

LOGGER.debug("Prompt for %s: %s", model, messages["prompt"])

result = await self.client.async_generate({
"model": model,
"context": messages["context"],
"system": messages["system"],
"prompt": messages["prompt"],
"stream": False,
"options": {
"mirostat": int(self.entry.options.get(CONF_MIROSTAT_MODE, DEFAULT_MIROSTAT_MODE)),
"mirostat_eta": self.entry.options.get(CONF_MIROSTAT_ETA, DEFAULT_MIROSTAT_ETA),
"mirostat_tau": self.entry.options.get(CONF_MIROSTAT_TAU, DEFAULT_MIROSTAT_TAU),
"top_k": self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K),
"top_p": self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P),
"num_ctx": self.entry.options.get(CONF_CTX_SIZE, DEFAULT_CTX_SIZE),
"num_predict": self.entry.options.get(CONF_MAX_TOKENS, DEFAULT_MAX_TOKENS),
"temperature": self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
}
})

response: str = result["response"]
LOGGER.debug("Response %s", response)
return result
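
Note: the refactored async_process above depends on get_exposed_entities from the new .helpers module, which is among the ten changed files but not shown in this excerpt. Below is a minimal sketch of what such a helper could look like, assuming it uses Home Assistant's standard async_should_expose check and that the returned field names (entity_id, name, state) are what the system-prompt template iterates over — both assumptions, not confirmed by this diff:

    from homeassistant.components.homeassistant.exposed_entities import async_should_expose
    from homeassistant.core import HomeAssistant


    def get_exposed_entities(hass: HomeAssistant) -> list[dict]:
        """Collect entities exposed to conversation agents (assumed sketch)."""
        exposed = []
        for state in hass.states.async_all():
            # Skip anything the user has not exposed to voice assistants.
            if not async_should_expose(hass, "conversation", state.entity_id):
                continue
            exposed.append({
                "entity_id": state.entity_id,  # e.g. "light.kitchen"
                "name": state.name,
                "state": state.state,
            })
        return exposed

A helper along these lines would explain why _async_generate_prompt now receives exposed_entities and passes it into the template context alongside ha_name.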
59 changes: 22 additions & 37 deletions custom_components/ollama_conversation/api.py
@@ -8,26 +8,16 @@
import async_timeout

from .const import TIMEOUT


class OllamaApiClientError(Exception):
"""Exception to indicate a general API error."""


class OllamaApiClientCommunicationError(
OllamaApiClientError
):
"""Exception to indicate a communication error."""


class OllamaApiClientAuthenticationError(
OllamaApiClientError
):
"""Exception to indicate an authentication error."""
from .exceptions import (
ApiClientError,
ApiCommError,
ApiJsonError,
ApiTimeoutError
)


class OllamaApiClient:
"""Sample API Client."""
"""Ollama API Client."""

def __init__(
self,
@@ -40,10 +30,10 @@ def __init__(

async def async_get_heartbeat(self) -> bool:
"""Get heartbeat from the API."""
response = await self._api_wrapper(
response: str = await self._api_wrapper(
method="get", url=self._base_url, decode_json=False
)
return response == "Ollama is running"
return response.strip() == "Ollama is running"

async def async_get_models(self) -> any:
"""Get models from the API."""
@@ -78,28 +68,23 @@ async def _api_wrapper(
method=method,
url=url,
headers=headers,
raise_for_status=True,
json=data,
)

if response.status in (401, 403):
raise OllamaApiClientAuthenticationError(
"Invalid credentials",
)
if response.status == 404 and decode_json:
json = await response.json()
raise ApiJsonError(json["error"])

response.raise_for_status()

if decode_json:
return await response.json()
return await response.text()

except asyncio.TimeoutError as exception:
raise OllamaApiClientCommunicationError(
"Timeout error fetching information",
) from exception
except (aiohttp.ClientError, socket.gaierror) as exception:
raise OllamaApiClientCommunicationError(
"Error fetching information",
) from exception
except Exception as exception: # pylint: disable=broad-except
raise OllamaApiClientError(
"Something really wrong happened!"
) from exception
except ApiJsonError as e:
raise e
except asyncio.TimeoutError as e:
raise ApiTimeoutError("timeout while talking to the server") from e
except (aiohttp.ClientError, socket.gaierror) as e:
raise ApiCommError("unknown error while talking to the server") from e
except Exception as e: # pylint: disable=broad-except
raise ApiClientError("something really went wrong!") from e
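
Note: both files now import their error types from the new custom_components/ollama_conversation/exceptions.py, another of the ten changed files not shown in this excerpt. Given the names imported above, and the fact that __init__.py catches HomeAssistantError around client calls, a plausible sketch of the module is the following — the HomeAssistantError base class is an inference, not confirmed by this diff:

    """Custom exceptions for the Ollama conversation integration (assumed sketch)."""
    from homeassistant.exceptions import HomeAssistantError


    class ApiClientError(HomeAssistantError):
        """General API error; the catch-all raised by _api_wrapper."""


    class ApiCommError(ApiClientError):
        """Network-level failure (aiohttp.ClientError, socket.gaierror)."""


    class ApiJsonError(ApiClientError):
        """The Ollama server answered 404 with a JSON error body."""


    class ApiTimeoutError(ApiClientError):
        """The request to the Ollama server timed out."""

Subclassing HomeAssistantError would explain why the except HomeAssistantError clause in async_process covers every failure _api_wrapper can raise.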