
Commit

Update (#21)
kesamet authored Jun 12, 2024
1 parent fdb3917 commit a295b7f
Showing 5 changed files with 76 additions and 27 deletions.
3 changes: 0 additions & 3 deletions app.py
@@ -7,8 +7,6 @@
from streamlit_app.vision_assistant import vision_assistant
# from phoenix.trace.langchain import LangChainInstrumentor

# from streamlit_app.financial_assistant import financial_assistant

# Setup tracing
# LangChainInstrumentor().instrument()

@@ -19,7 +17,6 @@ def main():
    dict_pages = {
        "ReAct Agent": agent_react,
        "Gemini Functions Agent": agent_gemini_functions,
        # "Financial Assistant": financial_assistant,
        "Chatbot Playground": chatbot,
        "Vision Assistant": vision_assistant,
        "Code Assistant": code_assistant,
29 changes: 29 additions & 0 deletions src/chains.py
@@ -0,0 +1,29 @@
from langchain.llms.base import LLM
from langchain.prompts import BasePromptTemplate, PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import Runnable, RunnableBranch


def condense_question_chain(
    llm: LLM, prompt: BasePromptTemplate | None = None
) -> Runnable:
    """Builds a chain that condenses the question and chat history into a standalone question."""
    if prompt is None:
        template = (
            "Given the following chat history and a follow up question, "
            "rephrase the follow up question to be a standalone question, in its original language.\n\n"
            "Chat History:\n{chat_history}\n\nFollow Up Question: {question}\nStandalone question:"
        )
        prompt = PromptTemplate.from_template(template)

    chain = RunnableBranch(
        (
            # Both empty string and empty list evaluate to False
            lambda x: not x.get("chat_history", False),
            # If there is no chat history, pass the question through unchanged
            (lambda x: x["question"]),
        ),
        # If there is chat history, pass the inputs to the LLM chain
        prompt | llm | StrOutputParser(),
    )
    return chain
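
For reference, a minimal usage sketch of this chain, assuming the gemini-pro LLM configured elsewhere in this repo; the questions and history are illustrative:

# Illustrative usage sketch of condense_question_chain; not part of this commit.
from langchain_google_genai import GoogleGenerativeAI

from src.chains import condense_question_chain

llm = GoogleGenerativeAI(model="gemini-pro", temperature=0.0)
chain = condense_question_chain(llm)

# Empty chat history: the first branch matches and the question passes through as-is.
standalone = chain.invoke({"question": "What is LCEL?", "chat_history": []})

# Non-empty history: the prompt | llm | StrOutputParser() branch rewrites the follow-up.
standalone = chain.invoke(
    {
        "question": "Who maintains it?",
        "chat_history": [("What is LangChain?", "An LLM application framework.")],
    }
)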
8 changes: 4 additions & 4 deletions src/tools/calculator_tool.py
@@ -3,11 +3,11 @@
from pydantic import BaseModel, Field
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.runnables import Runnable
from langchain.callbacks.manager import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain.chains.llm import LLMChain
from langchain.chains.llm_math.prompt import PROMPT
from langchain.tools import BaseTool

@@ -20,7 +20,7 @@ class CalculatorTool(BaseTool):
name = "Calculator"
description = "A useful tool for answering simple questions about math."
args_schema: Type[BaseModel] = CalculatorInput
llm_chain: LLMChain
llm_chain: Runnable

    def _run(
        self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None
@@ -40,8 +40,8 @@ def from_llm(
        llm: BaseLanguageModel,
        prompt: BasePromptTemplate = PROMPT,
        **kwargs: Any,
    ) -> LLMChain:
        llm_chain = LLMChain(llm=llm, prompt=prompt)
    ) -> Runnable:
        llm_chain = prompt | llm
        return cls(llm_chain=llm_chain, **kwargs)


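This file swaps the deprecated LLMChain for an LCEL composition. A sketch of the before and after, reusing the same llm_math PROMPT; the query is illustrative and assumes _run delegates to llm_chain:

# Illustrative sketch of the LLMChain -> LCEL migration shown above.
from langchain.chains.llm_math.prompt import PROMPT
from langchain_google_genai import GoogleGenerativeAI

from src.tools.calculator_tool import CalculatorTool

llm = GoogleGenerativeAI(model="gemini-pro", temperature=0.0)

# Before: llm_chain = LLMChain(llm=llm, prompt=PROMPT)
# After: an equivalent Runnable built by composition.
llm_chain = PROMPT | llm
llm_chain.invoke({"question": "What is 37593 * 67?"})

# CalculatorTool.from_llm performs the same composition internally.
calculator = CalculatorTool.from_llm(llm)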
7 changes: 4 additions & 3 deletions src/tools/common.py
@@ -11,7 +11,7 @@
llm = GoogleGenerativeAI(model="gemini-pro", temperature=0.0)

# Web Search Tool
search = TavilySearchAPIWrapper()
_search = TavilySearchAPIWrapper()
description = (
"A search engine optimized for comprehensive, accurate, "
"and trusted results. Useful for when you need to answer questions "
@@ -20,7 +20,7 @@
"If the user is asking about something that you don't know about, "
"you should probably use this tool to see if that can provide any information."
)
tavily_tool = TavilySearchResults(api_wrapper=search, description=description)
tavily_tool = TavilySearchResults(api_wrapper=_search, description=description)

# Wikipedia Tool
_wikipedia = WikipediaAPIWrapper()
@@ -29,7 +29,8 @@
    func=_wikipedia.run,
    description=(
        "A wrapper around Wikipedia. Useful for when you need to answer general questions about "
        "people, places, companies, facts, historical events, or other subjects."
        "people, places, companies, facts, historical events, or other subjects. "
        "Input should be a search query."
    ),
)

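The renamed wrappers are now module-private, and both tools accept a single query string, which the added input hint makes explicit to the agent. A sketch of calling the search tool directly, assuming TAVILY_API_KEY is set in the environment; the query is illustrative:

# Illustrative direct invocation of the search tool defined above.
from src.tools import tavily_tool

results = tavily_tool.run("latest stable Python release")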
56 changes: 39 additions & 17 deletions streamlit_app/agent_react.py
@@ -1,8 +1,10 @@
import streamlit as st
from langchain_community.callbacks.streamlit import StreamlitCallbackHandler
from langchain_core.prompts import PromptTemplate
from langchain.agents import AgentExecutor
from langchain_google_genai import GoogleGenerativeAI

from src.chains import condense_question_chain
from src.agents import create_react_agent
from src.tools import (
    tavily_tool,
@@ -65,25 +67,20 @@
    agent=agent,
    tools=tools,
    return_intermediate_steps=True,
    handle_parsing_errors=True,
    max_execution_time=60,
    verbose=True,
)

_chain = condense_question_chain(llm)


def init_messages() -> None:
    clear_button = st.sidebar.button("Clear Chat", key="react_agent")
    if clear_button or "react_messages" not in st.session_state:
        st.session_state.react_messages = []


def get_response(user_input: str) -> str:
    try:
        return agent_executor.invoke({"input": user_input})
    except Exception as e:
        st.error(e)
        return ""


def agent_react():
    st.sidebar.title("ReAct Agent")
    st.sidebar.info(
@@ -94,26 +91,51 @@ def agent_react():
    init_messages()

    # Display chat history
    for human, ai in st.session_state.react_messages:
    for human, ai, intermediate_steps in st.session_state.react_messages:
        with st.chat_message("user"):
            st.markdown(human)
        with st.chat_message("assistant"):
            st.markdown(ai)

            with st.expander("Thoughts and Actions"):
                for action, content in intermediate_steps:
                    st.markdown(f"`{action.log}`")
                    st.info(content)
                    st.markdown("---")

    if user_input := st.chat_input("Your input"):
        with st.chat_message("user"):
            st.markdown(user_input)

        chat_history = [x[:2] for x in st.session_state.react_messages]
        with st.chat_message("assistant"):
            with st.spinner("Thinking ..."):
                response = get_response(user_input)
            st_callback = StreamlitCallbackHandler(
                parent_container=st.container(),
                expand_new_thoughts=True,
                collapse_completed_thoughts=True,
            )
            query = _chain.invoke(
                {
                    "question": user_input,
                    "chat_history": chat_history,
                },
            )

            response = agent_executor.invoke(
                {"input": query},
                config={"callbacks": [st_callback]},
            )

            output = response["output"].replace("$", r"\$")  # escape $ so st.markdown does not render it as LaTeX
            st.markdown(response["output"])
            intermediate_steps = response["intermediate_steps"]
            # for r in intermediate_steps:
            #     r[1] = r[1].replace("$", r"\$")

            with st.expander("Thoughts and Actions"):
                for action, content in response["intermediate_steps"]:
                    st.markdown(f"`{action.log}`")
                    st.markdown(content)
            st.markdown(output)

            with st.expander("Sources"):
                for _, content in intermediate_steps:
                    st.info(content)
                    st.markdown("---")

        st.session_state.react_messages.append((user_input, output))
        st.session_state.react_messages.append((user_input, output, intermediate_steps))
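
Each history entry now carries the agent's intermediate steps alongside the exchange, while the condense-question chain only consumes the (human, ai) pairs. A sketch of the shapes involved; the values are illustrative:

# Illustrative shape of the new react_messages entries:
# (human input, agent output, [(AgentAction, observation), ...]).
from langchain_core.agents import AgentAction

action = AgentAction(tool="Calculator", tool_input="2 ** 8", log="Using Calculator")
react_messages = [("What is 2 ** 8?", "256", [(action, "256")])]

# Only the human/ai pair feeds the condense-question chain.
chat_history = [x[:2] for x in react_messages]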
