Skip to content

Commit

Permalink
Merge pull request #663 from iorisa/feature/unittest
Browse files Browse the repository at this point in the history
feat: +unit test
  • Loading branch information
geekan authored Jan 2, 2024
2 parents ea64e6a + b7d74c6 commit d2260a5
Show file tree
Hide file tree
Showing 12 changed files with 306 additions and 66 deletions.
33 changes: 0 additions & 33 deletions metagpt/roles/role.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,16 +372,6 @@ async def _act(self) -> Message:

return msg

def _find_news(self, observed: list[Message], existed: list[Message]) -> list[Message]:
news = []
# Warning, remove `id` here to make it work for recover
observed_pure = [msg.dict(exclude={"id": True}) for msg in observed]
existed_pure = [msg.dict(exclude={"id": True}) for msg in existed]
for idx, new in enumerate(observed_pure):
if (new["cause_by"] in self.rc.watch or self.name in new["send_to"]) and new not in existed_pure:
news.append(observed[idx])
return news

async def _observe(self, ignore_memory=False) -> int:
"""Prepare new messages for processing from the message buffer and other sources."""
# Read unprocessed messages from the msg buffer.
Expand All @@ -407,29 +397,6 @@ async def _observe(self, ignore_memory=False) -> int:
logger.debug(f"{self._setting} observed: {news_text}")
return len(self.rc.news)

# async def _observe(self, ignore_memory=False) -> int:
# """Prepare new messages for processing from the message buffer and other sources."""
# # Read unprocessed messages from the msg buffer.
# news = self.rc.msg_buffer.pop_all()
# if self.recovered:
# news = [self.latest_observed_msg] if self.latest_observed_msg else []
# else:
# self.latest_observed_msg = news[-1] if len(news) > 0 else None # record the latest observed msg
#
# # Store the read messages in your own memory to prevent duplicate processing.
# old_messages = [] if ignore_memory else self.rc.memory.get()
# self.rc.memory.add_batch(news)
# # Filter out messages of interest.
# self.rc.news = self._find_news(news, old_messages)
#
# # Design Rules:
# # If you need to further categorize Message objects, you can do so using the Message.set_meta function.
# # msg_buffer is a receiving buffer, avoid adding message data and operations to msg_buffer.
# news_text = [f"{i.role}: {i.content[:20]}..." for i in self.rc.news]
# if news_text:
# logger.debug(f"{self._setting} observed: {news_text}")
# return len(self.rc.news)

def publish_message(self, msg):
"""If the role belongs to env, then the role's messages will be broadcast to env"""
if not msg:
Expand Down
16 changes: 14 additions & 2 deletions metagpt/utils/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
@Author : mashenquan
@File : redis.py
"""
from __future__ import annotations

import traceback
from datetime import timedelta
Expand All @@ -22,7 +23,7 @@ def __init__(self):
async def _connect(self, force=False):
if self._client and not force:
return True
if not CONFIG.REDIS_HOST or not CONFIG.REDIS_PORT or CONFIG.REDIS_DB is None or CONFIG.REDIS_PASSWORD is None:
if not self.is_configured:
return False

try:
Expand All @@ -37,7 +38,7 @@ async def _connect(self, force=False):
logger.warning(f"Redis initialization has failed:{e}")
return False

async def get(self, key: str) -> bytes:
async def get(self, key: str) -> bytes | None:
if not await self._connect() or not key:
return None
try:
Expand Down Expand Up @@ -65,3 +66,14 @@ async def close(self):
@property
def is_valid(self) -> bool:
return self._client is not None

@property
def is_configured(self) -> bool:
return bool(
CONFIG.REDIS_HOST
and CONFIG.REDIS_HOST != "YOUR_REDIS_HOST"
and CONFIG.REDIS_PORT
and CONFIG.REDIS_PORT != "YOUR_REDIS_PORT"
and CONFIG.REDIS_DB is not None
and CONFIG.REDIS_PASSWORD is not None
)
25 changes: 13 additions & 12 deletions metagpt/utils/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,16 +154,17 @@ async def cache(self, data: str, file_ext: str, format: str = "") -> str:

@property
def is_valid(self):
is_invalid = (
not CONFIG.S3_ACCESS_KEY
or CONFIG.S3_ACCESS_KEY == "YOUR_S3_ACCESS_KEY"
or not CONFIG.S3_SECRET_KEY
or CONFIG.S3_SECRET_KEY == "YOUR_S3_SECRET_KEY"
or not CONFIG.S3_ENDPOINT_URL
or CONFIG.S3_ENDPOINT_URL == "YOUR_S3_ENDPOINT_URL"
or not CONFIG.S3_BUCKET
or CONFIG.S3_BUCKET == "YOUR_S3_BUCKET"
return self.is_configured

@property
def is_configured(self) -> bool:
return bool(
CONFIG.S3_ACCESS_KEY
and CONFIG.S3_ACCESS_KEY != "YOUR_S3_ACCESS_KEY"
and CONFIG.S3_SECRET_KEY
and CONFIG.S3_SECRET_KEY != "YOUR_S3_SECRET_KEY"
and CONFIG.S3_ENDPOINT_URL
and CONFIG.S3_ENDPOINT_URL != "YOUR_S3_ENDPOINT_URL"
and CONFIG.S3_BUCKET
and CONFIG.S3_BUCKET != "YOUR_S3_BUCKET"
)
if is_invalid:
logger.info("S3 is invalid")
return not is_invalid
92 changes: 92 additions & 0 deletions tests/data/demo_project/game.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
## game.py

import random
from typing import List, Tuple


class Game:
def __init__(self):
self.grid: List[List[int]] = [[0 for _ in range(4)] for _ in range(4)]
self.score: int = 0
self.game_over: bool = False

def reset_game(self):
self.grid = [[0 for _ in range(4)] for _ in range(4)]
self.score = 0
self.game_over = False
self.add_new_tile()
self.add_new_tile()

def move(self, direction: str):
if direction == "up":
self._move_up()
elif direction == "down":
self._move_down()
elif direction == "left":
self._move_left()
elif direction == "right":
self._move_right()

def is_game_over(self) -> bool:
for i in range(4):
for j in range(4):
if self.grid[i][j] == 0:
return False
if j < 3 and self.grid[i][j] == self.grid[i][j + 1]:
return False
if i < 3 and self.grid[i][j] == self.grid[i + 1][j]:
return False
return True

def get_empty_cells(self) -> List[Tuple[int, int]]:
empty_cells = []
for i in range(4):
for j in range(4):
if self.grid[i][j] == 0:
empty_cells.append((i, j))
return empty_cells

def add_new_tile(self):
empty_cells = self.get_empty_cells()
if empty_cells:
x, y = random.choice(empty_cells)
self.grid[x][y] = 2 if random.random() < 0.9 else 4

def get_score(self) -> int:
return self.score

def _move_up(self):
for j in range(4):
for i in range(1, 4):
if self.grid[i][j] != 0:
for k in range(i, 0, -1):
if self.grid[k - 1][j] == 0:
self.grid[k - 1][j] = self.grid[k][j]
self.grid[k][j] = 0

def _move_down(self):
for j in range(4):
for i in range(2, -1, -1):
if self.grid[i][j] != 0:
for k in range(i, 3):
if self.grid[k + 1][j] == 0:
self.grid[k + 1][j] = self.grid[k][j]
self.grid[k][j] = 0

def _move_left(self):
for i in range(4):
for j in range(1, 4):
if self.grid[i][j] != 0:
for k in range(j, 0, -1):
if self.grid[i][k - 1] == 0:
self.grid[i][k - 1] = self.grid[i][k]
self.grid[i][k] = 0

def _move_right(self):
for i in range(4):
for j in range(2, -1, -1):
if self.grid[i][j] != 0:
for k in range(j, 3):
if self.grid[i][k + 1] == 0:
self.grid[i][k + 1] = self.grid[i][k]
self.grid[i][k] = 0
12 changes: 8 additions & 4 deletions tests/metagpt/learn/test_text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,24 @@


@pytest.mark.asyncio
async def test():
async def test_metagpt_llm():
# Prerequisites
assert CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL
assert CONFIG.OPENAI_API_KEY

data = await text_to_image("Panda emoji", size_type="512x512")
assert "base64" in data or "http" in data
key = CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL
CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL = None

# Mock session env
old_options = CONFIG.options.copy()
new_options = old_options.copy()
new_options["METAGPT_TEXT_TO_IMAGE_MODEL_URL"] = None
CONFIG.set_context(new_options)
try:
data = await text_to_image("Panda emoji", size_type="512x512")
assert "base64" in data or "http" in data
finally:
CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL = key
CONFIG.set_context(old_options)


if __name__ == "__main__":
Expand Down
9 changes: 6 additions & 3 deletions tests/metagpt/learn/test_text_to_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,16 @@ async def test_text_to_speech():
assert "base64" in data or "http" in data

# test iflytek
key = CONFIG.AZURE_TTS_SUBSCRIPTION_KEY
CONFIG.AZURE_TTS_SUBSCRIPTION_KEY = ""
## Mock session env
old_options = CONFIG.options.copy()
new_options = old_options.copy()
new_options["AZURE_TTS_SUBSCRIPTION_KEY"] = ""
CONFIG.set_context(new_options)
try:
data = await text_to_speech("panda emoji")
assert "base64" in data or "http" in data
finally:
CONFIG.AZURE_TTS_SUBSCRIPTION_KEY = key
CONFIG.set_context(old_options)


if __name__ == "__main__":
Expand Down
26 changes: 23 additions & 3 deletions tests/metagpt/roles/test_architect.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,38 @@
@Modified By: mashenquan, 2023-11-1. In accordance with Chapter 2.2.1 and 2.2.2 of RFC 116, utilize the new message
distribution feature for message handling.
"""
import uuid

import pytest

from metagpt.actions import WriteDesign, WritePRD
from metagpt.config import CONFIG
from metagpt.const import PRDS_FILE_REPO
from metagpt.logs import logger
from metagpt.roles import Architect
from metagpt.schema import Message
from metagpt.utils.common import any_to_str, awrite
from tests.metagpt.roles.mock import MockMessages


@pytest.mark.asyncio
async def test_architect():
# FIXME: make git as env? Or should we support
# Prerequisites
filename = uuid.uuid4().hex + ".json"
await awrite(CONFIG.git_repo.workdir / PRDS_FILE_REPO / filename, data=MockMessages.prd.content)

role = Architect()
role.put_message(MockMessages.req)
rsp = await role.run(MockMessages.prd)
rsp = await role.run(with_message=Message(content="", cause_by=WritePRD))
logger.info(rsp)
assert len(rsp.content) > 0
assert rsp.cause_by == any_to_str(WriteDesign)

# test update
rsp = await role.run(with_message=Message(content="", cause_by=WritePRD))
assert rsp
assert rsp.cause_by == any_to_str(WriteDesign)
assert len(rsp.content) > 0


if __name__ == "__main__":
pytest.main([__file__, "-s"])
56 changes: 56 additions & 0 deletions tests/metagpt/roles/test_qa_engineer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,59 @@
@Author : alexanderwu
@File : test_qa_engineer.py
"""
from pathlib import Path
from typing import List

import pytest
from pydantic import Field

from metagpt.actions import DebugError, RunCode, WriteTest
from metagpt.actions.summarize_code import SummarizeCode
from metagpt.config import CONFIG
from metagpt.environment import Environment
from metagpt.roles import QaEngineer
from metagpt.schema import Message
from metagpt.utils.common import any_to_str, aread, awrite


async def test_qa():
# Prerequisites
demo_path = Path(__file__).parent / "../../data/demo_project"
CONFIG.src_workspace = Path(CONFIG.git_repo.workdir) / "qa/game_2048"
data = await aread(filename=demo_path / "game.py", encoding="utf-8")
await awrite(filename=CONFIG.src_workspace / "game.py", data=data, encoding="utf-8")
await awrite(filename=Path(CONFIG.git_repo.workdir) / "requirements.txt", data="")

class MockEnv(Environment):
msgs: List[Message] = Field(default_factory=list)

def publish_message(self, message: Message, peekable: bool = True) -> bool:
self.msgs.append(message)
return True

env = MockEnv()

role = QaEngineer()
role.set_env(env)
await role.run(with_message=Message(content="", cause_by=SummarizeCode))
assert env.msgs
assert env.msgs[0].cause_by == any_to_str(WriteTest)
msg = env.msgs[0]
env.msgs.clear()
await role.run(with_message=msg)
assert env.msgs
assert env.msgs[0].cause_by == any_to_str(RunCode)
msg = env.msgs[0]
env.msgs.clear()
await role.run(with_message=msg)
assert env.msgs
assert env.msgs[0].cause_by == any_to_str(DebugError)
msg = env.msgs[0]
env.msgs.clear()
role.test_round_allowed = 1
rsp = await role.run(with_message=msg)
assert "Exceeding" in rsp.content


if __name__ == "__main__":
pytest.main([__file__, "-s"])
Loading

0 comments on commit d2260a5

Please sign in to comment.