Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: ORCA Format KV Cache Utilization in Inference Response Header #7839

Open
wants to merge 2 commits into
base: r24.10
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
224 changes: 224 additions & 0 deletions src/http_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3381,6 +3381,7 @@ HTTPAPIServer::HandleGenerate(
request_release_payload.release();
}


TRITONSERVER_Error*
HTTPAPIServer::ModelInputMetadata(
const std::string& model_name, const int64_t model_version,
Expand Down Expand Up @@ -4226,6 +4227,62 @@ HTTPAPIServer::GenerateRequestClass::StartResponse(
return;
}


#ifdef TRITON_ENABLE_METRICS
// logic to add kv_cache metrics to response header
// Get the metrics in Prometheus format

// "ORCA_METRIC_FORMAT" is an environment variable that specifies which load
// report format `endpoint-load-metrics` will be in. If left unset the header
// will not be written and the feature is disabled.
//
// When set, the valid values for "ORCA_METRIC_FORMAT" are:
//
// "http"
// "json"
//
// Any other value behaves as if the variable were unset, and additionally
// logs an error.
//
// For specifics on the different formats for the load reporting formats, see:
// https://docs.google.com/document/d/1C1ybMmDKJIVlrbOLbywhu9iRYo4rilR-cT50OTtOFTs/edit?tab=t.0#heading=h.do9yfa1wlpk8
auto server = infer_request->EvHtpServer();
// Read the environment variable a single time so the null check and the value
// that is used below always refer to the same lookup.
const char* orca_format_env = std::getenv("ORCA_METRIC_FORMAT");
if (orca_format_env != nullptr && server != nullptr) {
  const std::string orca_type = orca_format_env;
  TRITONSERVER_Metrics* metrics = nullptr;
  const char* base = nullptr;
  size_t byte_size = 0;

  // Fetch the metrics object, then render it as Prometheus text. `err` carries
  // the first failure from either call so that both are reported and released
  // by the single error branch below.
  TRITONSERVER_Error* err = TRITONSERVER_ServerMetrics(server, &metrics);
  if (err == nullptr) {
    err = TRITONSERVER_MetricsFormatted(
        metrics, TRITONSERVER_METRIC_PROMETHEUS, &base, &byte_size);
  }

  if (err == nullptr) {
    // Copy the text into a std::string before `metrics` is deleted below.
    const std::string formatted_metrics(base, byte_size);
    // Extract the KV utilization metrics from the Prometheus formatted
    // string.
    const std::string extracted_kv_metrics =
        ExtractKVMetrics(formatted_metrics, orca_type);
    if (!extracted_kv_metrics.empty()) {
      evhtp_headers_add_header(
          req->headers_out,
          evhtp_header_new(
              "endpoint-load-metrics", extracted_kv_metrics.c_str(), 1, 1));
    } else {
      LOG_ERROR << "ORCA_METRIC_FORMAT is set but extracted_kv_metrics is "
                   "empty, no header written. orca_type="
                << orca_type;
    }
  } else {
    // Reached for a failure of either TRITONSERVER_ServerMetrics or
    // TRITONSERVER_MetricsFormatted; the error object must be deleted here or
    // it is leaked.
    LOG_ERROR << "Failed to get KV metrics: "
              << TRITONSERVER_ErrorMessage(err);
    TRITONSERVER_ErrorDelete(err);
  }

  // `metrics` is still null when TRITONSERVER_ServerMetrics failed, so only
  // release it when it was actually produced.
  if (metrics != nullptr) {
    TRITONSERVER_MetricsDelete(metrics);
  }
}
#endif // TRITON_ENABLE_METRICS

if (infer_request->streaming_) {
AddContentTypeHeader(req, "text/event-stream; charset=utf-8");
} else {
Expand All @@ -4235,6 +4292,173 @@ HTTPAPIServer::GenerateRequestClass::StartResponse(
evhtp_request_resume(req);
}


#ifdef TRITON_ENABLE_METRICS
std::vector<HTTPAPIServer::GenerateRequestClass::PromMetric>
HTTPAPIServer::GenerateRequestClass::MetricFamilyExtractor(
const std::string& input, const std::string& metricFamily)
{
std::vector<PromMetric> metrics;
// Construct the regex pattern using the provided metricFamily.

// `labelGroup` is a capturing group that captures all characters within curly
// braces, excluding line breaks.
std::string labelGroup = "(?:{(.*?)})";

// `valueGroup` is a capturing group that captures a number with its
// decimals if any.
std::string valueGroup = R"((\d+(?:\.\d+)?))";

// `patternStr` matches on lines starting with `metricFamily` then captures
// its labels if any, then (optionally) matches any whitespace, then captures
// its numeric double value.
//
// For example, `patternStr` would match on input:
// `nv_trt_llm_kv_cache_block_metrics{kv_cache_block_type="used",model="tensorrt_llm",version="1"}
// 3`
//
// with 2 capturing groups:
// 1. `kv_cache_block_type="used",model="tensorrt_llm",version="1"`
// 2. `3`
std::string patternStr = metricFamily + labelGroup + R"(?\s*)" + valueGroup;
re2::RE2 pattern(patternStr);
re2::StringPiece inputPiece(input);

std::string labelString;
std::string metric_value;

while (re2::RE2::FindAndConsume(
&inputPiece, pattern, &labelString, &metric_value)) {
PromMetric metric;

// Extract labels if they exist
if (!labelString.empty()) {
// `labelPattern` captures any alphanumeric sequence that precedes an '='
// character, then captures the following quoted character sequence. These
// groups are exahstive given the prometheus data model:
// https://prometheus.io/docs/concepts/data_model/#metric-names-and-labels
//
// For example, calling FindAndConsume() with `labelPattern` on input:
// `kv_cache_block_type="used",model="tensorrt_llm",version="1"`
//
// matches 3 times with 2 capturing groups each:
//
// Match #1
// 1. `kv_cache_block_type`
// 2. `used`
//
// Match #2
// 1. `model`
// 2. `tensorrt_llm`
//
// Match #3
// 1. `version`
// 2. `1`
re2::RE2 labelPattern(R"((\w+)=\"([^\"]*)\")");
re2::StringPiece labelPiece(labelString);
std::string key, value;
while (
re2::RE2::FindAndConsume(&labelPiece, labelPattern, &key, &value)) {
// Populate the metric's labels map
metric.labels[key] = value;
}
}

// Assign the metric its value and add it to the family list
metric.value = stod(metric_value);
metrics.push_back(metric);
}

return metrics;
}

std::string
HTTPAPIServer::GenerateRequestClass::ExtractKVMetrics(
const std::string& prometheus_metrics, const std::string& orca_type)
{
std::string metric_family = "nv_trt_llm_kv_cache_block_metrics";
std::vector<PromMetric> kv_cache_metrics =
MetricFamilyExtractor(prometheus_metrics, metric_family);

double tokens_per_block = -1;
double used_blocks = -1;
double max_blocks = -1;

for (const auto& metric : kv_cache_metrics) {
if (metric.labels.count("kv_cache_block_type") > 0) {
std::string type = metric.labels.at("kv_cache_block_type");
if (type == "tokens_per") {
tokens_per_block = metric.value;
} else if (type == "used") {
used_blocks = metric.value;
} else if (type == "max") {
max_blocks = metric.value;
}
}
}

// Return early if not all kv metrics are found and set.
if (tokens_per_block < 0 || used_blocks < 0 || max_blocks < 0) {
LOG_ERROR << "One or more of the kv metrics was not found or invalid.";
return "";
}

// Calculate derived metrics
double kv_cache_utilization = 0;
if (max_blocks > 0) {
kv_cache_utilization = used_blocks / max_blocks;
}
uint64_t max_token_capacity =
static_cast<uint64_t>(max_blocks * tokens_per_block);

return OrcaKVMetricHeader(
orca_type, kv_cache_utilization, max_token_capacity);
}

std::string
HTTPAPIServer::GenerateRequestClass::OrcaKVMetricHeader(
const std::string& orca_type, const double kv_cache_utilization,
const uint64_t max_token_capacity)
{
// Logic to construct and format response header
std::string header_contents = "";
const std::string named_metrics_key = "named_metrics";
const std::string kv_util_key = "kv_cache_utilization";
const std::string max_token_key = "max_token_capacity";

if (orca_type == "json") {
// Format the metrics according to the ORCA protocol as JSON.
triton::common::TritonJson::Value orca_metrics(
triton::common::TritonJson::ValueType::OBJECT);
triton::common::TritonJson::Value named_metrics(
orca_metrics, triton::common::TritonJson::ValueType::OBJECT);

named_metrics.AddDouble(kv_util_key.c_str(), kv_cache_utilization);
named_metrics.AddUInt(max_token_key.c_str(), max_token_capacity);
orca_metrics.Add(named_metrics_key.c_str(), std::move(named_metrics));

triton::common::TritonJson::WriteBuffer buffer;
orca_metrics.Write(&buffer);
header_contents = std::string("JSON ") + buffer.Contents();

} else if (orca_type == "http") {
// Format the metrics according to the ORCA protocol as Native HTTP
// (comma separated list).
const std::string prefix = named_metrics_key + ".";

header_contents = "TEXT ";
header_contents += prefix + kv_util_key + "=" +
std::to_string(kv_cache_utilization) + ", ";
header_contents +=
prefix + max_token_key + "=" + std::to_string(max_token_capacity);
} else {
LOG_ERROR << "orca_type is set to an invalid type: " << orca_type;
}

return header_contents;
}
#endif // TRITON_ENABLE_METRICS

void
HTTPAPIServer::GenerateRequestClass::ChunkResponseCallback(
evthr_t* thr, void* arg, void* shared)
Expand Down
27 changes: 26 additions & 1 deletion src/http_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,8 @@ class HTTPAPIServer : public HTTPServer {
}
virtual ~GenerateRequestClass();

// Accessor for the `server_` handle held by this request object.
TRITONSERVER_Server* EvHtpServer() const
{
  return server_;
}

// [FIXME] Specialize response complete function for now, should have
// been a dispatcher and call into object specific response function.
static void InferResponseComplete(
Expand Down Expand Up @@ -393,6 +395,12 @@ class HTTPAPIServer : public HTTPServer {
const MappingSchema* ResponseSchema() { return response_schema_; }

private:
#ifdef TRITON_ENABLE_METRICS
// One sample of a Prometheus metric family: its label set and numeric value.
struct PromMetric {
  // Label name -> label value for this sample.
  std::unordered_map<std::string, std::string> labels;
  // Sample value. Initialized so that a default-constructed PromMetric never
  // carries an indeterminate double.
  double value = 0.0;
};
#endif // TRITON_ENABLE_METRICS
struct TritonOutput {
enum class Type { RESERVED, TENSOR, PARAMETER };
TritonOutput(Type t, const std::string& val) : type(t), value(val) {}
Expand All @@ -403,6 +411,23 @@ class HTTPAPIServer : public HTTPServer {
// TENSOR, PARAMETER type
uint32_t index;
};

#ifdef TRITON_ENABLE_METRICS
// Builds the KV-cache utilization value for the inference response header
// from Prometheus-formatted metrics text. 'orca_type' selects the report
// format. Returns an empty string when the required KV-cache metrics are not
// all present or when 'orca_type' is not a recognized format.
static std::string ExtractKVMetrics(
const std::string& prometheus_metrics, const std::string& orca_type);
// Extracts every sample of the metric family 'metricFamily' from the
// Prometheus-formatted text 'input'. Returns one PromMetric per sample, each
// holding the sample's map of labels and its numeric value.
static std::vector<PromMetric> MetricFamilyExtractor(
const std::string& input, const std::string& metricFamily);
// Creates a header string in the proper reporting format for the provided
// KV-cache metrics. 'reporting_format' is either "json" or "http"; any other
// value logs an error and yields an empty string.
static std::string OrcaKVMetricHeader(
const std::string& reporting_format, const double kv_cache_utilization,
const uint64_t max_token_capacity);
#endif // TRITON_ENABLE_METRICS

TRITONSERVER_Error* ExactMappingInput(
const std::string& name, triton::common::TritonJson::Value& value,
std::map<std::string, triton::common::TritonJson::Value>&
Expand Down Expand Up @@ -455,6 +480,7 @@ class HTTPAPIServer : public HTTPServer {
evbuffer* buffer_ = nullptr;
};


protected:
explicit HTTPAPIServer(
const std::shared_ptr<TRITONSERVER_Server>& server,
Expand Down Expand Up @@ -558,7 +584,6 @@ class HTTPAPIServer : public HTTPServer {
void HandleGenerate(
evhtp_request_t* req, const std::string& model_name,
const std::string& model_version_str, bool streaming);

// 'meta_data_root' is the root JSON document for 'input_metadata'.
// In TritonJson, the Value objects are references to the root document.
// Therefore the document must stay valid.
Expand Down
Loading