From 8b85ddaed282751dc5eb7bf46678cba72c29c59f Mon Sep 17 00:00:00 2001 From: Nguyen The Nam Date: Wed, 27 Nov 2024 21:43:52 +0100 Subject: [PATCH] Manage context length; Allow configuration for RAG and generation --- .github/workflows/publish-to-pypi.yml | 2 +- llama_assistant/agent.py | 204 ++++++++++++++++++++----- llama_assistant/config.py | 11 +- llama_assistant/llama_assistant_app.py | 36 +++-- llama_assistant/model_handler.py | 62 +++++--- llama_assistant/processing_thread.py | 26 +++- llama_assistant/setting_dialog.py | 112 ++++++++++++++ llama_assistant/ui_manager.py | 16 +- pyproject.toml | 6 + requirements.txt | 3 +- 10 files changed, 388 insertions(+), 90 deletions(-) diff --git a/.github/workflows/publish-to-pypi.yml b/.github/workflows/publish-to-pypi.yml index 34558a8..2999842 100644 --- a/.github/workflows/publish-to-pypi.yml +++ b/.github/workflows/publish-to-pypi.yml @@ -52,4 +52,4 @@ jobs: path: dist/ - name: Publish release distributions to PyPI - uses: pypa/gh-action-pypi-publish@release/v1 \ No newline at end of file + uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/llama_assistant/agent.py b/llama_assistant/agent.py index ebb65cf..21ec19f 100644 --- a/llama_assistant/agent.py +++ b/llama_assistant/agent.py @@ -1,5 +1,4 @@ -from typing import List, Set, Optional -from collections import defaultdict +from typing import List, Set, Optional, Dict from llama_cpp import Llama from llama_index.core import VectorStoreIndex @@ -11,6 +10,7 @@ from llama_index.core.workflow import Event, StartEvent, StopEvent, Workflow, step + def convert_message_list_to_str(messages): chat_history_str = "" for message in messages: @@ -18,23 +18,63 @@ def convert_message_list_to_str(messages): chat_history_str += message["role"] + ": " + message["content"] + "\n" else: chat_history_str += message["role"] + ": " + message["content"]["text"] + "\n" - + return chat_history_str + class SetupEvent(Event): pass + class CondenseQueryEvent(Event): condensed_query_str: str + class RetrievalEvent(Event): nodes: List[NodeWithScore] + +class ChatHistory: + def __init__(self, max_history_size: int): + self.max_history_size = max_history_size + self.total_size = 0 + self.chat_history = [] + + def add_message(self, message: dict): + if "content" in message and type(message["content"]) is list: + # multimodal model's message format + new_msg_size = len(message["content"][0]["text"].split()) + else: + # text-only model's message format + new_msg_size = len(message["content"].split()) + + self.total_size += new_msg_size + self.chat_history.append(message) + while self.total_size > self.max_history_size: + oldest_msg = self.chat_history.pop(0) + if "content" in oldest_msg and type(oldest_msg["content"]) is list: + # multimodal model's message format + len_oldest_msg = len(oldest_msg["content"][0]["text"].split()) + else: + # text-only model's message format + len_oldest_msg = len(oldest_msg["content"].split()) + self.total_size -= len_oldest_msg + + def get_chat_history(self): + return self.chat_history + + def clear(self): + self.chat_history = [] + self.total_size = 0 + + def __len__(self): + return len(self.chat_history) + + class RAGAgent(Workflow): SUMMARY_TEMPLATE = ( - "Given the chat history:\n" + "Given the converstation:\n" "'''{chat_history_str}'''\n\n" - "And the user asked the following question:{query_str}\n" "Rewrite to a standalone question:\n" ) @@ -45,49 +85,111 @@ class RAGAgent(Workflow): "-----\n" "Please write a response to the following question, using the above information if 
relevant:\n" "{query_str}\n" - ) - def __init__(self, embed_model_name: str, llm: Llama, timeout: int = 60, verbose: bool = False): + ) + SYSTEM_PROMPT = {"role": "system", "content": "Generate short and simple response:"} + + def __init__( + self, + generation_setting: Dict, + rag_setting: Dict, + llm: Llama, + timeout: int = 60, + verbose: bool = False, + ): super().__init__(timeout=timeout, verbose=verbose) - self.k = 3 + self.generation_setting = generation_setting + self.context_len = generation_setting["context_len"] + # we want the retrieved context = our set value but not more than the context_len + self.retrieval_top_k = (self.context_len - rag_setting["chunk_size"]) // rag_setting[ + "chunk_overlap" + ] + self.retrieval_top_k = min(max(1, self.retrieval_top_k), rag_setting["max_retrieval_top_k"]) self.search_index = None self.retriever = None - self.chat_history = [] + # 1 token ~ 3/4 words, so we multiply the context_len by 0.7 to get the max number of words + # why not 0.75? because we want to be a bit more conservative + self.chat_history = ChatHistory(max_history_size=self.context_len * 0.7) self.lookup_files = set() - self.embed_model = HuggingFaceEmbedding(model_name=embed_model_name) + self.embed_model = HuggingFaceEmbedding(model_name=rag_setting["embed_model_name"]) Settings.embed_model = self.embed_model - self.node_processor = SimilarityPostprocessor(similarity_cutoff=0.3) + Settings.chunk_size = rag_setting["chunk_size"] + Settings.chunk_overlap = rag_setting["chunk_overlap"] + self.node_processor = SimilarityPostprocessor( + similarity_cutoff=rag_setting["similarity_threshold"] + ) self.llm = llm - def udpate_index(self, files: Optional[Set[str] ] = set()): + def update_index(self, files: Optional[Set[str]] = set()): if not files: print("No lookup files provided, clearing index...") self.retriever = None self.search_index = None return - + print("Indexing documents...") - documents = SimpleDirectoryReader(input_files=files, recursive=True).load_data(show_progress=True, num_workers=1) - page_num_tracker = defaultdict(int) - for doc in documents: - key = doc.metadata['file_path'] - doc.metadata['page_index'] = page_num_tracker[key] - page_num_tracker[key] += 1 + documents = SimpleDirectoryReader(input_files=files, recursive=True).load_data( + show_progress=True, num_workers=1 + ) if self.search_index is None: - self.search_index = VectorStoreIndex.from_documents(documents, embed_model=self.embed_model) + self.search_index = VectorStoreIndex.from_documents( + documents, embed_model=self.embed_model + ) else: for doc in documents: - self.search_index.insert(doc) # Add the new document to the index + self.search_index.insert(doc) # Add the new document to the index + + self.retriever = self.search_index.as_retriever(similarity_top_k=self.retrieval_top_k) + + def update_rag_setting(self, rag_setting: Dict): + if self.node_processor.similarity_cutoff != rag_setting["similarity_threshold"]: + self.node_processor = SimilarityPostprocessor( + similarity_cutoff=rag_setting["similarity_threshold"] + ) + + new_top_k = (self.context_len - rag_setting["chunk_size"]) // rag_setting["chunk_overlap"] + new_top_k = min(max(1, new_top_k), rag_setting["max_retrieval_top_k"]) + if self.retrieval_top_k != new_top_k: + self.retrieval_top_k = new_top_k + if self.retriever: + self.retriever = self.search_index.as_retriever( + similarity_top_k=self.retrieval_top_k + ) + + if ( + self.embed_model.model_name != rag_setting["embed_model_name"] + or Settings.chunk_size != rag_setting["chunk_size"] 
+ or Settings.chunk_overlap != rag_setting["chunk_overlap"] + ): + self.embed_model = HuggingFaceEmbedding(model_name=rag_setting["embed_model_name"]) + Settings.embed_model = self.embed_model + Settings.chunk_size = rag_setting["chunk_size"] + Settings.chunk_overlap = rag_setting["chunk_overlap"] + + # reindex since those are the settings that affect the index + if self.lookup_files: + print("Re-indexing documents since the rag settings have changed...") + documents = SimpleDirectoryReader( + input_files=self.lookup_files, recursive=True + ).load_data(show_progress=True, num_workers=1) + self.search_index = VectorStoreIndex.from_documents( + documents, embed_model=self.embed_model + ) + self.retriever = self.search_index.as_retriever( + similarity_top_k=self.retrieval_top_k + ) + + def update_generation_setting(self, generation_setting): + self.generation_setting = generation_setting + self.context_len = generation_setting["context_len"] - self.retriever = self.search_index.as_retriever(similarity_top_k=self.k) - @step async def setup(self, ctx: Context, ev: StartEvent) -> SetupEvent: - # set frequetly used variables to context + # set frequently used variables to context query_str = ev.query_str image = ev.image - lookup_files = ev.lookup_files + lookup_files = ev.lookup_files if ev.lookup_files else set() streaming = ev.streaming await ctx.set("query_str", query_str) await ctx.set("image", image) @@ -96,7 +198,7 @@ async def setup(self, ctx: Context, ev: StartEvent) -> SetupEvent: # update index if needed if lookup_files != self.lookup_files: print("Different lookup files, updating index...") - self.udpate_index(lookup_files) + self.update_index(lookup_files) self.lookup_files = lookup_files.copy() @@ -105,36 +207,48 @@ async def setup(self, ctx: Context, ev: StartEvent) -> SetupEvent: @step async def condense_history_to_query(self, ctx: Context, ev: SetupEvent) -> CondenseQueryEvent: """ - Condense the chat history and the query into a single query. Only used for retrieval. + Condense the chat history and the query into a single query. Only used for retrieval. 
""" query_str = await ctx.get("query_str") formated_query = "" - + if len(self.chat_history) > 0 or self.retriever is not None: - chat_history_str = convert_message_list_to_str(self.chat_history) - formated_query = self.SUMMARY_TEMPLATE.format(chat_history_str=chat_history_str, query_str=query_str) + self.chat_history.add_message({"role": "user", "content": query_str}) + chat_history_str = convert_message_list_to_str(self.chat_history.get_chat_history()) + # use llm to summarize the chat history into a single query, + # which is used to retrieve context from the documents + formated_query = self.SUMMARY_TEMPLATE.format(chat_history_str=chat_history_str) + history_summary = self.llm.create_chat_completion( - messages=[{"role": "user", "content": formated_query}], stream=False + messages=[{"role": "user", "content": formated_query}], + stream=False, + top_k=self.generation_setting["top_k"], + top_p=self.generation_setting["top_p"], + temperature=self.generation_setting["temperature"], )["choices"][0]["message"]["content"] + condensed_query = "Context:\n" + history_summary + "\nQuestion: " + query_str + # remove the last user message from the chat history + # later we will add the query with retrieved context (RAG) + self.chat_history.chat_history.pop() else: # if there is no history or no need for retrieval, return the query as is condensed_query = query_str return CondenseQueryEvent(condensed_query_str=condensed_query) - + @step async def retrieve(self, ctx: Context, ev: CondenseQueryEvent) -> RetrievalEvent: - # retrieve from context if not self.retriever: return RetrievalEvent(nodes=[]) + # retrieve from dropped documents condensed_query_str = ev.condensed_query_str nodes = await self.retriever.aretrieve(condensed_query_str) nodes = self.node_processor.postprocess_nodes(nodes) return RetrievalEvent(nodes=nodes) - + def _prepare_query_with_context( self, query_str: str, @@ -144,7 +258,7 @@ def _prepare_query_with_context( if len(nodes) == 0: return query_str - + for idx, node in enumerate(nodes): node_text = node.get_content(metadata_mode="llm") node_context += f"\n{node_text}\n\n" @@ -152,11 +266,11 @@ def _prepare_query_with_context( formatted_query = self.CONTEXT_PROMPT_TEMPLATE.format( node_context=node_context, query_str=query_str ) - + return formatted_query @step - async def llm_response(self, ctx: Context, retrieval_ev: RetrievalEvent) -> StopEvent: + async def llm_response(self, ctx: Context, retrieval_ev: RetrievalEvent) -> StopEvent: nodes = retrieval_ev.nodes query_str = await ctx.get("query_str") image = await ctx.get("image") @@ -174,12 +288,18 @@ async def llm_response(self, ctx: Context, retrieval_ev: RetrievalEvent) -> Sto else: formated_message = {"role": "user", "content": query_with_ctx} + self.chat_history.add_message(formated_message) + response = self.llm.create_chat_completion( - messages=self.chat_history+[formated_message], stream=streaming + messages=[self.SYSTEM_PROMPT] + self.chat_history.get_chat_history(), + stream=streaming, + top_k=self.generation_setting["top_k"], + top_p=self.generation_setting["top_p"], + temperature=self.generation_setting["temperature"], ) - self.chat_history.append({"role": "user", "content": query_str}) + # remove the query with context from the chat history, + self.chat_history.chat_history.pop() + # add the short query (without context) instead -> not to clutter the chat history + self.chat_history.add_message({"role": "user", "content": query_str}) return StopEvent(result=response) - - - \ No newline at end of file diff --git 
a/llama_assistant/config.py b/llama_assistant/config.py index 6ad8f41..9d6278b 100644 --- a/llama_assistant/config.py +++ b/llama_assistant/config.py @@ -111,6 +111,11 @@ }, ] +# generation setting +context_len = 2048 +top_k = 40 +top_p = 0.95 +temperature = 0.2 home_dir = Path.home() llama_assistant_dir = home_dir / "llama_assistant" @@ -119,8 +124,12 @@ settings_file = llama_assistant_dir / "settings.json" document_icon = "llama_assistant/resources/document_icon.png" -# for RAG pipeline +# RAG setting embed_model_name = "BAAI/bge-base-en-v1.5" +chunk_size = 256 +chunk_overlap = 128 +max_retrieval_top_k = 3 +similarity_threshold = 0.6 if custom_models_file.exists(): with open(custom_models_file, "r") as f: diff --git a/llama_assistant/llama_assistant_app.py b/llama_assistant/llama_assistant_app.py index 956b47a..912bf7a 100644 --- a/llama_assistant/llama_assistant_app.py +++ b/llama_assistant/llama_assistant_app.py @@ -97,6 +97,8 @@ def load_settings(self): self.deinit_wake_word_detector() self.current_text_model = self.settings.get("text_model") self.current_multimodal_model = self.settings.get("multimodal_model") + self.generation_setting = self.settings.get("generation") + self.rag_setting = self.settings.get("rag") def setup_global_shortcut(self): try: @@ -177,7 +179,7 @@ def on_submit(self): for file_path in self.dropped_files: self.remove_file_thumbnail(self.file_containers[file_path], file_path) - + return self.last_response = "" @@ -189,7 +191,6 @@ def on_submit(self): else: QTimer.singleShot(100, lambda: self.process_text(message, self.dropped_files, "chat")) - def on_task_button_clicked(self): button = self.sender() task = button.text() @@ -203,7 +204,7 @@ def process_text(self, message, file_paths, task="chat"): self.clear_chat() self.show_chat_box() if task == "chat": - prompt = message + " \n" + "Generate a short and simple response." + prompt = message elif task == "Summarize": prompt = f"Summarize the following text: {message}" elif task == "Rephrase": @@ -220,7 +221,13 @@ def process_text(self, message, file_paths, task="chat"): self.start_cursor_pos = self.ui_manager.chat_box.textCursor().position() - self.processing_thread = ProcessingThread(self.current_text_model, prompt, lookup_files=file_paths) + self.processing_thread = ProcessingThread( + self.current_text_model, + self.generation_setting, + self.rag_setting, + prompt, + lookup_files=file_paths, + ) self.processing_thread.update_signal.connect(self.update_chat_box) self.processing_thread.finished_signal.connect(self.on_processing_finished) self.processing_thread.start() @@ -238,7 +245,12 @@ def process_image_with_prompt(self, image_path, file_paths, prompt): image = image_to_base64_data_uri(image_path) self.processing_thread = ProcessingThread( - self.current_multimodal_model, prompt, image=image, lookup_files=file_paths + self.current_multimodal_model, + self.generation_setting, + self.rag_setting, + prompt, + image=image, + lookup_files=file_paths, ) self.processing_thread.update_signal.connect(self.update_chat_box) self.processing_thread.finished_signal.connect(self.on_processing_finished) @@ -252,7 +264,9 @@ def update_chat_box(self, text): markdown_response = markdown_response.replace("
\n\n", "\n\n").replace("\n\n", "") markdown_response += "\n
" cursor = self.ui_manager.chat_box.textCursor() - cursor.setPosition(self.start_cursor_pos) # regenerate the updated text from the start position + cursor.setPosition( + self.start_cursor_pos + ) # regenerate the updated text from the start position # Select all text from the start_pos to the end cursor.movePosition(QTextCursor.End, QTextCursor.KeepAnchor) # Remove the selected text @@ -310,11 +324,12 @@ def dropEvent(self, event: QDropEvent): elif file_path.lower().endswith((".pdf", "doc", ".docx", ".txt")): if file_path not in self.dropped_files: self.dropped_files.add(file_path) - self.ui_manager.input_field.setPlaceholderText("Enter a prompt for the document...") + self.ui_manager.input_field.setPlaceholderText( + "Enter a prompt for the document..." + ) self.show_file_thumbnail(file_path) else: print(f"File {file_path} already added") - def remove_file_thumbnail(self, file_label, file_path): file_label.setParent(None) @@ -402,7 +417,7 @@ def show_file_thumbnail(self, file_path): # Create a QLabel for the text text_label = QLabel(file_path.split("/")[-1], container) - # set text background color to white, text size to 5px + # set text background color to white, text size to 5px # and rounded corners, vertical alignment to top text_label.setStyleSheet( """ @@ -444,6 +459,7 @@ def show_file_thumbnail(self, file_path): # Load and set the pixmap import os + print("Icon path:", str(config.document_icon), os.path.exists(str(config.document_icon))) pixmap = QPixmap(str(config.document_icon)) scaled_pixmap = pixmap.scaled( @@ -460,7 +476,7 @@ def show_file_thumbnail(self, file_path): self.setFixedHeight(self.height() + 110) # Increase height to accommodate larger file self.file_containers[file_path] = container - + def remove_image_thumbnail(self): if self.image_label: self.image_label.setParent(None) diff --git a/llama_assistant/model_handler.py b/llama_assistant/model_handler.py index 66c3ae6..b7b13b8 100644 --- a/llama_assistant/model_handler.py +++ b/llama_assistant/model_handler.py @@ -1,3 +1,4 @@ +import asyncio from typing import List, Dict, Set, Optional import time from threading import Timer @@ -12,6 +13,7 @@ from llama_assistant import config from llama_assistant.agent import RAGAgent + class Model: def __init__( self, @@ -54,19 +56,21 @@ def remove_supported_model(self, model_id: str): if self.current_model_id == model_id: self.unload_agent() - def load_agent(self, model_id: str) -> Optional[Dict]: + def load_agent(self, model_id: str, generation_setting, rag_setting: Dict) -> Optional[Dict]: self.refresh_supported_models() if self.current_model_id == model_id and self.loaded_agent: - return self.loaded_agent + if generation_setting["context_len"] == self.loaded_agent["model"].context_params.n_ctx: + self.loaded_agent["agent"].update_rag_setting(rag_setting) + self.loaded_agent["agent"].update_generation_setting(generation_setting) + return self.loaded_agent + # if no model is loaded or different model is loaded, or context_len is different, reinitialize the agent self.unload_agent() # Unload the current model if any model = next((m for m in self.supported_models if m.model_id == model_id), None) if not model: print(f"Model with ID {model_id} not found.") return None - - print('Load agent =========') if model.is_online(): if model.model_type == "text": @@ -74,7 +78,7 @@ def load_agent(self, model_id: str) -> Optional[Dict]: loaded_model = Llama.from_pretrained( repo_id=model.repo_id, filename=model.filename, - n_ctx=2048, + n_ctx=generation_setting["context_len"], ) elif 
model.model_type == "image": if "moondream2" in model.model_id: @@ -86,7 +90,7 @@ def load_agent(self, model_id: str) -> Optional[Dict]: repo_id=model.repo_id, filename=model.filename, chat_handler=chat_handler, - n_ctx=2048, + n_ctx=generation_setting["context_len"], ) elif "MiniCPM" in model.model_id: chat_handler = MiniCPMv26ChatHandler.from_pretrained( @@ -97,7 +101,7 @@ def load_agent(self, model_id: str) -> Optional[Dict]: repo_id=model.repo_id, filename=model.filename, chat_handler=chat_handler, - n_ctx=2048, + n_ctx=generation_setting["context_len"], ) elif "llava-v1.5" in model.model_id: chat_handler = Llava15ChatHandler.from_pretrained( @@ -108,7 +112,7 @@ def load_agent(self, model_id: str) -> Optional[Dict]: repo_id=model.repo_id, filename=model.filename, chat_handler=chat_handler, - n_ctx=2048, + n_ctx=generation_setting["context_len"], ) elif "llava-v1.6" in model.model_id: chat_handler = Llava16ChatHandler.from_pretrained( @@ -119,7 +123,7 @@ def load_agent(self, model_id: str) -> Optional[Dict]: repo_id=model.repo_id, filename=model.filename, chat_handler=chat_handler, - n_ctx=2048, + n_ctx=generation_setting["context_len"], ) else: print(f"Unsupported model type: {model.model_type}") @@ -129,19 +133,24 @@ def load_agent(self, model_id: str) -> Optional[Dict]: print("load local model") loaded_model = Llama(model_path=model.model_path) + print("Intializing agent ...") + agent = RAGAgent( - config.embed_model_name, + generation_setting, + rag_setting, llm=loaded_model, ) self.loaded_agent = { "model": loaded_model, "agent": agent, - "last_used": time.time() + "generation_setting": generation_setting, + "rag_setting": rag_setting, + "last_used": time.time(), } self.current_model_id = model_id self._schedule_unload() - + return self.loaded_agent def unload_agent(self): @@ -153,20 +162,26 @@ def unload_agent(self): self.unload_timer.cancel() self.unload_timer = None - async def run_agent(self, agent: RAGAgent, message: str, lookup_files: Set, image: str, stream: bool): - response = await agent.run(query_str=message, lookup_files=lookup_files, image=image, streaming=stream) + async def run_agent( + self, agent: RAGAgent, message: str, lookup_files: Set, image: str, stream: bool + ): + response = await agent.run( + query_str=message, lookup_files=lookup_files, image=image, streaming=stream + ) return response - + def chat_completion( self, model_id: str, + generation_setting: Dict, + rag_setting: Dict, message: str, image: Optional[str] = None, - lookup_files: Optional[Set[str] ] = set(), + lookup_files: Optional[Set[str]] = None, stream: bool = False, ) -> str: print("In chat_completion") - agent_data = self.load_agent(model_id) + agent_data = self.load_agent(model_id, generation_setting, rag_setting) agent = agent_data.get("agent") if not agent_data: return "Failed to load model" @@ -174,24 +189,25 @@ def chat_completion( agent_data["last_used"] = time.time() self._schedule_unload() - import asyncio try: loop = asyncio.get_running_loop() - response = loop.run_until_complete(self.run_agent(agent, message, lookup_files, image, stream)) + response = loop.run_until_complete( + self.run_agent(agent, message, lookup_files, image, stream) + ) except RuntimeError: # no running event loop response = asyncio.run(self.run_agent(agent, message, lookup_files, image, stream)) - + return response - + def update_chat_history(self, message: str, role: str): agent = self.loaded_agent.get("agent") if agent: - agent.chat_history.append({"role": role, "content": message}) + 
agent.chat_history.add_message({"role": role, "content": message}) def clear_chat_history(self): agent = self.loaded_agent.get("agent") if agent: - agent.chat_history = [] + agent.chat_history.clear() def _schedule_unload(self): if self.unload_timer: diff --git a/llama_assistant/processing_thread.py b/llama_assistant/processing_thread.py index 7cc529a..80bce14 100644 --- a/llama_assistant/processing_thread.py +++ b/llama_assistant/processing_thread.py @@ -1,4 +1,4 @@ -from typing import Set +from typing import Set, Optional, Dict from PyQt5.QtCore import ( QThread, pyqtSignal, @@ -10,17 +10,32 @@ class ProcessingThread(QThread): update_signal = pyqtSignal(str) finished_signal = pyqtSignal() - def __init__(self, model, prompt, lookup_files=set(), image=None): + def __init__( + self, + model: str, + generation_setting: Dict, + rag_setting: Dict, + prompt: str, + lookup_files: Optional[Set[str]] = None, + image: str = None, + ): super().__init__() self.model = model + self.generation_setting = generation_setting + self.rag_setting = rag_setting self.prompt = prompt self.image = image self.lookup_files = lookup_files def run(self): output = model_handler.chat_completion( - self.model, self.prompt, image=self.image, - lookup_files=self.lookup_files, stream=True + self.model, + self.generation_setting, + self.rag_setting, + self.prompt, + image=self.image, + lookup_files=self.lookup_files, + stream=True, ) full_response_str = "" for chunk in output: @@ -34,7 +49,6 @@ def run(self): model_handler.update_chat_history(full_response_str, "assistant") self.finished_signal.emit() - def clear_chat_history(self): model_handler.clear_chat_history() - self.finished_signal.emit() \ No newline at end of file + self.finished_signal.emit() diff --git a/llama_assistant/setting_dialog.py b/llama_assistant/setting_dialog.py index bab4d41..e19c747 100644 --- a/llama_assistant/setting_dialog.py +++ b/llama_assistant/setting_dialog.py @@ -44,6 +44,8 @@ def __init__(self, parent=None): # Voice Activation Settings Group self.create_voice_activation_settings_group() + self.create_rag_settings_group() + # Create a horizontal layout for the save button button_layout = QHBoxLayout() self.save_button = QPushButton("Save") @@ -122,6 +124,34 @@ def create_model_settings_group(self): multimodal_model_layout.addStretch() layout.addLayout(multimodal_model_layout) + context_len_layout = QHBoxLayout() + context_len_label = QLabel("Context Length:") + self.context_len_input = QLineEdit() + context_len_layout.addWidget(context_len_label) + context_len_layout.addWidget(self.context_len_input) + layout.addLayout(context_len_layout) + + temperature_layout = QHBoxLayout() + temperature_label = QLabel("Temperature:") + self.temperature_input = QLineEdit() + temperature_layout.addWidget(temperature_label) + temperature_layout.addWidget(self.temperature_input) + layout.addLayout(temperature_layout) + + top_p_layout = QHBoxLayout() + top_p_label = QLabel("Top p:") + self.top_p_input = QLineEdit() + top_p_layout.addWidget(top_p_label) + top_p_layout.addWidget(self.top_p_input) + layout.addLayout(top_p_layout) + + top_k_layout = QHBoxLayout() + top_k_label = QLabel("Top k:") + self.top_k_input = QLineEdit() + top_k_layout.addWidget(top_k_label) + top_k_layout.addWidget(self.top_k_input) + layout.addLayout(top_k_layout) + self.manage_custom_models_button = QPushButton("Manage Custom Models") self.manage_custom_models_button.clicked.connect(self.open_custom_models_dialog) layout.addWidget(self.manage_custom_models_button) @@ -143,6 +173,48 @@ 
def create_voice_activation_settings_group(self): group_box.setLayout(layout) self.main_layout.addWidget(group_box) + def create_rag_settings_group(self): + group_box = QGroupBox("RAG Settings") + layout = QVBoxLayout() + + embed_model_layout = QHBoxLayout() + embed_model_label = QLabel("Embed Model Name:") + self.embed_model_input = QLineEdit() + embed_model_layout.addWidget(embed_model_label) + embed_model_layout.addWidget(self.embed_model_input) + layout.addLayout(embed_model_layout) + + chunk_size_layout = QHBoxLayout() + chunk_size_label = QLabel("Chunk Size:") + self.chunk_size_input = QLineEdit() + chunk_size_layout.addWidget(chunk_size_label) + chunk_size_layout.addWidget(self.chunk_size_input) + layout.addLayout(chunk_size_layout) + + chunk_overlap_layout = QHBoxLayout() + chunk_overlap_label = QLabel("Chunk Overlap:") + self.chunk_overlap_input = QLineEdit() + chunk_overlap_layout.addWidget(chunk_overlap_label) + chunk_overlap_layout.addWidget(self.chunk_overlap_input) + layout.addLayout(chunk_overlap_layout) + + max_retrieval_top_k_layout = QHBoxLayout() + max_retrieval_top_k_label = QLabel("Max Retrieval Top k:") + self.max_retrieval_top_k_input = QLineEdit() + max_retrieval_top_k_layout.addWidget(max_retrieval_top_k_label) + max_retrieval_top_k_layout.addWidget(self.max_retrieval_top_k_input) + layout.addLayout(max_retrieval_top_k_layout) + + similarity_threshold_layout = QHBoxLayout() + similarity_threshold_label = QLabel("Similarity Threshold:") + self.similarity_threshold_input = QLineEdit() + similarity_threshold_layout.addWidget(similarity_threshold_label) + similarity_threshold_layout.addWidget(self.similarity_threshold_input) + layout.addLayout(similarity_threshold_layout) + + group_box.setLayout(layout) + self.main_layout.addWidget(group_box) + def accept(self): self.save_settings() self.settingsSaved.emit() @@ -186,6 +258,33 @@ def load_settings(self): self.hey_llama_chat_checkbox.setChecked(settings.get("hey_llama_chat", False)) self.hey_llama_mic_checkbox.setChecked(settings.get("hey_llama_mic", False)) self.update_hey_llama_mic_state(settings.get("hey_llama_chat", False)) + + # Load new settings + if "generation" not in settings: + settings["generation"] = {} + if "rag" not in settings: + settings["rag"] = {} + self.embed_model_input.setText( + settings["rag"].get("embed_model_name", config.embed_model_name) + ) + self.chunk_size_input.setText(str(settings["rag"].get("chunk_size", config.chunk_size))) + self.chunk_overlap_input.setText( + str(settings["rag"].get("chunk_overlap", config.chunk_overlap)) + ) + self.max_retrieval_top_k_input.setText( + str(settings["rag"].get("max_retrieval_top_k", config.max_retrieval_top_k)) + ) + self.similarity_threshold_input.setText( + str(settings["rag"].get("similarity_threshold", config.similarity_threshold)) + ) + self.context_len_input.setText( + str(settings["generation"].get("context_len", config.context_len)) + ) + self.temperature_input.setText( + str(settings["generation"].get("temperature", config.temperature)) + ) + self.top_p_input.setText(str(settings["generation"].get("top_p", config.top_p))) + self.top_k_input.setText(str(settings["generation"].get("top_k", config.top_k))) else: self.color = QColor("#1E1E1E") self.shortcut_recorder.setText("++") @@ -199,6 +298,19 @@ def get_settings(self): "multimodal_model": self.multimodal_model_combo.currentText(), "hey_llama_chat": self.hey_llama_chat_checkbox.isChecked(), "hey_llama_mic": self.hey_llama_mic_checkbox.isChecked(), + "generation": { + "context_len": 
int(self.context_len_input.text()), + "temperature": float(self.temperature_input.text()), + "top_p": float(self.top_p_input.text()), + "top_k": int(self.top_k_input.text()), + }, + "rag": { + "embed_model_name": self.embed_model_input.text(), + "chunk_size": int(self.chunk_size_input.text()), + "chunk_overlap": int(self.chunk_overlap_input.text()), + "max_retrieval_top_k": int(self.max_retrieval_top_k_input.text()), + "similarity_threshold": float(self.similarity_threshold_input.text()), + }, } def save_settings(self, settings=None): diff --git a/llama_assistant/ui_manager.py b/llama_assistant/ui_manager.py index d14cdd4..651f155 100644 --- a/llama_assistant/ui_manager.py +++ b/llama_assistant/ui_manager.py @@ -26,16 +26,18 @@ microphone_icon_svg, ) + class CustomQTextBrowser(QTextBrowser): def __init__(self, parent): super().__init__(parent) - + # Apply stylesheet specific to generated text content - self.document().setDefaultStyleSheet(""" + self.document().setDefaultStyleSheet( + """ p { color: #FFFFFF; font-size: 16px; - line-height: 1.3; + line-height: 1.3; } li { line-height: 1.3; @@ -49,7 +51,9 @@ def __init__(self, parent): font-family: Consolas, "Courier New", monospace; overflow: hidden; } - """) + """ + ) + class UIManager: def __init__(self, parent): @@ -177,7 +181,7 @@ def init_ui(self): self.scroll_area = QScrollArea(self.parent) self.scroll_area.setWidgetResizable(True) # Allow the widget inside to resize - self.scroll_area.setMinimumHeight(400) + self.scroll_area.setMinimumHeight(400) self.scroll_area.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff) self.scroll_area.setStyleSheet( """ @@ -206,7 +210,7 @@ def init_ui(self): self.scroll_area.setWidget(self.chat_box) self.scroll_area.hide() # Hide the scroll area initially - + # Ensure the scroll area can expand fully in the layout self.scroll_area.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) main_layout.addWidget(self.scroll_area) diff --git a/pyproject.toml b/pyproject.toml index 7dc78eb..8c63013 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ classifiers = [ "Programming Language :: Python :: 3.12", ] dependencies = [ + "numpy", "ffmpeg-python", "PyQt5", "markdown", @@ -34,6 +35,11 @@ dependencies = [ "huggingface_hub", "openwakeword", "whispercpp", + "llama-index-core", + "llama-index-readers-file", + "llama-index-embeddings-huggingface", + "docx2txt", + "mistune" ] dynamic = [] diff --git a/requirements.txt b/requirements.txt index db7c2cd..3768d2e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,7 +9,8 @@ openwakeword==0.6.0 pyinstaller==6.10.0 ffmpeg-python==0.2.0 llama-index-core==0.12.0 +llama-index-readers-file==1.2.2 llama-index-embeddings-huggingface==0.4.0 docx2txt==0.8 mistune==3.0.2 -git+https://github.com/stlukey/whispercpp.py \ No newline at end of file +git+https://github.com/stlukey/whispercpp.py
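Note on the retrieval budget added in llama_assistant/agent.py: the retriever's top-k is derived from the generation context length and the RAG chunking settings, then clamped to max_retrieval_top_k. A minimal sketch of that arithmetic, using the defaults added in llama_assistant/config.py (the helper function name is illustrative, not part of this patch):

def retrieval_top_k(context_len: int, chunk_size: int, chunk_overlap: int, max_top_k: int) -> int:
    # Same formula as RAGAgent.__init__ / update_rag_setting in this patch:
    # reserve one chunk_size, split the remainder into chunk_overlap-sized steps.
    k = (context_len - chunk_size) // chunk_overlap
    return min(max(1, k), max_top_k)

# With the config.py defaults: (2048 - 256) // 128 = 14, clamped to max_retrieval_top_k = 3
print(retrieval_top_k(2048, 256, 128, 3))  # -> 3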
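Usage sketch for the new ChatHistory class in llama_assistant/agent.py: it keeps a running word count and evicts the oldest messages once the budget is exceeded (RAGAgent passes context_len * 0.7 as the budget, a conservative words-per-token margin). The snippet below only exercises the class as defined in this patch:

from llama_assistant.agent import ChatHistory

# Same budget RAGAgent uses for a 2048-token context window.
history = ChatHistory(max_history_size=int(2048 * 0.7))
history.add_message({"role": "user", "content": "What does this patch change?"})
history.add_message({"role": "assistant", "content": "It makes context length, generation and RAG settings configurable."})

# Oldest messages are dropped automatically once total_size exceeds the budget,
# so the history handed to llm.create_chat_completion stays inside the context window.
print(len(history), history.total_size)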
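For reference, chat_completion in model_handler.py now takes the generation and RAG settings explicitly. A hedged example call, building both dicts from the config.py defaults; the handler instance name is an assumption (it is the same object ProcessingThread already uses, imported outside the hunks shown here), and the model id and file path are placeholders:

from llama_assistant import config
from llama_assistant.model_handler import handler as model_handler  # assumption: module-level handler instance

generation_setting = {
    "context_len": config.context_len,    # 2048
    "top_k": config.top_k,                # 40
    "top_p": config.top_p,                # 0.95
    "temperature": config.temperature,    # 0.2
}
rag_setting = {
    "embed_model_name": config.embed_model_name,
    "chunk_size": config.chunk_size,
    "chunk_overlap": config.chunk_overlap,
    "max_retrieval_top_k": config.max_retrieval_top_k,
    "similarity_threshold": config.similarity_threshold,
}

output = model_handler.chat_completion(
    "hf-model-id",                           # placeholder model id
    generation_setting,
    rag_setting,
    "Summarize the attached document.",
    lookup_files={"/path/to/document.pdf"},  # placeholder path
    stream=True,
)
# With stream=True, output is the token stream that ProcessingThread.run iterates over.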