Manage context length; Allow configuration for RAG and generation
gallegi committed Nov 27, 2024
1 parent d3a3056 commit 8b85dda
Showing 10 changed files with 388 additions and 90 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/publish-to-pypi.yml
@@ -52,4 +52,4 @@ jobs:
path: dist/

- name: Publish release distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
uses: pypa/gh-action-pypi-publish@release/v1
204 changes: 162 additions & 42 deletions llama_assistant/agent.py
@@ -1,5 +1,4 @@
from typing import List, Set, Optional
from collections import defaultdict
from typing import List, Set, Optional, Dict

from llama_cpp import Llama
from llama_index.core import VectorStoreIndex
@@ -11,30 +10,71 @@

from llama_index.core.workflow import Event, StartEvent, StopEvent, Workflow, step


def convert_message_list_to_str(messages):
chat_history_str = ""
for message in messages:
if type(message["content"]) is str:
chat_history_str += message["role"] + ": " + message["content"] + "\n"
else:
chat_history_str += message["role"] + ": " + message["content"]["text"] + "\n"

return chat_history_str


class SetupEvent(Event):
pass


class CondenseQueryEvent(Event):
condensed_query_str: str


class RetrievalEvent(Event):
nodes: List[NodeWithScore]


class ChatHistory:
def __init__(self, max_history_size: int):
self.max_history_size = max_history_size
self.total_size = 0
self.chat_history = []

def add_message(self, message: dict):
if "content" in message and type(message["content"]) is list:
# multimodal model's message format
new_msg_size = len(message["content"][0]["text"].split())
else:
# text-only model's message format
new_msg_size = len(message["content"].split())

self.total_size += new_msg_size
self.chat_history.append(message)
while self.total_size > self.max_history_size:
oldest_msg = self.chat_history.pop(0)
if "content" in oldest_msg and type(oldest_msg["content"]) is list:
# multimodal model's message format
len_oldest_msg = len(oldest_msg["content"][0]["text"].split())
else:
# text-only model's message format
len_oldest_msg = len(oldest_msg["content"].split())
self.total_size -= len_oldest_msg

def get_chat_history(self):
return self.chat_history

def clear(self):
self.chat_history = []
self.total_size = 0

def __len__(self):
return len(self.chat_history)
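
The new ChatHistory class trims by an approximate word budget rather than by message count. A minimal sketch of the eviction behavior above, with an illustrative budget (assumes the llama_assistant package is importable):

```python
from llama_assistant.agent import ChatHistory

history = ChatHistory(max_history_size=8)  # budget of roughly 8 words
history.add_message({"role": "user", "content": "one two three four five"})
history.add_message({"role": "assistant", "content": "six seven eight nine"})

# The second message pushes total_size to 9 > 8, so the oldest message
# is evicted from the front of the list.
print(len(history))        # 1
print(history.total_size)  # 4
```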


class RAGAgent(Workflow):
SUMMARY_TEMPLATE = (
"Given the chat history:\n"
"Given the converstation:\n"
"'''{chat_history_str}'''\n\n"
"And the user asked the following question:{query_str}\n"
"Rewrite to a standalone question:\n"
)

@@ -45,49 +85,111 @@ class RAGAgent(Workflow):
"-----\n"
"Please write a response to the following question, using the above information if relevant:\n"
"{query_str}\n"
)
def __init__(self, embed_model_name: str, llm: Llama, timeout: int = 60, verbose: bool = False):
)
SYSTEM_PROMPT = {"role": "system", "content": "Generate short and simple response:"}

def __init__(
self,
generation_setting: Dict,
rag_setting: Dict,
llm: Llama,
timeout: int = 60,
verbose: bool = False,
):
super().__init__(timeout=timeout, verbose=verbose)
self.k = 3
self.generation_setting = generation_setting
self.context_len = generation_setting["context_len"]
# we want the retrieved context = our set value but not more than the context_len
self.retrieval_top_k = (self.context_len - rag_setting["chunk_size"]) // rag_setting[
"chunk_overlap"
]
self.retrieval_top_k = min(max(1, self.retrieval_top_k), rag_setting["max_retrieval_top_k"])
self.search_index = None
self.retriever = None
self.chat_history = []
# 1 token ~ 3/4 words, so we multiply the context_len by 0.7 to get the max number of words
# why not 0.75? because we want to be a bit more conservative
self.chat_history = ChatHistory(max_history_size=self.context_len * 0.7)
self.lookup_files = set()

self.embed_model = HuggingFaceEmbedding(model_name=embed_model_name)
self.embed_model = HuggingFaceEmbedding(model_name=rag_setting["embed_model_name"])
Settings.embed_model = self.embed_model
self.node_processor = SimilarityPostprocessor(similarity_cutoff=0.3)
Settings.chunk_size = rag_setting["chunk_size"]
Settings.chunk_overlap = rag_setting["chunk_overlap"]
self.node_processor = SimilarityPostprocessor(
similarity_cutoff=rag_setting["similarity_threshold"]
)
self.llm = llm

def udpate_index(self, files: Optional[Set[str] ] = set()):
def update_index(self, files: Optional[Set[str]] = set()):
if not files:
print("No lookup files provided, clearing index...")
self.retriever = None
self.search_index = None
return

print("Indexing documents...")
documents = SimpleDirectoryReader(input_files=files, recursive=True).load_data(show_progress=True, num_workers=1)
page_num_tracker = defaultdict(int)
for doc in documents:
key = doc.metadata['file_path']
doc.metadata['page_index'] = page_num_tracker[key]
page_num_tracker[key] += 1
documents = SimpleDirectoryReader(input_files=files, recursive=True).load_data(
show_progress=True, num_workers=1
)

if self.search_index is None:
self.search_index = VectorStoreIndex.from_documents(documents, embed_model=self.embed_model)
self.search_index = VectorStoreIndex.from_documents(
documents, embed_model=self.embed_model
)
else:
for doc in documents:
self.search_index.insert(doc) # Add the new document to the index
self.search_index.insert(doc) # Add the new document to the index

self.retriever = self.search_index.as_retriever(similarity_top_k=self.retrieval_top_k)

def update_rag_setting(self, rag_setting: Dict):
if self.node_processor.similarity_cutoff != rag_setting["similarity_threshold"]:
self.node_processor = SimilarityPostprocessor(
similarity_cutoff=rag_setting["similarity_threshold"]
)

new_top_k = (self.context_len - rag_setting["chunk_size"]) // rag_setting["chunk_overlap"]
new_top_k = min(max(1, new_top_k), rag_setting["max_retrieval_top_k"])
if self.retrieval_top_k != new_top_k:
self.retrieval_top_k = new_top_k
if self.retriever:
self.retriever = self.search_index.as_retriever(
similarity_top_k=self.retrieval_top_k
)

if (
self.embed_model.model_name != rag_setting["embed_model_name"]
or Settings.chunk_size != rag_setting["chunk_size"]
or Settings.chunk_overlap != rag_setting["chunk_overlap"]
):
self.embed_model = HuggingFaceEmbedding(model_name=rag_setting["embed_model_name"])
Settings.embed_model = self.embed_model
Settings.chunk_size = rag_setting["chunk_size"]
Settings.chunk_overlap = rag_setting["chunk_overlap"]

# reindex since those are the settings that affect the index
if self.lookup_files:
print("Re-indexing documents since the rag settings have changed...")
documents = SimpleDirectoryReader(
input_files=self.lookup_files, recursive=True
).load_data(show_progress=True, num_workers=1)
self.search_index = VectorStoreIndex.from_documents(
documents, embed_model=self.embed_model
)
self.retriever = self.search_index.as_retriever(
similarity_top_k=self.retrieval_top_k
)

def update_generation_setting(self, generation_setting):
self.generation_setting = generation_setting
self.context_len = generation_setting["context_len"]

self.retriever = self.search_index.as_retriever(similarity_top_k=self.k)

@step
async def setup(self, ctx: Context, ev: StartEvent) -> SetupEvent:
# set frequetly used variables to context
# set frequently used variables to context
query_str = ev.query_str
image = ev.image
lookup_files = ev.lookup_files
lookup_files = ev.lookup_files if ev.lookup_files else set()
streaming = ev.streaming
await ctx.set("query_str", query_str)
await ctx.set("image", image)
@@ -96,7 +198,7 @@ async def setup(self, ctx: Context, ev: StartEvent) -> SetupEvent:
# update index if needed
if lookup_files != self.lookup_files:
print("Different lookup files, updating index...")
self.udpate_index(lookup_files)
self.update_index(lookup_files)

self.lookup_files = lookup_files.copy()

@@ -105,36 +207,48 @@ async def setup(self, ctx: Context, ev: StartEvent) -> SetupEvent:
@step
async def condense_history_to_query(self, ctx: Context, ev: SetupEvent) -> CondenseQueryEvent:
"""
Condense the chat history and the query into a single query. Only used for retrieval.
Condense the chat history and the query into a single query. Only used for retrieval.
"""
query_str = await ctx.get("query_str")

formated_query = ""

if len(self.chat_history) > 0 or self.retriever is not None:
chat_history_str = convert_message_list_to_str(self.chat_history)
formated_query = self.SUMMARY_TEMPLATE.format(chat_history_str=chat_history_str, query_str=query_str)
self.chat_history.add_message({"role": "user", "content": query_str})
chat_history_str = convert_message_list_to_str(self.chat_history.get_chat_history())
# use llm to summarize the chat history into a single query,
# which is used to retrieve context from the documents
formated_query = self.SUMMARY_TEMPLATE.format(chat_history_str=chat_history_str)

history_summary = self.llm.create_chat_completion(
messages=[{"role": "user", "content": formated_query}], stream=False
messages=[{"role": "user", "content": formated_query}],
stream=False,
top_k=self.generation_setting["top_k"],
top_p=self.generation_setting["top_p"],
temperature=self.generation_setting["temperature"],
)["choices"][0]["message"]["content"]

condensed_query = "Context:\n" + history_summary + "\nQuestion: " + query_str
# remove the last user message from the chat history
# later we will add the query with retrieved context (RAG)
self.chat_history.chat_history.pop()
else:
# if there is no history or no need for retrieval, return the query as is
condensed_query = query_str

return CondenseQueryEvent(condensed_query_str=condensed_query)

@step
async def retrieve(self, ctx: Context, ev: CondenseQueryEvent) -> RetrievalEvent:
# retrieve from context
if not self.retriever:
return RetrievalEvent(nodes=[])

# retrieve from dropped documents
condensed_query_str = ev.condensed_query_str
nodes = await self.retriever.aretrieve(condensed_query_str)
nodes = self.node_processor.postprocess_nodes(nodes)
return RetrievalEvent(nodes=nodes)

def _prepare_query_with_context(
self,
query_str: str,
@@ -144,19 +258,19 @@ def _prepare_query_with_context(

if len(nodes) == 0:
return query_str

for idx, node in enumerate(nodes):
node_text = node.get_content(metadata_mode="llm")
node_context += f"\n{node_text}\n\n"

formatted_query = self.CONTEXT_PROMPT_TEMPLATE.format(
node_context=node_context, query_str=query_str
)

return formatted_query

@step
async def llm_response(self, ctx: Context, retrieval_ev: RetrievalEvent) -> StopEvent:
async def llm_response(self, ctx: Context, retrieval_ev: RetrievalEvent) -> StopEvent:
nodes = retrieval_ev.nodes
query_str = await ctx.get("query_str")
image = await ctx.get("image")
@@ -174,12 +288,18 @@ async def llm_response(self, ctx: Context, retrieval_ev: RetrievalEvent) -> StopEvent:
else:
formated_message = {"role": "user", "content": query_with_ctx}

self.chat_history.add_message(formated_message)

response = self.llm.create_chat_completion(
messages=self.chat_history+[formated_message], stream=streaming
messages=[self.SYSTEM_PROMPT] + self.chat_history.get_chat_history(),
stream=streaming,
top_k=self.generation_setting["top_k"],
top_p=self.generation_setting["top_p"],
temperature=self.generation_setting["temperature"],
)
self.chat_history.append({"role": "user", "content": query_str})
# remove the query with context from the chat history,
self.chat_history.chat_history.pop()
# add the short query (without context) instead -> not to clutter the chat history
self.chat_history.add_message({"role": "user", "content": query_str})

return StopEvent(result=response)
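
For orientation, a rough sketch of how the reworked workflow might be driven end to end. The model path, document path, and settings values below are illustrative stand-ins, not part of this commit:

```python
import asyncio

from llama_cpp import Llama
from llama_assistant.agent import RAGAgent

# Illustrative settings; the default values live in llama_assistant/config.py.
generation_setting = {"context_len": 2048, "top_k": 40, "top_p": 0.95, "temperature": 0.2}
rag_setting = {
    "embed_model_name": "BAAI/bge-base-en-v1.5",
    "chunk_size": 256,
    "chunk_overlap": 128,
    "max_retrieval_top_k": 3,
    "similarity_threshold": 0.6,
}

llm = Llama(model_path="path/to/model.gguf", n_ctx=2048)  # hypothetical model file
agent = RAGAgent(generation_setting, rag_setting, llm=llm, timeout=60)

async def ask():
    # Workflow.run forwards these keyword arguments into the StartEvent read by setup().
    return await agent.run(
        query_str="Summarize the attached report.",
        image=None,
        lookup_files={"docs/report.pdf"},  # hypothetical document to index
        streaming=False,
    )

response = asyncio.run(ask())
print(response["choices"][0]["message"]["content"])
```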



11 changes: 10 additions & 1 deletion llama_assistant/config.py
@@ -111,6 +111,11 @@
},
]

# generation setting
context_len = 2048
top_k = 40
top_p = 0.95
temperature = 0.2

home_dir = Path.home()
llama_assistant_dir = home_dir / "llama_assistant"
@@ -119,8 +124,12 @@
settings_file = llama_assistant_dir / "settings.json"
document_icon = "llama_assistant/resources/document_icon.png"

# for RAG pipeline
# RAG setting
embed_model_name = "BAAI/bge-base-en-v1.5"
chunk_size = 256
chunk_overlap = 128
max_retrieval_top_k = 3
similarity_threshold = 0.6

if custom_models_file.exists():
with open(custom_models_file, "r") as f:
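
With these defaults, the budgets RAGAgent derives in its __init__ (see agent.py above) work out as follows; this is just a worked check of the arithmetic, not code added by the commit:

```python
context_len, chunk_size, chunk_overlap, max_retrieval_top_k = 2048, 256, 128, 3

# Retrieval depth: (context_len - chunk_size) // chunk_overlap, clamped to the maximum.
retrieval_top_k = (context_len - chunk_size) // chunk_overlap        # (2048 - 256) // 128 = 14
retrieval_top_k = min(max(1, retrieval_top_k), max_retrieval_top_k)  # -> 3

# Chat-history budget: ~0.7 words per context token (conservative 3/4-word rule).
max_history_size = context_len * 0.7                                 # -> 1433.6 words
```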

