Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: +unit test #663

Merged
merged 3 commits into from
Jan 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion metagpt/provider/azure_openai_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class AzureOpenAILLM(OpenAILLM):
def _init_client(self):
kwargs = self._make_client_kwargs()
# https://learn.microsoft.com/zh-cn/azure/ai-services/openai/how-to/migration?tabs=python-new%2Cdalle-fix
self.async_client = AsyncAzureOpenAI(**kwargs)
self.aclient = AsyncAzureOpenAI(**kwargs)
self.model = self.config.DEPLOYMENT_NAME # Used in _calc_usage & _cons_kwargs

def _make_client_kwargs(self) -> dict:
Expand Down
33 changes: 0 additions & 33 deletions metagpt/roles/role.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,16 +372,6 @@ async def _act(self) -> Message:

return msg

def _find_news(self, observed: list[Message], existed: list[Message]) -> list[Message]:
news = []
# Warning, remove `id` here to make it work for recover
observed_pure = [msg.dict(exclude={"id": True}) for msg in observed]
existed_pure = [msg.dict(exclude={"id": True}) for msg in existed]
for idx, new in enumerate(observed_pure):
if (new["cause_by"] in self.rc.watch or self.name in new["send_to"]) and new not in existed_pure:
news.append(observed[idx])
return news

async def _observe(self, ignore_memory=False) -> int:
"""Prepare new messages for processing from the message buffer and other sources."""
# Read unprocessed messages from the msg buffer.
Expand All @@ -407,29 +397,6 @@ async def _observe(self, ignore_memory=False) -> int:
logger.debug(f"{self._setting} observed: {news_text}")
return len(self.rc.news)

# async def _observe(self, ignore_memory=False) -> int:
# """Prepare new messages for processing from the message buffer and other sources."""
# # Read unprocessed messages from the msg buffer.
# news = self.rc.msg_buffer.pop_all()
# if self.recovered:
# news = [self.latest_observed_msg] if self.latest_observed_msg else []
# else:
# self.latest_observed_msg = news[-1] if len(news) > 0 else None # record the latest observed msg
#
# # Store the read messages in your own memory to prevent duplicate processing.
# old_messages = [] if ignore_memory else self.rc.memory.get()
# self.rc.memory.add_batch(news)
# # Filter out messages of interest.
# self.rc.news = self._find_news(news, old_messages)
#
# # Design Rules:
# # If you need to further categorize Message objects, you can do so using the Message.set_meta function.
# # msg_buffer is a receiving buffer, avoid adding message data and operations to msg_buffer.
# news_text = [f"{i.role}: {i.content[:20]}..." for i in self.rc.news]
# if news_text:
# logger.debug(f"{self._setting} observed: {news_text}")
# return len(self.rc.news)

def publish_message(self, msg):
"""If the role belongs to env, then the role's messages will be broadcast to env"""
if not msg:
Expand Down
16 changes: 14 additions & 2 deletions metagpt/utils/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
@Author : mashenquan
@File : redis.py
"""
from __future__ import annotations

import traceback
from datetime import timedelta
Expand All @@ -22,7 +23,7 @@ def __init__(self):
async def _connect(self, force=False):
if self._client and not force:
return True
if not CONFIG.REDIS_HOST or not CONFIG.REDIS_PORT or CONFIG.REDIS_DB is None or CONFIG.REDIS_PASSWORD is None:
if not self.is_configured:
return False

try:
Expand All @@ -37,7 +38,7 @@ async def _connect(self, force=False):
logger.warning(f"Redis initialization has failed:{e}")
return False

async def get(self, key: str) -> bytes:
async def get(self, key: str) -> bytes | None:
if not await self._connect() or not key:
return None
try:
Expand Down Expand Up @@ -65,3 +66,14 @@ async def close(self):
@property
def is_valid(self) -> bool:
return self._client is not None

@property
def is_configured(self) -> bool:
return bool(
CONFIG.REDIS_HOST
and CONFIG.REDIS_HOST != "YOUR_REDIS_HOST"
and CONFIG.REDIS_PORT
and CONFIG.REDIS_PORT != "YOUR_REDIS_PORT"
and CONFIG.REDIS_DB is not None
and CONFIG.REDIS_PASSWORD is not None
)
25 changes: 13 additions & 12 deletions metagpt/utils/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,16 +154,17 @@ async def cache(self, data: str, file_ext: str, format: str = "") -> str:

@property
def is_valid(self):
is_invalid = (
not CONFIG.S3_ACCESS_KEY
or CONFIG.S3_ACCESS_KEY == "YOUR_S3_ACCESS_KEY"
or not CONFIG.S3_SECRET_KEY
or CONFIG.S3_SECRET_KEY == "YOUR_S3_SECRET_KEY"
or not CONFIG.S3_ENDPOINT_URL
or CONFIG.S3_ENDPOINT_URL == "YOUR_S3_ENDPOINT_URL"
or not CONFIG.S3_BUCKET
or CONFIG.S3_BUCKET == "YOUR_S3_BUCKET"
return self.is_configured

@property
def is_configured(self) -> bool:
return bool(
CONFIG.S3_ACCESS_KEY
and CONFIG.S3_ACCESS_KEY != "YOUR_S3_ACCESS_KEY"
and CONFIG.S3_SECRET_KEY
and CONFIG.S3_SECRET_KEY != "YOUR_S3_SECRET_KEY"
and CONFIG.S3_ENDPOINT_URL
and CONFIG.S3_ENDPOINT_URL != "YOUR_S3_ENDPOINT_URL"
and CONFIG.S3_BUCKET
and CONFIG.S3_BUCKET != "YOUR_S3_BUCKET"
)
if is_invalid:
logger.info("S3 is invalid")
return not is_invalid
92 changes: 92 additions & 0 deletions tests/data/demo_project/game.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
## game.py

import random
from typing import List, Tuple


class Game:
def __init__(self):
self.grid: List[List[int]] = [[0 for _ in range(4)] for _ in range(4)]
self.score: int = 0
self.game_over: bool = False

def reset_game(self):
self.grid = [[0 for _ in range(4)] for _ in range(4)]
self.score = 0
self.game_over = False
self.add_new_tile()
self.add_new_tile()

def move(self, direction: str):
if direction == "up":
self._move_up()
elif direction == "down":
self._move_down()
elif direction == "left":
self._move_left()
elif direction == "right":
self._move_right()

def is_game_over(self) -> bool:
for i in range(4):
for j in range(4):
if self.grid[i][j] == 0:
return False
if j < 3 and self.grid[i][j] == self.grid[i][j + 1]:
return False
if i < 3 and self.grid[i][j] == self.grid[i + 1][j]:
return False
return True

def get_empty_cells(self) -> List[Tuple[int, int]]:
empty_cells = []
for i in range(4):
for j in range(4):
if self.grid[i][j] == 0:
empty_cells.append((i, j))
return empty_cells

def add_new_tile(self):
empty_cells = self.get_empty_cells()
if empty_cells:
x, y = random.choice(empty_cells)
self.grid[x][y] = 2 if random.random() < 0.9 else 4

def get_score(self) -> int:
return self.score

def _move_up(self):
for j in range(4):
for i in range(1, 4):
if self.grid[i][j] != 0:
for k in range(i, 0, -1):
if self.grid[k - 1][j] == 0:
self.grid[k - 1][j] = self.grid[k][j]
self.grid[k][j] = 0

def _move_down(self):
for j in range(4):
for i in range(2, -1, -1):
if self.grid[i][j] != 0:
for k in range(i, 3):
if self.grid[k + 1][j] == 0:
self.grid[k + 1][j] = self.grid[k][j]
self.grid[k][j] = 0

def _move_left(self):
for i in range(4):
for j in range(1, 4):
if self.grid[i][j] != 0:
for k in range(j, 0, -1):
if self.grid[i][k - 1] == 0:
self.grid[i][k - 1] = self.grid[i][k]
self.grid[i][k] = 0

def _move_right(self):
for i in range(4):
for j in range(2, -1, -1):
if self.grid[i][j] != 0:
for k in range(j, 3):
if self.grid[i][k + 1] == 0:
self.grid[i][k + 1] = self.grid[i][k]
self.grid[i][k] = 0
12 changes: 8 additions & 4 deletions tests/metagpt/learn/test_text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,24 @@


@pytest.mark.asyncio
async def test():
async def test_metagpt_llm():
# Prerequisites
assert CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL
assert CONFIG.OPENAI_API_KEY

data = await text_to_image("Panda emoji", size_type="512x512")
assert "base64" in data or "http" in data
key = CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL
CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL = None

# Mock session env
old_options = CONFIG.options.copy()
new_options = old_options.copy()
new_options["METAGPT_TEXT_TO_IMAGE_MODEL_URL"] = None
CONFIG.set_context(new_options)
try:
data = await text_to_image("Panda emoji", size_type="512x512")
assert "base64" in data or "http" in data
finally:
CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL = key
CONFIG.set_context(old_options)


if __name__ == "__main__":
Expand Down
9 changes: 6 additions & 3 deletions tests/metagpt/learn/test_text_to_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,16 @@ async def test_text_to_speech():
assert "base64" in data or "http" in data

# test iflytek
key = CONFIG.AZURE_TTS_SUBSCRIPTION_KEY
CONFIG.AZURE_TTS_SUBSCRIPTION_KEY = ""
## Mock session env
old_options = CONFIG.options.copy()
new_options = old_options.copy()
new_options["AZURE_TTS_SUBSCRIPTION_KEY"] = ""
CONFIG.set_context(new_options)
try:
data = await text_to_speech("panda emoji")
assert "base64" in data or "http" in data
finally:
CONFIG.AZURE_TTS_SUBSCRIPTION_KEY = key
CONFIG.set_context(old_options)


if __name__ == "__main__":
Expand Down
26 changes: 23 additions & 3 deletions tests/metagpt/roles/test_architect.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,38 @@
@Modified By: mashenquan, 2023-11-1. In accordance with Chapter 2.2.1 and 2.2.2 of RFC 116, utilize the new message
distribution feature for message handling.
"""
import uuid

import pytest

from metagpt.actions import WriteDesign, WritePRD
from metagpt.config import CONFIG
from metagpt.const import PRDS_FILE_REPO
from metagpt.logs import logger
from metagpt.roles import Architect
from metagpt.schema import Message
from metagpt.utils.common import any_to_str, awrite
from tests.metagpt.roles.mock import MockMessages


@pytest.mark.asyncio
async def test_architect():
# FIXME: make git as env? Or should we support
# Prerequisites
filename = uuid.uuid4().hex + ".json"
await awrite(CONFIG.git_repo.workdir / PRDS_FILE_REPO / filename, data=MockMessages.prd.content)

role = Architect()
role.put_message(MockMessages.req)
rsp = await role.run(MockMessages.prd)
rsp = await role.run(with_message=Message(content="", cause_by=WritePRD))
logger.info(rsp)
assert len(rsp.content) > 0
assert rsp.cause_by == any_to_str(WriteDesign)

# test update
rsp = await role.run(with_message=Message(content="", cause_by=WritePRD))
assert rsp
assert rsp.cause_by == any_to_str(WriteDesign)
assert len(rsp.content) > 0


if __name__ == "__main__":
pytest.main([__file__, "-s"])
56 changes: 56 additions & 0 deletions tests/metagpt/roles/test_qa_engineer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,59 @@
@Author : alexanderwu
@File : test_qa_engineer.py
"""
from pathlib import Path
from typing import List

import pytest
from pydantic import Field

from metagpt.actions import DebugError, RunCode, WriteTest
from metagpt.actions.summarize_code import SummarizeCode
from metagpt.config import CONFIG
from metagpt.environment import Environment
from metagpt.roles import QaEngineer
from metagpt.schema import Message
from metagpt.utils.common import any_to_str, aread, awrite


async def test_qa():
# Prerequisites
demo_path = Path(__file__).parent / "../../data/demo_project"
CONFIG.src_workspace = Path(CONFIG.git_repo.workdir) / "qa/game_2048"
data = await aread(filename=demo_path / "game.py", encoding="utf-8")
await awrite(filename=CONFIG.src_workspace / "game.py", data=data, encoding="utf-8")
await awrite(filename=Path(CONFIG.git_repo.workdir) / "requirements.txt", data="")

class MockEnv(Environment):
msgs: List[Message] = Field(default_factory=list)

def publish_message(self, message: Message, peekable: bool = True) -> bool:
self.msgs.append(message)
return True

env = MockEnv()

role = QaEngineer()
role.set_env(env)
await role.run(with_message=Message(content="", cause_by=SummarizeCode))
assert env.msgs
assert env.msgs[0].cause_by == any_to_str(WriteTest)
msg = env.msgs[0]
env.msgs.clear()
await role.run(with_message=msg)
assert env.msgs
assert env.msgs[0].cause_by == any_to_str(RunCode)
msg = env.msgs[0]
env.msgs.clear()
await role.run(with_message=msg)
assert env.msgs
assert env.msgs[0].cause_by == any_to_str(DebugError)
msg = env.msgs[0]
env.msgs.clear()
role.test_round_allowed = 1
rsp = await role.run(with_message=msg)
assert "Exceeding" in rsp.content


if __name__ == "__main__":
pytest.main([__file__, "-s"])
Loading
Loading