diff --git a/README.md b/README.md
index 4f88cfe..fc1d98c 100644
--- a/README.md
+++ b/README.md
@@ -54,6 +54,7 @@ https://github.com/user-attachments/assets/af2c544b-6d46-4c44-87d8-9a051ba213db
- [x] ⚡ Streaming support for response.
- [x] 🎙️ Add offline STT support: WhisperCPP.
- [x] 🧠 Knowledge database: LlamaIndex
+- [x] ⌖ Screen Spot: Capture a screen region and analyze it with OCR.
- [ ] 🔌 Plugin system for extensibility.
- [ ] 📰 News and weather updates.
- [ ] 📧 Email integration with Gmail and Outlook.
diff --git a/llama_assistant/agent.py b/llama_assistant/agent.py
index 8e6ec17..9489acb 100644
--- a/llama_assistant/agent.py
+++ b/llama_assistant/agent.py
@@ -12,6 +12,7 @@
SYSTEM_PROMPT = {"role": "system", "content": "Generate short and simple response."}
+
def convert_message_list_to_str(messages):
chat_history_str = ""
for message in messages:
@@ -39,9 +40,11 @@ class ChatHistory:
def __init__(self, llm, max_history_size: int, max_output_tokens: int):
self.llm = llm
self.max_output_tokens = max_output_tokens
- self.max_history_size = max_history_size # in tokens
+ self.max_history_size = max_history_size # in tokens
self.max_history_size_in_words = max_history_size * 3 / 4
- self.max_history_size_in_words = self.max_history_size_in_words - 128 # to account for some formatting tokens
+ self.max_history_size_in_words = (
+ self.max_history_size_in_words - 128
+ ) # to account for some formatting tokens
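+        # e.g. with max_history_size=4096 tokens, roughly 4096 * 3 / 4 - 128 = 2944 words of history are kept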
print("Max history size in words:", self.max_history_size_in_words)
self.total_size = 0
self.chat_history = []
@@ -57,7 +60,14 @@ def add_message(self, message: dict):
if self.total_size + new_msg_size > self.max_history_size_in_words:
print("Chat history is too long, summarizing the conversation...")
history_summary = self.llm.create_chat_completion(
- messages= [SYSTEM_PROMPT] + self.chat_history + [{"role": "user", "content": "Briefly summarize the conversation in a few sentences."}],
+ messages=[SYSTEM_PROMPT]
+ + self.chat_history
+ + [
+ {
+ "role": "user",
+ "content": "Briefly summarize the conversation in a few sentences.",
+ }
+ ],
stream=False,
max_tokens=256,
)["choices"][0]["message"]["content"]
@@ -119,8 +129,11 @@ def __init__(
self.search_index = None
self.retriever = None
- self.chat_history = ChatHistory(llm=llm, max_output_tokens=self.max_output_tokens,
- max_history_size=self.max_input_tokens)
+ self.chat_history = ChatHistory(
+ llm=llm,
+ max_output_tokens=self.max_output_tokens,
+ max_history_size=self.max_input_tokens,
+ )
self.lookup_files = set()
self.embed_model = HuggingFaceEmbedding(model_name=rag_setting["embed_model_name"])
@@ -144,9 +157,7 @@ def update_index(self, files: Optional[Set[str]] = set()):
show_progress=True, num_workers=1
)
- self.search_index = VectorStoreIndex.from_documents(
- documents, embed_model=self.embed_model
- )
+ self.search_index = VectorStoreIndex.from_documents(documents, embed_model=self.embed_model)
self.retriever = self.search_index.as_retriever(similarity_top_k=self.retrieval_top_k)
@@ -221,7 +232,15 @@ async def condense_history_to_query(self, ctx: Context, ev: SetupEvent) -> Conde
if len(self.chat_history) > 0 and self.retriever is not None:
standalone_query = self.llm.create_chat_completion(
- messages=[SYSTEM_PROMPT] + self.chat_history.get_chat_history() + [{"role": "user", "content": query_str + "\n Condense this conversation to stand-alone question using only 1 sentence."}],
+ messages=[SYSTEM_PROMPT]
+ + self.chat_history.get_chat_history()
+ + [
+ {
+ "role": "user",
+ "content": query_str
+ + "\n Condense this conversation to stand-alone question using only 1 sentence.",
+ }
+ ],
stream=False,
top_k=self.generation_setting["top_k"],
top_p=self.generation_setting["top_p"],
diff --git a/llama_assistant/config.py b/llama_assistant/config.py
index b966ff3..97f1a5c 100644
--- a/llama_assistant/config.py
+++ b/llama_assistant/config.py
@@ -12,11 +12,11 @@
"hey_llama_chat": False,
"hey_llama_mic": False,
"generation": {
- "context_len": 4096,
+ "context_len": 4096,
"max_output_tokens": 1024,
"top_k": 40,
"top_p": 0.95,
- "temperature": 0.2
+ "temperature": 0.2,
},
"rag": {
"embed_model_name": "BAAI/bge-base-en-v1.5",
@@ -28,22 +28,26 @@
}
VALIDATOR = {
- 'generation': {
- 'context_len': {'type': 'int', 'min': 2048},
- 'max_output_tokens': {'type': 'int', 'min': 512, 'max': 2048},
- 'top_k': {'type': 'int', 'min': 1, 'max': 100},
- 'top_p': {'type': 'float', 'min': 0, 'max': 1},
- 'temperature': {'type': 'float', 'min': 0, 'max': 1},
- },
- 'rag': {
- 'chunk_size': {'type': 'int', 'min': 64, 'max': 512},
- 'chunk_overlap': {'type': 'int', 'min': 64, 'max': 256},
- 'max_retrieval_top_k': {'type': 'int', 'min': 1, 'max': 5},
- 'similarity_threshold': {'type': 'float', 'min': 0, 'max': 1},
- }
+ "generation": {
+ "context_len": {"type": "int", "min": 2048},
+ "max_output_tokens": {"type": "int", "min": 512, "max": 2048},
+ "top_k": {"type": "int", "min": 1, "max": 100},
+ "top_p": {"type": "float", "min": 0, "max": 1},
+ "temperature": {"type": "float", "min": 0, "max": 1},
+ },
+ "rag": {
+ "chunk_size": {"type": "int", "min": 64, "max": 512},
+ "chunk_overlap": {"type": "int", "min": 64, "max": 256},
+ "max_retrieval_top_k": {"type": "int", "min": 1, "max": 5},
+ "similarity_threshold": {"type": "float", "min": 0, "max": 1},
+ },
}
-DEFAULT_EMBEDING_MODELS = ["BAAI/bge-small-en-v1.5", "BAAI/bge-base-en-v1.5", "BAAI/bge-large-en-v1.5"]
+DEFAULT_EMBEDING_MODELS = [
+ "BAAI/bge-small-en-v1.5",
+ "BAAI/bge-base-en-v1.5",
+ "BAAI/bge-large-en-v1.5",
+]
DEFAULT_MODELS = [
{
@@ -150,6 +154,7 @@
custom_models_file = llama_assistant_dir / "custom_models.json"
settings_file = llama_assistant_dir / "settings.json"
document_icon = "llama_assistant/resources/document_icon.png"
+ocr_tmp_file = llama_assistant_dir / "ocr_tmp.png"
if custom_models_file.exists():
with open(custom_models_file, "r") as f:
diff --git a/llama_assistant/icons.py b/llama_assistant/icons.py
index c4564d5..4fd5dff 100644
--- a/llama_assistant/icons.py
+++ b/llama_assistant/icons.py
@@ -29,6 +29,16 @@
"""
+crosshair_icon_svg = """
+
+
+
+
+
+
+
+"""
+
def create_icon_from_svg(svg_string):
svg_bytes = QByteArray(svg_string.encode("utf-8"))
diff --git a/llama_assistant/llama_assistant_app.py b/llama_assistant/llama_assistant_app.py
index f6e4d5e..2559d5a 100644
--- a/llama_assistant/llama_assistant_app.py
+++ b/llama_assistant/llama_assistant_app.py
@@ -14,12 +14,9 @@
QVBoxLayout,
QMessageBox,
QSystemTrayIcon,
+ QRubberBand,
)
-from PyQt5.QtCore import (
- Qt,
- QPoint,
- QTimer,
-)
+from PyQt5.QtCore import Qt, QPoint, QTimer, QSize, QRect
from PyQt5.QtGui import (
QPixmap,
QPainter,
@@ -36,9 +33,10 @@
from llama_assistant.setting_dialog import SettingsDialog
from llama_assistant.speech_recognition_thread import SpeechRecognitionThread
from llama_assistant.utils import image_to_base64_data_uri
-from llama_assistant.processing_thread import ProcessingThread
+from llama_assistant.processing_thread import ProcessingThread, OCRThread
from llama_assistant.ui_manager import UIManager
from llama_assistant.tray_manager import TrayManager
+from llama_assistant.screen_capture_widget import ScreenCaptureWidget
from llama_assistant.setting_validator import validate_numeric_field
from llama_assistant.utils import load_image
@@ -50,6 +48,7 @@ def __init__(self):
self.load_settings()
self.ui_manager = UIManager(self)
self.tray_manager = TrayManager(self)
+ self.screen_capture_widget = ScreenCaptureWidget(self)
self.setup_global_shortcut()
self.last_response = ""
self.dropped_image = None
@@ -62,6 +61,12 @@ def __init__(self):
self.current_multimodal_model = self.settings.get("multimodal_model")
self.processing_thread = None
self.markdown_creator = mistune.create_markdown()
+ self.gen_mark_down = True
+ self.has_ocr_context = False
+
+ def capture_screenshot(self):
+ self.hide()
+ self.screen_capture_widget.show()
def tray_icon_activated(self, reason):
if reason == QSystemTrayIcon.ActivationReason.Trigger:
@@ -183,6 +188,38 @@ def toggle_visibility(self):
self.raise_()
self.ui_manager.input_field.setFocus()
+ def on_ocr_button_clicked(self):
+ self.show()
+ self.show_chat_box()
+ self.screen_capture_widget.hide()
+
+ self.last_response = ""
+ self.gen_mark_down = False
+
+        self.ui_manager.chat_box.append("You: OCR this captured region")
+        self.ui_manager.chat_box.append("AI: ")
+
+ self.start_cursor_pos = self.ui_manager.chat_box.textCursor().position()
+
+ img_path = config.ocr_tmp_file
+ if not img_path.exists():
+            print("No image found for OCR")
+ self.ui_manager.chat_box.append("No image found for OCR")
+ return
+
+ self.processing_thread = OCRThread(img_path, streaming=True)
+ self.processing_thread.preloader_signal.connect(self.indicate_loading)
+ self.processing_thread.update_signal.connect(self.update_chat_box)
+ self.processing_thread.finished_signal.connect(self.on_processing_finished)
+ self.processing_thread.start()
+
+ def on_ask_with_ocr_context(self):
+ self.show()
+ self.screen_capture_widget.hide()
+ self.has_ocr_context = True
+
def on_submit(self):
message = self.ui_manager.input_field.toPlainText()
if message == "":
@@ -200,6 +237,7 @@ def on_submit(self):
return
self.last_response = ""
+ self.gen_mark_down = True
if self.dropped_image:
self.process_image_with_prompt(self.dropped_image, self.dropped_files, message)
@@ -244,6 +282,7 @@ def process_text(self, message, file_paths, task="chat"):
self.rag_setting,
prompt,
lookup_files=file_paths,
+ ocr_img_path=config.ocr_tmp_file if self.has_ocr_context else None,
)
self.processing_thread.preloader_signal.connect(self.indicate_loading)
@@ -251,6 +290,8 @@ def process_text(self, message, file_paths, task="chat"):
self.processing_thread.finished_signal.connect(self.on_processing_finished)
self.processing_thread.start()
+ self.has_ocr_context = False
+
def process_image_with_prompt(self, image_path, file_paths, prompt):
self.show_chat_box()
self.ui_manager.chat_box.append(
@@ -270,20 +311,27 @@ def process_image_with_prompt(self, image_path, file_paths, prompt):
prompt,
image=image,
lookup_files=file_paths,
+ ocr_img_path=config.ocr_tmp_file if self.has_ocr_context else None,
)
self.processing_thread.preloader_signal.connect(self.indicate_loading)
self.processing_thread.update_signal.connect(self.update_chat_box)
self.processing_thread.finished_signal.connect(self.on_processing_finished)
self.processing_thread.start()
+ self.has_ocr_context = False
+
+ def clear_text_from_start_pos(self):
+ cursor = self.ui_manager.chat_box.textCursor()
+ cursor.setPosition(self.start_cursor_pos)
+ # Select all text from the start_pos to the end
+ cursor.movePosition(QTextCursor.End, QTextCursor.KeepAnchor)
+ # Remove the selected text
+ cursor.removeSelectedText()
+
def indicate_loading(self, message):
while self.processing_thread.is_preloading():
+ self.clear_text_from_start_pos()
cursor = self.ui_manager.chat_box.textCursor()
- cursor.setPosition(self.start_cursor_pos)
- # Select all text from the start_pos to the end
- cursor.movePosition(QTextCursor.End, QTextCursor.KeepAnchor)
- # Remove the selected text
- cursor.removeSelectedText()
# create animation where the characters are displayed one by one
for c in message:
if c == " ":
@@ -296,20 +344,21 @@ def indicate_loading(self, message):
def update_chat_box(self, text):
self.last_response += text
- markdown_response = self.markdown_creator(self.last_response)
- # Since cannot change the font size of the h1, h2 tag, we will replace it with h3
-        markdown_response = markdown_response.replace("<h1>", "<h3>").replace("</h1>", "</h3>")
-        markdown_response = markdown_response.replace("<h2>", "<h3>").replace("</h2>", "</h3>")
-        markdown_response += "<br>"
+
+ formatted_text = ""
+ if self.gen_mark_down:
+ markdown_response = self.markdown_creator(self.last_response)
+            # Since we cannot change the font size of the h1 and h2 tags, replace them with h3
+            markdown_response = markdown_response.replace("<h1>", "<h3>").replace("</h1>", "</h3>")
+            markdown_response = markdown_response.replace("<h2>", "<h3>").replace("</h2>", "</h3>")
+            markdown_response += "<br>"
+ formatted_text = markdown_response
+ else:
+            formatted_text = self.last_response.replace("\n", " ") + "<br>"
+
+ self.clear_text_from_start_pos()
cursor = self.ui_manager.chat_box.textCursor()
- cursor.setPosition(
- self.start_cursor_pos
- ) # regenerate the updated text from the start position
- # Select all text from the start_pos to the end
- cursor.movePosition(QTextCursor.End, QTextCursor.KeepAnchor)
- # Remove the selected text
- cursor.removeSelectedText()
- cursor.insertHtml(markdown_response)
+ cursor.insertHtml(formatted_text)
self.ui_manager.chat_box.verticalScrollBar().setValue(
self.ui_manager.chat_box.verticalScrollBar().maximum()
)
@@ -514,14 +563,6 @@ def remove_image_thumbnail(self):
self.ui_manager.input_field.setPlaceholderText("Ask me anything...")
self.setFixedHeight(self.height() - 110) # Decrease height after removing image
- def mousePressEvent(self, event):
- self.oldPos = event.globalPos()
-
- def mouseMoveEvent(self, event):
- delta = QPoint(event.globalPos() - self.oldPos)
- self.move(self.x() + delta.x(), self.y() + delta.y())
- self.oldPos = event.globalPos()
-
def on_wake_word_detected(self, model_name):
self.show()
self.activateWindow()
diff --git a/llama_assistant/model_handler.py b/llama_assistant/model_handler.py
index c149b18..6c1a112 100644
--- a/llama_assistant/model_handler.py
+++ b/llama_assistant/model_handler.py
@@ -214,6 +214,10 @@ def chat_completion(
return response
def update_chat_history(self, message: str, role: str):
+ if self.loaded_agent is None:
+ print("Agent has not been initialized. Cannot update chat history.")
+ return
+
agent = self.loaded_agent.get("agent")
if agent:
agent.chat_history.add_message({"role": role, "content": message})
diff --git a/llama_assistant/ocr_engine.py b/llama_assistant/ocr_engine.py
new file mode 100644
index 0000000..9ab2244
--- /dev/null
+++ b/llama_assistant/ocr_engine.py
@@ -0,0 +1,157 @@
+from paddleocr import PaddleOCR
+from PIL import Image
+import copy
+import numpy as np
+import time
+
+
+def group_boxes_to_lines(bboxes, vertical_tolerance=5):
+ """
+ Groups bounding boxes into lines based on vertical alignment.
+
+ Args:
+ bboxes: List of bounding boxes [(xmin, ymin, xmax, ymax)].
+ vertical_tolerance: Tolerance for vertical proximity to consider boxes in the same line.
+
+ Returns:
+ List of lines, where each line is a list of bounding boxes.
+ """
+ # Sort bounding boxes by ymin (top edge)
+ bboxes = sorted(bboxes, key=lambda bbox: bbox[1])
+
+ lines = []
+ current_line = []
+ current_ymin = None
+
+ for bbox in bboxes:
+ xmin, ymin, xmax, ymax = bbox
+
+ # Check if starting a new line or current box is not vertically aligned
+ if current_ymin is None or ymin > current_ymin + vertical_tolerance:
+ # Save the current line if not empty
+ if current_line:
+ # Sort current line by xmin
+ current_line.sort(key=lambda box: box[0])
+ lines.append(current_line)
+ # Start a new line
+ current_line = [bbox]
+ current_ymin = ymin
+ else:
+ # Add box to the current line
+ current_line.append(bbox)
+
+ # Add the last line if any
+ if current_line:
+ current_line.sort(key=lambda box: box[0])
+ lines.append(current_line)
+
+ return lines
+
+
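+# Illustrative sketch (comment only, not executed): two boxes whose top edges fall
+# within the default 5 px tolerance end up on the same line, ordered left to right,
+# while the third box starts a new line.
+#   group_boxes_to_lines([(10, 12, 50, 30), (60, 10, 90, 28), (12, 50, 40, 70)])
+#   # -> [[(10, 12, 50, 30), (60, 10, 90, 28)], [(12, 50, 40, 70)]]
+
+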
+def quad_to_rect(quad_boxes):
+ """
+    Converts quadrilateral bounding boxes to rectangular bounding boxes.
+
+    Args:
+        quad_boxes: List of quadrilaterals, each given as 4 points [(x1, y1), (x2, y2), (x3, y3), (x4, y4)].
+
+    Returns:
+        List of rectangular bounding boxes (xmin, ymin, xmax, ymax).
+ """
+ result = []
+ for quad_box in quad_boxes:
+ quad_box = quad_box.astype(np.int32)
+ # Extract all x and y coordinates
+ x_coords = quad_box[:, 0]
+ y_coords = quad_box[:, 1]
+
+ # Find the enclosing rectangle
+ xmin = np.min(x_coords)
+ ymin = np.min(y_coords)
+ xmax = np.max(x_coords)
+ ymax = np.max(y_coords)
+
+ result.append((xmin, ymin, xmax, ymax))
+
+ return result
+
+
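+# Illustrative sketch (comment only, not executed): a tilted quadrilateral from the
+# detector becomes its axis-aligned enclosing rectangle.
+#   quad_to_rect([np.array([[10, 5], [40, 8], [38, 25], [8, 22]])])
+#   # -> [(8, 5, 40, 25)]
+
+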
+class OCREngine:
+ def __init__(self):
+ self.ocr = None
+
+    def load_ocr(self, processing_thread=None):
+        if self.ocr is None:
+            if processing_thread:
+                processing_thread.set_preloading(True, "Initializing OCR ....")
+
+            self.ocr = PaddleOCR(use_angle_cls=True, lang="en")
+            time.sleep(1.2)
+
+            if processing_thread:
+                processing_thread.set_preloading(False, "...")
+
+ def perform_ocr(self, img_path, streaming=False, processing_thread=None):
+ self.load_ocr(processing_thread)
+ img = np.array(Image.open(img_path).convert("RGB"))
+ ori_im = img.copy()
+
+ # text detection
+ dt_boxes, _ = self.ocr.text_detector(img)
+ if dt_boxes is None:
+ return None
+
+ img_crop_list = []
+ dt_boxes = quad_to_rect(dt_boxes)
+
+ # group boxes into lines
+ lines = group_boxes_to_lines(dt_boxes)
+
+ # return a generator if streaming
+ if streaming:
+
+ def generate_result():
+ for boxes in lines:
+ img_crop_list = []
+ for bno in range(len(boxes)):
+ xmin, ymin, xmax, ymax = copy.deepcopy(boxes[bno])
+ img_crop = ori_im[ymin:ymax, xmin:xmax]
+ if any([dim <= 0 for dim in img_crop.shape[:2]]):
+ continue
+ img_crop_list.append(img_crop)
+ rec_res, _ = self.ocr.text_recognizer(img_crop_list)
+
+ line_text = ""
+ for rec_result in rec_res:
+ text, score = rec_result
+ if score >= self.ocr.drop_score:
+ line_text += text + " "
+
+ yield line_text + "\n"
+
+ return generate_result()
+
+ # non-streaming
+ full_result = ""
+ for boxes in lines:
+ img_crop_list = []
+ for bno in range(len(boxes)):
+ xmin, ymin, xmax, ymax = copy.deepcopy(boxes[bno])
+ img_crop = ori_im[ymin:ymax, xmin:xmax]
+ if any([dim <= 0 for dim in img_crop.shape[:2]]):
+ continue
+ img_crop_list.append(img_crop)
+ rec_res, _ = self.ocr.text_recognizer(img_crop_list)
+
+ line_text = ""
+ for rec_result in rec_res:
+ text, score = rec_result
+ if score >= self.ocr.drop_score:
+ line_text += text + " "
+
+ full_result += line_text + "\n"
+
+ return full_result
+
+
+ocr_engine = OCREngine()
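+
+# Minimal usage sketch (the image path here is hypothetical; the app passes
+# config.ocr_tmp_file). With streaming=False the full text is returned, with
+# streaming=True a generator yields one recognized line at a time:
+#   text = ocr_engine.perform_ocr("screenshot.png")
+#   for line in ocr_engine.perform_ocr("screenshot.png", streaming=True):
+#       print(line, end="")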
diff --git a/llama_assistant/processing_thread.py b/llama_assistant/processing_thread.py
index 3cc7a8b..ce1a492 100644
--- a/llama_assistant/processing_thread.py
+++ b/llama_assistant/processing_thread.py
@@ -4,6 +4,7 @@
pyqtSignal,
)
from llama_assistant.model_handler import handler as model_handler
+from llama_assistant.ocr_engine import ocr_engine
class ProcessingThread(QThread):
@@ -19,6 +20,7 @@ def __init__(
prompt: str,
lookup_files: Optional[Set[str]] = None,
image: str = None,
+ ocr_img_path: str = None,
):
super().__init__()
self.model = model
@@ -27,9 +29,17 @@ def __init__(
self.prompt = prompt
self.image = image
self.lookup_files = lookup_files
- self.preloadong = False
+ self.preloading = False
+ self.ocr_img_path = ocr_img_path
def run(self):
+ if self.ocr_img_path:
+ self.set_preloading(True, "Thinking ....")
+ ocr_output = ocr_engine.perform_ocr(self.ocr_img_path, streaming=False)
+ ocr_output = f"Here is the OCR result:\n{ocr_output}\n"
+ self.prompt = ocr_output + self.prompt
+ print("Prompt with OCR context:", self.prompt)
+
output = model_handler.chat_completion(
self.model,
self.generation_setting,
@@ -65,3 +75,46 @@ def set_preloading(self, preloading: bool, message: str):
def is_preloading(self):
return self.preloading
+
+
+class OCRThread(QThread):
+ preloader_signal = pyqtSignal(str)
+ update_signal = pyqtSignal(str)
+ finished_signal = pyqtSignal()
+
+ def __init__(self, img_path: str, streaming: bool = False):
+ super().__init__()
+ self.img_path = img_path
+ self.preloading = False
+ self.streaming = streaming
+ self.is_ocr_done = False
+
+ def emit_preloading_message(self, message: str):
+ self.preloader_signal.emit(message)
+
+ def set_preloading(self, preloading: bool, message: str):
+ self.preloading = preloading
+ self.emit_preloading_message(message)
+
+ def is_preloading(self):
+ return self.preloading
+
+ def run(self):
+ output = ocr_engine.perform_ocr(
+ self.img_path, streaming=self.streaming, processing_thread=self
+ )
+ full_response_str = "Here is the OCR result:\n"
+ self.is_ocr_done = True
+
+        if not self.streaming and isinstance(output, str):
+            full_response_str += output
+            self.update_signal.emit(full_response_str)
+        else:
+            self.update_signal.emit(full_response_str)
+            for chunk in output:
+                self.update_signal.emit(chunk)
+                full_response_str += chunk
+
+ model_handler.update_chat_history("OCR this image:", "user")
+ model_handler.update_chat_history(full_response_str, "assistant")
+ self.finished_signal.emit()
diff --git a/llama_assistant/screen_capture_widget.py b/llama_assistant/screen_capture_widget.py
new file mode 100644
index 0000000..ed351cf
--- /dev/null
+++ b/llama_assistant/screen_capture_widget.py
@@ -0,0 +1,169 @@
+from typing import TYPE_CHECKING
+from PyQt5.QtWidgets import QApplication, QWidget, QDesktopWidget, QPushButton
+from PyQt5.QtCore import Qt, QRect
+from PyQt5.QtGui import QPainter, QColor, QPen
+
+from llama_assistant import config
+from llama_assistant.ocr_engine import OCREngine
+
+if TYPE_CHECKING:
+ from llama_assistant.llama_assistant_app import LlamaAssistantApp
+
+
+class ScreenCaptureWidget(QWidget):
+ def __init__(self, parent: "LlamaAssistantApp"):
+ super().__init__()
+ self.setWindowFlags(Qt.FramelessWindowHint)
+
+ self.parent = parent
+ self.ocr_engine = OCREngine()
+
+ # Get screen size
+ screen = QDesktopWidget().screenGeometry()
+ self.setGeometry(0, 0, screen.width(), screen.height())
+
+ # Set crosshairs cursor
+ self.setCursor(Qt.CrossCursor)
+
+ # To store the start and end points of the mouse region
+ self.start_point = None
+ self.end_point = None
+
+ # Buttons to appear after selection
+ self.button_widget = QWidget()
+ self.ocr_button = QPushButton("OCR", self.button_widget)
+ self.ask_button = QPushButton("Ask", self.button_widget)
+ self.ocr_button.setCursor(Qt.PointingHandCursor)
+ self.ask_button.setCursor(Qt.PointingHandCursor)
+ opacity = self.parent.settings.get("transparency", 90) / 100
+ base_style = f"""
+ border: none;
+ border-radius: 20px;
+ color: white;
+ padding: 10px 15px;
+ font-size: 16px;
+ """
+ button_style = f"""
+ QPushButton {{
+ {base_style}
+ padding: 2.5px 5px;
+ border-radius: 5px;
+ background-color: rgba{QColor(self.parent.settings["color"]).lighter(120).getRgb()[:3] + (opacity,)};
+ }}
+ """
+ self.ocr_button.setStyleSheet(button_style)
+ self.ask_button.setStyleSheet(button_style)
+ self.button_widget.hide()
+
+ # Connect button signals
+ self.ocr_button.clicked.connect(self.parent.on_ocr_button_clicked)
+ self.ask_button.clicked.connect(self.parent.on_ask_with_ocr_context)
+
+ def show(self):
+ # remove painting if any
+ self.start_point = None
+ self.end_point = None
+ self.update()
+
+ # Set window opacity to 50%
+ self.setWindowOpacity(0.5)
+ # self.setAttribute(Qt.WA_TranslucentBackground, True)
+ super().show()
+
+ def hide(self):
+ self.button_widget.hide()
+ super().hide()
+
+ def mousePressEvent(self, event):
+ if event.button() == Qt.LeftButton:
+ self.start_point = event.pos() # Capture start position
+ self.end_point = event.pos() # Initialize end point to start position
+ print(f"Mouse press at {self.start_point}")
+
+ def mouseReleaseEvent(self, event):
+ if event.button() == Qt.LeftButton:
+ self.end_point = event.pos() # Capture end position
+
+ print(f"Mouse release at {self.end_point}")
+
+ # Capture the region between start and end points
+ if self.start_point and self.end_point:
+ self.capture_region(self.start_point, self.end_point)
+
+ # Trigger repaint to show the red rectangle
+ self.update()
+
+ self.show_buttons()
+
+ def mouseMoveEvent(self, event):
+ if self.start_point:
+ # Update the end_point to the current mouse position as it moves
+ self.end_point = event.pos()
+
+ # Trigger repaint to update the rectangle
+ self.update()
+
+ def capture_region(self, start_point, end_point):
+ # Convert local widget coordinates to global screen coordinates
+ start_global = self.mapToGlobal(start_point)
+ end_global = self.mapToGlobal(end_point)
+
+ # Create a QRect from the global start and end points
+ region_rect = QRect(start_global, end_global)
+
+ # Ensure the rectangle is valid (non-negative width/height)
+ region_rect = region_rect.normalized()
+
+ # Capture the screen region
+ screen = QApplication.primaryScreen()
+ pixmap = screen.grabWindow(
+ 0, region_rect.x(), region_rect.y(), region_rect.width(), region_rect.height()
+ )
+
+ # Save the captured region as an image
+ pixmap.save(str(config.ocr_tmp_file), "PNG")
+ print(f"Captured region saved at '{config.ocr_tmp_file}'.")
+
+ def paintEvent(self, event):
+ # If the start and end points are set, draw the rectangle
+ if self.start_point and self.end_point:
+ # Create a painter object
+ painter = QPainter(self)
+
+ # Set the pen color to red
+ pen = QPen(QColor(255, 0, 0)) # Red color
+ pen.setWidth(3) # Set width of the border
+ painter.setPen(pen)
+
+ # Draw the rectangle from start_point to end_point
+ self.region_rect = QRect(self.start_point, self.end_point)
+ self.region_rect = (
+ self.region_rect.normalized()
+ ) # Normalize to ensure correct width/height
+
+ painter.drawRect(self.region_rect) # Draw the rectangle
+
+ super().paintEvent(event) # Call the base class paintEvent
+
+ def show_buttons(self):
+ if self.start_point and self.end_point:
+ # Get normalized rectangle
+ rect = QRect(self.start_point, self.end_point).normalized()
+
+ # Calculate button positions
+ button_y = rect.bottom() + 10 # Place buttons below the rectangle
+ button_width = 80
+ button_height = 30
+ spacing = 10
+
+ print("Showing buttons")
+
+ self.ocr_button.setGeometry(0, 0, button_width, button_height)
+ self.ask_button.setGeometry(button_width + spacing, 0, button_width, button_height)
+
+ self.button_widget.setGeometry(
+ rect.left(), button_y, 2 * button_width + spacing, button_height
+ )
+ self.button_widget.setAttribute(Qt.WA_TranslucentBackground)
+ self.button_widget.setWindowFlags(Qt.FramelessWindowHint)
+ self.button_widget.show()
diff --git a/llama_assistant/setting_dialog.py b/llama_assistant/setting_dialog.py
index 9a5ae12..843d748 100644
--- a/llama_assistant/setting_dialog.py
+++ b/llama_assistant/setting_dialog.py
@@ -24,6 +24,7 @@
from llama_assistant import config
from llama_assistant.setting_validator import validate_numeric_field
+
class SettingsDialog(QDialog):
settingsSaved = pyqtSignal()
@@ -218,56 +219,76 @@ def create_rag_settings_group(self):
self.main_layout.addWidget(group_box)
def accept(self):
- valid, message = validate_numeric_field("Context Length", self.context_len_input.text(),
- constraints=config.VALIDATOR['generation']['context_len'])
+ valid, message = validate_numeric_field(
+ "Context Length",
+ self.context_len_input.text(),
+ constraints=config.VALIDATOR["generation"]["context_len"],
+ )
if not valid:
QMessageBox.warning(self, "Validation Error", message)
return
-
- valid, message = validate_numeric_field("Temperature", self.temperature_input.text(),
- constraints=config.VALIDATOR['generation']['temperature'])
+
+ valid, message = validate_numeric_field(
+ "Temperature",
+ self.temperature_input.text(),
+ constraints=config.VALIDATOR["generation"]["temperature"],
+ )
if not valid:
QMessageBox.warning(self, "Validation Error", message)
return
-
- valid, message = validate_numeric_field("Top p", self.top_p_input.text(),
- constraints=config.VALIDATOR['generation']['top_p'])
+
+ valid, message = validate_numeric_field(
+ "Top p", self.top_p_input.text(), constraints=config.VALIDATOR["generation"]["top_p"]
+ )
if not valid:
QMessageBox.warning(self, "Validation Error", message)
return
-
- valid, message = validate_numeric_field("Top k", self.top_k_input.text(),
- constraints=config.VALIDATOR['generation']['top_k'])
+
+ valid, message = validate_numeric_field(
+ "Top k", self.top_k_input.text(), constraints=config.VALIDATOR["generation"]["top_k"]
+ )
if not valid:
QMessageBox.warning(self, "Validation Error", message)
return
-
- valid, message = validate_numeric_field("Chunk Size", self.chunk_size_input.text(),
- constraints=config.VALIDATOR['rag']['chunk_size'])
+
+ valid, message = validate_numeric_field(
+ "Chunk Size",
+ self.chunk_size_input.text(),
+ constraints=config.VALIDATOR["rag"]["chunk_size"],
+ )
if not valid:
QMessageBox.warning(self, "Validation Error", message)
return
-
- valid, message = validate_numeric_field("Chunk Overlap", self.chunk_overlap_input.text(),
- constraints=config.VALIDATOR['rag']['chunk_overlap'])
+
+ valid, message = validate_numeric_field(
+ "Chunk Overlap",
+ self.chunk_overlap_input.text(),
+ constraints=config.VALIDATOR["rag"]["chunk_overlap"],
+ )
if not valid:
QMessageBox.warning(self, "Validation Error", message)
return
-
- valid, message = validate_numeric_field("Max Retrieval Top k", self.max_retrieval_top_k_input.text(),
- constraints=config.VALIDATOR['rag']['max_retrieval_top_k'])
-
+
+ valid, message = validate_numeric_field(
+ "Max Retrieval Top k",
+ self.max_retrieval_top_k_input.text(),
+ constraints=config.VALIDATOR["rag"]["max_retrieval_top_k"],
+ )
+
if not valid:
QMessageBox.warning(self, "Validation Error", message)
return
-
- valid, message = validate_numeric_field("Similarity Threshold", self.similarity_threshold_input.text(),
- constraints=config.VALIDATOR['rag']['similarity_threshold'])
-
+
+ valid, message = validate_numeric_field(
+ "Similarity Threshold",
+ self.similarity_threshold_input.text(),
+ constraints=config.VALIDATOR["rag"]["similarity_threshold"],
+ )
+
if not valid:
QMessageBox.warning(self, "Validation Error", message)
return
-
+
self.save_settings()
self.settingsSaved.emit()
super().accept()
@@ -317,29 +338,66 @@ def load_settings(self):
if "rag" not in settings:
settings["rag"] = {}
- embed_model = settings["rag"].get("embed_model_name", config.DEFAULT_SETTINGS['rag']['embed_model_name'])
+ embed_model = settings["rag"].get(
+ "embed_model_name", config.DEFAULT_SETTINGS["rag"]["embed_model_name"]
+ )
if embed_model in config.DEFAULT_EMBEDING_MODELS:
self.embed_model_combo.setCurrentText(embed_model)
-
- self.chunk_size_input.setText(str(settings["rag"].get("chunk_size", config.DEFAULT_SETTINGS['rag']['chunk_size'])))
+
+ self.chunk_size_input.setText(
+ str(settings["rag"].get("chunk_size", config.DEFAULT_SETTINGS["rag"]["chunk_size"]))
+ )
self.chunk_overlap_input.setText(
- str(settings["rag"].get("chunk_overlap", config.DEFAULT_SETTINGS['rag']['chunk_overlap']))
+ str(
+ settings["rag"].get(
+ "chunk_overlap", config.DEFAULT_SETTINGS["rag"]["chunk_overlap"]
+ )
+ )
)
self.max_retrieval_top_k_input.setText(
- str(settings["rag"].get("max_retrieval_top_k", config.DEFAULT_SETTINGS['rag']['max_retrieval_top_k']))
+ str(
+ settings["rag"].get(
+ "max_retrieval_top_k", config.DEFAULT_SETTINGS["rag"]["max_retrieval_top_k"]
+ )
+ )
)
self.similarity_threshold_input.setText(
- str(settings["rag"].get("similarity_threshold", config.DEFAULT_SETTINGS['rag']['similarity_threshold']))
+ str(
+ settings["rag"].get(
+ "similarity_threshold",
+ config.DEFAULT_SETTINGS["rag"]["similarity_threshold"],
+ )
+ )
)
self.context_len_input.setText(
- str(settings["generation"].get("context_len", config.DEFAULT_SETTINGS['generation']['context_len']))
+ str(
+ settings["generation"].get(
+ "context_len", config.DEFAULT_SETTINGS["generation"]["context_len"]
+ )
+ )
)
-
+
self.temperature_input.setText(
- str(settings["generation"].get("temperature", config.DEFAULT_SETTINGS['generation']['temperature']))
+ str(
+ settings["generation"].get(
+ "temperature", config.DEFAULT_SETTINGS["generation"]["temperature"]
+ )
+ )
+ )
+ self.top_p_input.setText(
+ str(
+ settings["generation"].get(
+ "top_p", config.DEFAULT_SETTINGS["generation"]["top_p"]
+ )
+ )
+ )
+ self.top_k_input.setText(
+ str(
+ settings["generation"].get(
+ "top_k", config.DEFAULT_SETTINGS["generation"]["top_k"]
+ )
+ )
)
- self.top_p_input.setText(str(settings["generation"].get("top_p", config.DEFAULT_SETTINGS['generation']['top_p'])))
- self.top_k_input.setText(str(settings["generation"].get("top_k", config.DEFAULT_SETTINGS['generation']['top_k'])))
else:
self.color = QColor("#1E1E1E")
             self.shortcut_recorder.setText("<cmd>+<shift>+<space>")
diff --git a/llama_assistant/setting_validator.py b/llama_assistant/setting_validator.py
index 19cdcb7..bb45a5e 100644
--- a/llama_assistant/setting_validator.py
+++ b/llama_assistant/setting_validator.py
@@ -1,10 +1,9 @@
-
def validate_numeric_field(name, value_str, constraints):
- type = constraints['type']
- min = constraints.get('min')
- max = constraints.get('max')
-
- if type == 'float':
+ type = constraints["type"]
+ min = constraints.get("min")
+ max = constraints.get("max")
+
+ if type == "float":
if isinstance(value_str, float):
value = value_str
else:
@@ -13,8 +12,8 @@ def validate_numeric_field(name, value_str, constraints):
except ValueError:
message = f"Invalid value for {name}. Expected a float, got {value_str}"
return False, message
-
- elif type == 'int':
+
+ elif type == "int":
if isinstance(value_str, int):
value = value_str
else:
@@ -23,11 +22,13 @@ def validate_numeric_field(name, value_str, constraints):
except ValueError:
message = f"Invalid value for {name}. Expected an integer, got {value_str}"
return False, message
-
+
if min is not None and value < min:
message = f"Invalid value for {name}. Expected a value greater than or equal to {min}, got {value}"
return False, message
if max is not None and value > max:
- message = f"Invalid value for {name}. Expected a value less than or equal to {max}, got {value}"
+ message = (
+ f"Invalid value for {name}. Expected a value less than or equal to {max}, got {value}"
+ )
return False, message
- return True, value
\ No newline at end of file
+ return True, value
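+
+
+# Illustrative usage (hypothetical values; constraints mirror the entries in
+# config.VALIDATOR). On success the parsed number is returned, on failure an
+# error message string:
+#   ok, result = validate_numeric_field(
+#       "Top p", "0.95", {"type": "float", "min": 0, "max": 1}
+#   )
+#   # ok is True and result is 0.95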
diff --git a/llama_assistant/ui_manager.py b/llama_assistant/ui_manager.py
index 651f155..15d2b59 100644
--- a/llama_assistant/ui_manager.py
+++ b/llama_assistant/ui_manager.py
@@ -24,6 +24,7 @@
copy_icon_svg,
clear_icon_svg,
microphone_icon_svg,
+ crosshair_icon_svg,
)
@@ -87,6 +88,7 @@ def init_ui(self):
self.input_field.dropEvent = self.parent.dropEvent
input_layout.addWidget(self.input_field)
+ button_layout = QVBoxLayout()
self.mic_button = QPushButton(self.parent)
self.mic_button.setIcon(create_icon_from_svg(microphone_icon_svg))
self.mic_button.setIconSize(QSize(24, 24))
@@ -104,7 +106,29 @@ def init_ui(self):
}
"""
)
- input_layout.addWidget(self.mic_button)
+ button_layout.addWidget(self.mic_button)
+
+ self.screenshot_button = QPushButton(self.parent)
+ self.screenshot_button.setIcon(create_icon_from_svg(crosshair_icon_svg))
+ self.screenshot_button.setIconSize(QSize(24, 24))
+ self.screenshot_button.setFixedSize(40, 40)
+ self.screenshot_button.clicked.connect(self.parent.capture_screenshot)
+ self.screenshot_button.setToolTip("Screen Spot")
+ self.screenshot_button.setStyleSheet(
+ """
+ QPushButton {
+ background-color: rgba(100, 100, 100, 200);
+ border: none;
+ border-radius: 20px;
+ }
+ QPushButton:hover {
+ background-color: rgba(100, 100, 100, 230);
+ }
+ """
+ )
+ button_layout.addWidget(self.screenshot_button)
+
+ input_layout.addLayout(button_layout)
close_button = QPushButton("×", self.parent)
close_button.clicked.connect(self.parent.hide)
diff --git a/pyproject.toml b/pyproject.toml
index e0a3512..a792bc6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -40,6 +40,8 @@ dependencies = [
"llama-index-embeddings-huggingface==0.4.0",
"docx2txt==0.8",
"mistune==3.0.2",
+ "paddlepaddle==2.6.2",
+ "paddleocr==2.9.1",
"whispercpp @ git+https://github.com/stlukey/whispercpp.py"
]
dynamic = []
diff --git a/requirements.txt b/requirements.txt
index 172f12a..f2f417b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,4 +13,6 @@ llama-index-readers-file==0.4.0
llama-index-embeddings-huggingface==0.4.0
docx2txt==0.8
mistune==3.0.2
+paddlepaddle==2.6.2
+paddleocr==2.9.1
git+https://github.com/stlukey/whispercpp.py