From be5a8168f96a8e38715d1f72125987960c9b2eaf Mon Sep 17 00:00:00 2001 From: Omer Date: Fri, 27 Sep 2024 14:57:27 -0400 Subject: [PATCH] server: refactor server for modularity and readability --- backend/server.py | 227 ---------------------- backend/server/__init__.py | 0 backend/server/server.py | 129 ++++++++++++ backend/server/server_utils.py | 120 ++++++++++++ backend/{ => server}/websocket_manager.py | 0 main.py | 2 +- multi_agents/main.py | 2 +- 7 files changed, 251 insertions(+), 229 deletions(-) delete mode 100644 backend/server.py create mode 100644 backend/server/__init__.py create mode 100644 backend/server/server.py create mode 100644 backend/server/server_utils.py rename backend/{ => server}/websocket_manager.py (100%) diff --git a/backend/server.py b/backend/server.py deleted file mode 100644 index d94849478..000000000 --- a/backend/server.py +++ /dev/null @@ -1,227 +0,0 @@ -import json -import os -import re -import time - -from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect, File, UploadFile, Header -from fastapi.responses import JSONResponse -from fastapi.middleware.cors import CORSMiddleware -from fastapi.staticfiles import StaticFiles -from fastapi.templating import Jinja2Templates -from pydantic import BaseModel - -from backend.utils import write_md_to_pdf, write_md_to_word, write_text_to_md -from backend.websocket_manager import WebSocketManager - -import shutil -from multi_agents.main import run_research_task -from gpt_researcher.document.document import DocumentLoader -from gpt_researcher.master.actions import stream_output - - -class ResearchRequest(BaseModel): - task: str - report_type: str - agent: str - -class ConfigRequest(BaseModel): - ANTHROPIC_API_KEY: str - TAVILY_API_KEY: str - LANGCHAIN_TRACING_V2: str - LANGCHAIN_API_KEY: str - OPENAI_API_KEY: str - DOC_PATH: str - RETRIEVER: str - GOOGLE_API_KEY: str = '' - GOOGLE_CX_KEY: str = '' - BING_API_KEY: str = '' - SEARCHAPI_API_KEY: str = '' - SERPAPI_API_KEY: str = '' - SERPER_API_KEY: str = '' - SEARX_URL: str = '' - -app = FastAPI() - -app.mount("/site", StaticFiles(directory="./frontend"), name="site") -app.mount("/static", StaticFiles(directory="./frontend/static"), name="static") - -templates = Jinja2Templates(directory="./frontend") - -manager = WebSocketManager() - -# Dynamic directory for outputs once first research is run -@app.on_event("startup") -def startup_event(): - os.makedirs("outputs", exist_ok=True) - app.mount("/outputs", StaticFiles(directory="outputs"), name="outputs") - - -@app.get("/") -async def read_root(request: Request): - return templates.TemplateResponse( - "index.html", {"request": request, "report": None} - ) - - -# Add the sanitize_filename function here -def sanitize_filename(filename): - return re.sub(r"[^\w\s-]", "", filename).strip() - -@app.websocket("/ws") -async def websocket_endpoint(websocket: WebSocket): - await manager.connect(websocket) - try: - while True: - data = await websocket.receive_text() - if data.startswith("start"): - json_data = json.loads(data[6:]) - task = json_data.get("task") - report_type = json_data.get("report_type") - source_urls = json_data.get("source_urls") - tone = json_data.get("tone") - headers = json_data.get("headers", {}) - filename = f"task_{int(time.time())}_{task}" - sanitized_filename = sanitize_filename( - filename - ) # Sanitize the filename - report_source = json_data.get("report_source") - if task and report_type: - report = await manager.start_streaming( - task, report_type, report_source, source_urls, tone, websocket, headers - ) - # Ensure report is a string - if not isinstance(report, str): - report = str(report) - - # Saving report as pdf - pdf_path = await write_md_to_pdf(report, sanitized_filename) - # Saving report as docx - docx_path = await write_md_to_word(report, sanitized_filename) - # Returning the path of saved report files - md_path = await write_text_to_md(report, sanitized_filename) - await websocket.send_json( - { - "type": "path", - "output": { - "pdf": pdf_path, - "docx": docx_path, - "md": md_path, - }, - } - ) - elif data.startswith("human_feedback"): - # Handle human feedback - feedback_data = json.loads(data[14:]) # Remove "human_feedback" prefix - # Process the feedback data as needed - # You might want to send this feedback to the appropriate agent or update the research state - print(f"Received human feedback: {feedback_data}") - # You can add logic here to forward the feedback to the appropriate agent or update the research state - else: - print("Error: not enough parameters provided.") - except WebSocketDisconnect: - await manager.disconnect(websocket) - -@app.post("/api/multi_agents") -async def run_multi_agents(): - websocket = manager.active_connections[0] if manager.active_connections else None - if websocket: - report = await run_research_task("Is AI in a hype cycle?", websocket, stream_output) - return {"report": report} - else: - return JSONResponse(status_code=400, content={"message": "No active WebSocket connection"}) - -@app.get("/getConfig") -async def get_config( - langchain_api_key: str = Header(None), - openai_api_key: str = Header(None), - tavily_api_key: str = Header(None), - google_api_key: str = Header(None), - google_cx_key: str = Header(None), - bing_api_key: str = Header(None), - searchapi_api_key: str = Header(None), - serpapi_api_key: str = Header(None), - serper_api_key: str = Header(None), - searx_url: str = Header(None) -): - config = { - "LANGCHAIN_API_KEY": langchain_api_key if langchain_api_key else os.getenv("LANGCHAIN_API_KEY", ""), - "OPENAI_API_KEY": openai_api_key if openai_api_key else os.getenv("OPENAI_API_KEY", ""), - "TAVILY_API_KEY": tavily_api_key if tavily_api_key else os.getenv("TAVILY_API_KEY", ""), - "GOOGLE_API_KEY": google_api_key if google_api_key else os.getenv("GOOGLE_API_KEY", ""), - "GOOGLE_CX_KEY": google_cx_key if google_cx_key else os.getenv("GOOGLE_CX_KEY", ""), - "BING_API_KEY": bing_api_key if bing_api_key else os.getenv("BING_API_KEY", ""), - "SEARCHAPI_API_KEY": searchapi_api_key if searchapi_api_key else os.getenv("SEARCHAPI_API_KEY", ""), - "SERPAPI_API_KEY": serpapi_api_key if serpapi_api_key else os.getenv("SERPAPI_API_KEY", ""), - "SERPER_API_KEY": serper_api_key if serper_api_key else os.getenv("SERPER_API_KEY", ""), - "SEARX_URL": searx_url if searx_url else os.getenv("SEARX_URL", ""), - "LANGCHAIN_TRACING_V2": os.getenv("LANGCHAIN_TRACING_V2", "true"), - "DOC_PATH": os.getenv("DOC_PATH", "./my-docs"), - "RETRIEVER": os.getenv("RETRIEVER", ""), - "EMBEDDING_MODEL": os.getenv("OPENAI_EMBEDDING_MODEL", "") - } - return config - -@app.post("/setConfig") -async def set_config(config: ConfigRequest): - os.environ["ANTHROPIC_API_KEY"] = config.ANTHROPIC_API_KEY - os.environ["TAVILY_API_KEY"] = config.TAVILY_API_KEY - os.environ["LANGCHAIN_TRACING_V2"] = config.LANGCHAIN_TRACING_V2 - os.environ["LANGCHAIN_API_KEY"] = config.LANGCHAIN_API_KEY - os.environ["OPENAI_API_KEY"] = config.OPENAI_API_KEY - os.environ["DOC_PATH"] = config.DOC_PATH - os.environ["RETRIEVER"] = config.RETRIEVER - os.environ["GOOGLE_API_KEY"] = config.GOOGLE_API_KEY - os.environ["GOOGLE_CX_KEY"] = config.GOOGLE_CX_KEY - os.environ["BING_API_KEY"] = config.BING_API_KEY - os.environ["SEARCHAPI_API_KEY"] = config.SEARCHAPI_API_KEY - os.environ["SERPAPI_API_KEY"] = config.SERPAPI_API_KEY - os.environ["SERPER_API_KEY"] = config.SERPER_API_KEY - os.environ["SEARX_URL"] = config.SEARX_URL - return {"message": "Config updated successfully"} - -# Enable CORS for your frontend domain (adjust accordingly) -app.add_middleware( - CORSMiddleware, - allow_origins=["http://localhost:3000"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - - -# Define DOC_PATH -DOC_PATH = os.getenv("DOC_PATH", "./my-docs") -if not os.path.exists(DOC_PATH): - os.makedirs(DOC_PATH) - - -@app.post("/upload/") -async def upload_file(file: UploadFile = File(...)): - file_path = os.path.join(DOC_PATH, file.filename) - with open(file_path, "wb") as buffer: - shutil.copyfileobj(file.file, buffer) - print(f"File uploaded to {file_path}") - - # Load documents after upload - document_loader = DocumentLoader(DOC_PATH) - await document_loader.load() - - return {"filename": file.filename, "path": file_path} - - -@app.get("/files/") -async def list_files(): - files = os.listdir(DOC_PATH) - print(f"Files in {DOC_PATH}: {files}") - return {"files": files} - -@app.delete("/files/{filename}") -async def delete_file(filename: str): - file_path = os.path.join(DOC_PATH, filename) - if os.path.exists(file_path): - os.remove(file_path) - print(f"File deleted: {file_path}") - return {"message": "File deleted successfully"} - else: - print(f"File not found: {file_path}") - return JSONResponse(status_code=404, content={"message": "File not found"}) diff --git a/backend/server/__init__.py b/backend/server/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/server/server.py b/backend/server/server.py new file mode 100644 index 000000000..832525faf --- /dev/null +++ b/backend/server/server.py @@ -0,0 +1,129 @@ +import json +import os +from typing import Dict, List + +from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect, File, UploadFile, Header +from fastapi.responses import JSONResponse +from fastapi.middleware.cors import CORSMiddleware +from fastapi.staticfiles import StaticFiles +from fastapi.templating import Jinja2Templates +from pydantic import BaseModel + +from backend.server.server_utils import generate_report_files +from backend.server.websocket_manager import WebSocketManager +from multi_agents.main import run_research_task +from gpt_researcher.document.document import DocumentLoader +from gpt_researcher.master.actions import stream_output +from backend.server.server_utils import ( + sanitize_filename, handle_start_command, handle_human_feedback, + generate_report_files, send_file_paths, get_config_dict, + update_environment_variables, handle_file_upload, handle_file_deletion, + execute_multi_agents, handle_websocket_communication, extract_command_data +) + +# Models +class ResearchRequest(BaseModel): + task: str + report_type: str + agent: str + +class ConfigRequest(BaseModel): + ANTHROPIC_API_KEY: str + TAVILY_API_KEY: str + LANGCHAIN_TRACING_V2: str + LANGCHAIN_API_KEY: str + OPENAI_API_KEY: str + DOC_PATH: str + RETRIEVER: str + GOOGLE_API_KEY: str = '' + GOOGLE_CX_KEY: str = '' + BING_API_KEY: str = '' + SEARCHAPI_API_KEY: str = '' + SERPAPI_API_KEY: str = '' + SERPER_API_KEY: str = '' + SEARX_URL: str = '' + +# App initialization +app = FastAPI() + +# Static files and templates +app.mount("/site", StaticFiles(directory="./frontend"), name="site") +app.mount("/static", StaticFiles(directory="./frontend/static"), name="static") +templates = Jinja2Templates(directory="./frontend") + +# WebSocket manager +manager = WebSocketManager() + +# Middleware +app.add_middleware( + CORSMiddleware, + allow_origins=["http://localhost:3000"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# Constants +DOC_PATH = os.getenv("DOC_PATH", "./my-docs") + +# Startup event +@app.on_event("startup") +def startup_event(): + os.makedirs("outputs", exist_ok=True) + app.mount("/outputs", StaticFiles(directory="outputs"), name="outputs") + os.makedirs(DOC_PATH, exist_ok=True) + +# Routes +@app.get("/") +async def read_root(request: Request): + return templates.TemplateResponse("index.html", {"request": request, "report": None}) + +@app.get("/getConfig") +async def get_config( + langchain_api_key: str = Header(None), + openai_api_key: str = Header(None), + tavily_api_key: str = Header(None), + google_api_key: str = Header(None), + google_cx_key: str = Header(None), + bing_api_key: str = Header(None), + searchapi_api_key: str = Header(None), + serpapi_api_key: str = Header(None), + serper_api_key: str = Header(None), + searx_url: str = Header(None) +): + return get_config_dict( + langchain_api_key, openai_api_key, tavily_api_key, + google_api_key, google_cx_key, bing_api_key, + searchapi_api_key, serpapi_api_key, serper_api_key, searx_url + ) + +@app.get("/files/") +async def list_files(): + files = os.listdir(DOC_PATH) + print(f"Files in {DOC_PATH}: {files}") + return {"files": files} + +@app.post("/api/multi_agents") +async def run_multi_agents(): + return await execute_multi_agents(manager) + +@app.post("/setConfig") +async def set_config(config: ConfigRequest): + update_environment_variables(config.dict()) + return {"message": "Config updated successfully"} + +@app.post("/upload/") +async def upload_file(file: UploadFile = File(...)): + return await handle_file_upload(file, DOC_PATH) + +@app.delete("/files/{filename}") +async def delete_file(filename: str): + return await handle_file_deletion(filename, DOC_PATH) + +@app.websocket("/ws") +async def websocket_endpoint(websocket: WebSocket): + await manager.connect(websocket) + try: + await handle_websocket_communication(websocket, manager) + except WebSocketDisconnect: + await manager.disconnect(websocket) \ No newline at end of file diff --git a/backend/server/server_utils.py b/backend/server/server_utils.py new file mode 100644 index 000000000..f48ad6052 --- /dev/null +++ b/backend/server/server_utils.py @@ -0,0 +1,120 @@ +import json +import os +import re +import time +import shutil +from typing import Dict, List +from fastapi.responses import JSONResponse +from gpt_researcher.document.document import DocumentLoader +from backend.utils import write_md_to_pdf, write_md_to_word, write_text_to_md # Add this import + + +def sanitize_filename(filename: str) -> str: + return re.sub(r"[^\w\s-]", "", filename).strip() + +async def handle_start_command(websocket, data: str, manager): + json_data = json.loads(data[6:]) + task, report_type, source_urls, tone, headers, report_source = extract_command_data(json_data) + + if not task or not report_type: + print("Error: Missing task or report_type") + return + + sanitized_filename = sanitize_filename(f"task_{int(time.time())}_{task}") + + report = await manager.start_streaming( + task, report_type, report_source, source_urls, tone, websocket, headers + ) + report = str(report) + + file_paths = await generate_report_files(report, sanitized_filename) + await send_file_paths(websocket, file_paths) + +async def handle_human_feedback(data: str): + feedback_data = json.loads(data[14:]) # Remove "human_feedback" prefix + print(f"Received human feedback: {feedback_data}") + # TODO: Add logic to forward the feedback to the appropriate agent or update the research state + +async def generate_report_files(report: str, filename: str) -> Dict[str, str]: + pdf_path = await write_md_to_pdf(report, filename) + docx_path = await write_md_to_word(report, filename) + md_path = await write_text_to_md(report, filename) + return {"pdf": pdf_path, "docx": docx_path, "md": md_path} + +async def send_file_paths(websocket, file_paths: Dict[str, str]): + await websocket.send_json({"type": "path", "output": file_paths}) + +def get_config_dict( + langchain_api_key: str, openai_api_key: str, tavily_api_key: str, + google_api_key: str, google_cx_key: str, bing_api_key: str, + searchapi_api_key: str, serpapi_api_key: str, serper_api_key: str, searx_url: str +) -> Dict[str, str]: + return { + "LANGCHAIN_API_KEY": langchain_api_key or os.getenv("LANGCHAIN_API_KEY", ""), + "OPENAI_API_KEY": openai_api_key or os.getenv("OPENAI_API_KEY", ""), + "TAVILY_API_KEY": tavily_api_key or os.getenv("TAVILY_API_KEY", ""), + "GOOGLE_API_KEY": google_api_key or os.getenv("GOOGLE_API_KEY", ""), + "GOOGLE_CX_KEY": google_cx_key or os.getenv("GOOGLE_CX_KEY", ""), + "BING_API_KEY": bing_api_key or os.getenv("BING_API_KEY", ""), + "SEARCHAPI_API_KEY": searchapi_api_key or os.getenv("SEARCHAPI_API_KEY", ""), + "SERPAPI_API_KEY": serpapi_api_key or os.getenv("SERPAPI_API_KEY", ""), + "SERPER_API_KEY": serper_api_key or os.getenv("SERPER_API_KEY", ""), + "SEARX_URL": searx_url or os.getenv("SEARX_URL", ""), + "LANGCHAIN_TRACING_V2": os.getenv("LANGCHAIN_TRACING_V2", "true"), + "DOC_PATH": os.getenv("DOC_PATH", "./my-docs"), + "RETRIEVER": os.getenv("RETRIEVER", ""), + "EMBEDDING_MODEL": os.getenv("OPENAI_EMBEDDING_MODEL", "") + } + +def update_environment_variables(config: Dict[str, str]): + for key, value in config.items(): + os.environ[key] = value + +async def handle_file_upload(file, DOC_PATH: str) -> Dict[str, str]: + file_path = os.path.join(DOC_PATH, file.filename) + with open(file_path, "wb") as buffer: + shutil.copyfileobj(file.file, buffer) + print(f"File uploaded to {file_path}") + + document_loader = DocumentLoader(DOC_PATH) + await document_loader.load() + + return {"filename": file.filename, "path": file_path} + +async def handle_file_deletion(filename: str, DOC_PATH: str) -> JSONResponse: + file_path = os.path.join(DOC_PATH, filename) + if os.path.exists(file_path): + os.remove(file_path) + print(f"File deleted: {file_path}") + return JSONResponse(content={"message": "File deleted successfully"}) + else: + print(f"File not found: {file_path}") + return JSONResponse(status_code=404, content={"message": "File not found"}) + +async def execute_multi_agents(manager) -> Dict[str, str]: + websocket = manager.active_connections[0] if manager.active_connections else None + if websocket: + report = await run_research_task("Is AI in a hype cycle?", websocket, stream_output) + return {"report": report} + else: + return JSONResponse(status_code=400, content={"message": "No active WebSocket connection"}) + +async def handle_websocket_communication(websocket, manager): + while True: + data = await websocket.receive_text() + if data.startswith("start"): + await handle_start_command(websocket, data, manager) + elif data.startswith("human_feedback"): + await handle_human_feedback(data) + else: + print("Error: Unknown command or not enough parameters provided.") + +def extract_command_data(json_data: Dict) -> tuple: + return ( + json_data.get("task"), + json_data.get("report_type"), + json_data.get("source_urls"), + json_data.get("tone"), + json_data.get("headers", {}), + json_data.get("report_source") + ) \ No newline at end of file diff --git a/backend/websocket_manager.py b/backend/server/websocket_manager.py similarity index 100% rename from backend/websocket_manager.py rename to backend/server/websocket_manager.py diff --git a/main.py b/main.py index 85490af26..0f48c2cba 100644 --- a/main.py +++ b/main.py @@ -2,7 +2,7 @@ load_dotenv() -from backend.server import app +from backend.server.server import app if __name__ == "__main__": import uvicorn diff --git a/multi_agents/main.py b/multi_agents/main.py index 91208f41e..8fda20b99 100644 --- a/multi_agents/main.py +++ b/multi_agents/main.py @@ -24,7 +24,7 @@ def open_task(): task = json.load(f) if not task: - raise Exception("No task provided. Please include a task.json file in the multi_agents directory.") + raise Exception("No task found. Please ensure a valid task.json file is present in the multi_agents directory and contains the necessary task information.") return task