diff --git a/.github/workflows/publish-to-pypi.yml b/.github/workflows/publish-to-pypi.yml
index 34558a8..2999842 100644
--- a/.github/workflows/publish-to-pypi.yml
+++ b/.github/workflows/publish-to-pypi.yml
@@ -52,4 +52,4 @@ jobs:
path: dist/
- name: Publish release distributions to PyPI
- uses: pypa/gh-action-pypi-publish@release/v1
\ No newline at end of file
+ uses: pypa/gh-action-pypi-publish@release/v1
diff --git a/llama_assistant/agent.py b/llama_assistant/agent.py
index ebb65cf..21ec19f 100644
--- a/llama_assistant/agent.py
+++ b/llama_assistant/agent.py
@@ -1,5 +1,4 @@
-from typing import List, Set, Optional
-from collections import defaultdict
+from typing import List, Set, Optional, Dict
from llama_cpp import Llama
from llama_index.core import VectorStoreIndex
@@ -11,6 +10,7 @@
from llama_index.core.workflow import Event, StartEvent, StopEvent, Workflow, step
+
def convert_message_list_to_str(messages):
chat_history_str = ""
for message in messages:
@@ -18,23 +18,63 @@ def convert_message_list_to_str(messages):
chat_history_str += message["role"] + ": " + message["content"] + "\n"
else:
chat_history_str += message["role"] + ": " + message["content"]["text"] + "\n"
-
+
return chat_history_str
+
class SetupEvent(Event):
pass
+
class CondenseQueryEvent(Event):
condensed_query_str: str
+
class RetrievalEvent(Event):
nodes: List[NodeWithScore]
+
+class ChatHistory:
+ def __init__(self, max_history_size: int):
+ self.max_history_size = max_history_size
+ self.total_size = 0
+ self.chat_history = []
+
+ def add_message(self, message: dict):
+ if "content" in message and type(message["content"]) is list:
+ # multimodal model's message format
+ new_msg_size = len(message["content"][0]["text"].split())
+ else:
+ # text-only model's message format
+ new_msg_size = len(message["content"].split())
+
+ self.total_size += new_msg_size
+ self.chat_history.append(message)
+ while self.total_size > self.max_history_size:
+ oldest_msg = self.chat_history.pop(0)
+ if "content" in oldest_msg and type(oldest_msg["content"]) is list:
+ # multimodal model's message format
+ len_oldest_msg = len(oldest_msg["content"][0]["text"].split())
+ else:
+ # text-only model's message format
+ len_oldest_msg = len(oldest_msg["content"].split())
+ self.total_size -= len_oldest_msg
+
+ def get_chat_history(self):
+ return self.chat_history
+
+ def clear(self):
+ self.chat_history = []
+ self.total_size = 0
+
+ def __len__(self):
+ return len(self.chat_history)
+
+
class RAGAgent(Workflow):
SUMMARY_TEMPLATE = (
- "Given the chat history:\n"
+ "Given the converstation:\n"
"'''{chat_history_str}'''\n\n"
- "And the user asked the following question:{query_str}\n"
"Rewrite to a standalone question:\n"
)
@@ -45,49 +85,111 @@ class RAGAgent(Workflow):
"-----\n"
"Please write a response to the following question, using the above information if relevant:\n"
"{query_str}\n"
- )
- def __init__(self, embed_model_name: str, llm: Llama, timeout: int = 60, verbose: bool = False):
+ )
+    SYSTEM_PROMPT = {"role": "system", "content": "Generate a short and simple response:"}
+
+ def __init__(
+ self,
+ generation_setting: Dict,
+ rag_setting: Dict,
+ llm: Llama,
+ timeout: int = 60,
+ verbose: bool = False,
+ ):
super().__init__(timeout=timeout, verbose=verbose)
- self.k = 3
+ self.generation_setting = generation_setting
+ self.context_len = generation_setting["context_len"]
+        # cap the number of retrieved chunks so the retrieved context still fits within context_len
+ self.retrieval_top_k = (self.context_len - rag_setting["chunk_size"]) // rag_setting[
+ "chunk_overlap"
+ ]
+ self.retrieval_top_k = min(max(1, self.retrieval_top_k), rag_setting["max_retrieval_top_k"])
self.search_index = None
self.retriever = None
- self.chat_history = []
+ # 1 token ~ 3/4 words, so we multiply the context_len by 0.7 to get the max number of words
+ # why not 0.75? because we want to be a bit more conservative
+        self.chat_history = ChatHistory(max_history_size=int(self.context_len * 0.7))
self.lookup_files = set()
- self.embed_model = HuggingFaceEmbedding(model_name=embed_model_name)
+ self.embed_model = HuggingFaceEmbedding(model_name=rag_setting["embed_model_name"])
Settings.embed_model = self.embed_model
- self.node_processor = SimilarityPostprocessor(similarity_cutoff=0.3)
+ Settings.chunk_size = rag_setting["chunk_size"]
+ Settings.chunk_overlap = rag_setting["chunk_overlap"]
+ self.node_processor = SimilarityPostprocessor(
+ similarity_cutoff=rag_setting["similarity_threshold"]
+ )
self.llm = llm
- def udpate_index(self, files: Optional[Set[str] ] = set()):
+ def update_index(self, files: Optional[Set[str]] = set()):
if not files:
print("No lookup files provided, clearing index...")
self.retriever = None
self.search_index = None
return
-
+
print("Indexing documents...")
- documents = SimpleDirectoryReader(input_files=files, recursive=True).load_data(show_progress=True, num_workers=1)
- page_num_tracker = defaultdict(int)
- for doc in documents:
- key = doc.metadata['file_path']
- doc.metadata['page_index'] = page_num_tracker[key]
- page_num_tracker[key] += 1
+ documents = SimpleDirectoryReader(input_files=files, recursive=True).load_data(
+ show_progress=True, num_workers=1
+ )
if self.search_index is None:
- self.search_index = VectorStoreIndex.from_documents(documents, embed_model=self.embed_model)
+ self.search_index = VectorStoreIndex.from_documents(
+ documents, embed_model=self.embed_model
+ )
else:
for doc in documents:
- self.search_index.insert(doc) # Add the new document to the index
+ self.search_index.insert(doc) # Add the new document to the index
+
+ self.retriever = self.search_index.as_retriever(similarity_top_k=self.retrieval_top_k)
+
+ def update_rag_setting(self, rag_setting: Dict):
+ if self.node_processor.similarity_cutoff != rag_setting["similarity_threshold"]:
+ self.node_processor = SimilarityPostprocessor(
+ similarity_cutoff=rag_setting["similarity_threshold"]
+ )
+
+ new_top_k = (self.context_len - rag_setting["chunk_size"]) // rag_setting["chunk_overlap"]
+ new_top_k = min(max(1, new_top_k), rag_setting["max_retrieval_top_k"])
+ if self.retrieval_top_k != new_top_k:
+ self.retrieval_top_k = new_top_k
+ if self.retriever:
+ self.retriever = self.search_index.as_retriever(
+ similarity_top_k=self.retrieval_top_k
+ )
+
+ if (
+ self.embed_model.model_name != rag_setting["embed_model_name"]
+ or Settings.chunk_size != rag_setting["chunk_size"]
+ or Settings.chunk_overlap != rag_setting["chunk_overlap"]
+ ):
+ self.embed_model = HuggingFaceEmbedding(model_name=rag_setting["embed_model_name"])
+ Settings.embed_model = self.embed_model
+ Settings.chunk_size = rag_setting["chunk_size"]
+ Settings.chunk_overlap = rag_setting["chunk_overlap"]
+
+            # re-index, since these settings affect how the documents are indexed
+            if self.lookup_files:
+                print("Re-indexing documents since the RAG settings have changed...")
+ documents = SimpleDirectoryReader(
+ input_files=self.lookup_files, recursive=True
+ ).load_data(show_progress=True, num_workers=1)
+ self.search_index = VectorStoreIndex.from_documents(
+ documents, embed_model=self.embed_model
+ )
+ self.retriever = self.search_index.as_retriever(
+ similarity_top_k=self.retrieval_top_k
+ )
+
+    def update_generation_setting(self, generation_setting: Dict):
+ self.generation_setting = generation_setting
+ self.context_len = generation_setting["context_len"]
- self.retriever = self.search_index.as_retriever(similarity_top_k=self.k)
-
@step
async def setup(self, ctx: Context, ev: StartEvent) -> SetupEvent:
- # set frequetly used variables to context
+ # set frequently used variables to context
query_str = ev.query_str
image = ev.image
- lookup_files = ev.lookup_files
+ lookup_files = ev.lookup_files if ev.lookup_files else set()
streaming = ev.streaming
await ctx.set("query_str", query_str)
await ctx.set("image", image)
@@ -96,7 +198,7 @@ async def setup(self, ctx: Context, ev: StartEvent) -> SetupEvent:
# update index if needed
if lookup_files != self.lookup_files:
print("Different lookup files, updating index...")
- self.udpate_index(lookup_files)
+ self.update_index(lookup_files)
self.lookup_files = lookup_files.copy()
@@ -105,36 +207,48 @@ async def setup(self, ctx: Context, ev: StartEvent) -> SetupEvent:
@step
async def condense_history_to_query(self, ctx: Context, ev: SetupEvent) -> CondenseQueryEvent:
"""
- Condense the chat history and the query into a single query. Only used for retrieval.
+ Condense the chat history and the query into a single query. Only used for retrieval.
"""
query_str = await ctx.get("query_str")
formated_query = ""
-
+
if len(self.chat_history) > 0 or self.retriever is not None:
- chat_history_str = convert_message_list_to_str(self.chat_history)
- formated_query = self.SUMMARY_TEMPLATE.format(chat_history_str=chat_history_str, query_str=query_str)
+ self.chat_history.add_message({"role": "user", "content": query_str})
+ chat_history_str = convert_message_list_to_str(self.chat_history.get_chat_history())
+ # use llm to summarize the chat history into a single query,
+ # which is used to retrieve context from the documents
+ formated_query = self.SUMMARY_TEMPLATE.format(chat_history_str=chat_history_str)
+
history_summary = self.llm.create_chat_completion(
- messages=[{"role": "user", "content": formated_query}], stream=False
+ messages=[{"role": "user", "content": formated_query}],
+ stream=False,
+ top_k=self.generation_setting["top_k"],
+ top_p=self.generation_setting["top_p"],
+ temperature=self.generation_setting["temperature"],
)["choices"][0]["message"]["content"]
+
condensed_query = "Context:\n" + history_summary + "\nQuestion: " + query_str
+ # remove the last user message from the chat history
+ # later we will add the query with retrieved context (RAG)
+ self.chat_history.chat_history.pop()
else:
# if there is no history or no need for retrieval, return the query as is
condensed_query = query_str
return CondenseQueryEvent(condensed_query_str=condensed_query)
-
+
@step
async def retrieve(self, ctx: Context, ev: CondenseQueryEvent) -> RetrievalEvent:
- # retrieve from context
if not self.retriever:
return RetrievalEvent(nodes=[])
+ # retrieve from dropped documents
condensed_query_str = ev.condensed_query_str
nodes = await self.retriever.aretrieve(condensed_query_str)
nodes = self.node_processor.postprocess_nodes(nodes)
return RetrievalEvent(nodes=nodes)
-
+
def _prepare_query_with_context(
self,
query_str: str,
@@ -144,7 +258,7 @@ def _prepare_query_with_context(
if len(nodes) == 0:
return query_str
-
+
for idx, node in enumerate(nodes):
node_text = node.get_content(metadata_mode="llm")
node_context += f"\n{node_text}\n\n"
@@ -152,11 +266,11 @@ def _prepare_query_with_context(
formatted_query = self.CONTEXT_PROMPT_TEMPLATE.format(
node_context=node_context, query_str=query_str
)
-
+
return formatted_query
@step
- async def llm_response(self, ctx: Context, retrieval_ev: RetrievalEvent) -> StopEvent:
+ async def llm_response(self, ctx: Context, retrieval_ev: RetrievalEvent) -> StopEvent:
nodes = retrieval_ev.nodes
query_str = await ctx.get("query_str")
image = await ctx.get("image")
@@ -174,12 +288,18 @@ async def llm_response(self, ctx: Context, retrieval_ev: RetrievalEvent) -> Sto
else:
formated_message = {"role": "user", "content": query_with_ctx}
+ self.chat_history.add_message(formated_message)
+
response = self.llm.create_chat_completion(
- messages=self.chat_history+[formated_message], stream=streaming
+ messages=[self.SYSTEM_PROMPT] + self.chat_history.get_chat_history(),
+ stream=streaming,
+ top_k=self.generation_setting["top_k"],
+ top_p=self.generation_setting["top_p"],
+ temperature=self.generation_setting["temperature"],
)
- self.chat_history.append({"role": "user", "content": query_str})
+        # drop the context-augmented query from the chat history
+        self.chat_history.chat_history.pop()
+        # and store the short query (without retrieved context) instead, so the history stays compact
+        self.chat_history.add_message({"role": "user", "content": query_str})
return StopEvent(result=response)
-
-
-
\ No newline at end of file
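A minimal sketch (outside the patch) of how the new ChatHistory word budget and the retrieval top-k formula behave, assuming the defaults added in config.py below (context_len=2048, chunk_size=256, chunk_overlap=128, max_retrieval_top_k=3); the messages and values are placeholders.

# illustrative only, using the ChatHistory class defined in llama_assistant/agent.py above
from llama_assistant.agent import ChatHistory

history = ChatHistory(max_history_size=int(2048 * 0.7))  # ~1433-word budget, as in RAGAgent.__init__
history.add_message({"role": "user", "content": "hello " * 1000})       # 1000 words
history.add_message({"role": "assistant", "content": "world " * 1000})  # 1000 more -> over budget
print(len(history), history.total_size)  # -> 1 1000: the oldest message was evicted

# retrieval_top_k with the same defaults: (2048 - 256) // 128 = 14, then capped to max_retrieval_top_k = 3
top_k = min(max(1, (2048 - 256) // 128), 3)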
diff --git a/llama_assistant/config.py b/llama_assistant/config.py
index 6ad8f41..9d6278b 100644
--- a/llama_assistant/config.py
+++ b/llama_assistant/config.py
@@ -111,6 +111,11 @@
},
]
+# generation setting
+context_len = 2048
+top_k = 40
+top_p = 0.95
+temperature = 0.2
home_dir = Path.home()
llama_assistant_dir = home_dir / "llama_assistant"
@@ -119,8 +124,12 @@
settings_file = llama_assistant_dir / "settings.json"
document_icon = "llama_assistant/resources/document_icon.png"
-# for RAG pipeline
+# RAG setting
embed_model_name = "BAAI/bge-base-en-v1.5"
+chunk_size = 256
+chunk_overlap = 128
+max_retrieval_top_k = 3
+similarity_threshold = 0.6
if custom_models_file.exists():
with open(custom_models_file, "r") as f:
diff --git a/llama_assistant/llama_assistant_app.py b/llama_assistant/llama_assistant_app.py
index 956b47a..912bf7a 100644
--- a/llama_assistant/llama_assistant_app.py
+++ b/llama_assistant/llama_assistant_app.py
@@ -97,6 +97,8 @@ def load_settings(self):
self.deinit_wake_word_detector()
self.current_text_model = self.settings.get("text_model")
self.current_multimodal_model = self.settings.get("multimodal_model")
+ self.generation_setting = self.settings.get("generation")
+ self.rag_setting = self.settings.get("rag")
def setup_global_shortcut(self):
try:
@@ -177,7 +179,7 @@ def on_submit(self):
for file_path in self.dropped_files:
self.remove_file_thumbnail(self.file_containers[file_path], file_path)
-
+
return
self.last_response = ""
@@ -189,7 +191,6 @@ def on_submit(self):
else:
QTimer.singleShot(100, lambda: self.process_text(message, self.dropped_files, "chat"))
-
def on_task_button_clicked(self):
button = self.sender()
task = button.text()
@@ -203,7 +204,7 @@ def process_text(self, message, file_paths, task="chat"):
self.clear_chat()
self.show_chat_box()
if task == "chat":
- prompt = message + " \n" + "Generate a short and simple response."
+ prompt = message
elif task == "Summarize":
prompt = f"Summarize the following text: {message}"
elif task == "Rephrase":
@@ -220,7 +221,13 @@ def process_text(self, message, file_paths, task="chat"):
self.start_cursor_pos = self.ui_manager.chat_box.textCursor().position()
- self.processing_thread = ProcessingThread(self.current_text_model, prompt, lookup_files=file_paths)
+ self.processing_thread = ProcessingThread(
+ self.current_text_model,
+ self.generation_setting,
+ self.rag_setting,
+ prompt,
+ lookup_files=file_paths,
+ )
self.processing_thread.update_signal.connect(self.update_chat_box)
self.processing_thread.finished_signal.connect(self.on_processing_finished)
self.processing_thread.start()
@@ -238,7 +245,12 @@ def process_image_with_prompt(self, image_path, file_paths, prompt):
image = image_to_base64_data_uri(image_path)
self.processing_thread = ProcessingThread(
- self.current_multimodal_model, prompt, image=image, lookup_files=file_paths
+ self.current_multimodal_model,
+ self.generation_setting,
+ self.rag_setting,
+ prompt,
+ image=image,
+ lookup_files=file_paths,
)
self.processing_thread.update_signal.connect(self.update_chat_box)
self.processing_thread.finished_signal.connect(self.on_processing_finished)
@@ -252,7 +264,9 @@ def update_chat_box(self, text):
markdown_response = markdown_response.replace("<p>", "").replace("</p>", "")
markdown_response += "<br>"
cursor = self.ui_manager.chat_box.textCursor()
- cursor.setPosition(self.start_cursor_pos) # regenerate the updated text from the start position
+ cursor.setPosition(
+ self.start_cursor_pos
+ ) # regenerate the updated text from the start position
# Select all text from the start_pos to the end
cursor.movePosition(QTextCursor.End, QTextCursor.KeepAnchor)
# Remove the selected text
@@ -310,11 +324,12 @@ def dropEvent(self, event: QDropEvent):
elif file_path.lower().endswith((".pdf", "doc", ".docx", ".txt")):
if file_path not in self.dropped_files:
self.dropped_files.add(file_path)
- self.ui_manager.input_field.setPlaceholderText("Enter a prompt for the document...")
+ self.ui_manager.input_field.setPlaceholderText(
+ "Enter a prompt for the document..."
+ )
self.show_file_thumbnail(file_path)
else:
print(f"File {file_path} already added")
-
def remove_file_thumbnail(self, file_label, file_path):
file_label.setParent(None)
@@ -402,7 +417,7 @@ def show_file_thumbnail(self, file_path):
# Create a QLabel for the text
text_label = QLabel(file_path.split("/")[-1], container)
- # set text background color to white, text size to 5px
+ # set text background color to white, text size to 5px
# and rounded corners, vertical alignment to top
text_label.setStyleSheet(
"""
@@ -444,6 +459,7 @@ def show_file_thumbnail(self, file_path):
# Load and set the pixmap
import os
+
print("Icon path:", str(config.document_icon), os.path.exists(str(config.document_icon)))
pixmap = QPixmap(str(config.document_icon))
scaled_pixmap = pixmap.scaled(
@@ -460,7 +476,7 @@ def show_file_thumbnail(self, file_path):
self.setFixedHeight(self.height() + 110) # Increase height to accommodate larger file
self.file_containers[file_path] = container
-
+
def remove_image_thumbnail(self):
if self.image_label:
self.image_label.setParent(None)
diff --git a/llama_assistant/model_handler.py b/llama_assistant/model_handler.py
index 66c3ae6..b7b13b8 100644
--- a/llama_assistant/model_handler.py
+++ b/llama_assistant/model_handler.py
@@ -1,3 +1,4 @@
+import asyncio
from typing import List, Dict, Set, Optional
import time
from threading import Timer
@@ -12,6 +13,7 @@
from llama_assistant import config
from llama_assistant.agent import RAGAgent
+
class Model:
def __init__(
self,
@@ -54,19 +56,21 @@ def remove_supported_model(self, model_id: str):
if self.current_model_id == model_id:
self.unload_agent()
- def load_agent(self, model_id: str) -> Optional[Dict]:
+    def load_agent(self, model_id: str, generation_setting: Dict, rag_setting: Dict) -> Optional[Dict]:
self.refresh_supported_models()
if self.current_model_id == model_id and self.loaded_agent:
- return self.loaded_agent
+ if generation_setting["context_len"] == self.loaded_agent["model"].context_params.n_ctx:
+ self.loaded_agent["agent"].update_rag_setting(rag_setting)
+ self.loaded_agent["agent"].update_generation_setting(generation_setting)
+ return self.loaded_agent
+        # if no model is loaded, a different model is loaded, or the context length has changed, reinitialize the agent
self.unload_agent() # Unload the current model if any
model = next((m for m in self.supported_models if m.model_id == model_id), None)
if not model:
print(f"Model with ID {model_id} not found.")
return None
-
- print('Load agent =========')
if model.is_online():
if model.model_type == "text":
@@ -74,7 +78,7 @@ def load_agent(self, model_id: str) -> Optional[Dict]:
loaded_model = Llama.from_pretrained(
repo_id=model.repo_id,
filename=model.filename,
- n_ctx=2048,
+ n_ctx=generation_setting["context_len"],
)
elif model.model_type == "image":
if "moondream2" in model.model_id:
@@ -86,7 +90,7 @@ def load_agent(self, model_id: str) -> Optional[Dict]:
repo_id=model.repo_id,
filename=model.filename,
chat_handler=chat_handler,
- n_ctx=2048,
+ n_ctx=generation_setting["context_len"],
)
elif "MiniCPM" in model.model_id:
chat_handler = MiniCPMv26ChatHandler.from_pretrained(
@@ -97,7 +101,7 @@ def load_agent(self, model_id: str) -> Optional[Dict]:
repo_id=model.repo_id,
filename=model.filename,
chat_handler=chat_handler,
- n_ctx=2048,
+ n_ctx=generation_setting["context_len"],
)
elif "llava-v1.5" in model.model_id:
chat_handler = Llava15ChatHandler.from_pretrained(
@@ -108,7 +112,7 @@ def load_agent(self, model_id: str) -> Optional[Dict]:
repo_id=model.repo_id,
filename=model.filename,
chat_handler=chat_handler,
- n_ctx=2048,
+ n_ctx=generation_setting["context_len"],
)
elif "llava-v1.6" in model.model_id:
chat_handler = Llava16ChatHandler.from_pretrained(
@@ -119,7 +123,7 @@ def load_agent(self, model_id: str) -> Optional[Dict]:
repo_id=model.repo_id,
filename=model.filename,
chat_handler=chat_handler,
- n_ctx=2048,
+ n_ctx=generation_setting["context_len"],
)
else:
print(f"Unsupported model type: {model.model_type}")
@@ -129,19 +133,24 @@ def load_agent(self, model_id: str) -> Optional[Dict]:
print("load local model")
loaded_model = Llama(model_path=model.model_path)
+ print("Intializing agent ...")
+
agent = RAGAgent(
- config.embed_model_name,
+ generation_setting,
+ rag_setting,
llm=loaded_model,
)
self.loaded_agent = {
"model": loaded_model,
"agent": agent,
- "last_used": time.time()
+ "generation_setting": generation_setting,
+ "rag_setting": rag_setting,
+ "last_used": time.time(),
}
self.current_model_id = model_id
self._schedule_unload()
-
+
return self.loaded_agent
def unload_agent(self):
@@ -153,20 +162,26 @@ def unload_agent(self):
self.unload_timer.cancel()
self.unload_timer = None
- async def run_agent(self, agent: RAGAgent, message: str, lookup_files: Set, image: str, stream: bool):
- response = await agent.run(query_str=message, lookup_files=lookup_files, image=image, streaming=stream)
+ async def run_agent(
+ self, agent: RAGAgent, message: str, lookup_files: Set, image: str, stream: bool
+ ):
+ response = await agent.run(
+ query_str=message, lookup_files=lookup_files, image=image, streaming=stream
+ )
return response
-
+
def chat_completion(
self,
model_id: str,
+ generation_setting: Dict,
+ rag_setting: Dict,
message: str,
image: Optional[str] = None,
- lookup_files: Optional[Set[str] ] = set(),
+ lookup_files: Optional[Set[str]] = None,
stream: bool = False,
) -> str:
print("In chat_completion")
- agent_data = self.load_agent(model_id)
+ agent_data = self.load_agent(model_id, generation_setting, rag_setting)
-        agent = agent_data.get("agent")
if not agent_data:
return "Failed to load model"
+        agent = agent_data.get("agent")
@@ -174,24 +189,25 @@ def chat_completion(
agent_data["last_used"] = time.time()
self._schedule_unload()
- import asyncio
try:
loop = asyncio.get_running_loop()
- response = loop.run_until_complete(self.run_agent(agent, message, lookup_files, image, stream))
+ response = loop.run_until_complete(
+ self.run_agent(agent, message, lookup_files, image, stream)
+ )
except RuntimeError: # no running event loop
response = asyncio.run(self.run_agent(agent, message, lookup_files, image, stream))
-
+
return response
-
+
def update_chat_history(self, message: str, role: str):
agent = self.loaded_agent.get("agent")
if agent:
- agent.chat_history.append({"role": role, "content": message})
+ agent.chat_history.add_message({"role": role, "content": message})
def clear_chat_history(self):
agent = self.loaded_agent.get("agent")
if agent:
- agent.chat_history = []
+ agent.chat_history.clear()
def _schedule_unload(self):
if self.unload_timer:
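A minimal usage sketch (outside the patch) of the updated chat_completion signature, assuming `model_handler` is the handler instance used in processing_thread.py and that the setting dicts carry the keys read above; the model id and file path are placeholders.

# illustrative only: explicit generation/RAG settings passed through to the agent
generation_setting = {"context_len": 2048, "top_k": 40, "top_p": 0.95, "temperature": 0.2}
rag_setting = {
    "embed_model_name": "BAAI/bge-base-en-v1.5",
    "chunk_size": 256,
    "chunk_overlap": 128,
    "max_retrieval_top_k": 3,
    "similarity_threshold": 0.6,
}
output = model_handler.chat_completion(
    "some-text-model-id",              # placeholder model_id
    generation_setting,
    rag_setting,
    "Summarize the attached document.",
    lookup_files={"/path/to/doc.pdf"},  # placeholder path
    stream=True,
)
for chunk in output:  # streamed llama.cpp chunks, consumed as in ProcessingThread.run
    ...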
diff --git a/llama_assistant/processing_thread.py b/llama_assistant/processing_thread.py
index 7cc529a..80bce14 100644
--- a/llama_assistant/processing_thread.py
+++ b/llama_assistant/processing_thread.py
@@ -1,4 +1,4 @@
-from typing import Set
+from typing import Set, Optional, Dict
from PyQt5.QtCore import (
QThread,
pyqtSignal,
@@ -10,17 +10,32 @@ class ProcessingThread(QThread):
update_signal = pyqtSignal(str)
finished_signal = pyqtSignal()
- def __init__(self, model, prompt, lookup_files=set(), image=None):
+ def __init__(
+ self,
+ model: str,
+ generation_setting: Dict,
+ rag_setting: Dict,
+ prompt: str,
+ lookup_files: Optional[Set[str]] = None,
+        image: Optional[str] = None,
+ ):
super().__init__()
self.model = model
+ self.generation_setting = generation_setting
+ self.rag_setting = rag_setting
self.prompt = prompt
self.image = image
self.lookup_files = lookup_files
def run(self):
output = model_handler.chat_completion(
- self.model, self.prompt, image=self.image,
- lookup_files=self.lookup_files, stream=True
+ self.model,
+ self.generation_setting,
+ self.rag_setting,
+ self.prompt,
+ image=self.image,
+ lookup_files=self.lookup_files,
+ stream=True,
)
full_response_str = ""
for chunk in output:
@@ -34,7 +49,6 @@ def run(self):
model_handler.update_chat_history(full_response_str, "assistant")
self.finished_signal.emit()
-
def clear_chat_history(self):
model_handler.clear_chat_history()
- self.finished_signal.emit()
\ No newline at end of file
+ self.finished_signal.emit()
diff --git a/llama_assistant/setting_dialog.py b/llama_assistant/setting_dialog.py
index bab4d41..e19c747 100644
--- a/llama_assistant/setting_dialog.py
+++ b/llama_assistant/setting_dialog.py
@@ -44,6 +44,8 @@ def __init__(self, parent=None):
# Voice Activation Settings Group
self.create_voice_activation_settings_group()
+ self.create_rag_settings_group()
+
# Create a horizontal layout for the save button
button_layout = QHBoxLayout()
self.save_button = QPushButton("Save")
@@ -122,6 +124,34 @@ def create_model_settings_group(self):
multimodal_model_layout.addStretch()
layout.addLayout(multimodal_model_layout)
+ context_len_layout = QHBoxLayout()
+ context_len_label = QLabel("Context Length:")
+ self.context_len_input = QLineEdit()
+ context_len_layout.addWidget(context_len_label)
+ context_len_layout.addWidget(self.context_len_input)
+ layout.addLayout(context_len_layout)
+
+ temperature_layout = QHBoxLayout()
+ temperature_label = QLabel("Temperature:")
+ self.temperature_input = QLineEdit()
+ temperature_layout.addWidget(temperature_label)
+ temperature_layout.addWidget(self.temperature_input)
+ layout.addLayout(temperature_layout)
+
+ top_p_layout = QHBoxLayout()
+ top_p_label = QLabel("Top p:")
+ self.top_p_input = QLineEdit()
+ top_p_layout.addWidget(top_p_label)
+ top_p_layout.addWidget(self.top_p_input)
+ layout.addLayout(top_p_layout)
+
+ top_k_layout = QHBoxLayout()
+ top_k_label = QLabel("Top k:")
+ self.top_k_input = QLineEdit()
+ top_k_layout.addWidget(top_k_label)
+ top_k_layout.addWidget(self.top_k_input)
+ layout.addLayout(top_k_layout)
+
self.manage_custom_models_button = QPushButton("Manage Custom Models")
self.manage_custom_models_button.clicked.connect(self.open_custom_models_dialog)
layout.addWidget(self.manage_custom_models_button)
@@ -143,6 +173,48 @@ def create_voice_activation_settings_group(self):
group_box.setLayout(layout)
self.main_layout.addWidget(group_box)
+ def create_rag_settings_group(self):
+ group_box = QGroupBox("RAG Settings")
+ layout = QVBoxLayout()
+
+ embed_model_layout = QHBoxLayout()
+ embed_model_label = QLabel("Embed Model Name:")
+ self.embed_model_input = QLineEdit()
+ embed_model_layout.addWidget(embed_model_label)
+ embed_model_layout.addWidget(self.embed_model_input)
+ layout.addLayout(embed_model_layout)
+
+ chunk_size_layout = QHBoxLayout()
+ chunk_size_label = QLabel("Chunk Size:")
+ self.chunk_size_input = QLineEdit()
+ chunk_size_layout.addWidget(chunk_size_label)
+ chunk_size_layout.addWidget(self.chunk_size_input)
+ layout.addLayout(chunk_size_layout)
+
+ chunk_overlap_layout = QHBoxLayout()
+ chunk_overlap_label = QLabel("Chunk Overlap:")
+ self.chunk_overlap_input = QLineEdit()
+ chunk_overlap_layout.addWidget(chunk_overlap_label)
+ chunk_overlap_layout.addWidget(self.chunk_overlap_input)
+ layout.addLayout(chunk_overlap_layout)
+
+ max_retrieval_top_k_layout = QHBoxLayout()
+ max_retrieval_top_k_label = QLabel("Max Retrieval Top k:")
+ self.max_retrieval_top_k_input = QLineEdit()
+ max_retrieval_top_k_layout.addWidget(max_retrieval_top_k_label)
+ max_retrieval_top_k_layout.addWidget(self.max_retrieval_top_k_input)
+ layout.addLayout(max_retrieval_top_k_layout)
+
+ similarity_threshold_layout = QHBoxLayout()
+ similarity_threshold_label = QLabel("Similarity Threshold:")
+ self.similarity_threshold_input = QLineEdit()
+ similarity_threshold_layout.addWidget(similarity_threshold_label)
+ similarity_threshold_layout.addWidget(self.similarity_threshold_input)
+ layout.addLayout(similarity_threshold_layout)
+
+ group_box.setLayout(layout)
+ self.main_layout.addWidget(group_box)
+
def accept(self):
self.save_settings()
self.settingsSaved.emit()
@@ -186,6 +258,33 @@ def load_settings(self):
self.hey_llama_chat_checkbox.setChecked(settings.get("hey_llama_chat", False))
self.hey_llama_mic_checkbox.setChecked(settings.get("hey_llama_mic", False))
self.update_hey_llama_mic_state(settings.get("hey_llama_chat", False))
+
+ # Load new settings
+ if "generation" not in settings:
+ settings["generation"] = {}
+ if "rag" not in settings:
+ settings["rag"] = {}
+ self.embed_model_input.setText(
+ settings["rag"].get("embed_model_name", config.embed_model_name)
+ )
+ self.chunk_size_input.setText(str(settings["rag"].get("chunk_size", config.chunk_size)))
+ self.chunk_overlap_input.setText(
+ str(settings["rag"].get("chunk_overlap", config.chunk_overlap))
+ )
+ self.max_retrieval_top_k_input.setText(
+ str(settings["rag"].get("max_retrieval_top_k", config.max_retrieval_top_k))
+ )
+ self.similarity_threshold_input.setText(
+ str(settings["rag"].get("similarity_threshold", config.similarity_threshold))
+ )
+ self.context_len_input.setText(
+ str(settings["generation"].get("context_len", config.context_len))
+ )
+ self.temperature_input.setText(
+ str(settings["generation"].get("temperature", config.temperature))
+ )
+ self.top_p_input.setText(str(settings["generation"].get("top_p", config.top_p)))
+ self.top_k_input.setText(str(settings["generation"].get("top_k", config.top_k)))
else:
self.color = QColor("#1E1E1E")
self.shortcut_recorder.setText("++")
@@ -199,6 +298,19 @@ def get_settings(self):
"multimodal_model": self.multimodal_model_combo.currentText(),
"hey_llama_chat": self.hey_llama_chat_checkbox.isChecked(),
"hey_llama_mic": self.hey_llama_mic_checkbox.isChecked(),
+ "generation": {
+ "context_len": int(self.context_len_input.text()),
+ "temperature": float(self.temperature_input.text()),
+ "top_p": float(self.top_p_input.text()),
+ "top_k": int(self.top_k_input.text()),
+ },
+ "rag": {
+ "embed_model_name": self.embed_model_input.text(),
+ "chunk_size": int(self.chunk_size_input.text()),
+ "chunk_overlap": int(self.chunk_overlap_input.text()),
+ "max_retrieval_top_k": int(self.max_retrieval_top_k_input.text()),
+ "similarity_threshold": float(self.similarity_threshold_input.text()),
+ },
}
def save_settings(self, settings=None):
diff --git a/llama_assistant/ui_manager.py b/llama_assistant/ui_manager.py
index d14cdd4..651f155 100644
--- a/llama_assistant/ui_manager.py
+++ b/llama_assistant/ui_manager.py
@@ -26,16 +26,18 @@
microphone_icon_svg,
)
+
class CustomQTextBrowser(QTextBrowser):
def __init__(self, parent):
super().__init__(parent)
-
+
# Apply stylesheet specific to generated text content
- self.document().setDefaultStyleSheet("""
+ self.document().setDefaultStyleSheet(
+ """
p {
color: #FFFFFF;
font-size: 16px;
- line-height: 1.3;
+ line-height: 1.3;
}
li {
line-height: 1.3;
@@ -49,7 +51,9 @@ def __init__(self, parent):
font-family: Consolas, "Courier New", monospace;
overflow: hidden;
}
- """)
+ """
+ )
+
class UIManager:
def __init__(self, parent):
@@ -177,7 +181,7 @@ def init_ui(self):
self.scroll_area = QScrollArea(self.parent)
self.scroll_area.setWidgetResizable(True) # Allow the widget inside to resize
- self.scroll_area.setMinimumHeight(400)
+ self.scroll_area.setMinimumHeight(400)
self.scroll_area.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff)
self.scroll_area.setStyleSheet(
"""
@@ -206,7 +210,7 @@ def init_ui(self):
self.scroll_area.setWidget(self.chat_box)
self.scroll_area.hide() # Hide the scroll area initially
-
+
# Ensure the scroll area can expand fully in the layout
self.scroll_area.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
main_layout.addWidget(self.scroll_area)
diff --git a/pyproject.toml b/pyproject.toml
index 7dc78eb..8c63013 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -25,6 +25,7 @@ classifiers = [
"Programming Language :: Python :: 3.12",
]
dependencies = [
+ "numpy",
"ffmpeg-python",
"PyQt5",
"markdown",
@@ -34,6 +35,11 @@ dependencies = [
"huggingface_hub",
"openwakeword",
"whispercpp",
+ "llama-index-core",
+ "llama-index-readers-file",
+ "llama-index-embeddings-huggingface",
+ "docx2txt",
+ "mistune"
]
dynamic = []
diff --git a/requirements.txt b/requirements.txt
index db7c2cd..3768d2e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,7 +9,8 @@ openwakeword==0.6.0
pyinstaller==6.10.0
ffmpeg-python==0.2.0
llama-index-core==0.12.0
+llama-index-readers-file==1.2.2
llama-index-embeddings-huggingface==0.4.0
docx2txt==0.8
mistune==3.0.2
-git+https://github.com/stlukey/whispercpp.py
\ No newline at end of file
+git+https://github.com/stlukey/whispercpp.py