Skip to content

Commit

Permalink
Merge pull request #268 from NexaAI/zack-dev
Browse files Browse the repository at this point in the history
Disable bark.cpp
  • Loading branch information
zhiyuan8 authored Nov 21, 2024
2 parents 5538d0b + 1e435f2 commit 6509ecc
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 24 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/generate-index-from-release.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name: Wheels Index
on:
# Trigger on new release
workflow_run:
workflows: ["Release", "Build Wheels (CUDA)", "Build Wheels (Metal)", "Build Wheels (ROCm)", "Build Wheels (Vulkan)"]
workflows: ["Build Wheels (CPU)", "Build Wheels (CUDA)", "Build Wheels (Metal)", "Build Wheels (ROCm)", "Build Wheels (Vulkan)"]
types:
- completed

Expand Down
3 changes: 2 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,8 @@ if(LLAMA_BUILD)
endif()

# bark_cpp project
option(BARK_BUILD "Build bark.cpp" ON)
# Temporarily disabled since version v0.0.9.3
option(BARK_BUILD "Build bark.cpp" OFF)
if(BARK_BUILD)
# Filter out HIPBLAS and Vulkan options for bark.cpp since it doesn't support them
set(BARK_CMAKE_OPTIONS ${USER_DEFINED_OPTIONS})
Expand Down
6 changes: 4 additions & 2 deletions nexa/cli/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,10 @@ def run_ggml_inference(args):
from nexa.gguf.nexa_inference_voice import NexaVoiceInference
inference = NexaVoiceInference(model_path=model_path, local_path=local_path, **kwargs)
elif run_type == "TTS":
from nexa.gguf.nexa_inference_tts import NexaTTSInference
inference = NexaTTSInference(model_path=model_path, local_path=local_path, **kwargs)
# # Temporarily disabled since version v0.0.9.3
raise NotImplementedError("TTS model is not supported in CLI mode.")
# from nexa.gguf.nexa_inference_tts import NexaTTSInference
# inference = NexaTTSInference(model_path=model_path, local_path=local_path, **kwargs)
elif run_type == "AudioLM":
from nexa.gguf.nexa_inference_audio_lm import NexaAudioLMInference
inference = NexaAudioLMInference(model_path=model_path, local_path=local_path, **kwargs)
Expand Down
6 changes: 4 additions & 2 deletions nexa/gguf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
from .nexa_inference_text import NexaTextInference
from .nexa_inference_vlm import NexaVLMInference
from .nexa_inference_voice import NexaVoiceInference
from .nexa_inference_tts import NexaTTSInference

# Temporarily disabled since version v0.0.9.3
# from .nexa_inference_tts import NexaTTSInference

__all__ = [
"NexaImageInference",
"NexaTextInference",
"NexaVLMInference",
"NexaVoiceInference",
"NexaTTSInference",
#"NexaTTSInference",
"NexaAudioLMInference"
]
38 changes: 20 additions & 18 deletions tests/test_tts_generation.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
from nexa.gguf import NexaTTSInference
# Temporarily disabled since version v0.0.9.3

def test_tts_generation():
tts = NexaTTSInference(
model_path="bark-small",
local_path=None,
n_threads=4,
seed=42,
sampling_rate=24000,
verbosity=2
)
# from nexa.gguf import NexaTTSInference

# def test_tts_generation():
# tts = NexaTTSInference(
# model_path="bark-small",
# local_path=None,
# n_threads=4,
# seed=42,
# sampling_rate=24000,
# verbosity=2
# )

# Generate audio from prompt
prompt = "Hello, this is a test of the Bark text to speech system."
audio_data = tts.audio_generation(prompt)
# # Generate audio from prompt
# prompt = "Hello, this is a test of the Bark text to speech system."
# audio_data = tts.audio_generation(prompt)

# Save the generated audio
tts._save_audio(audio_data, tts.sampling_rate, "tts_output")
print("TTS generation test completed successfully!")
# # Save the generated audio
# tts._save_audio(audio_data, tts.sampling_rate, "tts_output")
# print("TTS generation test completed successfully!")

if __name__ == "__main__":
test_tts_generation()
# if __name__ == "__main__":
# test_tts_generation()

0 comments on commit 6509ecc

Please sign in to comment.