Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat rag add more features #1114

Merged
merged 13 commits into from
Mar 27, 2024
74 changes: 62 additions & 12 deletions examples/rag_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""RAG pipeline"""

import asyncio
from functools import wraps

from pydantic import BaseModel

Expand All @@ -11,6 +12,9 @@
BM25RetrieverConfig,
ChromaIndexConfig,
ChromaRetrieverConfig,
ElasticsearchIndexConfig,
ElasticsearchRetrieverConfig,
ElasticsearchStoreConfig,
FAISSRetrieverConfig,
LLMRankerConfig,
)
Expand All @@ -24,6 +28,17 @@
LLM_TIP = "If you not sure, just answer I don't know."


def catch_exception(func):
    """Decorator for async functions: log any exception and return None instead of raising.

    Used on demo entry points so that one failing example (e.g. a missing
    Elasticsearch server) does not abort the remaining examples.
    """

    @wraps(func)
    async def wrapper(*args, **kwargs):
        try:
            return await func(*args, **kwargs)
        except Exception as e:
            # logger.exception preserves the traceback; logger.error would drop it.
            logger.exception(f"{func.__name__} exception: {e}")

    return wrapper
better629 marked this conversation as resolved.
Show resolved Hide resolved


class Player(BaseModel):
"""To demonstrate rag add objs."""

Expand All @@ -39,12 +54,22 @@ def rag_key(self) -> str:
class RAGExample:
"""Show how to use RAG."""

def __init__(self):
self.engine = SimpleEngine.from_docs(
input_files=[DOC_PATH],
retriever_configs=[FAISSRetrieverConfig(), BM25RetrieverConfig()],
ranker_configs=[LLMRankerConfig()],
)
    def __init__(self, engine: SimpleEngine = None):
        """Store an optional engine; when None, a default engine is built lazily by the `engine` property."""
        self._engine = engine

@property
def engine(self):
if not self._engine:
self._engine = SimpleEngine.from_docs(
input_files=[DOC_PATH],
retriever_configs=[FAISSRetrieverConfig(), BM25RetrieverConfig()],
ranker_configs=[LLMRankerConfig()],
)
return self._engine

    @engine.setter
    def engine(self, value: SimpleEngine):
        """Replace the underlying engine with a caller-provided one."""
        self._engine = value

async def run_pipeline(self, question=QUESTION, print_title=True):
"""This example run rag pipeline, use faiss&bm25 retriever and llm ranker, will print something like:
Expand Down Expand Up @@ -97,6 +122,7 @@ async def add_docs(self):
self.engine.add_docs([travel_filepath])
await self.run_pipeline(question=travel_question, print_title=False)

@catch_exception
async def add_objects(self, print_title=True):
"""This example show how to add objects.

Expand Down Expand Up @@ -154,20 +180,43 @@ async def init_and_query_chromadb(self):
"""
self._print_title("Init And Query ChromaDB")

# save index
# 1.save index
better629 marked this conversation as resolved.
Show resolved Hide resolved
output_dir = DATA_PATH / "rag"
SimpleEngine.from_docs(
input_files=[TRAVEL_DOC_PATH],
retriever_configs=[ChromaRetrieverConfig(persist_path=output_dir)],
)

# load index
engine = SimpleEngine.from_index(
index_config=ChromaIndexConfig(persist_path=output_dir),
# 2.load index
engine = SimpleEngine.from_index(index_config=ChromaIndexConfig(persist_path=output_dir))

# 3.query
answer = await engine.aquery(TRAVEL_QUESTION)
self._print_query_result(answer)

@catch_exception
async def init_and_query_es(self):
"""This example show how to use es. how to save and load index. will print something like:

Query Result:
Bob likes traveling.

If `Unclosed client session`, it's llamaindex elasticsearch problem, maybe fixed later.
"""
self._print_title("Init And Query Elasticsearch")

# 1.create es index and save docs
store_config = ElasticsearchStoreConfig(index_name="travel", es_url="http://127.0.0.1:9200")
engine = SimpleEngine.from_docs(
input_files=[TRAVEL_DOC_PATH],
retriever_configs=[ElasticsearchRetrieverConfig(store_config=store_config)],
)

# query
answer = engine.query(TRAVEL_QUESTION)
# 2.load index
engine = SimpleEngine.from_index(index_config=ElasticsearchIndexConfig(store_config=store_config))

# 3.query
answer = await engine.aquery(TRAVEL_QUESTION)
self._print_query_result(answer)

@staticmethod
Expand Down Expand Up @@ -205,6 +254,7 @@ async def main():
await e.add_objects()
await e.init_objects()
await e.init_and_query_chromadb()
await e.init_and_query_es()


if __name__ == "__main__":
Expand Down
3 changes: 2 additions & 1 deletion metagpt/rag/engines/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Engines init"""

from metagpt.rag.engines.simple import SimpleEngine
from metagpt.rag.engines.flare import FLAREEngine

__all__ = ["SimpleEngine"]
__all__ = ["SimpleEngine", "FLAREEngine"]
9 changes: 9 additions & 0 deletions metagpt/rag/engines/flare.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""FLARE Engine.

Use llamaindex's FLAREInstructQueryEngine as FLAREEngine, which accepts other engines as parameters.
For example, Create a simple engine, and then pass it to FLAREEngine.
"""

from llama_index.core.query_engine import ( # noqa: F401
FLAREInstructQueryEngine as FLAREEngine,
)
4 changes: 3 additions & 1 deletion metagpt/rag/engines/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,12 @@ def from_objs(
retriever_configs: Configuration for retrievers. If more than one config, will use SimpleHybridRetriever.
ranker_configs: Configuration for rankers.
"""
objs = objs or []
retriever_configs = retriever_configs or []

if not objs and any(isinstance(config, BM25RetrieverConfig) for config in retriever_configs):
raise ValueError("In BM25RetrieverConfig, Objs must not be empty.")

objs = objs or []
nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
index = VectorStoreIndex(
nodes=nodes,
Expand Down
2 changes: 1 addition & 1 deletion metagpt/rag/factories/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def get_instance(self, key: Any, **kwargs) -> Any:
if creator:
return creator(key, **kwargs)

raise ValueError(f"Unknown config: {key}")
raise ValueError(f"Unknown config: `{type(key)}`, {key}")

@staticmethod
def _val_from_config_or_kwargs(key: str, config: object = None, **kwargs) -> Any:
Expand Down
49 changes: 34 additions & 15 deletions metagpt/rag/factories/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,17 @@
from llama_index.core import StorageContext, VectorStoreIndex, load_index_from_storage
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.indices.base import BaseIndex
from llama_index.core.vector_stores.types import BasePydanticVectorStore
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
from llama_index.vector_stores.faiss import FaissVectorStore

from metagpt.rag.factories.base import ConfigBasedFactory
from metagpt.rag.schema import (
BaseIndexConfig,
BM25IndexConfig,
ChromaIndexConfig,
ElasticsearchIndexConfig,
ElasticsearchKeywordIndexConfig,
FAISSIndexConfig,
)
from metagpt.rag.vector_stores.chroma import ChromaVectorStore
Expand All @@ -22,6 +26,8 @@ def __init__(self):
FAISSIndexConfig: self._create_faiss,
ChromaIndexConfig: self._create_chroma,
BM25IndexConfig: self._create_bm25,
ElasticsearchIndexConfig: self._create_es,
ElasticsearchKeywordIndexConfig: self._create_es,
}
super().__init__(creators)

Expand All @@ -30,31 +36,44 @@ def get_index(self, config: BaseIndexConfig, **kwargs) -> BaseIndex:
return super().get_instance(config, **kwargs)

def _create_faiss(self, config: FAISSIndexConfig, **kwargs) -> VectorStoreIndex:
embed_model = self._extract_embed_model(config, **kwargs)

vector_store = FaissVectorStore.from_persist_dir(str(config.persist_path))
storage_context = StorageContext.from_defaults(vector_store=vector_store, persist_dir=config.persist_path)
index = load_index_from_storage(storage_context=storage_context, embed_model=embed_model)
return index

def _create_chroma(self, config: ChromaIndexConfig, **kwargs) -> VectorStoreIndex:
embed_model = self._extract_embed_model(config, **kwargs)
return self._index_from_storage(storage_context=storage_context, config=config, **kwargs)

def _create_bm25(self, config: BM25IndexConfig, **kwargs) -> VectorStoreIndex:
storage_context = StorageContext.from_defaults(persist_dir=config.persist_path)

return self._index_from_storage(storage_context=storage_context, config=config, **kwargs)

def _create_chroma(self, config: ChromaIndexConfig, **kwargs) -> VectorStoreIndex:
db = chromadb.PersistentClient(str(config.persist_path))
chroma_collection = db.get_or_create_collection(config.collection_name)
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
index = VectorStoreIndex.from_vector_store(
vector_store,
embed_model=embed_model,
)
return index

def _create_bm25(self, config: BM25IndexConfig, **kwargs) -> VectorStoreIndex:
return self._index_from_vector_store(vector_store=vector_store, config=config, **kwargs)

def _create_es(self, config: ElasticsearchIndexConfig, **kwargs) -> VectorStoreIndex:
vector_store = ElasticsearchStore(**config.store_config.model_dump())

return self._index_from_vector_store(vector_store=vector_store, config=config, **kwargs)

def _index_from_storage(
self, storage_context: StorageContext, config: BaseIndexConfig, **kwargs
) -> VectorStoreIndex:
embed_model = self._extract_embed_model(config, **kwargs)

storage_context = StorageContext.from_defaults(persist_dir=config.persist_path)
index = load_index_from_storage(storage_context=storage_context, embed_model=embed_model)
return index
return load_index_from_storage(storage_context=storage_context, embed_model=embed_model)

def _index_from_vector_store(
self, vector_store: BasePydanticVectorStore, config: BaseIndexConfig, **kwargs
) -> VectorStoreIndex:
embed_model = self._extract_embed_model(config, **kwargs)

return VectorStoreIndex.from_vector_store(
vector_store=vector_store,
embed_model=embed_model,
)

    def _extract_embed_model(self, config, **kwargs) -> BaseEmbedding:
        """Resolve the embed model from the config, falling back to kwargs."""
        return self._val_from_config_or_kwargs("embed_model", config, **kwargs)
Expand Down
4 changes: 3 additions & 1 deletion metagpt/rag/factories/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ class RAGLLM(CustomLLM):
@property
def metadata(self) -> LLMMetadata:
"""Get LLM metadata."""
return LLMMetadata(context_window=self.context_window, num_output=self.num_output, model_name=self.model_name)
return LLMMetadata(
context_window=self.context_window, num_output=self.num_output, model_name=self.model_name or "unknown"
)

@llm_completion_callback()
def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
Expand Down
17 changes: 16 additions & 1 deletion metagpt/rag/factories/ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,16 @@
from llama_index.core.llms import LLM
from llama_index.core.postprocessor import LLMRerank
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.postprocessor.colbert_rerank import ColbertRerank

from metagpt.rag.factories.base import ConfigBasedFactory
from metagpt.rag.schema import BaseRankerConfig, LLMRankerConfig
from metagpt.rag.rankers.object_ranker import ObjectSortPostprocessor
from metagpt.rag.schema import (
BaseRankerConfig,
ColbertRerankConfig,
LLMRankerConfig,
ObjectRankerConfig,
)


class RankerFactory(ConfigBasedFactory):
Expand All @@ -14,6 +21,8 @@ class RankerFactory(ConfigBasedFactory):
def __init__(self):
creators = {
LLMRankerConfig: self._create_llm_ranker,
ColbertRerankConfig: self._create_colbert_ranker,
ObjectRankerConfig: self._create_object_ranker,
}
super().__init__(creators)

Expand All @@ -28,6 +37,12 @@ def _create_llm_ranker(self, config: LLMRankerConfig, **kwargs) -> LLMRerank:
config.llm = self._extract_llm(config, **kwargs)
return LLMRerank(**config.model_dump())

def _create_colbert_ranker(self, config: ColbertRerankConfig, **kwargs) -> LLMRerank:
return ColbertRerank(**config.model_dump())

def _create_object_ranker(self, config: ObjectRankerConfig, **kwargs) -> LLMRerank:
return ObjectSortPostprocessor(**config.model_dump())

    def _extract_llm(self, config: BaseRankerConfig = None, **kwargs) -> LLM:
        """Resolve the LLM from the config, falling back to kwargs."""
        return self._val_from_config_or_kwargs("llm", config, **kwargs)

Expand Down
19 changes: 17 additions & 2 deletions metagpt/rag/factories/retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,22 @@
import faiss
from llama_index.core import StorageContext, VectorStoreIndex
from llama_index.core.vector_stores.types import BasePydanticVectorStore
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
from llama_index.vector_stores.faiss import FaissVectorStore

from metagpt.rag.factories.base import ConfigBasedFactory
from metagpt.rag.retrievers.base import RAGRetriever
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
from metagpt.rag.retrievers.chroma_retriever import ChromaRetriever
from metagpt.rag.retrievers.es_retriever import ElasticsearchRetriever
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.schema import (
BaseRetrieverConfig,
BM25RetrieverConfig,
ChromaRetrieverConfig,
ElasticsearchKeywordRetrieverConfig,
ElasticsearchRetrieverConfig,
FAISSRetrieverConfig,
IndexRetrieverConfig,
)
Expand All @@ -32,6 +36,8 @@ def __init__(self):
FAISSRetrieverConfig: self._create_faiss_retriever,
BM25RetrieverConfig: self._create_bm25_retriever,
ChromaRetrieverConfig: self._create_chroma_retriever,
ElasticsearchRetrieverConfig: self._create_es_retriever,
ElasticsearchKeywordRetrieverConfig: self._create_es_retriever,
}
super().__init__(creators)

Expand All @@ -53,20 +59,29 @@ def _create_default(self, **kwargs) -> RAGRetriever:
def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever:
vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions))
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)

return FAISSRetriever(**config.model_dump())

def _create_bm25_retriever(self, config: BM25RetrieverConfig, **kwargs) -> DynamicBM25Retriever:
config.index = copy.deepcopy(self._extract_index(config, **kwargs))
nodes = list(config.index.docstore.docs.values())
return DynamicBM25Retriever(nodes=nodes, **config.model_dump())

return DynamicBM25Retriever(nodes=list(config.index.docstore.docs.values()), **config.model_dump())

def _create_chroma_retriever(self, config: ChromaRetrieverConfig, **kwargs) -> ChromaRetriever:
db = chromadb.PersistentClient(path=str(config.persist_path))
chroma_collection = db.get_or_create_collection(config.collection_name)

vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)

return ChromaRetriever(**config.model_dump())

def _create_es_retriever(self, config: ElasticsearchRetrieverConfig, **kwargs) -> ElasticsearchRetriever:
vector_store = ElasticsearchStore(**config.store_config.model_dump())
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)

return ElasticsearchRetriever(**config.model_dump())

    def _extract_index(self, config: BaseRetrieverConfig = None, **kwargs) -> VectorStoreIndex:
        """Resolve the index from the config, falling back to kwargs."""
        return self._val_from_config_or_kwargs("index", config, **kwargs)

Expand Down
Loading
Loading