diff --git a/llama_assistant/agent.py b/llama_assistant/agent.py
index 8e6ec17..9489acb 100644
--- a/llama_assistant/agent.py
+++ b/llama_assistant/agent.py
@@ -12,6 +12,7 @@
SYSTEM_PROMPT = {"role": "system", "content": "Generate short and simple response."}
+
def convert_message_list_to_str(messages):
chat_history_str = ""
for message in messages:
@@ -39,9 +40,11 @@ class ChatHistory:
def __init__(self, llm, max_history_size: int, max_output_tokens: int):
self.llm = llm
self.max_output_tokens = max_output_tokens
- self.max_history_size = max_history_size # in tokens
+ self.max_history_size = max_history_size # in tokens
self.max_history_size_in_words = max_history_size * 3 / 4
- self.max_history_size_in_words = self.max_history_size_in_words - 128 # to account for some formatting tokens
+ self.max_history_size_in_words = (
+ self.max_history_size_in_words - 128
+ ) # to account for some formatting tokens
print("Max history size in words:", self.max_history_size_in_words)
self.total_size = 0
self.chat_history = []
@@ -57,7 +60,14 @@ def add_message(self, message: dict):
if self.total_size + new_msg_size > self.max_history_size_in_words:
print("Chat history is too long, summarizing the conversation...")
history_summary = self.llm.create_chat_completion(
- messages= [SYSTEM_PROMPT] + self.chat_history + [{"role": "user", "content": "Briefly summarize the conversation in a few sentences."}],
+ messages=[SYSTEM_PROMPT]
+ + self.chat_history
+ + [
+ {
+ "role": "user",
+ "content": "Briefly summarize the conversation in a few sentences.",
+ }
+ ],
stream=False,
max_tokens=256,
)["choices"][0]["message"]["content"]
@@ -119,8 +129,11 @@ def __init__(
self.search_index = None
self.retriever = None
- self.chat_history = ChatHistory(llm=llm, max_output_tokens=self.max_output_tokens,
- max_history_size=self.max_input_tokens)
+ self.chat_history = ChatHistory(
+ llm=llm,
+ max_output_tokens=self.max_output_tokens,
+ max_history_size=self.max_input_tokens,
+ )
self.lookup_files = set()
self.embed_model = HuggingFaceEmbedding(model_name=rag_setting["embed_model_name"])
@@ -144,9 +157,7 @@ def update_index(self, files: Optional[Set[str]] = set()):
show_progress=True, num_workers=1
)
- self.search_index = VectorStoreIndex.from_documents(
- documents, embed_model=self.embed_model
- )
+ self.search_index = VectorStoreIndex.from_documents(documents, embed_model=self.embed_model)
self.retriever = self.search_index.as_retriever(similarity_top_k=self.retrieval_top_k)
@@ -221,7 +232,15 @@ async def condense_history_to_query(self, ctx: Context, ev: SetupEvent) -> Conde
if len(self.chat_history) > 0 and self.retriever is not None:
standalone_query = self.llm.create_chat_completion(
- messages=[SYSTEM_PROMPT] + self.chat_history.get_chat_history() + [{"role": "user", "content": query_str + "\n Condense this conversation to stand-alone question using only 1 sentence."}],
+ messages=[SYSTEM_PROMPT]
+ + self.chat_history.get_chat_history()
+ + [
+ {
+ "role": "user",
+ "content": query_str
+ + "\n Condense this conversation to stand-alone question using only 1 sentence.",
+ }
+ ],
stream=False,
top_k=self.generation_setting["top_k"],
top_p=self.generation_setting["top_p"],
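
Note on the ChatHistory hunks above (illustrative, not part of the patch): the class budgets its history in words, assuming roughly 3/4 word per token and reserving 128 words for chat-template formatting. With the default context_len of 4096 from config.py, that leaves 4096 * 3/4 - 128 = 2944 words before add_message summarizes the conversation. A minimal sketch of the arithmetic (the helper name is hypothetical):

def history_budget_in_words(max_history_size_tokens: int) -> float:
    # Mirrors ChatHistory.__init__: ~3/4 word per token,
    # minus 128 words to account for formatting tokens.
    return max_history_size_tokens * 3 / 4 - 128

assert history_budget_in_words(4096) == 2944.0  # the default context_len
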
diff --git a/llama_assistant/config.py b/llama_assistant/config.py
index e602a11..97f1a5c 100644
--- a/llama_assistant/config.py
+++ b/llama_assistant/config.py
@@ -12,11 +12,11 @@
"hey_llama_chat": False,
"hey_llama_mic": False,
"generation": {
- "context_len": 4096,
+ "context_len": 4096,
"max_output_tokens": 1024,
"top_k": 40,
"top_p": 0.95,
- "temperature": 0.2
+ "temperature": 0.2,
},
"rag": {
"embed_model_name": "BAAI/bge-base-en-v1.5",
@@ -28,22 +28,26 @@
}
VALIDATOR = {
- 'generation': {
- 'context_len': {'type': 'int', 'min': 2048},
- 'max_output_tokens': {'type': 'int', 'min': 512, 'max': 2048},
- 'top_k': {'type': 'int', 'min': 1, 'max': 100},
- 'top_p': {'type': 'float', 'min': 0, 'max': 1},
- 'temperature': {'type': 'float', 'min': 0, 'max': 1},
- },
- 'rag': {
- 'chunk_size': {'type': 'int', 'min': 64, 'max': 512},
- 'chunk_overlap': {'type': 'int', 'min': 64, 'max': 256},
- 'max_retrieval_top_k': {'type': 'int', 'min': 1, 'max': 5},
- 'similarity_threshold': {'type': 'float', 'min': 0, 'max': 1},
- }
+ "generation": {
+ "context_len": {"type": "int", "min": 2048},
+ "max_output_tokens": {"type": "int", "min": 512, "max": 2048},
+ "top_k": {"type": "int", "min": 1, "max": 100},
+ "top_p": {"type": "float", "min": 0, "max": 1},
+ "temperature": {"type": "float", "min": 0, "max": 1},
+ },
+ "rag": {
+ "chunk_size": {"type": "int", "min": 64, "max": 512},
+ "chunk_overlap": {"type": "int", "min": 64, "max": 256},
+ "max_retrieval_top_k": {"type": "int", "min": 1, "max": 5},
+ "similarity_threshold": {"type": "float", "min": 0, "max": 1},
+ },
}
-DEFAULT_EMBEDING_MODELS = ["BAAI/bge-small-en-v1.5", "BAAI/bge-base-en-v1.5", "BAAI/bge-large-en-v1.5"]
+DEFAULT_EMBEDING_MODELS = [
+ "BAAI/bge-small-en-v1.5",
+ "BAAI/bge-base-en-v1.5",
+ "BAAI/bge-large-en-v1.5",
+]
DEFAULT_MODELS = [
{
diff --git a/llama_assistant/llama_assistant_app.py b/llama_assistant/llama_assistant_app.py
index 9caa99d..2559d5a 100644
--- a/llama_assistant/llama_assistant_app.py
+++ b/llama_assistant/llama_assistant_app.py
@@ -14,15 +14,9 @@
QVBoxLayout,
QMessageBox,
QSystemTrayIcon,
- QRubberBand
-)
-from PyQt5.QtCore import (
- Qt,
- QPoint,
- QTimer,
- QSize,
- QRect
+ QRubberBand,
)
+from PyQt5.QtCore import Qt, QPoint, QTimer, QSize, QRect
from PyQt5.QtGui import (
QPixmap,
QPainter,
@@ -201,18 +195,20 @@ def on_ocr_button_clicked(self):
self.last_response = ""
self.gen_mark_down = False
-
-        self.ui_manager.chat_box.append(f'You: OCR this captured region')
+
+        self.ui_manager.chat_box.append(
+            f'You: OCR this captured region'
+        )
self.ui_manager.chat_box.append('AI: ')
-
+
self.start_cursor_pos = self.ui_manager.chat_box.textCursor().position()
img_path = config.ocr_tmp_file
if not img_path.exists():
print("No image find for OCR")
- self.ui_manager.chat_box.append('No image found for OCR')
+ self.ui_manager.chat_box.append("No image found for OCR")
return
-
+
self.processing_thread = OCRThread(img_path, streaming=True)
self.processing_thread.preloader_signal.connect(self.indicate_loading)
self.processing_thread.update_signal.connect(self.update_chat_box)
@@ -224,7 +220,6 @@ def on_ask_with_ocr_context(self):
self.screen_capture_widget.hide()
self.has_ocr_context = True
-
def on_submit(self):
message = self.ui_manager.input_field.toPlainText()
if message == "":
@@ -238,7 +233,7 @@ def on_submit(self):
for file_path in self.dropped_files:
self.remove_file_thumbnail(self.file_containers[file_path], file_path)
-
+
return
self.last_response = ""
@@ -316,7 +311,7 @@ def process_image_with_prompt(self, image_path, file_paths, prompt):
prompt,
image=image,
lookup_files=file_paths,
- ocr_img_path=config.ocr_tmp_file if self.has_ocr_context else None
+ ocr_img_path=config.ocr_tmp_file if self.has_ocr_context else None,
)
self.processing_thread.preloader_signal.connect(self.indicate_loading)
self.processing_thread.update_signal.connect(self.update_chat_box)
@@ -346,7 +341,7 @@ def indicate_loading(self, message):
QApplication.processEvents() # Process events to update the UI
time.sleep(0.05)
time.sleep(0.5)
-
+
def update_chat_box(self, text):
self.last_response += text
@@ -360,7 +355,7 @@ def update_chat_box(self, text):
formatted_text = markdown_response
else:
            formatted_text = self.last_response.replace("\n", "<br>") + ""
-
+
self.clear_text_from_start_pos()
cursor = self.ui_manager.chat_box.textCursor()
cursor.insertHtml(formatted_text)
diff --git a/llama_assistant/model_handler.py b/llama_assistant/model_handler.py
index 6b8323e..6c1a112 100644
--- a/llama_assistant/model_handler.py
+++ b/llama_assistant/model_handler.py
@@ -217,7 +217,7 @@ def update_chat_history(self, message: str, role: str):
if self.loaded_agent is None:
print("Agent has not been initialized. Cannot update chat history.")
return
-
+
agent = self.loaded_agent.get("agent")
if agent:
agent.chat_history.add_message({"role": role, "content": message})
diff --git a/llama_assistant/ocr_engine.py b/llama_assistant/ocr_engine.py
index 847dc3d..9ab2244 100644
--- a/llama_assistant/ocr_engine.py
+++ b/llama_assistant/ocr_engine.py
@@ -4,27 +4,28 @@
import numpy as np
import time
+
def group_boxes_to_lines(bboxes, vertical_tolerance=5):
"""
Groups bounding boxes into lines based on vertical alignment.
-
+
Args:
bboxes: List of bounding boxes [(xmin, ymin, xmax, ymax)].
vertical_tolerance: Tolerance for vertical proximity to consider boxes in the same line.
-
+
Returns:
List of lines, where each line is a list of bounding boxes.
"""
# Sort bounding boxes by ymin (top edge)
bboxes = sorted(bboxes, key=lambda bbox: bbox[1])
-
+
lines = []
current_line = []
current_ymin = None
for bbox in bboxes:
xmin, ymin, xmax, ymax = bbox
-
+
# Check if starting a new line or current box is not vertically aligned
if current_ymin is None or ymin > current_ymin + vertical_tolerance:
# Save the current line if not empty
@@ -38,21 +39,22 @@ def group_boxes_to_lines(bboxes, vertical_tolerance=5):
else:
# Add box to the current line
current_line.append(bbox)
-
+
# Add the last line if any
if current_line:
current_line.sort(key=lambda box: box[0])
lines.append(current_line)
-
+
return lines
+
def quad_to_rect(quad_boxes):
"""
Converts a quadrilateral bounding box to a rectangular bounding box.
-
+
Args:
quad_boxes: List of 4 points [(x1, y1), (x2, y2), (x3, y3), (x4, y4)] representing the quadrilateral.
-
+
Returns:
List of rectangular bounding box (xmin, ymin, xmax, ymax).
"""
@@ -70,19 +72,20 @@ def quad_to_rect(quad_boxes):
ymax = np.max(y_coords)
result.append((xmin, ymin, xmax, ymax))
-
+
return result
+
class OCREngine:
def __init__(self):
self.ocr = None
-
+
def load_ocr(self, processign_thread=None):
if self.ocr is None:
if processign_thread:
processign_thread.set_preloading(True, "Initializing OCR ....")
- self.ocr = PaddleOCR(use_angle_cls=True, lang='en')
+ self.ocr = PaddleOCR(use_angle_cls=True, lang="en")
time.sleep(1.2)
if processign_thread:
@@ -90,14 +93,14 @@ def load_ocr(self, processign_thread=None):
def perform_ocr(self, img_path, streaming=False, processing_thread=None):
self.load_ocr(processing_thread)
- img = np.array(Image.open(img_path).convert('RGB'))
+ img = np.array(Image.open(img_path).convert("RGB"))
ori_im = img.copy()
# text detection
dt_boxes, _ = self.ocr.text_detector(img)
if dt_boxes is None:
return None
-
+
img_crop_list = []
dt_boxes = quad_to_rect(dt_boxes)
@@ -106,6 +109,7 @@ def perform_ocr(self, img_path, streaming=False, processing_thread=None):
# return a generator if streaming
if streaming:
+
def generate_result():
for boxes in lines:
img_crop_list = []
@@ -116,17 +120,17 @@ def generate_result():
continue
img_crop_list.append(img_crop)
rec_res, _ = self.ocr.text_recognizer(img_crop_list)
-
+
line_text = ""
for rec_result in rec_res:
text, score = rec_result
if score >= self.ocr.drop_score:
line_text += text + " "
- yield line_text+'\n'
+ yield line_text + "\n"
return generate_result()
-
+
# non-streaming
full_result = ""
for boxes in lines:
@@ -138,15 +142,16 @@ def generate_result():
continue
img_crop_list.append(img_crop)
rec_res, _ = self.ocr.text_recognizer(img_crop_list)
-
+
line_text = ""
for rec_result in rec_res:
text, score = rec_result
if score >= self.ocr.drop_score:
line_text += text + " "
- full_result += line_text + '\n'
+ full_result += line_text + "\n"
return full_result
-
-ocr_engine = OCREngine()
\ No newline at end of file
+
+
+ocr_engine = OCREngine()
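
Note on group_boxes_to_lines as reflowed above (illustrative, not part of the patch): boxes are sorted by their top edge, a new line starts whenever a box's ymin exceeds the current line's ymin plus the tolerance, and each finished line is sorted left to right. A small example with invented coordinates:

from llama_assistant.ocr_engine import group_boxes_to_lines

boxes = [
    (10, 12, 50, 24),   # word A, first visual line
    (60, 10, 120, 22),  # word B, same line (ymin within tolerance)
    (10, 40, 80, 52),   # word C, next line
]
lines = group_boxes_to_lines(boxes, vertical_tolerance=5)
# -> [[(10, 12, 50, 24), (60, 10, 120, 22)], [(10, 40, 80, 52)]]
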
diff --git a/llama_assistant/processing_thread.py b/llama_assistant/processing_thread.py
index 04bc6c2..ce1a492 100644
--- a/llama_assistant/processing_thread.py
+++ b/llama_assistant/processing_thread.py
@@ -6,6 +6,7 @@
from llama_assistant.model_handler import handler as model_handler
from llama_assistant.ocr_engine import ocr_engine
+
class ProcessingThread(QThread):
preloader_signal = pyqtSignal(str)
update_signal = pyqtSignal(str)
@@ -75,6 +76,7 @@ def set_preloading(self, preloading: bool, message: str):
def is_preloading(self):
return self.preloading
+
class OCRThread(QThread):
preloader_signal = pyqtSignal(str)
update_signal = pyqtSignal(str)
@@ -96,16 +98,18 @@ def set_preloading(self, preloading: bool, message: str):
def is_preloading(self):
return self.preloading
-
+
def run(self):
- output = ocr_engine.perform_ocr(self.img_path, streaming=self.streaming, processing_thread=self)
+ output = ocr_engine.perform_ocr(
+ self.img_path, streaming=self.streaming, processing_thread=self
+ )
full_response_str = "Here is the OCR result:\n"
self.is_ocr_done = True
if not self.streaming and type(output) == str:
self.update_signal.emit(full_response_str + output)
return
-
+
self.update_signal.emit(full_response_str)
for chunk in output:
self.update_signal.emit(chunk)
@@ -114,4 +118,3 @@ def run(self):
model_handler.update_chat_history("OCR this image:", "user")
model_handler.update_chat_history(full_response_str, "assistant")
self.finished_signal.emit()
-
\ No newline at end of file
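
Note (illustrative, not part of the patch): OCRThread is driven the same way on_ocr_button_clicked does in llama_assistant_app.py above — construct it with an image path, connect the Qt signals, keep a reference, and start it inside a running Qt application. A minimal sketch with placeholder handlers:

from llama_assistant import config
from llama_assistant.processing_thread import OCRThread

thread = OCRThread(config.ocr_tmp_file, streaming=True)
thread.preloader_signal.connect(lambda msg: print("loading:", msg))
thread.update_signal.connect(lambda chunk: print(chunk, end=""))
thread.finished_signal.connect(lambda *args: print("\n[OCR done]"))
thread.start()  # run() emits "Here is the OCR result:\n", then one chunk per line
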
diff --git a/llama_assistant/screen_capture_widget.py b/llama_assistant/screen_capture_widget.py
index 8523578..ed351cf 100644
--- a/llama_assistant/screen_capture_widget.py
+++ b/llama_assistant/screen_capture_widget.py
@@ -5,6 +5,7 @@
from llama_assistant import config
from llama_assistant.ocr_engine import OCREngine
+
if TYPE_CHECKING:
from llama_assistant.llama_assistant_app import LlamaAssistantApp
@@ -13,17 +14,17 @@ class ScreenCaptureWidget(QWidget):
def __init__(self, parent: "LlamaAssistantApp"):
super().__init__()
self.setWindowFlags(Qt.FramelessWindowHint)
-
+
self.parent = parent
self.ocr_engine = OCREngine()
-
+
# Get screen size
screen = QDesktopWidget().screenGeometry()
self.setGeometry(0, 0, screen.width(), screen.height())
-
+
# Set crosshairs cursor
self.setCursor(Qt.CrossCursor)
-
+
# To store the start and end points of the mouse region
self.start_point = None
self.end_point = None
@@ -73,71 +74,73 @@ def hide(self):
self.button_widget.hide()
super().hide()
-
def mousePressEvent(self, event):
if event.button() == Qt.LeftButton:
self.start_point = event.pos() # Capture start position
- self.end_point = event.pos() # Initialize end point to start position
+ self.end_point = event.pos() # Initialize end point to start position
print(f"Mouse press at {self.start_point}")
-
+
def mouseReleaseEvent(self, event):
if event.button() == Qt.LeftButton:
self.end_point = event.pos() # Capture end position
-
+
print(f"Mouse release at {self.end_point}")
-
+
# Capture the region between start and end points
if self.start_point and self.end_point:
self.capture_region(self.start_point, self.end_point)
-
+
# Trigger repaint to show the red rectangle
self.update()
self.show_buttons()
-
def mouseMoveEvent(self, event):
if self.start_point:
# Update the end_point to the current mouse position as it moves
self.end_point = event.pos()
-
+
# Trigger repaint to update the rectangle
self.update()
-
+
def capture_region(self, start_point, end_point):
# Convert local widget coordinates to global screen coordinates
start_global = self.mapToGlobal(start_point)
end_global = self.mapToGlobal(end_point)
-
+
# Create a QRect from the global start and end points
region_rect = QRect(start_global, end_global)
-
+
# Ensure the rectangle is valid (non-negative width/height)
region_rect = region_rect.normalized()
-
+
# Capture the screen region
screen = QApplication.primaryScreen()
- pixmap = screen.grabWindow(0, region_rect.x(), region_rect.y(), region_rect.width(), region_rect.height())
+ pixmap = screen.grabWindow(
+ 0, region_rect.x(), region_rect.y(), region_rect.width(), region_rect.height()
+ )
# Save the captured region as an image
pixmap.save(str(config.ocr_tmp_file), "PNG")
print(f"Captured region saved at '{config.ocr_tmp_file}'.")
-
+
def paintEvent(self, event):
# If the start and end points are set, draw the rectangle
if self.start_point and self.end_point:
# Create a painter object
painter = QPainter(self)
-
+
# Set the pen color to red
pen = QPen(QColor(255, 0, 0)) # Red color
pen.setWidth(3) # Set width of the border
painter.setPen(pen)
-
+
# Draw the rectangle from start_point to end_point
self.region_rect = QRect(self.start_point, self.end_point)
- self.region_rect = self.region_rect.normalized() # Normalize to ensure correct width/height
-
+ self.region_rect = (
+ self.region_rect.normalized()
+ ) # Normalize to ensure correct width/height
+
painter.drawRect(self.region_rect) # Draw the rectangle
super().paintEvent(event) # Call the base class paintEvent
@@ -158,9 +161,9 @@ def show_buttons(self):
self.ocr_button.setGeometry(0, 0, button_width, button_height)
self.ask_button.setGeometry(button_width + spacing, 0, button_width, button_height)
- self.button_widget.setGeometry(rect.left(), button_y, 2 * button_width + spacing, button_height)
+ self.button_widget.setGeometry(
+ rect.left(), button_y, 2 * button_width + spacing, button_height
+ )
self.button_widget.setAttribute(Qt.WA_TranslucentBackground)
self.button_widget.setWindowFlags(Qt.FramelessWindowHint)
self.button_widget.show()
-
-
\ No newline at end of file
diff --git a/llama_assistant/setting_dialog.py b/llama_assistant/setting_dialog.py
index 9a5ae12..843d748 100644
--- a/llama_assistant/setting_dialog.py
+++ b/llama_assistant/setting_dialog.py
@@ -24,6 +24,7 @@
from llama_assistant import config
from llama_assistant.setting_validator import validate_numeric_field
+
class SettingsDialog(QDialog):
settingsSaved = pyqtSignal()
@@ -218,56 +219,76 @@ def create_rag_settings_group(self):
self.main_layout.addWidget(group_box)
def accept(self):
- valid, message = validate_numeric_field("Context Length", self.context_len_input.text(),
- constraints=config.VALIDATOR['generation']['context_len'])
+ valid, message = validate_numeric_field(
+ "Context Length",
+ self.context_len_input.text(),
+ constraints=config.VALIDATOR["generation"]["context_len"],
+ )
if not valid:
QMessageBox.warning(self, "Validation Error", message)
return
-
- valid, message = validate_numeric_field("Temperature", self.temperature_input.text(),
- constraints=config.VALIDATOR['generation']['temperature'])
+
+ valid, message = validate_numeric_field(
+ "Temperature",
+ self.temperature_input.text(),
+ constraints=config.VALIDATOR["generation"]["temperature"],
+ )
if not valid:
QMessageBox.warning(self, "Validation Error", message)
return
-
- valid, message = validate_numeric_field("Top p", self.top_p_input.text(),
- constraints=config.VALIDATOR['generation']['top_p'])
+
+ valid, message = validate_numeric_field(
+ "Top p", self.top_p_input.text(), constraints=config.VALIDATOR["generation"]["top_p"]
+ )
if not valid:
QMessageBox.warning(self, "Validation Error", message)
return
-
- valid, message = validate_numeric_field("Top k", self.top_k_input.text(),
- constraints=config.VALIDATOR['generation']['top_k'])
+
+ valid, message = validate_numeric_field(
+ "Top k", self.top_k_input.text(), constraints=config.VALIDATOR["generation"]["top_k"]
+ )
if not valid:
QMessageBox.warning(self, "Validation Error", message)
return
-
- valid, message = validate_numeric_field("Chunk Size", self.chunk_size_input.text(),
- constraints=config.VALIDATOR['rag']['chunk_size'])
+
+ valid, message = validate_numeric_field(
+ "Chunk Size",
+ self.chunk_size_input.text(),
+ constraints=config.VALIDATOR["rag"]["chunk_size"],
+ )
if not valid:
QMessageBox.warning(self, "Validation Error", message)
return
-
- valid, message = validate_numeric_field("Chunk Overlap", self.chunk_overlap_input.text(),
- constraints=config.VALIDATOR['rag']['chunk_overlap'])
+
+ valid, message = validate_numeric_field(
+ "Chunk Overlap",
+ self.chunk_overlap_input.text(),
+ constraints=config.VALIDATOR["rag"]["chunk_overlap"],
+ )
if not valid:
QMessageBox.warning(self, "Validation Error", message)
return
-
- valid, message = validate_numeric_field("Max Retrieval Top k", self.max_retrieval_top_k_input.text(),
- constraints=config.VALIDATOR['rag']['max_retrieval_top_k'])
-
+
+ valid, message = validate_numeric_field(
+ "Max Retrieval Top k",
+ self.max_retrieval_top_k_input.text(),
+ constraints=config.VALIDATOR["rag"]["max_retrieval_top_k"],
+ )
+
if not valid:
QMessageBox.warning(self, "Validation Error", message)
return
-
- valid, message = validate_numeric_field("Similarity Threshold", self.similarity_threshold_input.text(),
- constraints=config.VALIDATOR['rag']['similarity_threshold'])
-
+
+ valid, message = validate_numeric_field(
+ "Similarity Threshold",
+ self.similarity_threshold_input.text(),
+ constraints=config.VALIDATOR["rag"]["similarity_threshold"],
+ )
+
if not valid:
QMessageBox.warning(self, "Validation Error", message)
return
-
+
self.save_settings()
self.settingsSaved.emit()
super().accept()
@@ -317,29 +338,66 @@ def load_settings(self):
if "rag" not in settings:
settings["rag"] = {}
- embed_model = settings["rag"].get("embed_model_name", config.DEFAULT_SETTINGS['rag']['embed_model_name'])
+ embed_model = settings["rag"].get(
+ "embed_model_name", config.DEFAULT_SETTINGS["rag"]["embed_model_name"]
+ )
if embed_model in config.DEFAULT_EMBEDING_MODELS:
self.embed_model_combo.setCurrentText(embed_model)
-
- self.chunk_size_input.setText(str(settings["rag"].get("chunk_size", config.DEFAULT_SETTINGS['rag']['chunk_size'])))
+
+ self.chunk_size_input.setText(
+ str(settings["rag"].get("chunk_size", config.DEFAULT_SETTINGS["rag"]["chunk_size"]))
+ )
self.chunk_overlap_input.setText(
- str(settings["rag"].get("chunk_overlap", config.DEFAULT_SETTINGS['rag']['chunk_overlap']))
+ str(
+ settings["rag"].get(
+ "chunk_overlap", config.DEFAULT_SETTINGS["rag"]["chunk_overlap"]
+ )
+ )
)
self.max_retrieval_top_k_input.setText(
- str(settings["rag"].get("max_retrieval_top_k", config.DEFAULT_SETTINGS['rag']['max_retrieval_top_k']))
+ str(
+ settings["rag"].get(
+ "max_retrieval_top_k", config.DEFAULT_SETTINGS["rag"]["max_retrieval_top_k"]
+ )
+ )
)
self.similarity_threshold_input.setText(
- str(settings["rag"].get("similarity_threshold", config.DEFAULT_SETTINGS['rag']['similarity_threshold']))
+ str(
+ settings["rag"].get(
+ "similarity_threshold",
+ config.DEFAULT_SETTINGS["rag"]["similarity_threshold"],
+ )
+ )
)
self.context_len_input.setText(
- str(settings["generation"].get("context_len", config.DEFAULT_SETTINGS['generation']['context_len']))
+ str(
+ settings["generation"].get(
+ "context_len", config.DEFAULT_SETTINGS["generation"]["context_len"]
+ )
+ )
)
-
+
self.temperature_input.setText(
- str(settings["generation"].get("temperature", config.DEFAULT_SETTINGS['generation']['temperature']))
+ str(
+ settings["generation"].get(
+ "temperature", config.DEFAULT_SETTINGS["generation"]["temperature"]
+ )
+ )
+ )
+ self.top_p_input.setText(
+ str(
+ settings["generation"].get(
+ "top_p", config.DEFAULT_SETTINGS["generation"]["top_p"]
+ )
+ )
+ )
+ self.top_k_input.setText(
+ str(
+ settings["generation"].get(
+ "top_k", config.DEFAULT_SETTINGS["generation"]["top_k"]
+ )
+ )
)
- self.top_p_input.setText(str(settings["generation"].get("top_p", config.DEFAULT_SETTINGS['generation']['top_p'])))
- self.top_k_input.setText(str(settings["generation"].get("top_k", config.DEFAULT_SETTINGS['generation']['top_k'])))
else:
self.color = QColor("#1E1E1E")
            self.shortcut_recorder.setText("<cmd>+<shift>+<space>")
diff --git a/llama_assistant/setting_validator.py b/llama_assistant/setting_validator.py
index 19cdcb7..bb45a5e 100644
--- a/llama_assistant/setting_validator.py
+++ b/llama_assistant/setting_validator.py
@@ -1,10 +1,9 @@
-
def validate_numeric_field(name, value_str, constraints):
- type = constraints['type']
- min = constraints.get('min')
- max = constraints.get('max')
-
- if type == 'float':
+ type = constraints["type"]
+ min = constraints.get("min")
+ max = constraints.get("max")
+
+ if type == "float":
if isinstance(value_str, float):
value = value_str
else:
@@ -13,8 +12,8 @@ def validate_numeric_field(name, value_str, constraints):
except ValueError:
message = f"Invalid value for {name}. Expected a float, got {value_str}"
return False, message
-
- elif type == 'int':
+
+ elif type == "int":
if isinstance(value_str, int):
value = value_str
else:
@@ -23,11 +22,13 @@ def validate_numeric_field(name, value_str, constraints):
except ValueError:
message = f"Invalid value for {name}. Expected an integer, got {value_str}"
return False, message
-
+
if min is not None and value < min:
message = f"Invalid value for {name}. Expected a value greater than or equal to {min}, got {value}"
return False, message
if max is not None and value > max:
- message = f"Invalid value for {name}. Expected a value less than or equal to {max}, got {value}"
+ message = (
+ f"Invalid value for {name}. Expected a value less than or equal to {max}, got {value}"
+ )
return False, message
- return True, value
\ No newline at end of file
+ return True, value
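
Note (illustrative, not part of the patch): validate_numeric_field returns (False, message) on failure and (True, parsed_value) on success, which is why accept() in setting_dialog.py checks the boolean before proceeding. A quick sketch using constraint dicts copied from config.VALIDATOR:

from llama_assistant.setting_validator import validate_numeric_field

ok, value = validate_numeric_field(
    "Top p", "0.95", constraints={"type": "float", "min": 0, "max": 1}
)
# ok is True, value is the parsed float 0.95

ok, message = validate_numeric_field(
    "Top k", "500", constraints={"type": "int", "min": 1, "max": 100}
)
# ok is False, message reports the violated "max" constraint
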
diff --git a/llama_assistant/ui_manager.py b/llama_assistant/ui_manager.py
index 817e60e..15d2b59 100644
--- a/llama_assistant/ui_manager.py
+++ b/llama_assistant/ui_manager.py
@@ -24,7 +24,7 @@
copy_icon_svg,
clear_icon_svg,
microphone_icon_svg,
- crosshair_icon_svg
+ crosshair_icon_svg,
)