Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Huggingface retry generate_image with delay #2745

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
0785547
Retry image_gen.py for huggingface with delay.
primaryobjects Apr 21, 2023
a0c9ade
Formatting.
primaryobjects Apr 21, 2023
afacb5f
isort
primaryobjects Apr 21, 2023
8aae747
Skip retry on error.
primaryobjects Apr 21, 2023
985eb87
Merge branch 'master' into imagegen_delay_retry_huggingface
primaryobjects Apr 21, 2023
e97d589
Merge branch 'master' into imagegen_delay_retry_huggingface
primaryobjects Apr 23, 2023
ebf3404
Merge branch 'master' into imagegen_delay_retry_huggingface
primaryobjects Apr 23, 2023
d2e67f8
Added unit tests.
primaryobjects Apr 23, 2023
ce4c27a
Formatting.
primaryobjects Apr 23, 2023
87427ae
Merge branch 'master' into imagegen_delay_retry_huggingface
primaryobjects Apr 23, 2023
e7edf58
Fix unit test for successful hg.
primaryobjects Apr 23, 2023
7ee9f03
Moved error cases to separate unit test class.
primaryobjects Apr 23, 2023
0183073
fix
primaryobjects Apr 23, 2023
c66a6ca
Merge branch 'master' into imagegen_delay_retry_huggingface
primaryobjects Apr 23, 2023
58a4209
Mock hf api key.
primaryobjects Apr 24, 2023
b593cc8
Merge branch 'master' into imagegen_delay_retry_huggingface
primaryobjects Apr 24, 2023
a9795bc
Merge branch 'master' into imagegen_delay_retry_huggingface
primaryobjects Apr 24, 2023
3be6ffd
Added unit test coverage.
primaryobjects Apr 24, 2023
6501f38
Merge branch 'master' into imagegen_delay_retry_huggingface
primaryobjects Apr 24, 2023
4336694
Merge branch 'master' into imagegen_delay_retry_huggingface
ntindle Apr 25, 2023
34d8dec
Merge branch 'master' into imagegen_delay_retry_huggingface
ntindle May 12, 2023
93c37c3
Update image_gen.py
ntindle May 12, 2023
8a0178d
remove unittest changes
ntindle May 12, 2023
e1dd84d
Merge branch 'master' into imagegen_delay_retry_huggingface
ntindle May 13, 2023
2d0ec19
Merge branch 'master' into imagegen_delay_retry_huggingface
ntindle May 13, 2023
4b8a5b6
Merge branch 'master' into imagegen_delay_retry_huggingface
primaryobjects May 14, 2023
147c6f0
Merge branch 'master' into imagegen_delay_retry_huggingface
primaryobjects May 14, 2023
6aca77e
Merge branch 'master' into imagegen_delay_retry_huggingface
ntindle May 14, 2023
981e7c5
Merge branch 'master' into imagegen_delay_retry_huggingface
k-boikov May 14, 2023
7844a8c
feat: tests from @rihp
ntindle May 14, 2023
54a4752
Merge branch 'master' into imagegen_delay_retry_huggingface
k-boikov May 15, 2023
dd788ea
isort
primaryobjects May 15, 2023
d73b4d5
Merge branch 'master' into imagegen_delay_retry_huggingface
primaryobjects May 15, 2023
03cb399
Merge branch 'master' into imagegen_delay_retry_huggingface
waynehamadi May 16, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 37 additions & 13 deletions autogpt/commands/image_gen.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
""" Image Generation Module for AutoGPT."""
import io
import json
import time
import uuid
from base64 import b64decode

Expand Down Expand Up @@ -61,20 +63,42 @@ def generate_image_with_hf(prompt: str, filename: str) -> str:
"X-Use-Cache": "false",
}

response = requests.post(
API_URL,
headers=headers,
json={
"inputs": prompt,
},
)

image = Image.open(io.BytesIO(response.content))
logger.info(f"Image Generated for prompt:{prompt}")

image.save(filename)
retry_count = 0
while retry_count < 10:
response = requests.post(
API_URL,
headers=headers,
json={
"inputs": prompt,
},
)

return f"Saved to disk:{filename}"
if response.ok:
try:
image = Image.open(io.BytesIO(response.content))
logger.info(f"Image Generated for prompt:{prompt}")
image.save(filename)
return f"Saved to disk:{filename}"
except Exception as e:
logger.error(e)
break
else:
try:
error = json.loads(response.text)
if "estimated_time" in error:
ntindle marked this conversation as resolved.
Show resolved Hide resolved
delay = error["estimated_time"]
logger.debug(response.text)
logger.info("Retrying in", delay)
time.sleep(delay)
ntindle marked this conversation as resolved.
Show resolved Hide resolved
else:
break
except Exception as e:
logger.error(e)
break

retry_count += 1

return f"Error creating image."


def generate_image_with_dalle(prompt: str, filename: str, size: int) -> str:
Expand Down
112 changes: 108 additions & 4 deletions tests/test_image_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from PIL import Image

from autogpt.commands.image_gen import generate_image, generate_image_with_sd_webui
from autogpt.config import Config
from tests.utils import requires_api_key


Expand All @@ -19,7 +20,7 @@ def image_size(request):
reason="The image is too big to be put in a cassette for a CI pipeline. We're looking into a solution."
)
@requires_api_key("OPENAI_API_KEY")
def test_dalle(config, workspace, image_size, patched_api_requestor):
def test_dalle(config, workspace, image_size):
"""Test DALL-E image generation."""
generate_and_validate(
config,
Expand Down Expand Up @@ -48,18 +49,18 @@ def test_huggingface(config, workspace, image_size, image_model):
)


@pytest.mark.skip(reason="External SD WebUI may not be available.")
@pytest.mark.xfail(reason="SD WebUI call does not work.")
def test_sd_webui(config, workspace, image_size):
"""Test SD WebUI image generation."""
generate_and_validate(
config,
workspace,
image_provider="sdwebui",
image_provider="sd_webui",
image_size=image_size,
)


@pytest.mark.skip(reason="External SD WebUI may not be available.")
@pytest.mark.xfail(reason="SD WebUI call does not work.")
def test_sd_webui_negative_prompt(config, workspace, image_size):
gen_image = functools.partial(
generate_image_with_sd_webui,
Expand Down Expand Up @@ -103,3 +104,106 @@ def generate_and_validate(
assert image_path.exists()
with Image.open(image_path) as img:
assert img.size == (image_size, image_size)


def test_huggingface_fail_request_with_delay(mocker):
config = Config()
config.huggingface_api_token = "1"

# Mock requests.post
mock_post = mocker.patch("requests.post")
mock_post.return_value.status_code = 500
mock_post.return_value.ok = False
mock_post.return_value.text = '{"error":"Model CompVis/stable-diffusion-v1-4 is currently loading","estimated_time":0}'

# Mock time.sleep
mock_sleep = mocker.patch("time.sleep")

config.image_provider = "huggingface"
config.huggingface_image_model = "CompVis/stable-diffusion-v1-4"

result = generate_image("astronaut riding a horse", 512)

assert result == "Error creating image."

# Verify retry was called with delay.
mock_sleep.assert_called_with(0)


def test_huggingface_fail_request_no_delay(mocker):
config = Config()
config.huggingface_api_token = "1"

# Mock requests.post
mock_post = mocker.patch("requests.post")
mock_post.return_value.status_code = 500
mock_post.return_value.ok = False
mock_post.return_value.text = (
'{"error":"Model CompVis/stable-diffusion-v1-4 is currently loading"}'
)

# Mock time.sleep
mock_sleep = mocker.patch("time.sleep")

config.image_provider = "huggingface"
config.huggingface_image_model = "CompVis/stable-diffusion-v1-4"

result = generate_image("astronaut riding a horse", 512)

assert result == "Error creating image."

# Verify retry was not called.
mock_sleep.assert_not_called()


def test_huggingface_fail_request_bad_json(mocker):
config = Config()
config.huggingface_api_token = "1"

# Mock requests.post
mock_post = mocker.patch("requests.post")
mock_post.return_value.status_code = 500
mock_post.return_value.ok = False
mock_post.return_value.text = '{"error:}'

# Mock time.sleep
mock_sleep = mocker.patch("time.sleep")

config.image_provider = "huggingface"
config.huggingface_image_model = "CompVis/stable-diffusion-v1-4"

result = generate_image("astronaut riding a horse", 512)

assert result == "Error creating image."

# Verify retry was not called.
mock_sleep.assert_not_called()


def test_huggingface_fail_request_bad_image(mocker):
config = Config()
config.huggingface_api_token = "1"

# Mock requests.post
mock_post = mocker.patch("requests.post")
mock_post.return_value.status_code = 200

config.image_provider = "huggingface"
config.huggingface_image_model = "CompVis/stable-diffusion-v1-4"

result = generate_image("astronaut riding a horse", 512)

assert result == "Error creating image."


def test_huggingface_fail_missing_api_token(mocker):
config = Config()
config.image_provider = "huggingface"
config.huggingface_image_model = "CompVis/stable-diffusion-v1-4"

# Mock requests.post to raise ValueError
mock_post = mocker.patch("requests.post", side_effect=ValueError)

# Verify request raises an error.
with pytest.raises(ValueError):
generate_image("astronaut riding a horse", 512)