Skip to content

Commit

Permalink
Merge pull request #633 from Stitch-z/update-unit-test
Browse files: browse the repository at this point in the history
Update: improve the unit tests for the technical tutorial assistant and the OCR assistant.
  • Loading branch information
geekan authored Dec 26, 2023
2 parents 5c7cdf5 + bf0f6bd commit d244c64
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 30 deletions.
1 change: 1 addition & 0 deletions metagpt/roles/tutorial_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,4 +90,5 @@ async def react(self) -> Message:
msg = await super().react()
root_path = TUTORIAL_PATH / datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
await File.write(root_path, f"{self.main_title}.md", self.total_content.encode("utf-8"))
msg.content = str(root_path / f"{self.main_title}.md")
return msg
9 changes: 6 additions & 3 deletions tests/metagpt/actions/test_invoice_ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
@Author : Stitch-z
@File : test_invoice_ocr.py
"""

import json
import os
from pathlib import Path

Expand Down Expand Up @@ -34,15 +34,18 @@ async def test_invoice_ocr(invoice_path: str):
@pytest.mark.parametrize(
("invoice_path", "expected_result"),
[
("../../data/invoices/invoice-1.pdf", [{"收款人": "小明", "城市": "深圳市", "总费用/元": "412.00", "开票日期": "2023年02月03日"}]),
(
"../../data/invoices/invoice-1.pdf",
[{"收款人": "小明", "城市": "深圳市", "总费用/元": "412.00", "开票日期": "2023年02月03日"}]
),
],
)
async def test_generate_table(invoice_path: str, expected_result: list[dict]):
invoice_path = os.path.abspath(os.path.join(os.getcwd(), invoice_path))
filename = os.path.basename(invoice_path)
ocr_result = await InvoiceOCR().run(file_path=Path(invoice_path), filename=filename)
table_data = await GenerateTable().run(ocr_results=ocr_result, filename=filename)
assert table_data == expected_result
assert json.dumps(table_data) == json.dumps(expected_result)


@pytest.mark.asyncio
Expand Down
32 changes: 14 additions & 18 deletions tests/metagpt/roles/test_invoice_ocr_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
@File : test_invoice_ocr_assistant.py
"""

import json
from pathlib import Path

import pandas as pd
Expand All @@ -25,39 +24,36 @@
"Invoicing date",
Path("../../data/invoices/invoice-1.pdf"),
Path("../../../data/invoice_table/invoice-1.xlsx"),
[{"收款人": "小明", "城市": "深圳市", "总费用/元": 412.00, "开票日期": "2023年02月03日"}],
{"收款人": "小明", "城市": "深圳", "总费用/元": 412.00, "开票日期": "2023年02月03日"},
),
(
"Invoicing date",
Path("../../data/invoices/invoice-2.png"),
Path("../../../data/invoice_table/invoice-2.xlsx"),
[{"收款人": "铁头", "城市": "广州市", "总费用/元": 898.00, "开票日期": "2023年03月17日"}],
{"收款人": "铁头", "城市": "广州", "总费用/元": 898.00, "开票日期": "2023年03月17日"},
),
(
"Invoicing date",
Path("../../data/invoices/invoice-3.jpg"),
Path("../../../data/invoice_table/invoice-3.xlsx"),
[{"收款人": "夏天", "城市": "福州市", "总费用/元": 2462.00, "开票日期": "2023年08月26日"}],
),
(
"Invoicing date",
Path("../../data/invoices/invoice-4.zip"),
Path("../../../data/invoice_table/invoice-4.xlsx"),
[
{"收款人": "小明", "城市": "深圳市", "总费用/元": 412.00, "开票日期": "2023年02月03日"},
{"收款人": "铁头", "城市": "广州市", "总费用/元": 898.00, "开票日期": "2023年03月17日"},
{"收款人": "夏天", "城市": "福州市", "总费用/元": 2462.00, "开票日期": "2023年08月26日"},
],
),
{"收款人": "夏天", "城市": "福州", "总费用/元": 2462.00, "开票日期": "2023年08月26日"},
)
],
)
async def test_invoice_ocr_assistant(
query: str, invoice_path: Path, invoice_table_path: Path, expected_result: list[dict]
query: str, invoice_path: Path, invoice_table_path: Path, expected_result: dict
):
invoice_path = Path.cwd() / invoice_path
role = InvoiceOCRAssistant()
await role.run(Message(content=query, instruct_content=InvoicePath(file_path=invoice_path)))
invoice_table_path = Path.cwd() / invoice_table_path
df = pd.read_excel(invoice_table_path)
dict_result = df.to_dict(orient="records")
assert json.dumps(dict_result) == json.dumps(expected_result)
resp = df.to_dict(orient="records")
assert isinstance(resp, list)
assert len(resp) == 1
resp = resp[0]
assert expected_result["收款人"] == resp["收款人"]
assert expected_result["城市"] in resp["城市"]
assert int(expected_result["总费用/元"]) == int(resp["总费用/元"])
assert expected_result["开票日期"] == resp["开票日期"]

15 changes: 6 additions & 9 deletions tests/metagpt/roles/test_tutorial_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,25 @@
@File : test_tutorial_assistant.py
"""
import shutil

import aiofiles
import pytest

from metagpt.const import TUTORIAL_PATH
from metagpt.roles.tutorial_assistant import TutorialAssistant


@pytest.mark.asyncio
@pytest.mark.parametrize(("language", "topic"), [("Chinese", "Write a tutorial about Python")])
@pytest.mark.parametrize(("language", "topic"), [("Chinese", "Write a tutorial about pip")])
async def test_tutorial_assistant(language: str, topic: str):
shutil.rmtree(path=TUTORIAL_PATH, ignore_errors=True)

topic = "Write a tutorial about MySQL"
role = TutorialAssistant(language=language)
msg = await role.run(topic)
assert "MySQL" in msg.content
assert TUTORIAL_PATH.exists()
# filename = msg.content
# title = filename.split("/")[-1].split(".")[0]
# async with aiofiles.open(filename, mode="r") as reader:
# content = await reader.read()
# assert content.startswith(f"# {title}")
filename = msg.content
async with aiofiles.open(filename, mode="r", encoding="utf-8") as reader:
content = await reader.read()
assert "pip" in content


if __name__ == "__main__":
Expand Down

0 comments on commit d244c64

Please sign in to comment.