diff --git a/README.md b/README.md index 4f88cfe..fc1d98c 100644 --- a/README.md +++ b/README.md @@ -54,6 +54,7 @@ https://github.com/user-attachments/assets/af2c544b-6d46-4c44-87d8-9a051ba213db - [x] ⚡ Streaming support for response. - [x] 🎙️ Add offline STT support: WhisperCPP. - [x] 🧠 Knowledge database: LlamaIndex +- [x] ⌖ Screen Spot: Screen capture and analyze with OCR - [ ] 🔌 Plugin system for extensibility. - [ ] 📰 News and weather updates. - [ ] 📧 Email integration with Gmail and Outlook. diff --git a/llama_assistant/config.py b/llama_assistant/config.py index b966ff3..e602a11 100644 --- a/llama_assistant/config.py +++ b/llama_assistant/config.py @@ -150,6 +150,7 @@ custom_models_file = llama_assistant_dir / "custom_models.json" settings_file = llama_assistant_dir / "settings.json" document_icon = "llama_assistant/resources/document_icon.png" +ocr_tmp_file = llama_assistant_dir / "ocr_tmp.png" if custom_models_file.exists(): with open(custom_models_file, "r") as f: diff --git a/llama_assistant/icons.py b/llama_assistant/icons.py index c4564d5..4fd5dff 100644 --- a/llama_assistant/icons.py +++ b/llama_assistant/icons.py @@ -29,6 +29,16 @@ """ +crosshair_icon_svg = """ + + + + + + + +""" + def create_icon_from_svg(svg_string): svg_bytes = QByteArray(svg_string.encode("utf-8")) diff --git a/llama_assistant/llama_assistant_app.py b/llama_assistant/llama_assistant_app.py index f6e4d5e..9caa99d 100644 --- a/llama_assistant/llama_assistant_app.py +++ b/llama_assistant/llama_assistant_app.py @@ -14,11 +14,14 @@ QVBoxLayout, QMessageBox, QSystemTrayIcon, + QRubberBand ) from PyQt5.QtCore import ( Qt, QPoint, QTimer, + QSize, + QRect ) from PyQt5.QtGui import ( QPixmap, @@ -36,9 +39,10 @@ from llama_assistant.setting_dialog import SettingsDialog from llama_assistant.speech_recognition_thread import SpeechRecognitionThread from llama_assistant.utils import image_to_base64_data_uri -from llama_assistant.processing_thread import ProcessingThread +from llama_assistant.processing_thread import ProcessingThread, OCRThread from llama_assistant.ui_manager import UIManager from llama_assistant.tray_manager import TrayManager +from llama_assistant.screen_capture_widget import ScreenCaptureWidget from llama_assistant.setting_validator import validate_numeric_field from llama_assistant.utils import load_image @@ -50,6 +54,7 @@ def __init__(self): self.load_settings() self.ui_manager = UIManager(self) self.tray_manager = TrayManager(self) + self.screen_capture_widget = ScreenCaptureWidget(self) self.setup_global_shortcut() self.last_response = "" self.dropped_image = None @@ -62,6 +67,12 @@ def __init__(self): self.current_multimodal_model = self.settings.get("multimodal_model") self.processing_thread = None self.markdown_creator = mistune.create_markdown() + self.gen_mark_down = True + self.has_ocr_context = False + + def capture_screenshot(self): + self.hide() + self.screen_capture_widget.show() def tray_icon_activated(self, reason): if reason == QSystemTrayIcon.ActivationReason.Trigger: @@ -183,6 +194,37 @@ def toggle_visibility(self): self.raise_() self.ui_manager.input_field.setFocus() + def on_ocr_button_clicked(self): + self.show() + self.show_chat_box() + self.screen_capture_widget.hide() + + self.last_response = "" + self.gen_mark_down = False + + self.ui_manager.chat_box.append(f'
You: OCR this captured region') + self.ui_manager.chat_box.append('AI: ') + + self.start_cursor_pos = self.ui_manager.chat_box.textCursor().position() + + img_path = config.ocr_tmp_file + if not img_path.exists(): + print("No image find for OCR") + self.ui_manager.chat_box.append('No image found for OCR') + return + + self.processing_thread = OCRThread(img_path, streaming=True) + self.processing_thread.preloader_signal.connect(self.indicate_loading) + self.processing_thread.update_signal.connect(self.update_chat_box) + self.processing_thread.finished_signal.connect(self.on_processing_finished) + self.processing_thread.start() + + def on_ask_with_ocr_context(self): + self.show() + self.screen_capture_widget.hide() + self.has_ocr_context = True + + def on_submit(self): message = self.ui_manager.input_field.toPlainText() if message == "": @@ -196,10 +238,11 @@ def on_submit(self): for file_path in self.dropped_files: self.remove_file_thumbnail(self.file_containers[file_path], file_path) - + return self.last_response = "" + self.gen_mark_down = True if self.dropped_image: self.process_image_with_prompt(self.dropped_image, self.dropped_files, message) @@ -244,6 +287,7 @@ def process_text(self, message, file_paths, task="chat"): self.rag_setting, prompt, lookup_files=file_paths, + ocr_img_path=config.ocr_tmp_file if self.has_ocr_context else None, ) self.processing_thread.preloader_signal.connect(self.indicate_loading) @@ -251,6 +295,8 @@ def process_text(self, message, file_paths, task="chat"): self.processing_thread.finished_signal.connect(self.on_processing_finished) self.processing_thread.start() + self.has_ocr_context = False + def process_image_with_prompt(self, image_path, file_paths, prompt): self.show_chat_box() self.ui_manager.chat_box.append( @@ -270,20 +316,27 @@ def process_image_with_prompt(self, image_path, file_paths, prompt): prompt, image=image, lookup_files=file_paths, + ocr_img_path=config.ocr_tmp_file if self.has_ocr_context else None ) self.processing_thread.preloader_signal.connect(self.indicate_loading) self.processing_thread.update_signal.connect(self.update_chat_box) self.processing_thread.finished_signal.connect(self.on_processing_finished) self.processing_thread.start() + self.has_ocr_context = False + + def clear_text_from_start_pos(self): + cursor = self.ui_manager.chat_box.textCursor() + cursor.setPosition(self.start_cursor_pos) + # Select all text from the start_pos to the end + cursor.movePosition(QTextCursor.End, QTextCursor.KeepAnchor) + # Remove the selected text + cursor.removeSelectedText() + def indicate_loading(self, message): while self.processing_thread.is_preloading(): + self.clear_text_from_start_pos() cursor = self.ui_manager.chat_box.textCursor() - cursor.setPosition(self.start_cursor_pos) - # Select all text from the start_pos to the end - cursor.movePosition(QTextCursor.End, QTextCursor.KeepAnchor) - # Remove the selected text - cursor.removeSelectedText() # create animation where the characters are displayed one by one for c in message: if c == " ": @@ -293,23 +346,24 @@ def indicate_loading(self, message): QApplication.processEvents() # Process events to update the UI time.sleep(0.05) time.sleep(0.5) - + def update_chat_box(self, text): self.last_response += text - markdown_response = self.markdown_creator(self.last_response) - # Since cannot change the font size of the h1, h2 tag, we will replace it with h3 - markdown_response = markdown_response.replace("
<h1>", "<h3>").replace("</h1>", "</h3>") -        markdown_response = markdown_response.replace("<h2>", "<h3>").replace("</h2>", "</h3>") -        markdown_response += "<br>
" + + formatted_text = "" + if self.gen_mark_down: + markdown_response = self.markdown_creator(self.last_response) + # Since cannot change the font size of the h1, h2 tag, we will replace it with h3 + markdown_response = markdown_response.replace("
<h1>", "<h3>").replace("</h1>", "</h3>") +            markdown_response = markdown_response.replace("<h2>", "<h3>").replace("</h2>", "</h3>") +            markdown_response += "<br>
" + formatted_text = markdown_response + else: + formatted_text = self.last_response.replace("\n", "
") + "
" + + self.clear_text_from_start_pos() cursor = self.ui_manager.chat_box.textCursor() - cursor.setPosition( - self.start_cursor_pos - ) # regenerate the updated text from the start position - # Select all text from the start_pos to the end - cursor.movePosition(QTextCursor.End, QTextCursor.KeepAnchor) - # Remove the selected text - cursor.removeSelectedText() - cursor.insertHtml(markdown_response) + cursor.insertHtml(formatted_text) self.ui_manager.chat_box.verticalScrollBar().setValue( self.ui_manager.chat_box.verticalScrollBar().maximum() ) @@ -514,14 +568,6 @@ def remove_image_thumbnail(self): self.ui_manager.input_field.setPlaceholderText("Ask me anything...") self.setFixedHeight(self.height() - 110) # Decrease height after removing image - def mousePressEvent(self, event): - self.oldPos = event.globalPos() - - def mouseMoveEvent(self, event): - delta = QPoint(event.globalPos() - self.oldPos) - self.move(self.x() + delta.x(), self.y() + delta.y()) - self.oldPos = event.globalPos() - def on_wake_word_detected(self, model_name): self.show() self.activateWindow() diff --git a/llama_assistant/model_handler.py b/llama_assistant/model_handler.py index c149b18..6b8323e 100644 --- a/llama_assistant/model_handler.py +++ b/llama_assistant/model_handler.py @@ -214,6 +214,10 @@ def chat_completion( return response def update_chat_history(self, message: str, role: str): + if self.loaded_agent is None: + print("Agent has not been initialized. Cannot update chat history.") + return + agent = self.loaded_agent.get("agent") if agent: agent.chat_history.add_message({"role": role, "content": message}) diff --git a/llama_assistant/ocr_engine.py b/llama_assistant/ocr_engine.py new file mode 100644 index 0000000..847dc3d --- /dev/null +++ b/llama_assistant/ocr_engine.py @@ -0,0 +1,152 @@ +from paddleocr import PaddleOCR +from PIL import Image +import copy +import numpy as np +import time + +def group_boxes_to_lines(bboxes, vertical_tolerance=5): + """ + Groups bounding boxes into lines based on vertical alignment. + + Args: + bboxes: List of bounding boxes [(xmin, ymin, xmax, ymax)]. + vertical_tolerance: Tolerance for vertical proximity to consider boxes in the same line. + + Returns: + List of lines, where each line is a list of bounding boxes. + """ + # Sort bounding boxes by ymin (top edge) + bboxes = sorted(bboxes, key=lambda bbox: bbox[1]) + + lines = [] + current_line = [] + current_ymin = None + + for bbox in bboxes: + xmin, ymin, xmax, ymax = bbox + + # Check if starting a new line or current box is not vertically aligned + if current_ymin is None or ymin > current_ymin + vertical_tolerance: + # Save the current line if not empty + if current_line: + # Sort current line by xmin + current_line.sort(key=lambda box: box[0]) + lines.append(current_line) + # Start a new line + current_line = [bbox] + current_ymin = ymin + else: + # Add box to the current line + current_line.append(bbox) + + # Add the last line if any + if current_line: + current_line.sort(key=lambda box: box[0]) + lines.append(current_line) + + return lines + +def quad_to_rect(quad_boxes): + """ + Converts a quadrilateral bounding box to a rectangular bounding box. + + Args: + quad_boxes: List of 4 points [(x1, y1), (x2, y2), (x3, y3), (x4, y4)] representing the quadrilateral. + + Returns: + List of rectangular bounding box (xmin, ymin, xmax, ymax). 
+ """ + result = [] + for quad_box in quad_boxes: + quad_box = quad_box.astype(np.int32) + # Extract all x and y coordinates + x_coords = quad_box[:, 0] + y_coords = quad_box[:, 1] + + # Find the enclosing rectangle + xmin = np.min(x_coords) + ymin = np.min(y_coords) + xmax = np.max(x_coords) + ymax = np.max(y_coords) + + result.append((xmin, ymin, xmax, ymax)) + + return result + +class OCREngine: + def __init__(self): + self.ocr = None + + def load_ocr(self, processign_thread=None): + if self.ocr is None: + if processign_thread: + processign_thread.set_preloading(True, "Initializing OCR ....") + + self.ocr = PaddleOCR(use_angle_cls=True, lang='en') + time.sleep(1.2) + + if processign_thread: + processign_thread.set_preloading(False, "...") + + def perform_ocr(self, img_path, streaming=False, processing_thread=None): + self.load_ocr(processing_thread) + img = np.array(Image.open(img_path).convert('RGB')) + ori_im = img.copy() + + # text detection + dt_boxes, _ = self.ocr.text_detector(img) + if dt_boxes is None: + return None + + img_crop_list = [] + dt_boxes = quad_to_rect(dt_boxes) + + # group boxes into lines + lines = group_boxes_to_lines(dt_boxes) + + # return a generator if streaming + if streaming: + def generate_result(): + for boxes in lines: + img_crop_list = [] + for bno in range(len(boxes)): + xmin, ymin, xmax, ymax = copy.deepcopy(boxes[bno]) + img_crop = ori_im[ymin:ymax, xmin:xmax] + if any([dim <= 0 for dim in img_crop.shape[:2]]): + continue + img_crop_list.append(img_crop) + rec_res, _ = self.ocr.text_recognizer(img_crop_list) + + line_text = "" + for rec_result in rec_res: + text, score = rec_result + if score >= self.ocr.drop_score: + line_text += text + " " + + yield line_text+'\n' + + return generate_result() + + # non-streaming + full_result = "" + for boxes in lines: + img_crop_list = [] + for bno in range(len(boxes)): + xmin, ymin, xmax, ymax = copy.deepcopy(boxes[bno]) + img_crop = ori_im[ymin:ymax, xmin:xmax] + if any([dim <= 0 for dim in img_crop.shape[:2]]): + continue + img_crop_list.append(img_crop) + rec_res, _ = self.ocr.text_recognizer(img_crop_list) + + line_text = "" + for rec_result in rec_res: + text, score = rec_result + if score >= self.ocr.drop_score: + line_text += text + " " + + full_result += line_text + '\n' + + return full_result + +ocr_engine = OCREngine() \ No newline at end of file diff --git a/llama_assistant/processing_thread.py b/llama_assistant/processing_thread.py index 3cc7a8b..04bc6c2 100644 --- a/llama_assistant/processing_thread.py +++ b/llama_assistant/processing_thread.py @@ -4,7 +4,7 @@ pyqtSignal, ) from llama_assistant.model_handler import handler as model_handler - +from llama_assistant.ocr_engine import ocr_engine class ProcessingThread(QThread): preloader_signal = pyqtSignal(str) @@ -19,6 +19,7 @@ def __init__( prompt: str, lookup_files: Optional[Set[str]] = None, image: str = None, + ocr_img_path: str = None, ): super().__init__() self.model = model @@ -27,9 +28,17 @@ def __init__( self.prompt = prompt self.image = image self.lookup_files = lookup_files - self.preloadong = False + self.preloading = False + self.ocr_img_path = ocr_img_path def run(self): + if self.ocr_img_path: + self.set_preloading(True, "Thinking ....") + ocr_output = ocr_engine.perform_ocr(self.ocr_img_path, streaming=False) + ocr_output = f"Here is the OCR result:\n{ocr_output}\n" + self.prompt = ocr_output + self.prompt + print("Prompt with OCR context:", self.prompt) + output = model_handler.chat_completion( self.model, self.generation_setting, @@ 
-65,3 +74,44 @@ def set_preloading(self, preloading: bool, message: str): def is_preloading(self): return self.preloading + +class OCRThread(QThread): + preloader_signal = pyqtSignal(str) + update_signal = pyqtSignal(str) + finished_signal = pyqtSignal() + + def __init__(self, img_path: str, streaming: bool = False): + super().__init__() + self.img_path = img_path + self.preloading = False + self.streaming = streaming + self.is_ocr_done = False + + def emit_preloading_message(self, message: str): + self.preloader_signal.emit(message) + + def set_preloading(self, preloading: bool, message: str): + self.preloading = preloading + self.emit_preloading_message(message) + + def is_preloading(self): + return self.preloading + + def run(self): + output = ocr_engine.perform_ocr(self.img_path, streaming=self.streaming, processing_thread=self) + full_response_str = "Here is the OCR result:\n" + self.is_ocr_done = True + + if not self.streaming and type(output) == str: + self.update_signal.emit(full_response_str + output) + return + + self.update_signal.emit(full_response_str) + for chunk in output: + self.update_signal.emit(chunk) + full_response_str += chunk + + model_handler.update_chat_history("OCR this image:", "user") + model_handler.update_chat_history(full_response_str, "assistant") + self.finished_signal.emit() + \ No newline at end of file diff --git a/llama_assistant/screen_capture_widget.py b/llama_assistant/screen_capture_widget.py new file mode 100644 index 0000000..8523578 --- /dev/null +++ b/llama_assistant/screen_capture_widget.py @@ -0,0 +1,166 @@ +from typing import TYPE_CHECKING +from PyQt5.QtWidgets import QApplication, QWidget, QDesktopWidget, QPushButton +from PyQt5.QtCore import Qt, QRect +from PyQt5.QtGui import QPainter, QColor, QPen + +from llama_assistant import config +from llama_assistant.ocr_engine import OCREngine +if TYPE_CHECKING: + from llama_assistant.llama_assistant_app import LlamaAssistantApp + + +class ScreenCaptureWidget(QWidget): + def __init__(self, parent: "LlamaAssistantApp"): + super().__init__() + self.setWindowFlags(Qt.FramelessWindowHint) + + self.parent = parent + self.ocr_engine = OCREngine() + + # Get screen size + screen = QDesktopWidget().screenGeometry() + self.setGeometry(0, 0, screen.width(), screen.height()) + + # Set crosshairs cursor + self.setCursor(Qt.CrossCursor) + + # To store the start and end points of the mouse region + self.start_point = None + self.end_point = None + + # Buttons to appear after selection + self.button_widget = QWidget() + self.ocr_button = QPushButton("OCR", self.button_widget) + self.ask_button = QPushButton("Ask", self.button_widget) + self.ocr_button.setCursor(Qt.PointingHandCursor) + self.ask_button.setCursor(Qt.PointingHandCursor) + opacity = self.parent.settings.get("transparency", 90) / 100 + base_style = f""" + border: none; + border-radius: 20px; + color: white; + padding: 10px 15px; + font-size: 16px; + """ + button_style = f""" + QPushButton {{ + {base_style} + padding: 2.5px 5px; + border-radius: 5px; + background-color: rgba{QColor(self.parent.settings["color"]).lighter(120).getRgb()[:3] + (opacity,)}; + }} + """ + self.ocr_button.setStyleSheet(button_style) + self.ask_button.setStyleSheet(button_style) + self.button_widget.hide() + + # Connect button signals + self.ocr_button.clicked.connect(self.parent.on_ocr_button_clicked) + self.ask_button.clicked.connect(self.parent.on_ask_with_ocr_context) + + def show(self): + # remove painting if any + self.start_point = None + self.end_point = None + 
self.update() + + # Set window opacity to 50% + self.setWindowOpacity(0.5) + # self.setAttribute(Qt.WA_TranslucentBackground, True) + super().show() + + def hide(self): + self.button_widget.hide() + super().hide() + + + def mousePressEvent(self, event): + if event.button() == Qt.LeftButton: + self.start_point = event.pos() # Capture start position + self.end_point = event.pos() # Initialize end point to start position + print(f"Mouse press at {self.start_point}") + + def mouseReleaseEvent(self, event): + if event.button() == Qt.LeftButton: + self.end_point = event.pos() # Capture end position + + print(f"Mouse release at {self.end_point}") + + # Capture the region between start and end points + if self.start_point and self.end_point: + self.capture_region(self.start_point, self.end_point) + + # Trigger repaint to show the red rectangle + self.update() + + self.show_buttons() + + + def mouseMoveEvent(self, event): + if self.start_point: + # Update the end_point to the current mouse position as it moves + self.end_point = event.pos() + + # Trigger repaint to update the rectangle + self.update() + + def capture_region(self, start_point, end_point): + # Convert local widget coordinates to global screen coordinates + start_global = self.mapToGlobal(start_point) + end_global = self.mapToGlobal(end_point) + + # Create a QRect from the global start and end points + region_rect = QRect(start_global, end_global) + + # Ensure the rectangle is valid (non-negative width/height) + region_rect = region_rect.normalized() + + # Capture the screen region + screen = QApplication.primaryScreen() + pixmap = screen.grabWindow(0, region_rect.x(), region_rect.y(), region_rect.width(), region_rect.height()) + + # Save the captured region as an image + pixmap.save(str(config.ocr_tmp_file), "PNG") + print(f"Captured region saved at '{config.ocr_tmp_file}'.") + + def paintEvent(self, event): + # If the start and end points are set, draw the rectangle + if self.start_point and self.end_point: + # Create a painter object + painter = QPainter(self) + + # Set the pen color to red + pen = QPen(QColor(255, 0, 0)) # Red color + pen.setWidth(3) # Set width of the border + painter.setPen(pen) + + # Draw the rectangle from start_point to end_point + self.region_rect = QRect(self.start_point, self.end_point) + self.region_rect = self.region_rect.normalized() # Normalize to ensure correct width/height + + painter.drawRect(self.region_rect) # Draw the rectangle + + super().paintEvent(event) # Call the base class paintEvent + + def show_buttons(self): + if self.start_point and self.end_point: + # Get normalized rectangle + rect = QRect(self.start_point, self.end_point).normalized() + + # Calculate button positions + button_y = rect.bottom() + 10 # Place buttons below the rectangle + button_width = 80 + button_height = 30 + spacing = 10 + + print("Showing buttons") + + self.ocr_button.setGeometry(0, 0, button_width, button_height) + self.ask_button.setGeometry(button_width + spacing, 0, button_width, button_height) + + self.button_widget.setGeometry(rect.left(), button_y, 2 * button_width + spacing, button_height) + self.button_widget.setAttribute(Qt.WA_TranslucentBackground) + self.button_widget.setWindowFlags(Qt.FramelessWindowHint) + self.button_widget.show() + + \ No newline at end of file diff --git a/llama_assistant/ui_manager.py b/llama_assistant/ui_manager.py index 651f155..817e60e 100644 --- a/llama_assistant/ui_manager.py +++ b/llama_assistant/ui_manager.py @@ -24,6 +24,7 @@ copy_icon_svg, clear_icon_svg, 
microphone_icon_svg, + crosshair_icon_svg ) @@ -87,6 +88,7 @@ def init_ui(self): self.input_field.dropEvent = self.parent.dropEvent input_layout.addWidget(self.input_field) + button_layout = QVBoxLayout() self.mic_button = QPushButton(self.parent) self.mic_button.setIcon(create_icon_from_svg(microphone_icon_svg)) self.mic_button.setIconSize(QSize(24, 24)) @@ -104,7 +106,29 @@ def init_ui(self): } """ ) - input_layout.addWidget(self.mic_button) + button_layout.addWidget(self.mic_button) + + self.screenshot_button = QPushButton(self.parent) + self.screenshot_button.setIcon(create_icon_from_svg(crosshair_icon_svg)) + self.screenshot_button.setIconSize(QSize(24, 24)) + self.screenshot_button.setFixedSize(40, 40) + self.screenshot_button.clicked.connect(self.parent.capture_screenshot) + self.screenshot_button.setToolTip("Screen Spot") + self.screenshot_button.setStyleSheet( + """ + QPushButton { + background-color: rgba(100, 100, 100, 200); + border: none; + border-radius: 20px; + } + QPushButton:hover { + background-color: rgba(100, 100, 100, 230); + } + """ + ) + button_layout.addWidget(self.screenshot_button) + + input_layout.addLayout(button_layout) close_button = QPushButton("×", self.parent) close_button.clicked.connect(self.parent.hide) diff --git a/pyproject.toml b/pyproject.toml index e0a3512..a792bc6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,8 @@ dependencies = [ "llama-index-embeddings-huggingface==0.4.0", "docx2txt==0.8", "mistune==3.0.2", + "paddlepaddle==2.6.2", + "paddleocr==2.9.1", "whispercpp @ git+https://github.com/stlukey/whispercpp.py" ] dynamic = [] diff --git a/requirements.txt b/requirements.txt index 172f12a..f2f417b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,4 +13,6 @@ llama-index-readers-file==0.4.0 llama-index-embeddings-huggingface==0.4.0 docx2txt==0.8 mistune==3.0.2 +paddlepaddle==2.6.2 +paddleocr==2.9.1 git+https://github.com/stlukey/whispercpp.py
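For reviewers, a minimal usage sketch (not part of the diff) of the two geometry helpers introduced in llama_assistant/ocr_engine.py: quad_to_rect collapses each detector quadrilateral into its enclosing rectangle, and group_boxes_to_lines buckets those rectangles into reading-order lines. The coordinates below are invented for illustration, and importing the module assumes the paddlepaddle/paddleocr dependencies added above are installed, since PaddleOCR is imported at module load.

```python
import numpy as np

# Importing ocr_engine pulls in paddleocr, so the new dependencies must be installed.
from llama_assistant.ocr_engine import group_boxes_to_lines, quad_to_rect

# Detector-style output: one quadrilateral (4 corner points) per text region.
# These coordinates are made up for illustration.
quads = [
    np.array([[10, 10], [100, 12], [99, 28], [11, 27]]),      # "Hello"
    np.array([[120, 12], [200, 12], [200, 30], [120, 30]]),   # "world"
    np.array([[10, 60], [150, 60], [150, 80], [10, 80]]),     # a second line
]

# Each quad becomes its enclosing (xmin, ymin, xmax, ymax) rectangle.
rects = quad_to_rect(quads)

# Rectangles whose top edges fall within vertical_tolerance of the first box in a
# line are grouped together, then each line is sorted left-to-right by xmin.
lines = group_boxes_to_lines(rects, vertical_tolerance=5)
# -> first line: (10, 10, 100, 28), (120, 12, 200, 30); second line: (10, 60, 150, 80)
print(lines)
```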
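OCRThread.run consumes OCREngine.perform_ocr in streaming mode; the sketch below shows the same consumption pattern outside the Qt thread, with "screenshot.png" as a placeholder path standing in for config.ocr_tmp_file (same install assumption as above).

```python
from llama_assistant.ocr_engine import ocr_engine

# streaming=True returns a generator that yields one recognized line per iteration
# (each ending in "\n"); streaming=False returns the full string; None means the
# detector found no text regions in the image.
chunks = ocr_engine.perform_ocr("screenshot.png", streaming=True)  # placeholder path

if chunks is None:
    print("No text detected")
else:
    full_text = ""
    for line in chunks:
        full_text += line  # OCRThread emits each chunk through update_signal instead
    print(full_text)
```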