-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #941 from assafelovic/feature/strategic_llm
Feature/strategic llm
- Loading branch information
Showing
10 changed files
with
130 additions
and
59 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,58 +1,113 @@ | ||
import json_repair | ||
from ..utils.llm import create_chat_completion | ||
from ..prompts import generate_search_queries_prompt | ||
from typing import Any | ||
from typing import Any, List, Dict | ||
from ..config import Config | ||
import logging | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
async def get_sub_queries( | ||
async def get_search_results(query: str, retriever: Any) -> List[Dict[str, Any]]: | ||
""" | ||
Get web search results for a given query. | ||
Args: | ||
query: The search query | ||
retriever: The retriever instance | ||
Returns: | ||
A list of search results | ||
""" | ||
search_retriever = retriever(query) | ||
return search_retriever.search() | ||
|
||
async def generate_sub_queries( | ||
query: str, | ||
retriever: Any, | ||
parent_query: str, | ||
report_type: str, | ||
context: List[Dict[str, Any]], | ||
cfg: Config, | ||
cost_callback: callable = None | ||
) -> List[str]: | ||
""" | ||
Generate sub-queries using the specified LLM model. | ||
Args: | ||
query: The original query | ||
parent_query: The parent query | ||
report_type: The type of report | ||
max_iterations: Maximum number of research iterations | ||
context: Search results context | ||
cfg: Configuration object | ||
cost_callback: Callback for cost calculation | ||
Returns: | ||
A list of sub-queries | ||
""" | ||
gen_queries_prompt = generate_search_queries_prompt( | ||
query, | ||
parent_query, | ||
report_type, | ||
max_iterations=cfg.max_iterations or 1, | ||
context=context | ||
) | ||
|
||
try: | ||
response = await create_chat_completion( | ||
model=cfg.strategic_llm_model, | ||
messages=[{"role": "user", "content": gen_queries_prompt}], | ||
temperature=1, | ||
llm_provider=cfg.strategic_llm_provider, | ||
max_tokens=None, | ||
llm_kwargs=cfg.llm_kwargs, | ||
cost_callback=cost_callback, | ||
) | ||
except Exception as e: | ||
logger.warning(f"Error with strategic LLM: {e}. Falling back to smart LLM.") | ||
response = await create_chat_completion( | ||
model=cfg.smart_llm_model, | ||
messages=[{"role": "user", "content": gen_queries_prompt}], | ||
temperature=cfg.temperature, | ||
max_tokens=cfg.smart_token_limit, | ||
llm_provider=cfg.smart_llm_provider, | ||
llm_kwargs=cfg.llm_kwargs, | ||
cost_callback=cost_callback, | ||
) | ||
|
||
return json_repair.loads(response) | ||
|
||
async def plan_research_outline( | ||
query: str, | ||
search_results: List[Dict[str, Any]], | ||
agent_role_prompt: str, | ||
cfg, | ||
cfg: Config, | ||
parent_query: str, | ||
report_type: str, | ||
cost_callback: callable = None, | ||
): | ||
) -> List[str]: | ||
""" | ||
Gets the sub queries | ||
Plan the research outline by generating sub-queries. | ||
Args: | ||
query: original query | ||
retriever: retriever instance | ||
agent_role_prompt: agent role prompt | ||
cfg: Config | ||
parent_query: parent query | ||
report_type: report type | ||
cost_callback: callback for cost calculation | ||
query: Original query | ||
retriever: Retriever instance | ||
agent_role_prompt: Agent role prompt | ||
cfg: Configuration object | ||
parent_query: Parent query | ||
report_type: Report type | ||
cost_callback: Callback for cost calculation | ||
Returns: | ||
sub_queries: List of sub queries | ||
A list of sub-queries | ||
""" | ||
# Get web search results prior to generating subqueries for improved context around real time data tasks | ||
search_retriever = retriever(query) | ||
search_results = search_retriever.search() | ||
|
||
max_research_iterations = cfg.max_iterations if cfg.max_iterations else 1 | ||
response = await create_chat_completion( | ||
model=cfg.smart_llm_model, | ||
messages=[ | ||
{"role": "system", "content": f"{agent_role_prompt}"}, | ||
{ | ||
"role": "user", | ||
"content": generate_search_queries_prompt( | ||
query, | ||
parent_query, | ||
report_type, | ||
max_iterations=max_research_iterations, | ||
context=search_results | ||
), | ||
}, | ||
], | ||
temperature=0.1, | ||
llm_provider=cfg.smart_llm_provider, | ||
llm_kwargs=cfg.llm_kwargs, | ||
cost_callback=cost_callback, | ||
|
||
sub_queries = await generate_sub_queries( | ||
query, | ||
parent_query, | ||
report_type, | ||
search_results, | ||
cfg, | ||
cost_callback | ||
) | ||
|
||
sub_queries = json_repair.loads(response) | ||
|
||
return sub_queries |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters