Merge pull request #941 from assafelovic/feature/strategic_llm
Feature/strategic llm
assafelovic authored Oct 23, 2024
2 parents e6ef61c + b643d71 commit b0aa661
Showing 10 changed files with 130 additions and 59 deletions.
1 change: 1 addition & 0 deletions docs/docs/gpt-researcher/gptr/config.md
@@ -22,6 +22,7 @@ Below is a list of current supported options:
- **`EMBEDDING`**: Embedding model. Defaults to `openai:text-embedding-3-small`. Options: `ollama`, `huggingface`, `azure_openai`, `custom`.
- **`FAST_LLM`**: Model name for fast LLM operations such as summaries. Defaults to `openai:gpt-4o-mini`.
- **`SMART_LLM`**: Model name for smart operations like generating research reports and reasoning. Defaults to `openai:gpt-4o`.
- **`STRATEGIC_LLM`**: Model name for strategic operations like generating research plans and strategies. Defaults to `openai:o1-preview`.
- **`FAST_TOKEN_LIMIT`**: Maximum token limit for fast LLM responses. Defaults to `2000`.
- **`SMART_TOKEN_LIMIT`**: Maximum token limit for smart LLM responses. Defaults to `4000`.
- **`BROWSE_CHUNK_MAX_LENGTH`**: Maximum length of text chunks to browse in web sources. Defaults to `8192`.
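A minimal usage sketch of the new option, assuming the usual environment-variable override path (the variable name mirrors the config key above; the `Config` attributes come from `_set_llm_attributes` later in this commit):

```python
import os

# Override the new strategic model before constructing Config; the value uses
# the same "provider:model" convention as FAST_LLM and SMART_LLM.
os.environ["STRATEGIC_LLM"] = "openai:o1-preview"

from gpt_researcher.config import Config

cfg = Config()
# parse_llm splits the setting into provider and model (see config.py below).
print(cfg.strategic_llm_provider, cfg.strategic_llm_model)
```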
4 changes: 2 additions & 2 deletions gpt_researcher/actions/__init__.py
@@ -1,5 +1,5 @@
from .retriever import get_retriever, get_retrievers
from .query_processing import get_sub_queries
from .query_processing import plan_research_outline
from .agent_creator import extract_json_with_regex, choose_agent
from .web_scraping import scrape_urls
from .report_generation import write_conclusion, summarize_url, generate_draft_section_titles, generate_report, write_report_introduction
@@ -9,7 +9,7 @@
__all__ = [
"get_retriever",
"get_retrievers",
"get_sub_queries",
"plan_research_outline",
"extract_json_with_regex",
"scrape_urls",
"write_conclusion",
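For downstream code, the rename itself is a one-line import migration, though note that the call signature changes too: callers now pass pre-fetched search results instead of a retriever (see `query_processing.py` below).

```python
# Before this commit:
from gpt_researcher.actions import get_sub_queries

# After this commit:
from gpt_researcher.actions import plan_research_outline
```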
137 changes: 96 additions & 41 deletions gpt_researcher/actions/query_processing.py
@@ -1,58 +1,113 @@
import json_repair
from ..utils.llm import create_chat_completion
from ..prompts import generate_search_queries_prompt
from typing import Any
from typing import Any, List, Dict
from ..config import Config
import logging

logger = logging.getLogger(__name__)

async def get_sub_queries(
async def get_search_results(query: str, retriever: Any) -> List[Dict[str, Any]]:
"""
Get web search results for a given query.
Args:
query: The search query
retriever: The retriever instance
Returns:
A list of search results
"""
search_retriever = retriever(query)
return search_retriever.search()

async def generate_sub_queries(
query: str,
retriever: Any,
parent_query: str,
report_type: str,
context: List[Dict[str, Any]],
cfg: Config,
cost_callback: callable = None
) -> List[str]:
"""
Generate sub-queries using the specified LLM model.
Args:
query: The original query
parent_query: The parent query
report_type: The type of report
context: Search results context
cfg: Configuration object
cost_callback: Callback for cost calculation
Returns:
A list of sub-queries
"""
gen_queries_prompt = generate_search_queries_prompt(
query,
parent_query,
report_type,
max_iterations=cfg.max_iterations or 1,
context=context
)

try:
response = await create_chat_completion(
model=cfg.strategic_llm_model,
messages=[{"role": "user", "content": gen_queries_prompt}],
temperature=1,
llm_provider=cfg.strategic_llm_provider,
max_tokens=None,
llm_kwargs=cfg.llm_kwargs,
cost_callback=cost_callback,
)
except Exception as e:
logger.warning(f"Error with strategic LLM: {e}. Falling back to smart LLM.")
response = await create_chat_completion(
model=cfg.smart_llm_model,
messages=[{"role": "user", "content": gen_queries_prompt}],
temperature=cfg.temperature,
max_tokens=cfg.smart_token_limit,
llm_provider=cfg.smart_llm_provider,
llm_kwargs=cfg.llm_kwargs,
cost_callback=cost_callback,
)

return json_repair.loads(response)

async def plan_research_outline(
query: str,
search_results: List[Dict[str, Any]],
agent_role_prompt: str,
cfg,
cfg: Config,
parent_query: str,
report_type: str,
cost_callback: callable = None,
):
) -> List[str]:
"""
Gets the sub queries
Plan the research outline by generating sub-queries.
Args:
query: original query
retriever: retriever instance
agent_role_prompt: agent role prompt
cfg: Config
parent_query: parent query
report_type: report type
cost_callback: callback for cost calculation
query: Original query
search_results: Initial web search results for the query
agent_role_prompt: Agent role prompt
cfg: Configuration object
parent_query: Parent query
report_type: Report type
cost_callback: Callback for cost calculation
Returns:
sub_queries: List of sub queries
A list of sub-queries
"""
# Get web search results prior to generating subqueries for improved context around real time data tasks
search_retriever = retriever(query)
search_results = search_retriever.search()

max_research_iterations = cfg.max_iterations if cfg.max_iterations else 1
response = await create_chat_completion(
model=cfg.smart_llm_model,
messages=[
{"role": "system", "content": f"{agent_role_prompt}"},
{
"role": "user",
"content": generate_search_queries_prompt(
query,
parent_query,
report_type,
max_iterations=max_research_iterations,
context=search_results
),
},
],
temperature=0.1,
llm_provider=cfg.smart_llm_provider,
llm_kwargs=cfg.llm_kwargs,
cost_callback=cost_callback,

sub_queries = await generate_sub_queries(
query,
parent_query,
report_type,
search_results,
cfg,
cost_callback
)

sub_queries = json_repair.loads(response)

return sub_queries
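End to end, the refactor splits planning into two calls: fetch initial results, then ask the strategic LLM for sub-queries (falling back to the smart LLM on error). A hedged, self-contained sketch — the stub retriever and query strings are illustrative, and a real run needs the relevant API keys:

```python
import asyncio
from gpt_researcher.config import Config
from gpt_researcher.actions.query_processing import (
    get_search_results,
    plan_research_outline,
)

class StubRetriever:
    """Illustrative stand-in for a configured retriever (Tavily, DuckDuckGo, ...)."""
    def __init__(self, query: str):
        self.query = query

    def search(self):
        return [{"title": "Example result", "href": "https://example.com",
                 "body": "Snippet used as planning context."}]

async def main():
    cfg = Config()
    # Step 1: gather real-time context before planning.
    results = await get_search_results("state of open-source LLMs", StubRetriever)
    # Step 2: generate sub-queries with the strategic (or fallback smart) LLM.
    sub_queries = await plan_research_outline(
        query="state of open-source LLMs",
        search_results=results,
        agent_role_prompt="You are an expert research agent.",
        cfg=cfg,
        parent_query="",
        report_type="research_report",
    )
    print(sub_queries)

asyncio.run(main())
```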
8 changes: 5 additions & 3 deletions gpt_researcher/config/config.py
@@ -48,6 +48,7 @@ def _set_embedding_attributes(self) -> None:
def _set_llm_attributes(self) -> None:
self.fast_llm_provider, self.fast_llm_model = self.parse_llm(self.fast_llm)
self.smart_llm_provider, self.smart_llm_model = self.parse_llm(self.smart_llm)
self.strategic_llm_provider, self.strategic_llm_model = self.parse_llm(self.strategic_llm)

def _handle_deprecated_attributes(self) -> None:
if os.getenv("EMBEDDING_PROVIDER") is not None:
@@ -110,9 +111,10 @@ def load_config(cls, config_path: str | None) -> Dict[str, Any]:

# config_path = os.path.join(cls.CONFIG_DIR, config_path)
if not os.path.exists(config_path):
print(f"Warning: Configuration not found at '{config_path}'. Using default configuration.")
if not config_path.endswith(".json"):
print(f"Do you mean '{config_path}.json'?")
if config_path:
print(f"Warning: Configuration not found at '{config_path}'. Using default configuration.")
if not config_path.endswith(".json"):
print(f"Do you mean '{config_path}.json'?")
return DEFAULT_CONFIG

with open(config_path, "r") as f:
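`parse_llm` itself is outside this diff; presumably it splits the `provider:model` string, roughly like this sketch (a reconstruction, not the repository's actual implementation):

```python
def parse_llm(llm_str: str | None) -> tuple[str | None, str | None]:
    # "openai:o1-preview" -> ("openai", "o1-preview"); tolerate unset values.
    if not llm_str:
        return None, None
    provider, _, model = llm_str.partition(":")
    return provider, model
```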
1 change: 1 addition & 0 deletions gpt_researcher/config/variables/base.py
@@ -8,6 +8,7 @@ class BaseConfig(TypedDict):
SIMILARITY_THRESHOLD: float
FAST_LLM: str
SMART_LLM: str
STRATEGIC_LLM: str
FAST_TOKEN_LIMIT: int
SMART_TOKEN_LIMIT: int
BROWSE_CHUNK_MAX_LENGTH: int
1 change: 1 addition & 0 deletions gpt_researcher/config/variables/default.py
@@ -6,6 +6,7 @@
"SIMILARITY_THRESHOLD": 0.42,
"FAST_LLM": "openai:gpt-4o-mini",
"SMART_LLM": "openai:gpt-4o-2024-08-06",
"STRATEGIC_LLM": "openai:o1-preview",
"FAST_TOKEN_LIMIT": 2000,
"SMART_TOKEN_LIMIT": 4000,
"BROWSE_CHUNK_MAX_LENGTH": 8192,
6 changes: 4 additions & 2 deletions gpt_researcher/prompts.py
@@ -1,15 +1,16 @@
import warnings
from datetime import date, datetime, timezone

from gpt_researcher.utils.enum import ReportSource, ReportType, Tone
from .utils.enum import ReportSource, ReportType, Tone
from typing import List, Dict, Any


def generate_search_queries_prompt(
question: str,
parent_query: str,
report_type: str,
max_iterations: int = 3,
context: str = ""
context: List[Dict[str, Any]] = [],
):
"""Generates the search queries prompt for the given question.
Args:
@@ -31,6 +32,7 @@ def generate_search_queries_prompt(
task = question

context_prompt = f"""
You are a seasoned research assistant tasked with generating search queries to find relevant information for the following task: "{task}".
Context: {context}
Use this context to inform and refine your search queries. The context provides real-time web information that can help you generate more specific and relevant queries. Consider any current events, recent developments, or specific details mentioned in the context that could enhance the search queries.
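Callers now pass the raw search-results list as `context` rather than a string. A small illustrative call (values are placeholders):

```python
from gpt_researcher.prompts import generate_search_queries_prompt

prompt = generate_search_queries_prompt(
    question="What changed in EU AI regulation this year?",
    parent_query="",
    report_type="research_report",
    max_iterations=3,
    context=[{"title": "...", "href": "https://example.com", "body": "..."}],
)
```

The mutable default `context=[]` is safe here only because the function never mutates it; `None` plus an in-function default would be the more defensive idiom.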
23 changes: 16 additions & 7 deletions gpt_researcher/skills/researcher.py
@@ -4,7 +4,7 @@
from typing import Dict, Optional

from ..actions.utils import stream_output
from ..actions.query_processing import get_sub_queries
from ..actions.query_processing import plan_research_outline, get_search_results
from ..document import DocumentLoader, LangChainDocumentLoader
from ..utils.enum import ReportSource, ReportType, Tone

@@ -111,7 +111,7 @@ async def __get_context_by_vectorstore(self, query, filter: Optional[dict] = Non
"""
context = []
# Generate Sub-Queries including original query
sub_queries = await self.get_sub_queries(query)
sub_queries = await self.plan_research(query)
# If this is not part of a sub researcher, add original query to research for better results
if self.researcher.report_type != "subtopic_report":
sub_queries.append(query)
@@ -143,7 +143,7 @@ async def __get_context_by_search(self, query, scraped_data: list = []):
"""
context = []
# Generate Sub-Queries including original query
sub_queries = await self.get_sub_queries(query)
sub_queries = await self.plan_research(query)
# If this is not part of a sub researcher, add original query to research for better results
if self.researcher.report_type != "subtopic_report":
sub_queries.append(query)
@@ -305,17 +305,26 @@ async def __scrape_data_by_query(self, sub_query):

return scraped_content

async def get_sub_queries(self, query):
async def plan_research(self, query):
await stream_output(
"logs",
"planning_research",
f"🌐 Browsing the web and planning research for query: {query}...",
f"🌐 Browsing the web to learn more about the task: {query}...",
self.researcher.websocket,
)

return await get_sub_queries(
search_results = await get_search_results(query, self.researcher.retrievers[0])

await stream_output(
"logs",
"planning_research",
f"🤔 Planning the research strategy and subtasks...",
self.researcher.websocket,
)

return await plan_research_outline(
query=query,
retriever=self.researcher.retrievers[0],
search_results=search_results,
agent_role_prompt=self.researcher.role,
cfg=self.researcher.cfg,
parent_query=self.researcher.parent_query,
2 changes: 1 addition & 1 deletion gpt_researcher/utils/llm.py
@@ -22,7 +22,7 @@ def get_llm(llm_provider, **kwargs):
async def create_chat_completion(
messages: list, # type: ignore
model: Optional[str] = None,
temperature: float = 0.4,
temperature: Optional[float] = 0.4,
max_tokens: Optional[int] = 4000,
llm_provider: Optional[str] = None,
stream: Optional[bool] = False,
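The `temperature` parameter becomes `Optional` because o1-family reasoning models reject custom sampling settings; the strategic path therefore calls with `temperature=1` and `max_tokens=None`, as in this sketch (assumes API credentials are configured):

```python
import asyncio
from gpt_researcher.config import Config
from gpt_researcher.utils.llm import create_chat_completion

async def strategic_completion() -> str:
    cfg = Config()
    # Mirrors the strategic call in query_processing.py: no token cap and
    # temperature pinned to 1 for reasoning models.
    return await create_chat_completion(
        model=cfg.strategic_llm_model,
        messages=[{"role": "user", "content": "Draft a research plan for X."}],
        temperature=1,
        max_tokens=None,
        llm_provider=cfg.strategic_llm_provider,
        llm_kwargs=cfg.llm_kwargs,
    )

print(asyncio.run(strategic_completion()))
```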
6 changes: 3 additions & 3 deletions requirements.txt
@@ -9,9 +9,9 @@ pydantic
fastapi
python-multipart
markdown
langchain>=0.2,<0.3
langchain_community>=0.2,<0.3
langchain-openai>=0.1,<0.2
langchain>=0.2,<0.4
langchain_community>=0.2,<0.4
langchain-openai>=0.1,<0.4
langgraph
tiktoken
gpt-researcher
