Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix example server and add use gpu instruction in readme #185

Merged
merged 2 commits into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ In order to build bark.cpp you must use `CMake`:
```bash
mkdir build
cd build
# To enable nvidia gpu, use the following option
# cmake -DGGML_CUBLAS=ON ..
cmake ..
cmake --build . --config Release
```
Expand Down
2 changes: 1 addition & 1 deletion examples/server/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ set(TARGET server)
add_executable(${TARGET} server.cpp httplib.h json.hpp)

install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE bark ${CMAKE_THREAD_LIBS_INIT})
target_link_libraries(${TARGET} PRIVATE bark common ${CMAKE_THREAD_LIBS_INIT})

if (WIN32)
target_link_libraries(${TARGET} PRIVATE ws2_32)
Expand Down
78 changes: 43 additions & 35 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <vector>

#include "bark.h"
#include "common.h"
#include "httplib.h"
#include "json.hpp"

Expand All @@ -29,20 +30,7 @@ struct server_params {
int32_t write_timeout = 600;
};

struct bark_params {
int32_t n_threads = std::min(1, static_cast<int32_t>(std::thread::hardware_concurrency()));

// user prompt
std::string prompt = "this is an audio";

// paths
std::string model_path = "./ggml_weights";

int32_t seed = 0;
server_params sparams;
};

void bark_print_usage(char **argv, const bark_params &params) {
void bark_print_usage(char **argv, const bark_params &params, const server_params &server_params) {
fprintf(stderr, "usage: %s [options]\n", argv[0]);
fprintf(stderr, "\n");
fprintf(stderr, "options:\n");
Expand All @@ -54,7 +42,7 @@ void bark_print_usage(char **argv, const bark_params &params) {
fprintf(stderr, "\n");
}

void bark_params_parse(int argc, char **argv, bark_params &params) {
void bark_params_parse(int argc, char **argv, bark_params &params, server_params &server_params) {
bool model_req = false;
for (int i = 1; i < argc; i++) {
std::string arg = argv[i];
Expand All @@ -66,33 +54,54 @@ void bark_params_parse(int argc, char **argv, bark_params &params) {
} else if (arg == "-t" || arg == "--thread") {
params.n_threads = std::stoi(argv[++i]);
} else if (arg == "-p" || arg == "--port") {
params.sparams.port = std::stoi(argv[++i]);
server_params.port = std::stoi(argv[++i]);
} else if (arg == "-a" || arg == "--address") {
params.sparams.hostname = argv[++i];
server_params.hostname = argv[++i];
} else if (arg == "-h" || arg == "--help") {
bark_print_usage(argv, params);
bark_print_usage(argv, params, server_params);
exit(0);
} else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
bark_print_usage(argv, params);
bark_print_usage(argv, params, server_params);
exit(1);
}
}
if (!model_req) {
fprintf(stderr, "error: no model path specified\n");
bark_print_usage(argv, params);
bark_print_usage(argv, params, server_params);
exit(1);
}
}

bool generate_audio(int n_threads, bark_context *bctx, std::string text, std::string dest) {
if (!bark_generate_audio(bctx, text.c_str(), n_threads)) {
fprintf(stderr, "%s: An error occured. If the problem persists, feel free to open an issue to report it.\n", __func__);
return false;
}

const float *audio_data = bark_get_audio_data(bctx);
if (audio_data == NULL) {
fprintf(stderr, "%s: Could not get audio data\n", __func__);
return false;
}

const int audio_arr_size = bark_get_audio_data_size(bctx);

std::vector<float> audio_arr(audio_data, audio_data + audio_arr_size);

write_wav_on_disk(audio_arr, dest);
return true;
}

int main(int argc, char **argv) {
ggml_time_init();
const int64_t t_main_start_us = ggml_time_us();

bark_params params;
server_params server_params;
bark_verbosity_level verbosity = bark_verbosity_level::LOW;

bark_params_parse(argc, argv, params);
bark_params_parse(argc, argv, params, server_params);

struct bark_context_params ctx_params = bark_context_default_params();
ctx_params.verbosity = verbosity;
Expand All @@ -114,8 +123,7 @@ int main(int argc, char **argv) {
// this is only called if no index.html is found in the public --path
svr.Get("/", [&default_content](const Request &, Response &res) {
res.set_content(default_content.c_str(), default_content.size(), "text/html");
return false;
});
return false; });

svr.Post("/bark", [&](const Request &req, Response &res) {
// aquire bark model mutex lock
Expand All @@ -124,14 +132,15 @@ int main(int argc, char **argv) {
json jreq = json::parse(req.body);
std::string text = jreq.at("text");

// generate audio
std::string dest_wav_path = "/tmp/bark_tmp.wav";
bark_generate_audio(bctx, text.c_str(), params.n_threads);

// generate audio
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rm comments

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed all these as requested.

bool generated = generate_audio(params.n_threads, bctx, text, dest_wav_path);

// read audio as binary
std::ifstream wav_file("/tmp/bark_tmp.wav", std::ios::binary);
std::ifstream wav_file(dest_wav_path, std::ios::binary);

if (wav_file.is_open()) {
if (generated && wav_file.is_open()) {
// Read the contents of the WAV file
std::string wav_contents((std::istreambuf_iterator<char>(wav_file)),
std::istreambuf_iterator<char>());
Expand All @@ -151,24 +160,23 @@ int main(int argc, char **argv) {
std::remove("/tmp/bark_tmp.wav");

// return bark model mutex lock
bark_mutex.unlock();
});
bark_mutex.unlock(); });

svr.set_read_timeout(params.sparams.read_timeout);
svr.set_write_timeout(params.sparams.write_timeout);
svr.set_read_timeout(server_params.read_timeout);
svr.set_write_timeout(server_params.write_timeout);

if (!svr.bind_to_port(params.sparams.hostname, params.sparams.port)) {
if (!svr.bind_to_port(server_params.hostname, server_params.port)) {
fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n",
params.sparams.hostname.c_str(), params.sparams.port);
server_params.hostname.c_str(), server_params.port);
return 1;
}

// Set the base directory for serving static files
svr.set_base_dir(params.sparams.public_path);
svr.set_base_dir(server_params.public_path);

// to make it ctrl+clickable:
printf("\nbark server listening at http://%s:%d\n\n",
params.sparams.hostname.c_str(), params.sparams.port);
server_params.hostname.c_str(), server_params.port);

if (!svr.listen_after_bind()) {
return 1;
Expand Down
Loading