Skip to content

Commit

Permalink
fixbug: azure openai
Browse files Browse the repository at this point in the history
  • Loading branch information
莘权 马 committed Jan 2, 2024
1 parent d9c5809 commit b7d74c6
Show file tree
Hide file tree
Showing 8 changed files with 83 additions and 44 deletions.
2 changes: 1 addition & 1 deletion metagpt/provider/azure_openai_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class AzureOpenAILLM(OpenAILLM):
def _init_client(self):
kwargs = self._make_client_kwargs()
# https://learn.microsoft.com/zh-cn/azure/ai-services/openai/how-to/migration?tabs=python-new%2Cdalle-fix
self.async_client = AsyncAzureOpenAI(**kwargs)
self.aclient = AsyncAzureOpenAI(**kwargs)
self.model = self.config.DEPLOYMENT_NAME # Used in _calc_usage & _cons_kwargs

def _make_client_kwargs(self) -> dict:
Expand Down
21 changes: 12 additions & 9 deletions metagpt/utils/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,7 @@ def __init__(self):
async def _connect(self, force=False):
if self._client and not force:
return True
is_ready = (
CONFIG.REDIS_HOST
and CONFIG.REDIS_HOST != "YOUR_REDIS_HOST"
and CONFIG.REDIS_PORT
and CONFIG.REDIS_PORT != "YOUR_REDIS_PORT"
and CONFIG.REDIS_DB is not None
and CONFIG.REDIS_PASSWORD is not None
)
if not is_ready:
if not self.is_configured:
return False

try:
Expand Down Expand Up @@ -74,3 +66,14 @@ async def close(self):
@property
def is_valid(self) -> bool:
return self._client is not None

@property
def is_configured(self) -> bool:
return bool(
CONFIG.REDIS_HOST
and CONFIG.REDIS_HOST != "YOUR_REDIS_HOST"
and CONFIG.REDIS_PORT
and CONFIG.REDIS_PORT != "YOUR_REDIS_PORT"
and CONFIG.REDIS_DB is not None
and CONFIG.REDIS_PASSWORD is not None
)
25 changes: 13 additions & 12 deletions metagpt/utils/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,16 +154,17 @@ async def cache(self, data: str, file_ext: str, format: str = "") -> str:

@property
def is_valid(self):
is_invalid = (
not CONFIG.S3_ACCESS_KEY
or CONFIG.S3_ACCESS_KEY == "YOUR_S3_ACCESS_KEY"
or not CONFIG.S3_SECRET_KEY
or CONFIG.S3_SECRET_KEY == "YOUR_S3_SECRET_KEY"
or not CONFIG.S3_ENDPOINT_URL
or CONFIG.S3_ENDPOINT_URL == "YOUR_S3_ENDPOINT_URL"
or not CONFIG.S3_BUCKET
or CONFIG.S3_BUCKET == "YOUR_S3_BUCKET"
return self.is_configured

@property
def is_configured(self) -> bool:
return bool(
CONFIG.S3_ACCESS_KEY
and CONFIG.S3_ACCESS_KEY != "YOUR_S3_ACCESS_KEY"
and CONFIG.S3_SECRET_KEY
and CONFIG.S3_SECRET_KEY != "YOUR_S3_SECRET_KEY"
and CONFIG.S3_ENDPOINT_URL
and CONFIG.S3_ENDPOINT_URL != "YOUR_S3_ENDPOINT_URL"
and CONFIG.S3_BUCKET
and CONFIG.S3_BUCKET != "YOUR_S3_BUCKET"
)
if is_invalid:
logger.info("S3 is invalid")
return not is_invalid
12 changes: 8 additions & 4 deletions tests/metagpt/learn/test_text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,24 @@


@pytest.mark.asyncio
async def test():
async def test_metagpt_llm():
# Prerequisites
assert CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL
assert CONFIG.OPENAI_API_KEY

data = await text_to_image("Panda emoji", size_type="512x512")
assert "base64" in data or "http" in data
key = CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL
CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL = None

# Mock session env
old_options = CONFIG.options.copy()
new_options = old_options.copy()
new_options["METAGPT_TEXT_TO_IMAGE_MODEL_URL"] = None
CONFIG.set_context(new_options)
try:
data = await text_to_image("Panda emoji", size_type="512x512")
assert "base64" in data or "http" in data
finally:
CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL = key
CONFIG.set_context(old_options)


if __name__ == "__main__":
Expand Down
9 changes: 6 additions & 3 deletions tests/metagpt/learn/test_text_to_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,16 @@ async def test_text_to_speech():
assert "base64" in data or "http" in data

# test iflytek
key = CONFIG.AZURE_TTS_SUBSCRIPTION_KEY
CONFIG.AZURE_TTS_SUBSCRIPTION_KEY = ""
## Mock session env
old_options = CONFIG.options.copy()
new_options = old_options.copy()
new_options["AZURE_TTS_SUBSCRIPTION_KEY"] = ""
CONFIG.set_context(new_options)
try:
data = await text_to_speech("panda emoji")
assert "base64" in data or "http" in data
finally:
CONFIG.AZURE_TTS_SUBSCRIPTION_KEY = key
CONFIG.set_context(old_options)


if __name__ == "__main__":
Expand Down
26 changes: 23 additions & 3 deletions tests/metagpt/roles/test_architect.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,38 @@
@Modified By: mashenquan, 2023-11-1. In accordance with Chapter 2.2.1 and 2.2.2 of RFC 116, utilize the new message
distribution feature for message handling.
"""
import uuid

import pytest

from metagpt.actions import WriteDesign, WritePRD
from metagpt.config import CONFIG
from metagpt.const import PRDS_FILE_REPO
from metagpt.logs import logger
from metagpt.roles import Architect
from metagpt.schema import Message
from metagpt.utils.common import any_to_str, awrite
from tests.metagpt.roles.mock import MockMessages


@pytest.mark.asyncio
async def test_architect():
# FIXME: make git as env? Or should we support
# Prerequisites
filename = uuid.uuid4().hex + ".json"
await awrite(CONFIG.git_repo.workdir / PRDS_FILE_REPO / filename, data=MockMessages.prd.content)

role = Architect()
role.put_message(MockMessages.req)
rsp = await role.run(MockMessages.prd)
rsp = await role.run(with_message=Message(content="", cause_by=WritePRD))
logger.info(rsp)
assert len(rsp.content) > 0
assert rsp.cause_by == any_to_str(WriteDesign)

# test update
rsp = await role.run(with_message=Message(content="", cause_by=WritePRD))
assert rsp
assert rsp.cause_by == any_to_str(WriteDesign)
assert len(rsp.content) > 0


if __name__ == "__main__":
pytest.main([__file__, "-s"])
19 changes: 12 additions & 7 deletions tests/metagpt/utils/test_redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,18 @@ async def test_redis():
assert await conn.get("test") == b"test"
await conn.close()

key = CONFIG.REDIS_HOST
CONFIG.REDIS_HOST = "YOUR_REDIS_HOST"
conn = Redis()
await conn.set("test", "test", timeout_sec=0)
assert not await conn.get("test") == b"test"
CONFIG.REDIS_HOST = key
await conn.close()
# Mock session env
old_options = CONFIG.options.copy()
new_options = old_options.copy()
new_options["REDIS_HOST"] = "YOUR_REDIS_HOST"
CONFIG.set_context(new_options)
try:
conn = Redis()
await conn.set("test", "test", timeout_sec=0)
assert not await conn.get("test") == b"test"
await conn.close()
finally:
CONFIG.set_context(old_options)


if __name__ == "__main__":
Expand Down
13 changes: 8 additions & 5 deletions tests/metagpt/utils/test_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,18 @@ async def test_s3():
res = await conn.cache(data, ".bak", "script")
assert "http" in res

key = CONFIG.S3_ACCESS_KEY
CONFIG.S3_ACCESS_KEY = "YOUR_S3_ACCESS_KEY"
conn = S3()
assert not conn.is_valid
# Mock session env
old_options = CONFIG.options.copy()
new_options = old_options.copy()
new_options["S3_ACCESS_KEY"] = "YOUR_S3_ACCESS_KEY"
CONFIG.set_context(new_options)
try:
conn = S3()
assert not conn.is_valid
res = await conn.cache("ABC", ".bak", "script")
assert not res
finally:
CONFIG.S3_ACCESS_KEY = key
CONFIG.set_context(old_options)


if __name__ == "__main__":
Expand Down

0 comments on commit b7d74c6

Please sign in to comment.