
Commit

Merge pull request #1114 from better629/feat_memory
Feat rag add more features
geekan authored Mar 27, 2024
2 parents 6434503 + 90e1b62 commit 6450a09
Showing 15 changed files with 325 additions and 37 deletions.
61 changes: 49 additions & 12 deletions examples/rag_pipeline.py
@@ -11,9 +11,13 @@
BM25RetrieverConfig,
ChromaIndexConfig,
ChromaRetrieverConfig,
ElasticsearchIndexConfig,
ElasticsearchRetrieverConfig,
ElasticsearchStoreConfig,
FAISSRetrieverConfig,
LLMRankerConfig,
)
from metagpt.utils.exceptions import handle_exception

DOC_PATH = EXAMPLE_DATA_PATH / "rag/writer.txt"
QUESTION = "What are key qualities to be a good writer?"
@@ -39,12 +43,22 @@ def rag_key(self) -> str:
class RAGExample:
"""Show how to use RAG."""

def __init__(self):
self.engine = SimpleEngine.from_docs(
input_files=[DOC_PATH],
retriever_configs=[FAISSRetrieverConfig(), BM25RetrieverConfig()],
ranker_configs=[LLMRankerConfig()],
)
def __init__(self, engine: SimpleEngine = None):
self._engine = engine

@property
def engine(self):
if not self._engine:
self._engine = SimpleEngine.from_docs(
input_files=[DOC_PATH],
retriever_configs=[FAISSRetrieverConfig(), BM25RetrieverConfig()],
ranker_configs=[LLMRankerConfig()],
)
return self._engine

@engine.setter
def engine(self, value: SimpleEngine):
self._engine = value

async def run_pipeline(self, question=QUESTION, print_title=True):
"""This example run rag pipeline, use faiss&bm25 retriever and llm ranker, will print something like:
@@ -97,6 +111,7 @@ async def add_docs(self):
self.engine.add_docs([travel_filepath])
await self.run_pipeline(question=travel_question, print_title=False)

@handle_exception
async def add_objects(self, print_title=True):
"""This example show how to add objects.
@@ -154,20 +169,41 @@ async def init_and_query_chromadb(self):
"""
self._print_title("Init And Query ChromaDB")

# save index
# 1. save index
output_dir = DATA_PATH / "rag"
SimpleEngine.from_docs(
input_files=[TRAVEL_DOC_PATH],
retriever_configs=[ChromaRetrieverConfig(persist_path=output_dir)],
)

# load index
engine = SimpleEngine.from_index(
index_config=ChromaIndexConfig(persist_path=output_dir),
# 2. load index
engine = SimpleEngine.from_index(index_config=ChromaIndexConfig(persist_path=output_dir))

# 3. query
answer = await engine.aquery(TRAVEL_QUESTION)
self._print_query_result(answer)

@handle_exception
async def init_and_query_es(self):
"""This example show how to use es. how to save and load index. will print something like:
Query Result:
Bob likes traveling.
"""
self._print_title("Init And Query Elasticsearch")

# 1. create es index and save docs
store_config = ElasticsearchStoreConfig(index_name="travel", es_url="http://127.0.0.1:9200")
engine = SimpleEngine.from_docs(
input_files=[TRAVEL_DOC_PATH],
retriever_configs=[ElasticsearchRetrieverConfig(store_config=store_config)],
)

# query
answer = engine.query(TRAVEL_QUESTION)
# 2. load index
engine = SimpleEngine.from_index(index_config=ElasticsearchIndexConfig(store_config=store_config))

# 3. query
answer = await engine.aquery(TRAVEL_QUESTION)
self._print_query_result(answer)

@staticmethod
@@ -205,6 +241,7 @@ async def main():
await e.add_objects()
await e.init_objects()
await e.init_and_query_chromadb()
await e.init_and_query_es()


if __name__ == "__main__":
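Since RAGExample now takes an optional engine (with the lazy property as a fallback), a caller can inject a prebuilt engine instead of the default FAISS + BM25 one. A minimal sketch, not part of this commit, assuming the import paths below and an Elasticsearch instance reachable at the URL used in the example:

import asyncio

from examples.rag_pipeline import RAGExample, TRAVEL_QUESTION  # assumed import path for the example module
from metagpt.rag.engines import SimpleEngine
from metagpt.rag.schema import ElasticsearchIndexConfig, ElasticsearchStoreConfig  # assumed to live in metagpt.rag.schema


async def main():
    # Assumes the "travel" index already exists (e.g. created by init_and_query_es above).
    store_config = ElasticsearchStoreConfig(index_name="travel", es_url="http://127.0.0.1:9200")
    engine = SimpleEngine.from_index(index_config=ElasticsearchIndexConfig(store_config=store_config))

    # Inject the prebuilt engine; RAGExample then skips building its lazy default.
    example = RAGExample(engine=engine)
    await example.run_pipeline(question=TRAVEL_QUESTION, print_title=False)


if __name__ == "__main__":
    asyncio.run(main())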
3 changes: 2 additions & 1 deletion metagpt/rag/engines/__init__.py
@@ -1,5 +1,6 @@
"""Engines init"""

from metagpt.rag.engines.simple import SimpleEngine
from metagpt.rag.engines.flare import FLAREEngine

__all__ = ["SimpleEngine"]
__all__ = ["SimpleEngine", "FLAREEngine"]
9 changes: 9 additions & 0 deletions metagpt/rag/engines/flare.py
@@ -0,0 +1,9 @@
"""FLARE Engine.
Use llama_index's FLAREInstructQueryEngine as FLAREEngine; it accepts another query engine as a parameter.
For example, create a SimpleEngine, then pass it to FLAREEngine.
"""

from llama_index.core.query_engine import ( # noqa: F401
FLAREInstructQueryEngine as FLAREEngine,
)
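Since FLAREEngine is just llama_index's FLAREInstructQueryEngine re-exported, it wraps another query engine rather than building one itself. A minimal sketch of the intended usage, not part of this commit; the file path is illustrative and the keyword arguments are assumptions based on llama_index's constructor:

from metagpt.rag.engines import FLAREEngine, SimpleEngine
from metagpt.rag.schema import FAISSRetrieverConfig

# Build an ordinary engine first, then let FLARE drive it with iterative,
# forward-looking retrieval.
base_engine = SimpleEngine.from_docs(
    input_files=["examples/data/rag/writer.txt"],  # illustrative path
    retriever_configs=[FAISSRetrieverConfig()],
)
flare_engine = FLAREEngine(query_engine=base_engine, max_iterations=4, verbose=True)

answer = flare_engine.query("What are key qualities to be a good writer?")
print(answer)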
4 changes: 3 additions & 1 deletion metagpt/rag/engines/simple.py
@@ -130,10 +130,12 @@ def from_objs(
retriever_configs: Configuration for retrievers. If more than one config is given, SimpleHybridRetriever will be used.
ranker_configs: Configuration for rankers.
"""
objs = objs or []
retriever_configs = retriever_configs or []

if not objs and any(isinstance(config, BM25RetrieverConfig) for config in retriever_configs):
raise ValueError("When using BM25RetrieverConfig, objs must not be empty.")

objs = objs or []
nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
index = VectorStoreIndex(
nodes=nodes,
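With the guard moved ahead of node construction, from_objs now fails fast when a BM25 retriever is requested but no objects are supplied. A minimal sketch of a call that satisfies it, not part of this commit; the Player model is illustrative and only needs a rag_key() returning the text to index:

from pydantic import BaseModel

from metagpt.rag.engines import SimpleEngine
from metagpt.rag.schema import BM25RetrieverConfig, FAISSRetrieverConfig


class Player(BaseModel):
    """Illustrative object; rag_key() tells the engine which text to index."""

    name: str
    goal: str

    def rag_key(self) -> str:
        return self.goal


# Because a BM25RetrieverConfig is present, objs must be non-empty or from_objs raises ValueError.
engine = SimpleEngine.from_objs(
    objs=[Player(name="Jeff", goal="Win the championship")],
    retriever_configs=[FAISSRetrieverConfig(), BM25RetrieverConfig()],
)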
2 changes: 1 addition & 1 deletion metagpt/rag/factories/base.py
@@ -41,7 +41,7 @@ def get_instance(self, key: Any, **kwargs) -> Any:
if creator:
return creator(key, **kwargs)

raise ValueError(f"Unknown config: {key}")
raise ValueError(f"Unknown config: `{type(key)}`, {key}")

@staticmethod
def _val_from_config_or_kwargs(key: str, config: object = None, **kwargs) -> Any:
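ConfigBasedFactory resolves a creator by the type of the config it receives, so the richer error message now reports that type when the lookup fails. A minimal sketch of the dispatch pattern, not part of this commit; the GreetingConfig/GreetingFactory pair is illustrative:

from metagpt.rag.factories.base import ConfigBasedFactory


class GreetingConfig:
    """Illustrative config; real configs are the models in metagpt.rag.schema."""


class GreetingFactory(ConfigBasedFactory):
    def __init__(self):
        # Map config types to creator callables, as the index/retriever/ranker factories do.
        super().__init__({GreetingConfig: self._create_greeting})

    def _create_greeting(self, config: GreetingConfig, **kwargs) -> str:
        return "hello"


factory = GreetingFactory()
factory.get_instance(GreetingConfig())  # -> "hello"
factory.get_instance(object())          # -> ValueError: Unknown config: `<class 'object'>`, ...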
49 changes: 34 additions & 15 deletions metagpt/rag/factories/index.py
@@ -4,13 +4,17 @@
from llama_index.core import StorageContext, VectorStoreIndex, load_index_from_storage
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.indices.base import BaseIndex
from llama_index.core.vector_stores.types import BasePydanticVectorStore
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
from llama_index.vector_stores.faiss import FaissVectorStore

from metagpt.rag.factories.base import ConfigBasedFactory
from metagpt.rag.schema import (
BaseIndexConfig,
BM25IndexConfig,
ChromaIndexConfig,
ElasticsearchIndexConfig,
ElasticsearchKeywordIndexConfig,
FAISSIndexConfig,
)
from metagpt.rag.vector_stores.chroma import ChromaVectorStore
@@ -22,6 +26,8 @@ def __init__(self):
FAISSIndexConfig: self._create_faiss,
ChromaIndexConfig: self._create_chroma,
BM25IndexConfig: self._create_bm25,
ElasticsearchIndexConfig: self._create_es,
ElasticsearchKeywordIndexConfig: self._create_es,
}
super().__init__(creators)

@@ -30,31 +36,44 @@ def get_index(self, config: BaseIndexConfig, **kwargs) -> BaseIndex:
return super().get_instance(config, **kwargs)

def _create_faiss(self, config: FAISSIndexConfig, **kwargs) -> VectorStoreIndex:
embed_model = self._extract_embed_model(config, **kwargs)

vector_store = FaissVectorStore.from_persist_dir(str(config.persist_path))
storage_context = StorageContext.from_defaults(vector_store=vector_store, persist_dir=config.persist_path)
index = load_index_from_storage(storage_context=storage_context, embed_model=embed_model)
return index

def _create_chroma(self, config: ChromaIndexConfig, **kwargs) -> VectorStoreIndex:
embed_model = self._extract_embed_model(config, **kwargs)
return self._index_from_storage(storage_context=storage_context, config=config, **kwargs)

def _create_bm25(self, config: BM25IndexConfig, **kwargs) -> VectorStoreIndex:
storage_context = StorageContext.from_defaults(persist_dir=config.persist_path)

return self._index_from_storage(storage_context=storage_context, config=config, **kwargs)

def _create_chroma(self, config: ChromaIndexConfig, **kwargs) -> VectorStoreIndex:
db = chromadb.PersistentClient(str(config.persist_path))
chroma_collection = db.get_or_create_collection(config.collection_name)
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
index = VectorStoreIndex.from_vector_store(
vector_store,
embed_model=embed_model,
)
return index

def _create_bm25(self, config: BM25IndexConfig, **kwargs) -> VectorStoreIndex:
return self._index_from_vector_store(vector_store=vector_store, config=config, **kwargs)

def _create_es(self, config: ElasticsearchIndexConfig, **kwargs) -> VectorStoreIndex:
vector_store = ElasticsearchStore(**config.store_config.model_dump())

return self._index_from_vector_store(vector_store=vector_store, config=config, **kwargs)

def _index_from_storage(
self, storage_context: StorageContext, config: BaseIndexConfig, **kwargs
) -> VectorStoreIndex:
embed_model = self._extract_embed_model(config, **kwargs)

storage_context = StorageContext.from_defaults(persist_dir=config.persist_path)
index = load_index_from_storage(storage_context=storage_context, embed_model=embed_model)
return index
return load_index_from_storage(storage_context=storage_context, embed_model=embed_model)

def _index_from_vector_store(
self, vector_store: BasePydanticVectorStore, config: BaseIndexConfig, **kwargs
) -> VectorStoreIndex:
embed_model = self._extract_embed_model(config, **kwargs)

return VectorStoreIndex.from_vector_store(
vector_store=vector_store,
embed_model=embed_model,
)

def _extract_embed_model(self, config, **kwargs) -> BaseEmbedding:
return self._val_from_config_or_kwargs("embed_model", config, **kwargs)
4 changes: 3 additions & 1 deletion metagpt/rag/factories/llm.py
@@ -33,7 +33,9 @@ class RAGLLM(CustomLLM):
@property
def metadata(self) -> LLMMetadata:
"""Get LLM metadata."""
return LLMMetadata(context_window=self.context_window, num_output=self.num_output, model_name=self.model_name)
return LLMMetadata(
context_window=self.context_window, num_output=self.num_output, model_name=self.model_name or "unknown"
)

@llm_completion_callback()
def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
17 changes: 16 additions & 1 deletion metagpt/rag/factories/ranker.py
@@ -3,9 +3,16 @@
from llama_index.core.llms import LLM
from llama_index.core.postprocessor import LLMRerank
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.postprocessor.colbert_rerank import ColbertRerank

from metagpt.rag.factories.base import ConfigBasedFactory
from metagpt.rag.schema import BaseRankerConfig, LLMRankerConfig
from metagpt.rag.rankers.object_ranker import ObjectSortPostprocessor
from metagpt.rag.schema import (
BaseRankerConfig,
ColbertRerankConfig,
LLMRankerConfig,
ObjectRankerConfig,
)


class RankerFactory(ConfigBasedFactory):
@@ -14,6 +21,8 @@ class RankerFactory(ConfigBasedFactory):
def __init__(self):
creators = {
LLMRankerConfig: self._create_llm_ranker,
ColbertRerankConfig: self._create_colbert_ranker,
ObjectRankerConfig: self._create_object_ranker,
}
super().__init__(creators)

@@ -28,6 +37,12 @@ def _create_llm_ranker(self, config: LLMRankerConfig, **kwargs) -> LLMRerank:
config.llm = self._extract_llm(config, **kwargs)
return LLMRerank(**config.model_dump())

def _create_colbert_ranker(self, config: ColbertRerankConfig, **kwargs) -> ColbertRerank:
return ColbertRerank(**config.model_dump())

def _create_object_ranker(self, config: ObjectRankerConfig, **kwargs) -> ObjectSortPostprocessor:
return ObjectSortPostprocessor(**config.model_dump())

def _extract_llm(self, config: BaseRankerConfig = None, **kwargs) -> LLM:
return self._val_from_config_or_kwargs("llm", config, **kwargs)

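With the two new creators registered, the ranker is chosen purely by the config type handed to the engine. A minimal sketch, not part of this commit; the input file is illustrative and ColbertRerankConfig is assumed to have usable defaults (ColBERT model name, top_n):

from metagpt.rag.engines import SimpleEngine
from metagpt.rag.schema import ColbertRerankConfig, FAISSRetrieverConfig

# RankerFactory dispatch: LLMRankerConfig -> LLMRerank,
# ColbertRerankConfig -> ColbertRerank, ObjectRankerConfig -> ObjectSortPostprocessor.
engine = SimpleEngine.from_docs(
    input_files=["examples/data/rag/writer.txt"],  # illustrative path
    retriever_configs=[FAISSRetrieverConfig()],
    ranker_configs=[ColbertRerankConfig()],  # assumes defaults are acceptable
)
answer = engine.query("What are key qualities to be a good writer?")
print(answer)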
19 changes: 17 additions & 2 deletions metagpt/rag/factories/retriever.py
@@ -6,18 +6,22 @@
import faiss
from llama_index.core import StorageContext, VectorStoreIndex
from llama_index.core.vector_stores.types import BasePydanticVectorStore
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
from llama_index.vector_stores.faiss import FaissVectorStore

from metagpt.rag.factories.base import ConfigBasedFactory
from metagpt.rag.retrievers.base import RAGRetriever
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
from metagpt.rag.retrievers.chroma_retriever import ChromaRetriever
from metagpt.rag.retrievers.es_retriever import ElasticsearchRetriever
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.schema import (
BaseRetrieverConfig,
BM25RetrieverConfig,
ChromaRetrieverConfig,
ElasticsearchKeywordRetrieverConfig,
ElasticsearchRetrieverConfig,
FAISSRetrieverConfig,
IndexRetrieverConfig,
)
@@ -32,6 +36,8 @@ def __init__(self):
FAISSRetrieverConfig: self._create_faiss_retriever,
BM25RetrieverConfig: self._create_bm25_retriever,
ChromaRetrieverConfig: self._create_chroma_retriever,
ElasticsearchRetrieverConfig: self._create_es_retriever,
ElasticsearchKeywordRetrieverConfig: self._create_es_retriever,
}
super().__init__(creators)

@@ -53,20 +59,29 @@ def _create_default(self, **kwargs) -> RAGRetriever:
def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever:
vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions))
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)

return FAISSRetriever(**config.model_dump())

def _create_bm25_retriever(self, config: BM25RetrieverConfig, **kwargs) -> DynamicBM25Retriever:
config.index = copy.deepcopy(self._extract_index(config, **kwargs))
nodes = list(config.index.docstore.docs.values())
return DynamicBM25Retriever(nodes=nodes, **config.model_dump())

return DynamicBM25Retriever(nodes=list(config.index.docstore.docs.values()), **config.model_dump())

def _create_chroma_retriever(self, config: ChromaRetrieverConfig, **kwargs) -> ChromaRetriever:
db = chromadb.PersistentClient(path=str(config.persist_path))
chroma_collection = db.get_or_create_collection(config.collection_name)

vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)

return ChromaRetriever(**config.model_dump())

def _create_es_retriever(self, config: ElasticsearchRetrieverConfig, **kwargs) -> ElasticsearchRetriever:
vector_store = ElasticsearchStore(**config.store_config.model_dump())
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)

return ElasticsearchRetriever(**config.model_dump())

def _extract_index(self, config: BaseRetrieverConfig = None, **kwargs) -> VectorStoreIndex:
return self._val_from_config_or_kwargs("index", config, **kwargs)
