diff --git a/metagpt/roles/tutorial_assistant.py b/metagpt/roles/tutorial_assistant.py index 5d1323371..bedf8b3be 100644 --- a/metagpt/roles/tutorial_assistant.py +++ b/metagpt/roles/tutorial_assistant.py @@ -90,4 +90,5 @@ async def react(self) -> Message: msg = await super().react() root_path = TUTORIAL_PATH / datetime.now().strftime("%Y-%m-%d_%H-%M-%S") await File.write(root_path, f"{self.main_title}.md", self.total_content.encode("utf-8")) + msg.content = str(root_path / f"{self.main_title}.md") return msg diff --git a/tests/metagpt/actions/test_invoice_ocr.py b/tests/metagpt/actions/test_invoice_ocr.py index 7f16aa9a4..b3b93cf9f 100644 --- a/tests/metagpt/actions/test_invoice_ocr.py +++ b/tests/metagpt/actions/test_invoice_ocr.py @@ -6,7 +6,7 @@ @Author : Stitch-z @File : test_invoice_ocr.py """ - +import json import os from pathlib import Path @@ -34,7 +34,10 @@ async def test_invoice_ocr(invoice_path: str): @pytest.mark.parametrize( ("invoice_path", "expected_result"), [ - ("../../data/invoices/invoice-1.pdf", [{"收款人": "小明", "城市": "深圳市", "总费用/元": "412.00", "开票日期": "2023年02月03日"}]), + ( + "../../data/invoices/invoice-1.pdf", + [{"收款人": "小明", "城市": "深圳市", "总费用/元": "412.00", "开票日期": "2023年02月03日"}] + ), ], ) async def test_generate_table(invoice_path: str, expected_result: list[dict]): @@ -42,7 +45,7 @@ async def test_generate_table(invoice_path: str, expected_result: list[dict]): filename = os.path.basename(invoice_path) ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename) table_data = await GenerateTable().run(ocr_results=ocr_result, filename=filename) - assert table_data == expected_result + assert json.dumps(table_data) == json.dumps(expected_result) @pytest.mark.asyncio diff --git a/tests/metagpt/roles/test_invoice_ocr_assistant.py b/tests/metagpt/roles/test_invoice_ocr_assistant.py index ab3092004..48abb9eb8 100644 --- a/tests/metagpt/roles/test_invoice_ocr_assistant.py +++ b/tests/metagpt/roles/test_invoice_ocr_assistant.py @@ -7,7 +7,6 @@ @File : test_invoice_ocr_assistant.py """ -import json from pathlib import Path import pandas as pd @@ -25,39 +24,36 @@ "Invoicing date", Path("../../data/invoices/invoice-1.pdf"), Path("../../../data/invoice_table/invoice-1.xlsx"), - [{"收款人": "小明", "城市": "深圳市", "总费用/元": 412.00, "开票日期": "2023年02月03日"}], + {"收款人": "小明", "城市": "深圳", "总费用/元": 412.00, "开票日期": "2023年02月03日"}, ), ( "Invoicing date", Path("../../data/invoices/invoice-2.png"), Path("../../../data/invoice_table/invoice-2.xlsx"), - [{"收款人": "铁头", "城市": "广州市", "总费用/元": 898.00, "开票日期": "2023年03月17日"}], + {"收款人": "铁头", "城市": "广州", "总费用/元": 898.00, "开票日期": "2023年03月17日"}, ), ( "Invoicing date", Path("../../data/invoices/invoice-3.jpg"), Path("../../../data/invoice_table/invoice-3.xlsx"), - [{"收款人": "夏天", "城市": "福州市", "总费用/元": 2462.00, "开票日期": "2023年08月26日"}], - ), - ( - "Invoicing date", - Path("../../data/invoices/invoice-4.zip"), - Path("../../../data/invoice_table/invoice-4.xlsx"), - [ - {"收款人": "小明", "城市": "深圳市", "总费用/元": 412.00, "开票日期": "2023年02月03日"}, - {"收款人": "铁头", "城市": "广州市", "总费用/元": 898.00, "开票日期": "2023年03月17日"}, - {"收款人": "夏天", "城市": "福州市", "总费用/元": 2462.00, "开票日期": "2023年08月26日"}, - ], - ), + {"收款人": "夏天", "城市": "福州", "总费用/元": 2462.00, "开票日期": "2023年08月26日"}, + ) ], ) async def test_invoice_ocr_assistant( - query: str, invoice_path: Path, invoice_table_path: Path, expected_result: list[dict] + query: str, invoice_path: Path, invoice_table_path: Path, expected_result: dict ): invoice_path = Path.cwd() / invoice_path role = InvoiceOCRAssistant() await role.run(Message(content=query, instruct_content=InvoicePath(file_path=invoice_path))) invoice_table_path = Path.cwd() / invoice_table_path df = pd.read_excel(invoice_table_path) - dict_result = df.to_dict(orient="records") - assert json.dumps(dict_result) == json.dumps(expected_result) + resp = df.to_dict(orient="records") + assert isinstance(resp, list) + assert len(resp) == 1 + resp = resp[0] + assert expected_result["收款人"] == resp["收款人"] + assert expected_result["城市"] in resp["城市"] + assert int(expected_result["总费用/元"]) == int(resp["总费用/元"]) + assert expected_result["开票日期"] == resp["开票日期"] + diff --git a/tests/metagpt/roles/test_tutorial_assistant.py b/tests/metagpt/roles/test_tutorial_assistant.py index 3158a5fc1..4455e1bf6 100644 --- a/tests/metagpt/roles/test_tutorial_assistant.py +++ b/tests/metagpt/roles/test_tutorial_assistant.py @@ -6,7 +6,7 @@ @File : test_tutorial_assistant.py """ import shutil - +import aiofiles import pytest from metagpt.const import TUTORIAL_PATH @@ -14,20 +14,17 @@ @pytest.mark.asyncio -@pytest.mark.parametrize(("language", "topic"), [("Chinese", "Write a tutorial about Python")]) +@pytest.mark.parametrize(("language", "topic"), [("Chinese", "Write a tutorial about pip")]) async def test_tutorial_assistant(language: str, topic: str): shutil.rmtree(path=TUTORIAL_PATH, ignore_errors=True) - topic = "Write a tutorial about MySQL" role = TutorialAssistant(language=language) msg = await role.run(topic) - assert "MySQL" in msg.content assert TUTORIAL_PATH.exists() - # filename = msg.content - # title = filename.split("/")[-1].split(".")[0] - # async with aiofiles.open(filename, mode="r") as reader: - # content = await reader.read() - # assert content.startswith(f"# {title}") + filename = msg.content + async with aiofiles.open(filename, mode="r", encoding="utf-8") as reader: + content = await reader.read() + assert "pip" in content if __name__ == "__main__":