From 65804a4c3df377f6a4cbcb5138cff6284bea2cfc Mon Sep 17 00:00:00 2001 From: Nguyen The Nam Date: Fri, 13 Dec 2024 11:37:56 +0100 Subject: [PATCH] Precommit reformat --- llama_assistant/agent.py | 37 +++++-- llama_assistant/config.py | 36 ++++--- llama_assistant/llama_assistant_app.py | 31 +++--- llama_assistant/model_handler.py | 2 +- llama_assistant/ocr_engine.py | 45 ++++---- llama_assistant/processing_thread.py | 11 +- llama_assistant/screen_capture_widget.py | 53 ++++----- llama_assistant/setting_dialog.py | 132 ++++++++++++++++------- llama_assistant/setting_validator.py | 23 ++-- llama_assistant/ui_manager.py | 2 +- 10 files changed, 230 insertions(+), 142 deletions(-) diff --git a/llama_assistant/agent.py b/llama_assistant/agent.py index 8e6ec17..9489acb 100644 --- a/llama_assistant/agent.py +++ b/llama_assistant/agent.py @@ -12,6 +12,7 @@ SYSTEM_PROMPT = {"role": "system", "content": "Generate short and simple response."} + def convert_message_list_to_str(messages): chat_history_str = "" for message in messages: @@ -39,9 +40,11 @@ class ChatHistory: def __init__(self, llm, max_history_size: int, max_output_tokens: int): self.llm = llm self.max_output_tokens = max_output_tokens - self.max_history_size = max_history_size # in tokens + self.max_history_size = max_history_size # in tokens self.max_history_size_in_words = max_history_size * 3 / 4 - self.max_history_size_in_words = self.max_history_size_in_words - 128 # to account for some formatting tokens + self.max_history_size_in_words = ( + self.max_history_size_in_words - 128 + ) # to account for some formatting tokens print("Max history size in words:", self.max_history_size_in_words) self.total_size = 0 self.chat_history = [] @@ -57,7 +60,14 @@ def add_message(self, message: dict): if self.total_size + new_msg_size > self.max_history_size_in_words: print("Chat history is too long, summarizing the conversation...") history_summary = self.llm.create_chat_completion( - messages= [SYSTEM_PROMPT] + self.chat_history + [{"role": "user", "content": "Briefly summarize the conversation in a few sentences."}], + messages=[SYSTEM_PROMPT] + + self.chat_history + + [ + { + "role": "user", + "content": "Briefly summarize the conversation in a few sentences.", + } + ], stream=False, max_tokens=256, )["choices"][0]["message"]["content"] @@ -119,8 +129,11 @@ def __init__( self.search_index = None self.retriever = None - self.chat_history = ChatHistory(llm=llm, max_output_tokens=self.max_output_tokens, - max_history_size=self.max_input_tokens) + self.chat_history = ChatHistory( + llm=llm, + max_output_tokens=self.max_output_tokens, + max_history_size=self.max_input_tokens, + ) self.lookup_files = set() self.embed_model = HuggingFaceEmbedding(model_name=rag_setting["embed_model_name"]) @@ -144,9 +157,7 @@ def update_index(self, files: Optional[Set[str]] = set()): show_progress=True, num_workers=1 ) - self.search_index = VectorStoreIndex.from_documents( - documents, embed_model=self.embed_model - ) + self.search_index = VectorStoreIndex.from_documents(documents, embed_model=self.embed_model) self.retriever = self.search_index.as_retriever(similarity_top_k=self.retrieval_top_k) @@ -221,7 +232,15 @@ async def condense_history_to_query(self, ctx: Context, ev: SetupEvent) -> Conde if len(self.chat_history) > 0 and self.retriever is not None: standalone_query = self.llm.create_chat_completion( - messages=[SYSTEM_PROMPT] + self.chat_history.get_chat_history() + [{"role": "user", "content": query_str + "\n Condense this conversation to stand-alone 
question using only 1 sentence."}], + messages=[SYSTEM_PROMPT] + + self.chat_history.get_chat_history() + + [ + { + "role": "user", + "content": query_str + + "\n Condense this conversation to stand-alone question using only 1 sentence.", + } + ], stream=False, top_k=self.generation_setting["top_k"], top_p=self.generation_setting["top_p"], diff --git a/llama_assistant/config.py b/llama_assistant/config.py index e602a11..97f1a5c 100644 --- a/llama_assistant/config.py +++ b/llama_assistant/config.py @@ -12,11 +12,11 @@ "hey_llama_chat": False, "hey_llama_mic": False, "generation": { - "context_len": 4096, + "context_len": 4096, "max_output_tokens": 1024, "top_k": 40, "top_p": 0.95, - "temperature": 0.2 + "temperature": 0.2, }, "rag": { "embed_model_name": "BAAI/bge-base-en-v1.5", @@ -28,22 +28,26 @@ } VALIDATOR = { - 'generation': { - 'context_len': {'type': 'int', 'min': 2048}, - 'max_output_tokens': {'type': 'int', 'min': 512, 'max': 2048}, - 'top_k': {'type': 'int', 'min': 1, 'max': 100}, - 'top_p': {'type': 'float', 'min': 0, 'max': 1}, - 'temperature': {'type': 'float', 'min': 0, 'max': 1}, - }, - 'rag': { - 'chunk_size': {'type': 'int', 'min': 64, 'max': 512}, - 'chunk_overlap': {'type': 'int', 'min': 64, 'max': 256}, - 'max_retrieval_top_k': {'type': 'int', 'min': 1, 'max': 5}, - 'similarity_threshold': {'type': 'float', 'min': 0, 'max': 1}, - } + "generation": { + "context_len": {"type": "int", "min": 2048}, + "max_output_tokens": {"type": "int", "min": 512, "max": 2048}, + "top_k": {"type": "int", "min": 1, "max": 100}, + "top_p": {"type": "float", "min": 0, "max": 1}, + "temperature": {"type": "float", "min": 0, "max": 1}, + }, + "rag": { + "chunk_size": {"type": "int", "min": 64, "max": 512}, + "chunk_overlap": {"type": "int", "min": 64, "max": 256}, + "max_retrieval_top_k": {"type": "int", "min": 1, "max": 5}, + "similarity_threshold": {"type": "float", "min": 0, "max": 1}, + }, } -DEFAULT_EMBEDING_MODELS = ["BAAI/bge-small-en-v1.5", "BAAI/bge-base-en-v1.5", "BAAI/bge-large-en-v1.5"] +DEFAULT_EMBEDING_MODELS = [ + "BAAI/bge-small-en-v1.5", + "BAAI/bge-base-en-v1.5", + "BAAI/bge-large-en-v1.5", +] DEFAULT_MODELS = [ { diff --git a/llama_assistant/llama_assistant_app.py b/llama_assistant/llama_assistant_app.py index 9caa99d..2559d5a 100644 --- a/llama_assistant/llama_assistant_app.py +++ b/llama_assistant/llama_assistant_app.py @@ -14,15 +14,9 @@ QVBoxLayout, QMessageBox, QSystemTrayIcon, - QRubberBand -) -from PyQt5.QtCore import ( - Qt, - QPoint, - QTimer, - QSize, - QRect + QRubberBand, ) +from PyQt5.QtCore import Qt, QPoint, QTimer, QSize, QRect from PyQt5.QtGui import ( QPixmap, QPainter, @@ -201,18 +195,20 @@ def on_ocr_button_clicked(self): self.last_response = "" self.gen_mark_down = False - - self.ui_manager.chat_box.append(f'
You: OCR this captured region') + + self.ui_manager.chat_box.append( + f'
You: OCR this captured region'
+        )
         self.ui_manager.chat_box.append('AI: ')
-        
+
         self.start_cursor_pos = self.ui_manager.chat_box.textCursor().position()
 
         img_path = config.ocr_tmp_file
         if not img_path.exists():
             print("No image find for OCR")
-            self.ui_manager.chat_box.append('No image found for OCR')
+            self.ui_manager.chat_box.append("No image found for OCR")
             return
-        
+
         self.processing_thread = OCRThread(img_path, streaming=True)
         self.processing_thread.preloader_signal.connect(self.indicate_loading)
         self.processing_thread.update_signal.connect(self.update_chat_box)
@@ -224,7 +220,6 @@ def on_ask_with_ocr_context(self):
         self.screen_capture_widget.hide()
 
         self.has_ocr_context = True
-
     def on_submit(self):
         message = self.ui_manager.input_field.toPlainText()
         if message == "":
@@ -238,7 +233,7 @@ def on_submit(self):
 
             for file_path in self.dropped_files:
                 self.remove_file_thumbnail(self.file_containers[file_path], file_path)
-            
+
             return
 
         self.last_response = ""
@@ -316,7 +311,7 @@ def process_image_with_prompt(self, image_path, file_paths, prompt):
             prompt,
             image=image,
             lookup_files=file_paths,
-            ocr_img_path=config.ocr_tmp_file if self.has_ocr_context else None
+            ocr_img_path=config.ocr_tmp_file if self.has_ocr_context else None,
         )
         self.processing_thread.preloader_signal.connect(self.indicate_loading)
         self.processing_thread.update_signal.connect(self.update_chat_box)
@@ -346,7 +341,7 @@ def indicate_loading(self, message):
             QApplication.processEvents()  # Process events to update the UI
             time.sleep(0.05)
         time.sleep(0.5)
-    
+
     def update_chat_box(self, text):
         self.last_response += text
 
@@ -360,7 +355,7 @@ def update_chat_box(self, text):
             formatted_text = markdown_response
         else:
             formatted_text = self.last_response.replace("\n", "<br>") + "<br>"
-            
+
             self.clear_text_from_start_pos()
             cursor = self.ui_manager.chat_box.textCursor()
             cursor.insertHtml(formatted_text)
diff --git a/llama_assistant/model_handler.py b/llama_assistant/model_handler.py
index 6b8323e..6c1a112 100644
--- a/llama_assistant/model_handler.py
+++ b/llama_assistant/model_handler.py
@@ -217,7 +217,7 @@ def update_chat_history(self, message: str, role: str):
         if self.loaded_agent is None:
             print("Agent has not been initialized. Cannot update chat history.")
             return
-        
+
         agent = self.loaded_agent.get("agent")
         if agent:
             agent.chat_history.add_message({"role": role, "content": message})
diff --git a/llama_assistant/ocr_engine.py b/llama_assistant/ocr_engine.py
index 847dc3d..9ab2244 100644
--- a/llama_assistant/ocr_engine.py
+++ b/llama_assistant/ocr_engine.py
@@ -4,27 +4,28 @@
 import numpy as np
 import time
 
+
 def group_boxes_to_lines(bboxes, vertical_tolerance=5):
     """
     Groups bounding boxes into lines based on vertical alignment.
-    
+
     Args:
         bboxes: List of bounding boxes [(xmin, ymin, xmax, ymax)].
         vertical_tolerance: Tolerance for vertical proximity to consider boxes in the same line.
-    
+
     Returns:
         List of lines, where each line is a list of bounding boxes.
     """
     # Sort bounding boxes by ymin (top edge)
     bboxes = sorted(bboxes, key=lambda bbox: bbox[1])
-    
+
     lines = []
     current_line = []
     current_ymin = None
 
     for bbox in bboxes:
         xmin, ymin, xmax, ymax = bbox
-        
+
         # Check if starting a new line or current box is not vertically aligned
         if current_ymin is None or ymin > current_ymin + vertical_tolerance:
             # Save the current line if not empty
@@ -38,21 +39,22 @@ def group_boxes_to_lines(bboxes, vertical_tolerance=5):
         else:
             # Add box to the current line
             current_line.append(bbox)
-    
+
     # Add the last line if any
     if current_line:
         current_line.sort(key=lambda box: box[0])
         lines.append(current_line)
-    
+
     return lines
 
+
 def quad_to_rect(quad_boxes):
     """
     Converts a quadrilateral bounding box to a rectangular bounding box.
-    
+
     Args:
         quad_boxes: List of 4 points [(x1, y1), (x2, y2), (x3, y3), (x4, y4)] representing the quadrilateral.
-    
+
     Returns:
         List of rectangular bounding box (xmin, ymin, xmax, ymax).
""" @@ -70,19 +72,20 @@ def quad_to_rect(quad_boxes): ymax = np.max(y_coords) result.append((xmin, ymin, xmax, ymax)) - + return result + class OCREngine: def __init__(self): self.ocr = None - + def load_ocr(self, processign_thread=None): if self.ocr is None: if processign_thread: processign_thread.set_preloading(True, "Initializing OCR ....") - self.ocr = PaddleOCR(use_angle_cls=True, lang='en') + self.ocr = PaddleOCR(use_angle_cls=True, lang="en") time.sleep(1.2) if processign_thread: @@ -90,14 +93,14 @@ def load_ocr(self, processign_thread=None): def perform_ocr(self, img_path, streaming=False, processing_thread=None): self.load_ocr(processing_thread) - img = np.array(Image.open(img_path).convert('RGB')) + img = np.array(Image.open(img_path).convert("RGB")) ori_im = img.copy() # text detection dt_boxes, _ = self.ocr.text_detector(img) if dt_boxes is None: return None - + img_crop_list = [] dt_boxes = quad_to_rect(dt_boxes) @@ -106,6 +109,7 @@ def perform_ocr(self, img_path, streaming=False, processing_thread=None): # return a generator if streaming if streaming: + def generate_result(): for boxes in lines: img_crop_list = [] @@ -116,17 +120,17 @@ def generate_result(): continue img_crop_list.append(img_crop) rec_res, _ = self.ocr.text_recognizer(img_crop_list) - + line_text = "" for rec_result in rec_res: text, score = rec_result if score >= self.ocr.drop_score: line_text += text + " " - yield line_text+'\n' + yield line_text + "\n" return generate_result() - + # non-streaming full_result = "" for boxes in lines: @@ -138,15 +142,16 @@ def generate_result(): continue img_crop_list.append(img_crop) rec_res, _ = self.ocr.text_recognizer(img_crop_list) - + line_text = "" for rec_result in rec_res: text, score = rec_result if score >= self.ocr.drop_score: line_text += text + " " - full_result += line_text + '\n' + full_result += line_text + "\n" return full_result - -ocr_engine = OCREngine() \ No newline at end of file + + +ocr_engine = OCREngine() diff --git a/llama_assistant/processing_thread.py b/llama_assistant/processing_thread.py index 04bc6c2..ce1a492 100644 --- a/llama_assistant/processing_thread.py +++ b/llama_assistant/processing_thread.py @@ -6,6 +6,7 @@ from llama_assistant.model_handler import handler as model_handler from llama_assistant.ocr_engine import ocr_engine + class ProcessingThread(QThread): preloader_signal = pyqtSignal(str) update_signal = pyqtSignal(str) @@ -75,6 +76,7 @@ def set_preloading(self, preloading: bool, message: str): def is_preloading(self): return self.preloading + class OCRThread(QThread): preloader_signal = pyqtSignal(str) update_signal = pyqtSignal(str) @@ -96,16 +98,18 @@ def set_preloading(self, preloading: bool, message: str): def is_preloading(self): return self.preloading - + def run(self): - output = ocr_engine.perform_ocr(self.img_path, streaming=self.streaming, processing_thread=self) + output = ocr_engine.perform_ocr( + self.img_path, streaming=self.streaming, processing_thread=self + ) full_response_str = "Here is the OCR result:\n" self.is_ocr_done = True if not self.streaming and type(output) == str: self.update_signal.emit(full_response_str + output) return - + self.update_signal.emit(full_response_str) for chunk in output: self.update_signal.emit(chunk) @@ -114,4 +118,3 @@ def run(self): model_handler.update_chat_history("OCR this image:", "user") model_handler.update_chat_history(full_response_str, "assistant") self.finished_signal.emit() - \ No newline at end of file diff --git a/llama_assistant/screen_capture_widget.py 
b/llama_assistant/screen_capture_widget.py index 8523578..ed351cf 100644 --- a/llama_assistant/screen_capture_widget.py +++ b/llama_assistant/screen_capture_widget.py @@ -5,6 +5,7 @@ from llama_assistant import config from llama_assistant.ocr_engine import OCREngine + if TYPE_CHECKING: from llama_assistant.llama_assistant_app import LlamaAssistantApp @@ -13,17 +14,17 @@ class ScreenCaptureWidget(QWidget): def __init__(self, parent: "LlamaAssistantApp"): super().__init__() self.setWindowFlags(Qt.FramelessWindowHint) - + self.parent = parent self.ocr_engine = OCREngine() - + # Get screen size screen = QDesktopWidget().screenGeometry() self.setGeometry(0, 0, screen.width(), screen.height()) - + # Set crosshairs cursor self.setCursor(Qt.CrossCursor) - + # To store the start and end points of the mouse region self.start_point = None self.end_point = None @@ -73,71 +74,73 @@ def hide(self): self.button_widget.hide() super().hide() - def mousePressEvent(self, event): if event.button() == Qt.LeftButton: self.start_point = event.pos() # Capture start position - self.end_point = event.pos() # Initialize end point to start position + self.end_point = event.pos() # Initialize end point to start position print(f"Mouse press at {self.start_point}") - + def mouseReleaseEvent(self, event): if event.button() == Qt.LeftButton: self.end_point = event.pos() # Capture end position - + print(f"Mouse release at {self.end_point}") - + # Capture the region between start and end points if self.start_point and self.end_point: self.capture_region(self.start_point, self.end_point) - + # Trigger repaint to show the red rectangle self.update() self.show_buttons() - def mouseMoveEvent(self, event): if self.start_point: # Update the end_point to the current mouse position as it moves self.end_point = event.pos() - + # Trigger repaint to update the rectangle self.update() - + def capture_region(self, start_point, end_point): # Convert local widget coordinates to global screen coordinates start_global = self.mapToGlobal(start_point) end_global = self.mapToGlobal(end_point) - + # Create a QRect from the global start and end points region_rect = QRect(start_global, end_global) - + # Ensure the rectangle is valid (non-negative width/height) region_rect = region_rect.normalized() - + # Capture the screen region screen = QApplication.primaryScreen() - pixmap = screen.grabWindow(0, region_rect.x(), region_rect.y(), region_rect.width(), region_rect.height()) + pixmap = screen.grabWindow( + 0, region_rect.x(), region_rect.y(), region_rect.width(), region_rect.height() + ) # Save the captured region as an image pixmap.save(str(config.ocr_tmp_file), "PNG") print(f"Captured region saved at '{config.ocr_tmp_file}'.") - + def paintEvent(self, event): # If the start and end points are set, draw the rectangle if self.start_point and self.end_point: # Create a painter object painter = QPainter(self) - + # Set the pen color to red pen = QPen(QColor(255, 0, 0)) # Red color pen.setWidth(3) # Set width of the border painter.setPen(pen) - + # Draw the rectangle from start_point to end_point self.region_rect = QRect(self.start_point, self.end_point) - self.region_rect = self.region_rect.normalized() # Normalize to ensure correct width/height - + self.region_rect = ( + self.region_rect.normalized() + ) # Normalize to ensure correct width/height + painter.drawRect(self.region_rect) # Draw the rectangle super().paintEvent(event) # Call the base class paintEvent @@ -158,9 +161,9 @@ def show_buttons(self): self.ocr_button.setGeometry(0, 0, 
button_width, button_height) self.ask_button.setGeometry(button_width + spacing, 0, button_width, button_height) - self.button_widget.setGeometry(rect.left(), button_y, 2 * button_width + spacing, button_height) + self.button_widget.setGeometry( + rect.left(), button_y, 2 * button_width + spacing, button_height + ) self.button_widget.setAttribute(Qt.WA_TranslucentBackground) self.button_widget.setWindowFlags(Qt.FramelessWindowHint) self.button_widget.show() - - \ No newline at end of file diff --git a/llama_assistant/setting_dialog.py b/llama_assistant/setting_dialog.py index 9a5ae12..843d748 100644 --- a/llama_assistant/setting_dialog.py +++ b/llama_assistant/setting_dialog.py @@ -24,6 +24,7 @@ from llama_assistant import config from llama_assistant.setting_validator import validate_numeric_field + class SettingsDialog(QDialog): settingsSaved = pyqtSignal() @@ -218,56 +219,76 @@ def create_rag_settings_group(self): self.main_layout.addWidget(group_box) def accept(self): - valid, message = validate_numeric_field("Context Length", self.context_len_input.text(), - constraints=config.VALIDATOR['generation']['context_len']) + valid, message = validate_numeric_field( + "Context Length", + self.context_len_input.text(), + constraints=config.VALIDATOR["generation"]["context_len"], + ) if not valid: QMessageBox.warning(self, "Validation Error", message) return - - valid, message = validate_numeric_field("Temperature", self.temperature_input.text(), - constraints=config.VALIDATOR['generation']['temperature']) + + valid, message = validate_numeric_field( + "Temperature", + self.temperature_input.text(), + constraints=config.VALIDATOR["generation"]["temperature"], + ) if not valid: QMessageBox.warning(self, "Validation Error", message) return - - valid, message = validate_numeric_field("Top p", self.top_p_input.text(), - constraints=config.VALIDATOR['generation']['top_p']) + + valid, message = validate_numeric_field( + "Top p", self.top_p_input.text(), constraints=config.VALIDATOR["generation"]["top_p"] + ) if not valid: QMessageBox.warning(self, "Validation Error", message) return - - valid, message = validate_numeric_field("Top k", self.top_k_input.text(), - constraints=config.VALIDATOR['generation']['top_k']) + + valid, message = validate_numeric_field( + "Top k", self.top_k_input.text(), constraints=config.VALIDATOR["generation"]["top_k"] + ) if not valid: QMessageBox.warning(self, "Validation Error", message) return - - valid, message = validate_numeric_field("Chunk Size", self.chunk_size_input.text(), - constraints=config.VALIDATOR['rag']['chunk_size']) + + valid, message = validate_numeric_field( + "Chunk Size", + self.chunk_size_input.text(), + constraints=config.VALIDATOR["rag"]["chunk_size"], + ) if not valid: QMessageBox.warning(self, "Validation Error", message) return - - valid, message = validate_numeric_field("Chunk Overlap", self.chunk_overlap_input.text(), - constraints=config.VALIDATOR['rag']['chunk_overlap']) + + valid, message = validate_numeric_field( + "Chunk Overlap", + self.chunk_overlap_input.text(), + constraints=config.VALIDATOR["rag"]["chunk_overlap"], + ) if not valid: QMessageBox.warning(self, "Validation Error", message) return - - valid, message = validate_numeric_field("Max Retrieval Top k", self.max_retrieval_top_k_input.text(), - constraints=config.VALIDATOR['rag']['max_retrieval_top_k']) - + + valid, message = validate_numeric_field( + "Max Retrieval Top k", + self.max_retrieval_top_k_input.text(), + constraints=config.VALIDATOR["rag"]["max_retrieval_top_k"], + ) + 
if not valid: QMessageBox.warning(self, "Validation Error", message) return - - valid, message = validate_numeric_field("Similarity Threshold", self.similarity_threshold_input.text(), - constraints=config.VALIDATOR['rag']['similarity_threshold']) - + + valid, message = validate_numeric_field( + "Similarity Threshold", + self.similarity_threshold_input.text(), + constraints=config.VALIDATOR["rag"]["similarity_threshold"], + ) + if not valid: QMessageBox.warning(self, "Validation Error", message) return - + self.save_settings() self.settingsSaved.emit() super().accept() @@ -317,29 +338,66 @@ def load_settings(self): if "rag" not in settings: settings["rag"] = {} - embed_model = settings["rag"].get("embed_model_name", config.DEFAULT_SETTINGS['rag']['embed_model_name']) + embed_model = settings["rag"].get( + "embed_model_name", config.DEFAULT_SETTINGS["rag"]["embed_model_name"] + ) if embed_model in config.DEFAULT_EMBEDING_MODELS: self.embed_model_combo.setCurrentText(embed_model) - - self.chunk_size_input.setText(str(settings["rag"].get("chunk_size", config.DEFAULT_SETTINGS['rag']['chunk_size']))) + + self.chunk_size_input.setText( + str(settings["rag"].get("chunk_size", config.DEFAULT_SETTINGS["rag"]["chunk_size"])) + ) self.chunk_overlap_input.setText( - str(settings["rag"].get("chunk_overlap", config.DEFAULT_SETTINGS['rag']['chunk_overlap'])) + str( + settings["rag"].get( + "chunk_overlap", config.DEFAULT_SETTINGS["rag"]["chunk_overlap"] + ) + ) ) self.max_retrieval_top_k_input.setText( - str(settings["rag"].get("max_retrieval_top_k", config.DEFAULT_SETTINGS['rag']['max_retrieval_top_k'])) + str( + settings["rag"].get( + "max_retrieval_top_k", config.DEFAULT_SETTINGS["rag"]["max_retrieval_top_k"] + ) + ) ) self.similarity_threshold_input.setText( - str(settings["rag"].get("similarity_threshold", config.DEFAULT_SETTINGS['rag']['similarity_threshold'])) + str( + settings["rag"].get( + "similarity_threshold", + config.DEFAULT_SETTINGS["rag"]["similarity_threshold"], + ) + ) ) self.context_len_input.setText( - str(settings["generation"].get("context_len", config.DEFAULT_SETTINGS['generation']['context_len'])) + str( + settings["generation"].get( + "context_len", config.DEFAULT_SETTINGS["generation"]["context_len"] + ) + ) ) - + self.temperature_input.setText( - str(settings["generation"].get("temperature", config.DEFAULT_SETTINGS['generation']['temperature'])) + str( + settings["generation"].get( + "temperature", config.DEFAULT_SETTINGS["generation"]["temperature"] + ) + ) + ) + self.top_p_input.setText( + str( + settings["generation"].get( + "top_p", config.DEFAULT_SETTINGS["generation"]["top_p"] + ) + ) + ) + self.top_k_input.setText( + str( + settings["generation"].get( + "top_k", config.DEFAULT_SETTINGS["generation"]["top_k"] + ) + ) ) - self.top_p_input.setText(str(settings["generation"].get("top_p", config.DEFAULT_SETTINGS['generation']['top_p']))) - self.top_k_input.setText(str(settings["generation"].get("top_k", config.DEFAULT_SETTINGS['generation']['top_k']))) else: self.color = QColor("#1E1E1E") self.shortcut_recorder.setText("++") diff --git a/llama_assistant/setting_validator.py b/llama_assistant/setting_validator.py index 19cdcb7..bb45a5e 100644 --- a/llama_assistant/setting_validator.py +++ b/llama_assistant/setting_validator.py @@ -1,10 +1,9 @@ - def validate_numeric_field(name, value_str, constraints): - type = constraints['type'] - min = constraints.get('min') - max = constraints.get('max') - - if type == 'float': + type = constraints["type"] + min = constraints.get("min") + 
max = constraints.get("max") + + if type == "float": if isinstance(value_str, float): value = value_str else: @@ -13,8 +12,8 @@ def validate_numeric_field(name, value_str, constraints): except ValueError: message = f"Invalid value for {name}. Expected a float, got {value_str}" return False, message - - elif type == 'int': + + elif type == "int": if isinstance(value_str, int): value = value_str else: @@ -23,11 +22,13 @@ def validate_numeric_field(name, value_str, constraints): except ValueError: message = f"Invalid value for {name}. Expected an integer, got {value_str}" return False, message - + if min is not None and value < min: message = f"Invalid value for {name}. Expected a value greater than or equal to {min}, got {value}" return False, message if max is not None and value > max: - message = f"Invalid value for {name}. Expected a value less than or equal to {max}, got {value}" + message = ( + f"Invalid value for {name}. Expected a value less than or equal to {max}, got {value}" + ) return False, message - return True, value \ No newline at end of file + return True, value diff --git a/llama_assistant/ui_manager.py b/llama_assistant/ui_manager.py index 817e60e..15d2b59 100644 --- a/llama_assistant/ui_manager.py +++ b/llama_assistant/ui_manager.py @@ -24,7 +24,7 @@ copy_icon_svg, clear_icon_svg, microphone_icon_svg, - crosshair_icon_svg + crosshair_icon_svg, )
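
A few usage notes on the code this patch touches (illustrative only, not
part of the diff):

ChatHistory (agent.py) converts its token budget into an approximate word
budget with max_history_size * 3 / 4 - 128. As a worked example, if
max_input_tokens were 4096 (the default "context_len" in config.py), the
history would be allowed 4096 * 3 / 4 - 128 = 2944 words before the
conversation is summarized, which is the figure the "Max history size in
words" print would report.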
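
group_boxes_to_lines (ocr_engine.py) sorts the detected boxes by their top
edge and opens a new line whenever a box starts more than
vertical_tolerance pixels below the current line's anchor ymin; each
finished line is then sorted left to right. A minimal sketch with made-up
boxes, assuming paddleocr is installed so the module imports cleanly:

    from llama_assistant.ocr_engine import group_boxes_to_lines

    # Two words on the same visual line, one word on the line below,
    # given as (xmin, ymin, xmax, ymax).
    boxes = [(60, 10, 110, 27), (10, 12, 50, 28), (12, 40, 80, 58)]
    lines = group_boxes_to_lines(boxes, vertical_tolerance=5)
    # lines == [[(10, 12, 50, 28), (60, 10, 110, 27)], [(12, 40, 80, 58)]]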
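
validate_numeric_field (setting_validator.py) returns (True, parsed_value)
when the input parses and satisfies the constraints, and (False, message)
otherwise; that tuple contract is what the accept() chain in
setting_dialog.py relies on. A quick sketch using constraint dicts shaped
like config.VALIDATOR:

    from llama_assistant.setting_validator import validate_numeric_field

    ok, value = validate_numeric_field(
        "Temperature", "0.2", constraints={"type": "float", "min": 0, "max": 1}
    )
    # ok is True and value == 0.2

    ok, message = validate_numeric_field(
        "Top k", "500", constraints={"type": "int", "min": 1, "max": 100}
    )
    # ok is False; message reports that the value must be at most 100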