diff --git a/README.md b/README.md index 4f88cfe..fc1d98c 100644 --- a/README.md +++ b/README.md @@ -54,6 +54,7 @@ https://github.com/user-attachments/assets/af2c544b-6d46-4c44-87d8-9a051ba213db - [x] ⚡ Streaming support for response. - [x] 🎙️ Add offline STT support: WhisperCPP. - [x] 🧠 Knowledge database: LlamaIndex +- [x] ⌖ Screen Spot: Screen capture and analyze with OCR - [ ] 🔌 Plugin system for extensibility. - [ ] 📰 News and weather updates. - [ ] 📧 Email integration with Gmail and Outlook. diff --git a/llama_assistant/agent.py b/llama_assistant/agent.py index 8e6ec17..9489acb 100644 --- a/llama_assistant/agent.py +++ b/llama_assistant/agent.py @@ -12,6 +12,7 @@ SYSTEM_PROMPT = {"role": "system", "content": "Generate short and simple response."} + def convert_message_list_to_str(messages): chat_history_str = "" for message in messages: @@ -39,9 +40,11 @@ class ChatHistory: def __init__(self, llm, max_history_size: int, max_output_tokens: int): self.llm = llm self.max_output_tokens = max_output_tokens - self.max_history_size = max_history_size # in tokens + self.max_history_size = max_history_size # in tokens self.max_history_size_in_words = max_history_size * 3 / 4 - self.max_history_size_in_words = self.max_history_size_in_words - 128 # to account for some formatting tokens + self.max_history_size_in_words = ( + self.max_history_size_in_words - 128 + ) # to account for some formatting tokens print("Max history size in words:", self.max_history_size_in_words) self.total_size = 0 self.chat_history = [] @@ -57,7 +60,14 @@ def add_message(self, message: dict): if self.total_size + new_msg_size > self.max_history_size_in_words: print("Chat history is too long, summarizing the conversation...") history_summary = self.llm.create_chat_completion( - messages= [SYSTEM_PROMPT] + self.chat_history + [{"role": "user", "content": "Briefly summarize the conversation in a few sentences."}], + messages=[SYSTEM_PROMPT] + + self.chat_history + + [ + { + "role": "user", + "content": "Briefly summarize the conversation in a few sentences.", + } + ], stream=False, max_tokens=256, )["choices"][0]["message"]["content"] @@ -119,8 +129,11 @@ def __init__( self.search_index = None self.retriever = None - self.chat_history = ChatHistory(llm=llm, max_output_tokens=self.max_output_tokens, - max_history_size=self.max_input_tokens) + self.chat_history = ChatHistory( + llm=llm, + max_output_tokens=self.max_output_tokens, + max_history_size=self.max_input_tokens, + ) self.lookup_files = set() self.embed_model = HuggingFaceEmbedding(model_name=rag_setting["embed_model_name"]) @@ -144,9 +157,7 @@ def update_index(self, files: Optional[Set[str]] = set()): show_progress=True, num_workers=1 ) - self.search_index = VectorStoreIndex.from_documents( - documents, embed_model=self.embed_model - ) + self.search_index = VectorStoreIndex.from_documents(documents, embed_model=self.embed_model) self.retriever = self.search_index.as_retriever(similarity_top_k=self.retrieval_top_k) @@ -221,7 +232,15 @@ async def condense_history_to_query(self, ctx: Context, ev: SetupEvent) -> Conde if len(self.chat_history) > 0 and self.retriever is not None: standalone_query = self.llm.create_chat_completion( - messages=[SYSTEM_PROMPT] + self.chat_history.get_chat_history() + [{"role": "user", "content": query_str + "\n Condense this conversation to stand-alone question using only 1 sentence."}], + messages=[SYSTEM_PROMPT] + + self.chat_history.get_chat_history() + + [ + { + "role": "user", + "content": query_str + + "\n Condense this 
conversation to stand-alone question using only 1 sentence.", + } + ], stream=False, top_k=self.generation_setting["top_k"], top_p=self.generation_setting["top_p"], diff --git a/llama_assistant/config.py b/llama_assistant/config.py index b966ff3..97f1a5c 100644 --- a/llama_assistant/config.py +++ b/llama_assistant/config.py @@ -12,11 +12,11 @@ "hey_llama_chat": False, "hey_llama_mic": False, "generation": { - "context_len": 4096, + "context_len": 4096, "max_output_tokens": 1024, "top_k": 40, "top_p": 0.95, - "temperature": 0.2 + "temperature": 0.2, }, "rag": { "embed_model_name": "BAAI/bge-base-en-v1.5", @@ -28,22 +28,26 @@ } VALIDATOR = { - 'generation': { - 'context_len': {'type': 'int', 'min': 2048}, - 'max_output_tokens': {'type': 'int', 'min': 512, 'max': 2048}, - 'top_k': {'type': 'int', 'min': 1, 'max': 100}, - 'top_p': {'type': 'float', 'min': 0, 'max': 1}, - 'temperature': {'type': 'float', 'min': 0, 'max': 1}, - }, - 'rag': { - 'chunk_size': {'type': 'int', 'min': 64, 'max': 512}, - 'chunk_overlap': {'type': 'int', 'min': 64, 'max': 256}, - 'max_retrieval_top_k': {'type': 'int', 'min': 1, 'max': 5}, - 'similarity_threshold': {'type': 'float', 'min': 0, 'max': 1}, - } + "generation": { + "context_len": {"type": "int", "min": 2048}, + "max_output_tokens": {"type": "int", "min": 512, "max": 2048}, + "top_k": {"type": "int", "min": 1, "max": 100}, + "top_p": {"type": "float", "min": 0, "max": 1}, + "temperature": {"type": "float", "min": 0, "max": 1}, + }, + "rag": { + "chunk_size": {"type": "int", "min": 64, "max": 512}, + "chunk_overlap": {"type": "int", "min": 64, "max": 256}, + "max_retrieval_top_k": {"type": "int", "min": 1, "max": 5}, + "similarity_threshold": {"type": "float", "min": 0, "max": 1}, + }, } -DEFAULT_EMBEDING_MODELS = ["BAAI/bge-small-en-v1.5", "BAAI/bge-base-en-v1.5", "BAAI/bge-large-en-v1.5"] +DEFAULT_EMBEDING_MODELS = [ + "BAAI/bge-small-en-v1.5", + "BAAI/bge-base-en-v1.5", + "BAAI/bge-large-en-v1.5", +] DEFAULT_MODELS = [ { @@ -150,6 +154,7 @@ custom_models_file = llama_assistant_dir / "custom_models.json" settings_file = llama_assistant_dir / "settings.json" document_icon = "llama_assistant/resources/document_icon.png" +ocr_tmp_file = llama_assistant_dir / "ocr_tmp.png" if custom_models_file.exists(): with open(custom_models_file, "r") as f: diff --git a/llama_assistant/icons.py b/llama_assistant/icons.py index c4564d5..4fd5dff 100644 --- a/llama_assistant/icons.py +++ b/llama_assistant/icons.py @@ -29,6 +29,16 @@ """ +crosshair_icon_svg = """ + + + + + + + +""" + def create_icon_from_svg(svg_string): svg_bytes = QByteArray(svg_string.encode("utf-8")) diff --git a/llama_assistant/llama_assistant_app.py b/llama_assistant/llama_assistant_app.py index f6e4d5e..2559d5a 100644 --- a/llama_assistant/llama_assistant_app.py +++ b/llama_assistant/llama_assistant_app.py @@ -14,12 +14,9 @@ QVBoxLayout, QMessageBox, QSystemTrayIcon, + QRubberBand, ) -from PyQt5.QtCore import ( - Qt, - QPoint, - QTimer, -) +from PyQt5.QtCore import Qt, QPoint, QTimer, QSize, QRect from PyQt5.QtGui import ( QPixmap, QPainter, @@ -36,9 +33,10 @@ from llama_assistant.setting_dialog import SettingsDialog from llama_assistant.speech_recognition_thread import SpeechRecognitionThread from llama_assistant.utils import image_to_base64_data_uri -from llama_assistant.processing_thread import ProcessingThread +from llama_assistant.processing_thread import ProcessingThread, OCRThread from llama_assistant.ui_manager import UIManager from llama_assistant.tray_manager import TrayManager +from 
llama_assistant.screen_capture_widget import ScreenCaptureWidget from llama_assistant.setting_validator import validate_numeric_field from llama_assistant.utils import load_image @@ -50,6 +48,7 @@ def __init__(self): self.load_settings() self.ui_manager = UIManager(self) self.tray_manager = TrayManager(self) + self.screen_capture_widget = ScreenCaptureWidget(self) self.setup_global_shortcut() self.last_response = "" self.dropped_image = None @@ -62,6 +61,12 @@ def __init__(self): self.current_multimodal_model = self.settings.get("multimodal_model") self.processing_thread = None self.markdown_creator = mistune.create_markdown() + self.gen_mark_down = True + self.has_ocr_context = False + + def capture_screenshot(self): + self.hide() + self.screen_capture_widget.show() def tray_icon_activated(self, reason): if reason == QSystemTrayIcon.ActivationReason.Trigger: @@ -183,6 +188,38 @@ def toggle_visibility(self): self.raise_() self.ui_manager.input_field.setFocus() + def on_ocr_button_clicked(self): + self.show() + self.show_chat_box() + self.screen_capture_widget.hide() + + self.last_response = "" + self.gen_mark_down = False + + self.ui_manager.chat_box.append( + f'
You: OCR this captured region' + ) + self.ui_manager.chat_box.append('AI: ') + + self.start_cursor_pos = self.ui_manager.chat_box.textCursor().position() + + img_path = config.ocr_tmp_file + if not img_path.exists(): + print("No image find for OCR") + self.ui_manager.chat_box.append("No image found for OCR") + return + + self.processing_thread = OCRThread(img_path, streaming=True) + self.processing_thread.preloader_signal.connect(self.indicate_loading) + self.processing_thread.update_signal.connect(self.update_chat_box) + self.processing_thread.finished_signal.connect(self.on_processing_finished) + self.processing_thread.start() + + def on_ask_with_ocr_context(self): + self.show() + self.screen_capture_widget.hide() + self.has_ocr_context = True + def on_submit(self): message = self.ui_manager.input_field.toPlainText() if message == "": @@ -200,6 +237,7 @@ def on_submit(self): return self.last_response = "" + self.gen_mark_down = True if self.dropped_image: self.process_image_with_prompt(self.dropped_image, self.dropped_files, message) @@ -244,6 +282,7 @@ def process_text(self, message, file_paths, task="chat"): self.rag_setting, prompt, lookup_files=file_paths, + ocr_img_path=config.ocr_tmp_file if self.has_ocr_context else None, ) self.processing_thread.preloader_signal.connect(self.indicate_loading) @@ -251,6 +290,8 @@ def process_text(self, message, file_paths, task="chat"): self.processing_thread.finished_signal.connect(self.on_processing_finished) self.processing_thread.start() + self.has_ocr_context = False + def process_image_with_prompt(self, image_path, file_paths, prompt): self.show_chat_box() self.ui_manager.chat_box.append( @@ -270,20 +311,27 @@ def process_image_with_prompt(self, image_path, file_paths, prompt): prompt, image=image, lookup_files=file_paths, + ocr_img_path=config.ocr_tmp_file if self.has_ocr_context else None, ) self.processing_thread.preloader_signal.connect(self.indicate_loading) self.processing_thread.update_signal.connect(self.update_chat_box) self.processing_thread.finished_signal.connect(self.on_processing_finished) self.processing_thread.start() + self.has_ocr_context = False + + def clear_text_from_start_pos(self): + cursor = self.ui_manager.chat_box.textCursor() + cursor.setPosition(self.start_cursor_pos) + # Select all text from the start_pos to the end + cursor.movePosition(QTextCursor.End, QTextCursor.KeepAnchor) + # Remove the selected text + cursor.removeSelectedText() + def indicate_loading(self, message): while self.processing_thread.is_preloading(): + self.clear_text_from_start_pos() cursor = self.ui_manager.chat_box.textCursor() - cursor.setPosition(self.start_cursor_pos) - # Select all text from the start_pos to the end - cursor.movePosition(QTextCursor.End, QTextCursor.KeepAnchor) - # Remove the selected text - cursor.removeSelectedText() # create animation where the characters are displayed one by one for c in message: if c == " ": @@ -296,20 +344,21 @@ def indicate_loading(self, message): def update_chat_box(self, text): self.last_response += text - markdown_response = self.markdown_creator(self.last_response) - # Since cannot change the font size of the h1, h2 tag, we will replace it with h3 - markdown_response = markdown_response.replace("
<h1>", "<h3>").replace("</h1>", "</h3>")
-        markdown_response = markdown_response.replace("<h2>", "<h3>").replace("</h2>", "</h3>")
-        markdown_response += "<br>"
+
+        formatted_text = ""
+        if self.gen_mark_down:
+            markdown_response = self.markdown_creator(self.last_response)
+            # Since cannot change the font size of the h1, h2 tag, we will replace it with h3
+            markdown_response = markdown_response.replace("<h1>", "<h3>").replace("</h1>", "</h3>")
+            markdown_response = markdown_response.replace("<h2>", "<h3>").replace("</h2>", "</h3>")
+            markdown_response += "<br>"
+            formatted_text = markdown_response
+        else:
+            formatted_text = self.last_response.replace("\n", "<br>") + "<br>
" + + self.clear_text_from_start_pos() cursor = self.ui_manager.chat_box.textCursor() - cursor.setPosition( - self.start_cursor_pos - ) # regenerate the updated text from the start position - # Select all text from the start_pos to the end - cursor.movePosition(QTextCursor.End, QTextCursor.KeepAnchor) - # Remove the selected text - cursor.removeSelectedText() - cursor.insertHtml(markdown_response) + cursor.insertHtml(formatted_text) self.ui_manager.chat_box.verticalScrollBar().setValue( self.ui_manager.chat_box.verticalScrollBar().maximum() ) @@ -514,14 +563,6 @@ def remove_image_thumbnail(self): self.ui_manager.input_field.setPlaceholderText("Ask me anything...") self.setFixedHeight(self.height() - 110) # Decrease height after removing image - def mousePressEvent(self, event): - self.oldPos = event.globalPos() - - def mouseMoveEvent(self, event): - delta = QPoint(event.globalPos() - self.oldPos) - self.move(self.x() + delta.x(), self.y() + delta.y()) - self.oldPos = event.globalPos() - def on_wake_word_detected(self, model_name): self.show() self.activateWindow() diff --git a/llama_assistant/model_handler.py b/llama_assistant/model_handler.py index c149b18..6c1a112 100644 --- a/llama_assistant/model_handler.py +++ b/llama_assistant/model_handler.py @@ -214,6 +214,10 @@ def chat_completion( return response def update_chat_history(self, message: str, role: str): + if self.loaded_agent is None: + print("Agent has not been initialized. Cannot update chat history.") + return + agent = self.loaded_agent.get("agent") if agent: agent.chat_history.add_message({"role": role, "content": message}) diff --git a/llama_assistant/ocr_engine.py b/llama_assistant/ocr_engine.py new file mode 100644 index 0000000..9ab2244 --- /dev/null +++ b/llama_assistant/ocr_engine.py @@ -0,0 +1,157 @@ +from paddleocr import PaddleOCR +from PIL import Image +import copy +import numpy as np +import time + + +def group_boxes_to_lines(bboxes, vertical_tolerance=5): + """ + Groups bounding boxes into lines based on vertical alignment. + + Args: + bboxes: List of bounding boxes [(xmin, ymin, xmax, ymax)]. + vertical_tolerance: Tolerance for vertical proximity to consider boxes in the same line. + + Returns: + List of lines, where each line is a list of bounding boxes. + """ + # Sort bounding boxes by ymin (top edge) + bboxes = sorted(bboxes, key=lambda bbox: bbox[1]) + + lines = [] + current_line = [] + current_ymin = None + + for bbox in bboxes: + xmin, ymin, xmax, ymax = bbox + + # Check if starting a new line or current box is not vertically aligned + if current_ymin is None or ymin > current_ymin + vertical_tolerance: + # Save the current line if not empty + if current_line: + # Sort current line by xmin + current_line.sort(key=lambda box: box[0]) + lines.append(current_line) + # Start a new line + current_line = [bbox] + current_ymin = ymin + else: + # Add box to the current line + current_line.append(bbox) + + # Add the last line if any + if current_line: + current_line.sort(key=lambda box: box[0]) + lines.append(current_line) + + return lines + + +def quad_to_rect(quad_boxes): + """ + Converts a quadrilateral bounding box to a rectangular bounding box. + + Args: + quad_boxes: List of 4 points [(x1, y1), (x2, y2), (x3, y3), (x4, y4)] representing the quadrilateral. + + Returns: + List of rectangular bounding box (xmin, ymin, xmax, ymax). 
+ """ + result = [] + for quad_box in quad_boxes: + quad_box = quad_box.astype(np.int32) + # Extract all x and y coordinates + x_coords = quad_box[:, 0] + y_coords = quad_box[:, 1] + + # Find the enclosing rectangle + xmin = np.min(x_coords) + ymin = np.min(y_coords) + xmax = np.max(x_coords) + ymax = np.max(y_coords) + + result.append((xmin, ymin, xmax, ymax)) + + return result + + +class OCREngine: + def __init__(self): + self.ocr = None + + def load_ocr(self, processign_thread=None): + if self.ocr is None: + if processign_thread: + processign_thread.set_preloading(True, "Initializing OCR ....") + + self.ocr = PaddleOCR(use_angle_cls=True, lang="en") + time.sleep(1.2) + + if processign_thread: + processign_thread.set_preloading(False, "...") + + def perform_ocr(self, img_path, streaming=False, processing_thread=None): + self.load_ocr(processing_thread) + img = np.array(Image.open(img_path).convert("RGB")) + ori_im = img.copy() + + # text detection + dt_boxes, _ = self.ocr.text_detector(img) + if dt_boxes is None: + return None + + img_crop_list = [] + dt_boxes = quad_to_rect(dt_boxes) + + # group boxes into lines + lines = group_boxes_to_lines(dt_boxes) + + # return a generator if streaming + if streaming: + + def generate_result(): + for boxes in lines: + img_crop_list = [] + for bno in range(len(boxes)): + xmin, ymin, xmax, ymax = copy.deepcopy(boxes[bno]) + img_crop = ori_im[ymin:ymax, xmin:xmax] + if any([dim <= 0 for dim in img_crop.shape[:2]]): + continue + img_crop_list.append(img_crop) + rec_res, _ = self.ocr.text_recognizer(img_crop_list) + + line_text = "" + for rec_result in rec_res: + text, score = rec_result + if score >= self.ocr.drop_score: + line_text += text + " " + + yield line_text + "\n" + + return generate_result() + + # non-streaming + full_result = "" + for boxes in lines: + img_crop_list = [] + for bno in range(len(boxes)): + xmin, ymin, xmax, ymax = copy.deepcopy(boxes[bno]) + img_crop = ori_im[ymin:ymax, xmin:xmax] + if any([dim <= 0 for dim in img_crop.shape[:2]]): + continue + img_crop_list.append(img_crop) + rec_res, _ = self.ocr.text_recognizer(img_crop_list) + + line_text = "" + for rec_result in rec_res: + text, score = rec_result + if score >= self.ocr.drop_score: + line_text += text + " " + + full_result += line_text + "\n" + + return full_result + + +ocr_engine = OCREngine() diff --git a/llama_assistant/processing_thread.py b/llama_assistant/processing_thread.py index 3cc7a8b..ce1a492 100644 --- a/llama_assistant/processing_thread.py +++ b/llama_assistant/processing_thread.py @@ -4,6 +4,7 @@ pyqtSignal, ) from llama_assistant.model_handler import handler as model_handler +from llama_assistant.ocr_engine import ocr_engine class ProcessingThread(QThread): @@ -19,6 +20,7 @@ def __init__( prompt: str, lookup_files: Optional[Set[str]] = None, image: str = None, + ocr_img_path: str = None, ): super().__init__() self.model = model @@ -27,9 +29,17 @@ def __init__( self.prompt = prompt self.image = image self.lookup_files = lookup_files - self.preloadong = False + self.preloading = False + self.ocr_img_path = ocr_img_path def run(self): + if self.ocr_img_path: + self.set_preloading(True, "Thinking ....") + ocr_output = ocr_engine.perform_ocr(self.ocr_img_path, streaming=False) + ocr_output = f"Here is the OCR result:\n{ocr_output}\n" + self.prompt = ocr_output + self.prompt + print("Prompt with OCR context:", self.prompt) + output = model_handler.chat_completion( self.model, self.generation_setting, @@ -65,3 +75,46 @@ def set_preloading(self, preloading: bool, 
message: str): def is_preloading(self): return self.preloading + + +class OCRThread(QThread): + preloader_signal = pyqtSignal(str) + update_signal = pyqtSignal(str) + finished_signal = pyqtSignal() + + def __init__(self, img_path: str, streaming: bool = False): + super().__init__() + self.img_path = img_path + self.preloading = False + self.streaming = streaming + self.is_ocr_done = False + + def emit_preloading_message(self, message: str): + self.preloader_signal.emit(message) + + def set_preloading(self, preloading: bool, message: str): + self.preloading = preloading + self.emit_preloading_message(message) + + def is_preloading(self): + return self.preloading + + def run(self): + output = ocr_engine.perform_ocr( + self.img_path, streaming=self.streaming, processing_thread=self + ) + full_response_str = "Here is the OCR result:\n" + self.is_ocr_done = True + + if not self.streaming and type(output) == str: + self.update_signal.emit(full_response_str + output) + return + + self.update_signal.emit(full_response_str) + for chunk in output: + self.update_signal.emit(chunk) + full_response_str += chunk + + model_handler.update_chat_history("OCR this image:", "user") + model_handler.update_chat_history(full_response_str, "assistant") + self.finished_signal.emit() diff --git a/llama_assistant/screen_capture_widget.py b/llama_assistant/screen_capture_widget.py new file mode 100644 index 0000000..ed351cf --- /dev/null +++ b/llama_assistant/screen_capture_widget.py @@ -0,0 +1,169 @@ +from typing import TYPE_CHECKING +from PyQt5.QtWidgets import QApplication, QWidget, QDesktopWidget, QPushButton +from PyQt5.QtCore import Qt, QRect +from PyQt5.QtGui import QPainter, QColor, QPen + +from llama_assistant import config +from llama_assistant.ocr_engine import OCREngine + +if TYPE_CHECKING: + from llama_assistant.llama_assistant_app import LlamaAssistantApp + + +class ScreenCaptureWidget(QWidget): + def __init__(self, parent: "LlamaAssistantApp"): + super().__init__() + self.setWindowFlags(Qt.FramelessWindowHint) + + self.parent = parent + self.ocr_engine = OCREngine() + + # Get screen size + screen = QDesktopWidget().screenGeometry() + self.setGeometry(0, 0, screen.width(), screen.height()) + + # Set crosshairs cursor + self.setCursor(Qt.CrossCursor) + + # To store the start and end points of the mouse region + self.start_point = None + self.end_point = None + + # Buttons to appear after selection + self.button_widget = QWidget() + self.ocr_button = QPushButton("OCR", self.button_widget) + self.ask_button = QPushButton("Ask", self.button_widget) + self.ocr_button.setCursor(Qt.PointingHandCursor) + self.ask_button.setCursor(Qt.PointingHandCursor) + opacity = self.parent.settings.get("transparency", 90) / 100 + base_style = f""" + border: none; + border-radius: 20px; + color: white; + padding: 10px 15px; + font-size: 16px; + """ + button_style = f""" + QPushButton {{ + {base_style} + padding: 2.5px 5px; + border-radius: 5px; + background-color: rgba{QColor(self.parent.settings["color"]).lighter(120).getRgb()[:3] + (opacity,)}; + }} + """ + self.ocr_button.setStyleSheet(button_style) + self.ask_button.setStyleSheet(button_style) + self.button_widget.hide() + + # Connect button signals + self.ocr_button.clicked.connect(self.parent.on_ocr_button_clicked) + self.ask_button.clicked.connect(self.parent.on_ask_with_ocr_context) + + def show(self): + # remove painting if any + self.start_point = None + self.end_point = None + self.update() + + # Set window opacity to 50% + self.setWindowOpacity(0.5) + # 
self.setAttribute(Qt.WA_TranslucentBackground, True) + super().show() + + def hide(self): + self.button_widget.hide() + super().hide() + + def mousePressEvent(self, event): + if event.button() == Qt.LeftButton: + self.start_point = event.pos() # Capture start position + self.end_point = event.pos() # Initialize end point to start position + print(f"Mouse press at {self.start_point}") + + def mouseReleaseEvent(self, event): + if event.button() == Qt.LeftButton: + self.end_point = event.pos() # Capture end position + + print(f"Mouse release at {self.end_point}") + + # Capture the region between start and end points + if self.start_point and self.end_point: + self.capture_region(self.start_point, self.end_point) + + # Trigger repaint to show the red rectangle + self.update() + + self.show_buttons() + + def mouseMoveEvent(self, event): + if self.start_point: + # Update the end_point to the current mouse position as it moves + self.end_point = event.pos() + + # Trigger repaint to update the rectangle + self.update() + + def capture_region(self, start_point, end_point): + # Convert local widget coordinates to global screen coordinates + start_global = self.mapToGlobal(start_point) + end_global = self.mapToGlobal(end_point) + + # Create a QRect from the global start and end points + region_rect = QRect(start_global, end_global) + + # Ensure the rectangle is valid (non-negative width/height) + region_rect = region_rect.normalized() + + # Capture the screen region + screen = QApplication.primaryScreen() + pixmap = screen.grabWindow( + 0, region_rect.x(), region_rect.y(), region_rect.width(), region_rect.height() + ) + + # Save the captured region as an image + pixmap.save(str(config.ocr_tmp_file), "PNG") + print(f"Captured region saved at '{config.ocr_tmp_file}'.") + + def paintEvent(self, event): + # If the start and end points are set, draw the rectangle + if self.start_point and self.end_point: + # Create a painter object + painter = QPainter(self) + + # Set the pen color to red + pen = QPen(QColor(255, 0, 0)) # Red color + pen.setWidth(3) # Set width of the border + painter.setPen(pen) + + # Draw the rectangle from start_point to end_point + self.region_rect = QRect(self.start_point, self.end_point) + self.region_rect = ( + self.region_rect.normalized() + ) # Normalize to ensure correct width/height + + painter.drawRect(self.region_rect) # Draw the rectangle + + super().paintEvent(event) # Call the base class paintEvent + + def show_buttons(self): + if self.start_point and self.end_point: + # Get normalized rectangle + rect = QRect(self.start_point, self.end_point).normalized() + + # Calculate button positions + button_y = rect.bottom() + 10 # Place buttons below the rectangle + button_width = 80 + button_height = 30 + spacing = 10 + + print("Showing buttons") + + self.ocr_button.setGeometry(0, 0, button_width, button_height) + self.ask_button.setGeometry(button_width + spacing, 0, button_width, button_height) + + self.button_widget.setGeometry( + rect.left(), button_y, 2 * button_width + spacing, button_height + ) + self.button_widget.setAttribute(Qt.WA_TranslucentBackground) + self.button_widget.setWindowFlags(Qt.FramelessWindowHint) + self.button_widget.show() diff --git a/llama_assistant/setting_dialog.py b/llama_assistant/setting_dialog.py index 9a5ae12..843d748 100644 --- a/llama_assistant/setting_dialog.py +++ b/llama_assistant/setting_dialog.py @@ -24,6 +24,7 @@ from llama_assistant import config from llama_assistant.setting_validator import validate_numeric_field + class 
SettingsDialog(QDialog): settingsSaved = pyqtSignal() @@ -218,56 +219,76 @@ def create_rag_settings_group(self): self.main_layout.addWidget(group_box) def accept(self): - valid, message = validate_numeric_field("Context Length", self.context_len_input.text(), - constraints=config.VALIDATOR['generation']['context_len']) + valid, message = validate_numeric_field( + "Context Length", + self.context_len_input.text(), + constraints=config.VALIDATOR["generation"]["context_len"], + ) if not valid: QMessageBox.warning(self, "Validation Error", message) return - - valid, message = validate_numeric_field("Temperature", self.temperature_input.text(), - constraints=config.VALIDATOR['generation']['temperature']) + + valid, message = validate_numeric_field( + "Temperature", + self.temperature_input.text(), + constraints=config.VALIDATOR["generation"]["temperature"], + ) if not valid: QMessageBox.warning(self, "Validation Error", message) return - - valid, message = validate_numeric_field("Top p", self.top_p_input.text(), - constraints=config.VALIDATOR['generation']['top_p']) + + valid, message = validate_numeric_field( + "Top p", self.top_p_input.text(), constraints=config.VALIDATOR["generation"]["top_p"] + ) if not valid: QMessageBox.warning(self, "Validation Error", message) return - - valid, message = validate_numeric_field("Top k", self.top_k_input.text(), - constraints=config.VALIDATOR['generation']['top_k']) + + valid, message = validate_numeric_field( + "Top k", self.top_k_input.text(), constraints=config.VALIDATOR["generation"]["top_k"] + ) if not valid: QMessageBox.warning(self, "Validation Error", message) return - - valid, message = validate_numeric_field("Chunk Size", self.chunk_size_input.text(), - constraints=config.VALIDATOR['rag']['chunk_size']) + + valid, message = validate_numeric_field( + "Chunk Size", + self.chunk_size_input.text(), + constraints=config.VALIDATOR["rag"]["chunk_size"], + ) if not valid: QMessageBox.warning(self, "Validation Error", message) return - - valid, message = validate_numeric_field("Chunk Overlap", self.chunk_overlap_input.text(), - constraints=config.VALIDATOR['rag']['chunk_overlap']) + + valid, message = validate_numeric_field( + "Chunk Overlap", + self.chunk_overlap_input.text(), + constraints=config.VALIDATOR["rag"]["chunk_overlap"], + ) if not valid: QMessageBox.warning(self, "Validation Error", message) return - - valid, message = validate_numeric_field("Max Retrieval Top k", self.max_retrieval_top_k_input.text(), - constraints=config.VALIDATOR['rag']['max_retrieval_top_k']) - + + valid, message = validate_numeric_field( + "Max Retrieval Top k", + self.max_retrieval_top_k_input.text(), + constraints=config.VALIDATOR["rag"]["max_retrieval_top_k"], + ) + if not valid: QMessageBox.warning(self, "Validation Error", message) return - - valid, message = validate_numeric_field("Similarity Threshold", self.similarity_threshold_input.text(), - constraints=config.VALIDATOR['rag']['similarity_threshold']) - + + valid, message = validate_numeric_field( + "Similarity Threshold", + self.similarity_threshold_input.text(), + constraints=config.VALIDATOR["rag"]["similarity_threshold"], + ) + if not valid: QMessageBox.warning(self, "Validation Error", message) return - + self.save_settings() self.settingsSaved.emit() super().accept() @@ -317,29 +338,66 @@ def load_settings(self): if "rag" not in settings: settings["rag"] = {} - embed_model = settings["rag"].get("embed_model_name", config.DEFAULT_SETTINGS['rag']['embed_model_name']) + embed_model = settings["rag"].get( + 
"embed_model_name", config.DEFAULT_SETTINGS["rag"]["embed_model_name"] + ) if embed_model in config.DEFAULT_EMBEDING_MODELS: self.embed_model_combo.setCurrentText(embed_model) - - self.chunk_size_input.setText(str(settings["rag"].get("chunk_size", config.DEFAULT_SETTINGS['rag']['chunk_size']))) + + self.chunk_size_input.setText( + str(settings["rag"].get("chunk_size", config.DEFAULT_SETTINGS["rag"]["chunk_size"])) + ) self.chunk_overlap_input.setText( - str(settings["rag"].get("chunk_overlap", config.DEFAULT_SETTINGS['rag']['chunk_overlap'])) + str( + settings["rag"].get( + "chunk_overlap", config.DEFAULT_SETTINGS["rag"]["chunk_overlap"] + ) + ) ) self.max_retrieval_top_k_input.setText( - str(settings["rag"].get("max_retrieval_top_k", config.DEFAULT_SETTINGS['rag']['max_retrieval_top_k'])) + str( + settings["rag"].get( + "max_retrieval_top_k", config.DEFAULT_SETTINGS["rag"]["max_retrieval_top_k"] + ) + ) ) self.similarity_threshold_input.setText( - str(settings["rag"].get("similarity_threshold", config.DEFAULT_SETTINGS['rag']['similarity_threshold'])) + str( + settings["rag"].get( + "similarity_threshold", + config.DEFAULT_SETTINGS["rag"]["similarity_threshold"], + ) + ) ) self.context_len_input.setText( - str(settings["generation"].get("context_len", config.DEFAULT_SETTINGS['generation']['context_len'])) + str( + settings["generation"].get( + "context_len", config.DEFAULT_SETTINGS["generation"]["context_len"] + ) + ) ) - + self.temperature_input.setText( - str(settings["generation"].get("temperature", config.DEFAULT_SETTINGS['generation']['temperature'])) + str( + settings["generation"].get( + "temperature", config.DEFAULT_SETTINGS["generation"]["temperature"] + ) + ) + ) + self.top_p_input.setText( + str( + settings["generation"].get( + "top_p", config.DEFAULT_SETTINGS["generation"]["top_p"] + ) + ) + ) + self.top_k_input.setText( + str( + settings["generation"].get( + "top_k", config.DEFAULT_SETTINGS["generation"]["top_k"] + ) + ) ) - self.top_p_input.setText(str(settings["generation"].get("top_p", config.DEFAULT_SETTINGS['generation']['top_p']))) - self.top_k_input.setText(str(settings["generation"].get("top_k", config.DEFAULT_SETTINGS['generation']['top_k']))) else: self.color = QColor("#1E1E1E") self.shortcut_recorder.setText("++") diff --git a/llama_assistant/setting_validator.py b/llama_assistant/setting_validator.py index 19cdcb7..bb45a5e 100644 --- a/llama_assistant/setting_validator.py +++ b/llama_assistant/setting_validator.py @@ -1,10 +1,9 @@ - def validate_numeric_field(name, value_str, constraints): - type = constraints['type'] - min = constraints.get('min') - max = constraints.get('max') - - if type == 'float': + type = constraints["type"] + min = constraints.get("min") + max = constraints.get("max") + + if type == "float": if isinstance(value_str, float): value = value_str else: @@ -13,8 +12,8 @@ def validate_numeric_field(name, value_str, constraints): except ValueError: message = f"Invalid value for {name}. Expected a float, got {value_str}" return False, message - - elif type == 'int': + + elif type == "int": if isinstance(value_str, int): value = value_str else: @@ -23,11 +22,13 @@ def validate_numeric_field(name, value_str, constraints): except ValueError: message = f"Invalid value for {name}. Expected an integer, got {value_str}" return False, message - + if min is not None and value < min: message = f"Invalid value for {name}. 
Expected a value greater than or equal to {min}, got {value}" return False, message if max is not None and value > max: - message = f"Invalid value for {name}. Expected a value less than or equal to {max}, got {value}" + message = ( + f"Invalid value for {name}. Expected a value less than or equal to {max}, got {value}" + ) return False, message - return True, value \ No newline at end of file + return True, value diff --git a/llama_assistant/ui_manager.py b/llama_assistant/ui_manager.py index 651f155..15d2b59 100644 --- a/llama_assistant/ui_manager.py +++ b/llama_assistant/ui_manager.py @@ -24,6 +24,7 @@ copy_icon_svg, clear_icon_svg, microphone_icon_svg, + crosshair_icon_svg, ) @@ -87,6 +88,7 @@ def init_ui(self): self.input_field.dropEvent = self.parent.dropEvent input_layout.addWidget(self.input_field) + button_layout = QVBoxLayout() self.mic_button = QPushButton(self.parent) self.mic_button.setIcon(create_icon_from_svg(microphone_icon_svg)) self.mic_button.setIconSize(QSize(24, 24)) @@ -104,7 +106,29 @@ def init_ui(self): } """ ) - input_layout.addWidget(self.mic_button) + button_layout.addWidget(self.mic_button) + + self.screenshot_button = QPushButton(self.parent) + self.screenshot_button.setIcon(create_icon_from_svg(crosshair_icon_svg)) + self.screenshot_button.setIconSize(QSize(24, 24)) + self.screenshot_button.setFixedSize(40, 40) + self.screenshot_button.clicked.connect(self.parent.capture_screenshot) + self.screenshot_button.setToolTip("Screen Spot") + self.screenshot_button.setStyleSheet( + """ + QPushButton { + background-color: rgba(100, 100, 100, 200); + border: none; + border-radius: 20px; + } + QPushButton:hover { + background-color: rgba(100, 100, 100, 230); + } + """ + ) + button_layout.addWidget(self.screenshot_button) + + input_layout.addLayout(button_layout) close_button = QPushButton("×", self.parent) close_button.clicked.connect(self.parent.hide) diff --git a/pyproject.toml b/pyproject.toml index e0a3512..a792bc6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,8 @@ dependencies = [ "llama-index-embeddings-huggingface==0.4.0", "docx2txt==0.8", "mistune==3.0.2", + "paddlepaddle==2.6.2", + "paddleocr==2.9.1", "whispercpp @ git+https://github.com/stlukey/whispercpp.py" ] dynamic = [] diff --git a/requirements.txt b/requirements.txt index 172f12a..f2f417b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,4 +13,6 @@ llama-index-readers-file==0.4.0 llama-index-embeddings-huggingface==0.4.0 docx2txt==0.8 mistune==3.0.2 +paddlepaddle==2.6.2 +paddleocr==2.9.1 git+https://github.com/stlukey/whispercpp.py