Skip to content

Commit

Permalink
fix(block): Updated model_version to prevent conflicts with pydantic …
Browse files Browse the repository at this point in the history
…naming (#8729)

changed model_version name to avoid conflicts
  • Loading branch information
Swiftyos authored Nov 20, 2024
1 parent 5fa5b71 commit d84ddfc
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions autogpt_platform/backend/backend/blocks/ai_music_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class Input(BlockSchema):
placeholder="e.g., 'An upbeat electronic dance track with heavy bass'",
title="Prompt",
)
model_version: MusicGenModelVersion = SchemaField(
music_gen_model_version: MusicGenModelVersion = SchemaField(
description="Model to use for generation",
default=MusicGenModelVersion.STEREO_LARGE,
title="Model Version",
Expand Down Expand Up @@ -118,7 +118,7 @@ def __init__(self):
test_input={
"credentials": TEST_CREDENTIALS_INPUT,
"prompt": "An upbeat electronic dance track with heavy bass",
"model_version": MusicGenModelVersion.STEREO_LARGE,
"music_gen_model_version": MusicGenModelVersion.STEREO_LARGE,
"duration": 8,
"temperature": 1.0,
"top_k": 250,
Expand All @@ -134,7 +134,7 @@ def __init__(self):
),
],
test_mock={
"run_model": lambda api_key, model_version, prompt, duration, temperature, top_k, top_p, classifier_free_guidance, output_format, normalization_strategy: "https://replicate.com/output/generated-audio-url.wav",
"run_model": lambda api_key, music_gen_model_version, prompt, duration, temperature, top_k, top_p, classifier_free_guidance, output_format, normalization_strategy: "https://replicate.com/output/generated-audio-url.wav",
},
test_credentials=TEST_CREDENTIALS,
)
Expand All @@ -153,7 +153,7 @@ def run(
)
result = self.run_model(
api_key=credentials.api_key,
model_version=input_data.model_version,
music_gen_model_version=input_data.music_gen_model_version,
prompt=input_data.prompt,
duration=input_data.duration,
temperature=input_data.temperature,
Expand Down Expand Up @@ -182,7 +182,7 @@ def run(
def run_model(
self,
api_key: SecretStr,
model_version: MusicGenModelVersion,
music_gen_model_version: MusicGenModelVersion,
prompt: str,
duration: int,
temperature: float,
Expand All @@ -200,7 +200,7 @@ def run_model(
"meta/musicgen:671ac645ce5e552cc63a54a2bbff63fcf798043055d2dac5fc9e36a837eedcfb",
input={
"prompt": prompt,
"model_version": model_version,
"music_gen_model_version": music_gen_model_version,
"duration": duration,
"temperature": temperature,
"top_k": top_k,
Expand Down

0 comments on commit d84ddfc

Please sign in to comment.