-
Notifications
You must be signed in to change notification settings - Fork 1
/
rag_chain.py
90 lines (71 loc) · 2.81 KB
/
rag_chain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import os
from dotenv import load_dotenv
from langchain import hub
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.messages.base import BaseMessage
from basic_chain import basic_chain, get_model
from remote_loader import get_wiki_docs
from splitter import split_documents
from vector_store import create_vector_db
def find_similar(vs, query):
docs = vs.similarity_search(query)
return docs
def format_docs(docs):
return "\n\n".join(doc.page_content for doc in docs)
def get_question(input):
if not input:
return None
elif isinstance(input,str):
return input
elif isinstance(input,dict) and 'question' in input:
return input['question']
elif isinstance(input,BaseMessage):
return input.content
else:
raise Exception("string or dict with 'question' key expected as RAG chain input.")
def make_rag_chain(model, retriever, rag_prompt = None):
# We will use a prompt template from langchain hub.
if not rag_prompt:
rag_prompt = hub.pull("rlm/rag-prompt")
# And we will use the LangChain RunnablePassthrough to add some custom processing into our chain.
rag_chain = (
{
"context": RunnableLambda(get_question) | retriever | format_docs,
"question": RunnablePassthrough()
}
| rag_prompt
| model
)
return rag_chain
def main():
load_dotenv()
model = get_model("ChatGPT")
docs = get_wiki_docs(query="Bertrand Russell", load_max_docs=5)
texts = split_documents(docs)
vs = create_vector_db(texts)
prompt = ChatPromptTemplate.from_messages([
("system", "You are a professor who teaches philosophical concepts to beginners."),
("user", "{input}")
])
# Besides similarly search, you can also use maximal marginal relevance (MMR) for selecting results.
# retriever = vs.as_retriever(search_type="mmr")
retriever = vs.as_retriever()
output_parser = StrOutputParser()
chain = basic_chain(model, prompt)
base_chain = chain | output_parser
rag_chain = make_rag_chain(model, retriever) | output_parser
questions = [
"What were the most important contributions of Bertrand Russell to philosophy?",
"What was the first book Bertrand Russell published?",
"What was most notable about \"An Essay on the Foundations of Geometry\"?",
]
for q in questions:
print("\n--- QUESTION: ", q)
print("* BASE:\n", base_chain.invoke({"input": q}))
print("* RAG:\n", rag_chain.invoke(q))
if __name__ == '__main__':
# this is to quite parallel tokenizers warning.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
main()