From 2e7e223ddf881e689b37a070b545dd09331ec1cd Mon Sep 17 00:00:00 2001 From: goriri Date: Mon, 15 Jul 2024 02:21:15 +0000 Subject: [PATCH 1/2] Fix example server and add use gpu instruction in readme --- README.md | 2 + examples/server/CMakeLists.txt | 2 +- examples/server/server.cpp | 157 +++++++++++++++++++++++---------- 3 files changed, 113 insertions(+), 48 deletions(-) diff --git a/README.md b/README.md index 699b301..968af1e 100644 --- a/README.md +++ b/README.md @@ -99,6 +99,8 @@ In order to build bark.cpp you must use `CMake`: ```bash mkdir build cd build +# To enable nvidia gpu, use the following option +# cmake -DGGML_CUBLAS=ON .. cmake .. cmake --build . --config Release ``` diff --git a/examples/server/CMakeLists.txt b/examples/server/CMakeLists.txt index f0000a6..457547d 100644 --- a/examples/server/CMakeLists.txt +++ b/examples/server/CMakeLists.txt @@ -2,7 +2,7 @@ set(TARGET server) add_executable(${TARGET} server.cpp httplib.h json.hpp) install(TARGETS ${TARGET} RUNTIME) -target_link_libraries(${TARGET} PRIVATE bark ${CMAKE_THREAD_LIBS_INIT}) +target_link_libraries(${TARGET} PRIVATE bark common ${CMAKE_THREAD_LIBS_INIT}) if (WIN32) target_link_libraries(${TARGET} PRIVATE ws2_32) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 2a2e6f7..957055b 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -9,9 +9,10 @@ #include "bark.h" #include "httplib.h" #include "json.hpp" +#include "common.h" #if defined(_MSC_VER) -#pragma warning(disable : 4244 4267) // possible loss of data +#pragma warning(disable : 4244 4267) // possible loss of data #endif #ifndef SERVER_VERBOSE @@ -21,7 +22,8 @@ using namespace httplib; using json = nlohmann::json; -struct server_params { +struct server_params +{ std::string hostname = "127.0.0.1"; std::string public_path = "examples/server/public"; int32_t port = 1337; @@ -29,20 +31,21 @@ struct server_params { int32_t write_timeout = 600; }; -struct bark_params { - int32_t n_threads = std::min(1, static_cast(std::thread::hardware_concurrency())); +// struct bark_params { +// int32_t n_threads = std::min(1, static_cast(std::thread::hardware_concurrency())); - // user prompt - std::string prompt = "this is an audio"; +// // user prompt +// std::string prompt = "this is an audio"; - // paths - std::string model_path = "./ggml_weights"; +// // paths +// std::string model_path = "./ggml_weights"; - int32_t seed = 0; - server_params sparams; -}; +// int32_t seed = 0; +// server_params sparams; +// }; -void bark_print_usage(char **argv, const bark_params ¶ms) { +void bark_print_usage(char **argv, const bark_params ¶ms, const server_params &server_params) +{ fprintf(stderr, "usage: %s [options]\n", argv[0]); fprintf(stderr, "\n"); fprintf(stderr, "options:\n"); @@ -54,51 +57,93 @@ void bark_print_usage(char **argv, const bark_params ¶ms) { fprintf(stderr, "\n"); } -void bark_params_parse(int argc, char **argv, bark_params ¶ms) { +void bark_params_parse(int argc, char **argv, bark_params ¶ms, server_params &server_params) +{ bool model_req = false; - for (int i = 1; i < argc; i++) { + for (int i = 1; i < argc; i++) + { std::string arg = argv[i]; - if (arg == "-t" || arg == "--threads") { + if (arg == "-t" || arg == "--threads") + { params.n_threads = std::stoi(argv[++i]); - } else if (arg == "-m" || arg == "--model") { + } + else if (arg == "-m" || arg == "--model") + { params.model_path = argv[++i]; model_req = true; - } else if (arg == "-t" || arg == "--thread") { + } + else if (arg == "-t" || arg == "--thread") + { params.n_threads = std::stoi(argv[++i]); - } else if (arg == "-p" || arg == "--port") { - params.sparams.port = std::stoi(argv[++i]); - } else if (arg == "-a" || arg == "--address") { - params.sparams.hostname = argv[++i]; - } else if (arg == "-h" || arg == "--help") { - bark_print_usage(argv, params); + } + else if (arg == "-p" || arg == "--port") + { + server_params.port = std::stoi(argv[++i]); + } + else if (arg == "-a" || arg == "--address") + { + server_params.hostname = argv[++i]; + } + else if (arg == "-h" || arg == "--help") + { + bark_print_usage(argv, params, server_params); exit(0); - } else { + } + else + { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); - bark_print_usage(argv, params); + bark_print_usage(argv, params, server_params); exit(1); } } - if (!model_req) { + if (!model_req) + { fprintf(stderr, "error: no model path specified\n"); - bark_print_usage(argv, params); + bark_print_usage(argv, params, server_params); exit(1); } } -int main(int argc, char **argv) { +bool generate_audio(int n_threads, bark_context *bctx, std::string text, std::string dest) +{ + if (!bark_generate_audio(bctx, text.c_str(), n_threads)) + { + fprintf(stderr, "%s: An error occured. If the problem persists, feel free to open an issue to report it.\n", __func__); + return false; + } + + const float *audio_data = bark_get_audio_data(bctx); + if (audio_data == NULL) + { + fprintf(stderr, "%s: Could not get audio data\n", __func__); + return false; + } + + const int audio_arr_size = bark_get_audio_data_size(bctx); + + std::vector audio_arr(audio_data, audio_data + audio_arr_size); + + write_wav_on_disk(audio_arr, dest); + return true; +} + +int main(int argc, char **argv) +{ ggml_time_init(); const int64_t t_main_start_us = ggml_time_us(); bark_params params; + server_params server_params; bark_verbosity_level verbosity = bark_verbosity_level::LOW; - bark_params_parse(argc, argv, params); + bark_params_parse(argc, argv, params, server_params); struct bark_context_params ctx_params = bark_context_default_params(); ctx_params.verbosity = verbosity; struct bark_context *bctx = bark_load_model(params.model_path.c_str(), ctx_params, params.seed); - if (!bctx) { + if (!bctx) + { fprintf(stderr, "%s: Could not load model\n", __func__); return 1; } @@ -112,26 +157,43 @@ int main(int argc, char **argv) { std::string default_content = "hello"; // this is only called if no index.html is found in the public --path - svr.Get("/", [&default_content](const Request &, Response &res) { + svr.Get("/", [&default_content](const Request &, Response &res) + { res.set_content(default_content.c_str(), default_content.size(), "text/html"); - return false; - }); + return false; }); - svr.Post("/bark", [&](const Request &req, Response &res) { + svr.Post("/bark", [&](const Request &req, Response &res) + { // aquire bark model mutex lock bark_mutex.lock(); json jreq = json::parse(req.body); std::string text = jreq.at("text"); - // generate audio std::string dest_wav_path = "/tmp/bark_tmp.wav"; - bark_generate_audio(bctx, text.c_str(), params.n_threads); + // generate audio + // if (!bark_generate_audio(bctx, text.c_str(), params.n_threads)) { + // fprintf(stderr, "%s: An error occured. If the problem persists, feel free to open an issue to report it.\n", __func__); + // exit(1); + // } + + // const float *audio_data = bark_get_audio_data(bctx); + // if (audio_data == NULL) { + // fprintf(stderr, "%s: Could not get audio data\n", __func__); + // exit(1); + // } + + // const int audio_arr_size = bark_get_audio_data_size(bctx); + + // std::vector audio_arr(audio_data, audio_data + audio_arr_size); + + // write_wav_on_disk(audio_arr, dest_wav_path); + bool generated = generate_audio(params.n_threads, bctx, text, dest_wav_path); // read audio as binary - std::ifstream wav_file("/tmp/bark_tmp.wav", std::ios::binary); + std::ifstream wav_file(dest_wav_path, std::ios::binary); - if (wav_file.is_open()) { + if (generated && wav_file.is_open()) { // Read the contents of the WAV file std::string wav_contents((std::istreambuf_iterator(wav_file)), std::istreambuf_iterator()); @@ -151,26 +213,27 @@ int main(int argc, char **argv) { std::remove("/tmp/bark_tmp.wav"); // return bark model mutex lock - bark_mutex.unlock(); - }); + bark_mutex.unlock(); }); - svr.set_read_timeout(params.sparams.read_timeout); - svr.set_write_timeout(params.sparams.write_timeout); + svr.set_read_timeout(server_params.read_timeout); + svr.set_write_timeout(server_params.write_timeout); - if (!svr.bind_to_port(params.sparams.hostname, params.sparams.port)) { + if (!svr.bind_to_port(server_params.hostname, server_params.port)) + { fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", - params.sparams.hostname.c_str(), params.sparams.port); + server_params.hostname.c_str(), server_params.port); return 1; } // Set the base directory for serving static files - svr.set_base_dir(params.sparams.public_path); + svr.set_base_dir(server_params.public_path); // to make it ctrl+clickable: printf("\nbark server listening at http://%s:%d\n\n", - params.sparams.hostname.c_str(), params.sparams.port); + server_params.hostname.c_str(), server_params.port); - if (!svr.listen_after_bind()) { + if (!svr.listen_after_bind()) + { return 1; } From f202eb1a8143f692f838eb55e61dd7de29c057ef Mon Sep 17 00:00:00 2001 From: goriri Date: Wed, 17 Jul 2024 12:11:54 +0800 Subject: [PATCH 2/2] rm unused comments and change formatting --- examples/server/server.cpp | 103 +++++++++---------------------------- 1 file changed, 24 insertions(+), 79 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 957055b..268a92a 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -7,12 +7,12 @@ #include #include "bark.h" +#include "common.h" #include "httplib.h" #include "json.hpp" -#include "common.h" #if defined(_MSC_VER) -#pragma warning(disable : 4244 4267) // possible loss of data +#pragma warning(disable : 4244 4267) // possible loss of data #endif #ifndef SERVER_VERBOSE @@ -22,8 +22,7 @@ using namespace httplib; using json = nlohmann::json; -struct server_params -{ +struct server_params { std::string hostname = "127.0.0.1"; std::string public_path = "examples/server/public"; int32_t port = 1337; @@ -31,21 +30,7 @@ struct server_params int32_t write_timeout = 600; }; -// struct bark_params { -// int32_t n_threads = std::min(1, static_cast(std::thread::hardware_concurrency())); - -// // user prompt -// std::string prompt = "this is an audio"; - -// // paths -// std::string model_path = "./ggml_weights"; - -// int32_t seed = 0; -// server_params sparams; -// }; - -void bark_print_usage(char **argv, const bark_params ¶ms, const server_params &server_params) -{ +void bark_print_usage(char **argv, const bark_params ¶ms, const server_params &server_params) { fprintf(stderr, "usage: %s [options]\n", argv[0]); fprintf(stderr, "\n"); fprintf(stderr, "options:\n"); @@ -57,64 +42,45 @@ void bark_print_usage(char **argv, const bark_params ¶ms, const server_param fprintf(stderr, "\n"); } -void bark_params_parse(int argc, char **argv, bark_params ¶ms, server_params &server_params) -{ +void bark_params_parse(int argc, char **argv, bark_params ¶ms, server_params &server_params) { bool model_req = false; - for (int i = 1; i < argc; i++) - { + for (int i = 1; i < argc; i++) { std::string arg = argv[i]; - if (arg == "-t" || arg == "--threads") - { + if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); - } - else if (arg == "-m" || arg == "--model") - { + } else if (arg == "-m" || arg == "--model") { params.model_path = argv[++i]; model_req = true; - } - else if (arg == "-t" || arg == "--thread") - { + } else if (arg == "-t" || arg == "--thread") { params.n_threads = std::stoi(argv[++i]); - } - else if (arg == "-p" || arg == "--port") - { + } else if (arg == "-p" || arg == "--port") { server_params.port = std::stoi(argv[++i]); - } - else if (arg == "-a" || arg == "--address") - { + } else if (arg == "-a" || arg == "--address") { server_params.hostname = argv[++i]; - } - else if (arg == "-h" || arg == "--help") - { + } else if (arg == "-h" || arg == "--help") { bark_print_usage(argv, params, server_params); exit(0); - } - else - { + } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); bark_print_usage(argv, params, server_params); exit(1); } } - if (!model_req) - { + if (!model_req) { fprintf(stderr, "error: no model path specified\n"); bark_print_usage(argv, params, server_params); exit(1); } } -bool generate_audio(int n_threads, bark_context *bctx, std::string text, std::string dest) -{ - if (!bark_generate_audio(bctx, text.c_str(), n_threads)) - { +bool generate_audio(int n_threads, bark_context *bctx, std::string text, std::string dest) { + if (!bark_generate_audio(bctx, text.c_str(), n_threads)) { fprintf(stderr, "%s: An error occured. If the problem persists, feel free to open an issue to report it.\n", __func__); return false; } const float *audio_data = bark_get_audio_data(bctx); - if (audio_data == NULL) - { + if (audio_data == NULL) { fprintf(stderr, "%s: Could not get audio data\n", __func__); return false; } @@ -127,8 +93,7 @@ bool generate_audio(int n_threads, bark_context *bctx, std::string text, std::st return true; } -int main(int argc, char **argv) -{ +int main(int argc, char **argv) { ggml_time_init(); const int64_t t_main_start_us = ggml_time_us(); @@ -142,8 +107,7 @@ int main(int argc, char **argv) ctx_params.verbosity = verbosity; struct bark_context *bctx = bark_load_model(params.model_path.c_str(), ctx_params, params.seed); - if (!bctx) - { + if (!bctx) { fprintf(stderr, "%s: Could not load model\n", __func__); return 1; } @@ -157,13 +121,11 @@ int main(int argc, char **argv) std::string default_content = "hello"; // this is only called if no index.html is found in the public --path - svr.Get("/", [&default_content](const Request &, Response &res) - { + svr.Get("/", [&default_content](const Request &, Response &res) { res.set_content(default_content.c_str(), default_content.size(), "text/html"); return false; }); - svr.Post("/bark", [&](const Request &req, Response &res) - { + svr.Post("/bark", [&](const Request &req, Response &res) { // aquire bark model mutex lock bark_mutex.lock(); @@ -171,23 +133,8 @@ int main(int argc, char **argv) std::string text = jreq.at("text"); std::string dest_wav_path = "/tmp/bark_tmp.wav"; + // generate audio - // if (!bark_generate_audio(bctx, text.c_str(), params.n_threads)) { - // fprintf(stderr, "%s: An error occured. If the problem persists, feel free to open an issue to report it.\n", __func__); - // exit(1); - // } - - // const float *audio_data = bark_get_audio_data(bctx); - // if (audio_data == NULL) { - // fprintf(stderr, "%s: Could not get audio data\n", __func__); - // exit(1); - // } - - // const int audio_arr_size = bark_get_audio_data_size(bctx); - - // std::vector audio_arr(audio_data, audio_data + audio_arr_size); - - // write_wav_on_disk(audio_arr, dest_wav_path); bool generated = generate_audio(params.n_threads, bctx, text, dest_wav_path); // read audio as binary @@ -218,8 +165,7 @@ int main(int argc, char **argv) svr.set_read_timeout(server_params.read_timeout); svr.set_write_timeout(server_params.write_timeout); - if (!svr.bind_to_port(server_params.hostname, server_params.port)) - { + if (!svr.bind_to_port(server_params.hostname, server_params.port)) { fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", server_params.hostname.c_str(), server_params.port); return 1; @@ -232,8 +178,7 @@ int main(int argc, char **argv) printf("\nbark server listening at http://%s:%d\n\n", server_params.hostname.c_str(), server_params.port); - if (!svr.listen_after_bind()) - { + if (!svr.listen_after_bind()) { return 1; }