Merge pull request #919 from assafelovic/feature/improved_search_queries
Feature/improved search queries using prior web search
assafelovic authored Oct 17, 2024
2 parents b492ba5 + 830aadf commit e2c7d48
Showing 6 changed files with 52 additions and 33 deletions.
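The change is small but end-to-end: before asking the LLM to plan sub-queries, the researcher now runs one web search on the original query and feeds the results into the query-generation prompt. A minimal, self-contained sketch of that flow (the stub retriever and stub LLM below are illustrative placeholders, not gpt_researcher's real classes):

import asyncio

class StubRetriever:
    # Placeholder for a gpt_researcher retriever: constructed with a query, exposes search().
    def __init__(self, query: str):
        self.query = query

    def search(self) -> list[dict]:
        # A real retriever would call a web search API here (Tavily, Bing, DuckDuckGo, ...).
        return [{"title": "stub", "body": f"recent web results for: {self.query}"}]

async def stub_llm(prompt: str) -> list[str]:
    # Placeholder for create_chat_completion; the real code parses the model's list output.
    return ["query 1", "query 2", "query 3", "query 4"]

async def plan_sub_queries(query: str, retriever_cls, max_iterations: int = 4) -> list[str]:
    # 1. Prior web search gives the planner real-time grounding for time-sensitive tasks.
    search_results = retriever_cls(query).search()
    # 2. Ask the LLM for sub-queries, conditioned on those results.
    prompt = (
        f'Write {max_iterations} google search queries for: "{query}"\n'
        f"Context: {search_results}\n"
        "Respond only with a list of strings."
    )
    return await stub_llm(prompt)

print(asyncio.run(plan_sub_queries("US inflation rate this month", StubRetriever)))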
2 changes: 1 addition & 1 deletion gpt_researcher/config/variables/default.py
@@ -24,7 +24,7 @@ class DefaultConfig(BaseConfig):
     "MEMORY_BACKEND": "local",
     "TOTAL_WORDS": 900,
     "REPORT_FORMAT": "APA",
-    "MAX_ITERATIONS": 3,
+    "MAX_ITERATIONS": 4,
     "AGENT_ROLE": None,
     "SCRAPER": "bs",
     "MAX_SUBTOPICS": 3,
15 changes: 11 additions & 4 deletions gpt_researcher/master/actions/query_processing.py
@@ -3,6 +3,7 @@
 import json_repair
 from ...utils.llm import create_chat_completion
 from ..prompts import auto_agent_instructions, generate_search_queries_prompt
+from typing import Any


 async def choose_agent(
@@ -77,6 +78,7 @@ def extract_json_with_regex(response):

 async def get_sub_queries(
     query: str,
+    retriever: Any,
     agent_role_prompt: str,
     cfg,
     parent_query: str,
@@ -87,16 +89,20 @@ async def get_sub_queries(
     Gets the sub queries
     Args:
         query: original query
+        retriever: retriever instance
         agent_role_prompt: agent role prompt
         cfg: Config
-        parent_query:
-        report_type:
-        cost_callback:
+        parent_query: parent query
+        report_type: report type
+        cost_callback: callback for cost calculation
     Returns:
         sub_queries: List of sub queries
     """
+    # Get web search results prior to generating subqueries for improved context around real time data tasks
+    search_retriever = retriever(query)
+    search_results = search_retriever.search()

     max_research_iterations = cfg.max_iterations if cfg.max_iterations else 1
     response = await create_chat_completion(
         model=cfg.smart_llm_model,
@@ -109,6 +115,7 @@
                     parent_query,
                     report_type,
                     max_iterations=max_research_iterations,
+                    context=search_results
                 ),
             },
         ],
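The new retriever parameter is only annotated as Any, but the two added lines pin down the effective contract: get_sub_queries receives a retriever class, instantiates it with the query, and calls search() on the instance. A hedged sketch of that implicit interface (the Protocol and stub class are illustrative, not part of gpt_researcher):

from typing import Protocol

class SearchRetriever(Protocol):
    # Shape of a retriever *instance* once get_sub_queries has done retriever(query).
    def search(self) -> list: ...

class StubSearchRetriever:
    # Minimal conforming retriever; a real one would call out to a search API.
    def __init__(self, query: str):
        self.query = query

    def search(self) -> list:
        return [{"href": "https://example.org", "body": f"stub hit for {self.query}"}]

def prior_search(retriever: type, query: str) -> list:
    # The two added lines from get_sub_queries, in isolation:
    search_retriever: SearchRetriever = retriever(query)  # `retriever` is a class, not an instance
    return search_retriever.search()

print(prior_search(StubSearchRetriever, "semiconductor export controls"))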
15 changes: 2 additions & 13 deletions gpt_researcher/master/agent/context_manager.py
@@ -72,7 +72,7 @@ async def __get_context_from_langchain_documents(self):
         return await self.__get_context_by_search(self.researcher.query, langchain_documents_data)

     async def __get_context_by_vectorstore(self, query, filter: Optional[dict] = None):
-        sub_queries = await self.__get_sub_queries(query)
+        sub_queries = await self.researcher.research_conductor.get_sub_queries(query)
         if self.researcher.report_type != "subtopic_report":
             sub_queries.append(query)

@@ -92,7 +92,7 @@ async def __get_context_by_vectorstore(self, query, filter: Optional[dict] = None):
         return context

     async def __get_context_by_search(self, query, scraped_data: list = []):
-        sub_queries = await self.__get_sub_queries(query)
+        sub_queries = await self.researcher.research_conductor.get_sub_queries(query)
         if self.researcher.report_type != "subtopic_report":
             sub_queries.append(query)

@@ -214,17 +214,6 @@ async def get_similar_content_by_query(self, query, pages):
             query=query, max_results=10, cost_callback=self.researcher.add_costs
         )

-    async def __get_sub_queries(self, query):
-        from gpt_researcher.master.actions import get_sub_queries
-        return await get_sub_queries(
-            query=query,
-            agent_role_prompt=self.researcher.role,
-            cfg=self.researcher.cfg,
-            parent_query=self.researcher.parent_query,
-            report_type=self.researcher.report_type,
-            cost_callback=self.researcher.add_costs,
-        )
-
     async def get_similar_written_contents_by_draft_section_titles(
         self,
         current_subtopic: str,
10 changes: 6 additions & 4 deletions gpt_researcher/master/agent/master.py
@@ -4,13 +4,15 @@
 from ...memory import Memory
 from ...utils.enum import ReportSource, ReportType, Tone
 from ...llm_provider import GenericLLMProvider
-from ..agent.researcher import ResearchConductor
-from ..agent.scraper import ReportScraper
-from ..agent.writer import ReportGenerator
-from ..agent.context_manager import ContextManager
 from ..actions import get_retrievers, choose_agent
 from ...vector_store import VectorStoreWrapper

+# Research agents
+from .researcher import ResearchConductor
+from .scraper import ReportScraper
+from .writer import ReportGenerator
+from .context_manager import ContextManager
+

 class GPTResearcher:
     def __init__(
18 changes: 13 additions & 5 deletions gpt_researcher/master/agent/researcher.py
@@ -3,7 +3,8 @@
 from typing import Dict, Optional

 from ..actions.utils import stream_output
-from ..actions import get_sub_queries, scrape_urls
+from ..actions import scrape_urls
+from ..actions.query_processing import get_sub_queries
 from ...document import DocumentLoader, LangChainDocumentLoader
 from ...utils.enum import ReportSource, ReportType, Tone

@@ -110,7 +111,7 @@ async def __get_context_by_vectorstore(self, query, filter: Optional[dict] = None):
         """
         context = []
         # Generate Sub-Queries including original query
-        sub_queries = await self.__get_sub_queries(query)
+        sub_queries = await self.get_sub_queries(query)
         # If this is not part of a sub researcher, add original query to research for better results
         if self.researcher.report_type != "subtopic_report":
             sub_queries.append(query)
@@ -142,7 +143,7 @@ async def __get_context_by_search(self, query, scraped_data: list = []):
         """
         context = []
         # Generate Sub-Queries including original query
-        sub_queries = await self.__get_sub_queries(query)
+        sub_queries = await self.get_sub_queries(query)
         # If this is not part of a sub researcher, add original query to research for better results
         if self.researcher.report_type != "subtopic_report":
             sub_queries.append(query)
@@ -306,10 +307,17 @@ async def __scrape_data_by_query(self, sub_query):

         return scraped_content_results

-    async def __get_sub_queries(self, query):
-        # Generate Sub-Queries including original query
+    async def get_sub_queries(self, query):
+        await stream_output(
+            "logs",
+            "planning_research",
+            f"🌐 Browsing the web and planning research for query: {query}...",
+            self.researcher.websocket,
+        )
+
         return await get_sub_queries(
             query=query,
+            retriever=self.researcher.retrievers[0],
             agent_role_prompt=self.researcher.role,
             cfg=self.researcher.cfg,
             parent_query=self.researcher.parent_query,
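Together, the context_manager.py and researcher.py changes move sub-query planning behind a single public method: ContextManager no longer keeps a private copy of the call but delegates to the conductor's get_sub_queries. A simplified sketch of that delegation (class names mirror the diff, but these are stand-ins, not the full gpt_researcher classes):

import asyncio

class ResearchConductorSketch:
    async def get_sub_queries(self, query: str) -> list[str]:
        # The real method streams a "planning_research" log over the websocket and then
        # calls actions.query_processing.get_sub_queries with the first configured retriever.
        return [f"{query} background", f"{query} latest developments"]

class ContextManagerSketch:
    def __init__(self, research_conductor: ResearchConductorSketch):
        self.research_conductor = research_conductor

    async def get_context(self, query: str) -> list[str]:
        # Previously this class had its own private __get_sub_queries; now it delegates.
        sub_queries = await self.research_conductor.get_sub_queries(query)
        sub_queries.append(query)  # non-subtopic reports also research the original query
        return sub_queries

conductor = ResearchConductorSketch()
print(asyncio.run(ContextManagerSketch(conductor).get_context("EU AI Act enforcement")))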
25 changes: 19 additions & 6 deletions gpt_researcher/master/prompts.py
@@ -9,13 +9,15 @@ def generate_search_queries_prompt(
     parent_query: str,
     report_type: str,
     max_iterations: int = 3,
+    context: str = ""
 ):
     """Generates the search queries prompt for the given question.
     Args:
         question (str): The question to generate the search queries prompt for
         parent_query (str): The main question (only relevant for detailed reports)
         report_type (str): The report type
         max_iterations (int): The maximum number of search queries to generate
+        context (str): Context for better understanding of the task with realtime web information
     Returns: str: The search queries prompt for the given question
     """
@@ -28,12 +30,22 @@
     else:
         task = question

-    return (
-        f'Write {max_iterations} google search queries to search online that form an objective opinion from the following task: "{task}"\n'
-        f"Assume the current date is {datetime.now(timezone.utc).strftime('%B %d, %Y')} if required.\n"
-        f'You must respond with a list of strings in the following format: ["query 1", "query 2", "query 3"].\n'
-        f"The response should contain ONLY the list."
-    )
+    context_prompt = f"""
+Context: {context}
+Use this context to inform and refine your search queries. The context provides real-time web information that can help you generate more specific and relevant queries. Consider any current events, recent developments, or specific details mentioned in the context that could enhance the search queries.
+""" if context else ""
+
+    dynamic_example = ", ".join([f'"query {i+1}"' for i in range(max_iterations)])
+
+    return f"""Write {max_iterations} google search queries to search online that form an objective opinion from the following task: "{task}"
+Assume the current date is {datetime.now(timezone.utc).strftime('%B %d, %Y')} if required.
+{context_prompt}
+You must respond with a list of strings in the following format: [{dynamic_example}].
+The response should contain ONLY the list.
+"""


 def generate_report_prompt(
@@ -404,3 +416,4 @@ def get_prompt_by_report_type(report_type):
         )
         prompt_by_type = report_type_mapping.get(default_report_type)
     return prompt_by_type
+
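Two small pieces do most of the work in the rewritten prompt: the example list now scales with max_iterations (so the MAX_ITERATIONS bump to 4 yields a four-item example), and the context block is emitted only when prior search results exist. A quick standalone check of both, using made-up inputs:

max_iterations = 4
dynamic_example = ", ".join([f'"query {i+1}"' for i in range(max_iterations)])
print(dynamic_example)  # "query 1", "query 2", "query 3", "query 4"

for context in ("", "stub web results about current mortgage rates"):
    # Mirrors the conditional block in generate_search_queries_prompt.
    context_prompt = f"\nContext: {context}\nUse this context to refine your queries.\n" if context else ""
    print(repr(context_prompt))  # empty string when there are no prior search results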