Skip to content

Commit

Permalink
Made regex matching cleaner for orca kv_cache metrics.
Browse files Browse the repository at this point in the history
  • Loading branch information
BenjaminBraunDev committed Dec 9, 2024
1 parent ef2f7c6 commit 43a1b18
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 26 deletions.
91 changes: 65 additions & 26 deletions src/http_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3239,9 +3239,11 @@ HTTPAPIServer::HandleGenerate(
std::string kv_utilization(base, byte_size);
// Extract the KV utilization metrics from the Prometheus formatted string.
std::string extracted_kv_metrics = ExtractKVMetrics(kv_utilization);
evhtp_headers_add_header(
req->headers_out,
evhtp_header_new("endpoint-load-metrics", extracted_kv_metrics.c_str(), 1, 1));
if (!extracted_kv_metrics.empty()) {
evhtp_headers_add_header(
req->headers_out,
evhtp_header_new("endpoint-load-metrics", extracted_kv_metrics.c_str(), 1, 1));
}
}
}
TRITONSERVER_MetricsDelete(metrics);
Expand Down Expand Up @@ -3407,38 +3409,74 @@ HTTPAPIServer::HandleGenerate(
request_release_payload.release();
}

std::string HTTPAPIServer::ExtractKVMetrics(
const std::string& prometheus_metrics) {
uint64_t tokens_per_block = 0;
uint64_t used_blocks = 0;
uint64_t max_blocks = 0;


const RE2 kv_cache_block_regex(
R"(nv_trt_llm_kv_cache_block_metrics{kv_cache_block_type=\"(?P<type>\w+)\",model=\"(?P<model>.*?)\",version=\"1\"}\s+(?P<value>\d+))");

re2::StringPiece input(prometheus_metrics);
std::string type, model, value;
// TODO: Add and example and how it's used.
std::vector<HTTPAPIServer::PromMetric> HTTPAPIServer::MetricFamilyExtractor(
const std::string& input, const std::string& metricFamily)
{
std::vector<PromMetric> metrics;
// Construct the regex pattern using the provided metricFamily
std::string patternStr = metricFamily + R"((?:{(.*?)})?\s+(\d+(?:\.\d+)?))";
re2::RE2 pattern(patternStr);
re2::StringPiece inputPiece(input);

std::string labelString;
std::string metric_value;

while (re2::RE2::FindAndConsume(&inputPiece, pattern, &labelString, &metric_value)) {
PromMetric metric;

// Extract labels if they exist
if (!labelString.empty()) {
re2::RE2 labelPattern(R"((\w+)=\"([^\"]+)\")");
re2::StringPiece labelPiece(labelString);
std::string key, value;
while (re2::RE2::FindAndConsume(&labelPiece, labelPattern, &key, &value)) {
metric.labels[key] = value;
}
}

while (RE2::FindAndConsume(&input, kv_cache_block_regex, &type, &model, &value)) {
// Assign the value
metric.value = stod(metric_value);
metrics.push_back(metric);
}

uint64_t numeric_value = std::stoull(value);
return metrics;
}

if (type == "tokens_per") {
tokens_per_block = numeric_value;
} else if (type == "used") {
used_blocks = numeric_value;
} else if (type == "max") {
max_blocks = numeric_value;
std::string HTTPAPIServer::ExtractKVMetrics(
const std::string& prometheus_metrics)
{
std::string metric_family = "nv_trt_llm_kv_cache_block_metrics";
std::vector<PromMetric> kv_cache_metrics = MetricFamilyExtractor(prometheus_metrics, metric_family);

double tokens_per_block = -1;
double used_blocks = -1;
double max_blocks = -1;

for (const auto& metric : kv_cache_metrics) {
if (metric.labels.count("kv_cache_block_type") > 0) {
std::string type = metric.labels.at("kv_cache_block_type");
if (type == "tokens_per") {
tokens_per_block = metric.value;
} else if (type == "used") {
used_blocks = metric.value;
} else if (type == "max") {
max_blocks = metric.value;
}
}
}

// One or more of the kv metrics was not found or invalid.
if (tokens_per_block < 0 || used_blocks < 0 || max_blocks < 0) {
return "";
}

// Calculate derived metrics
double kv_cache_utilization = 0.0;
double kv_cache_utilization = 0;
if (max_blocks > 0) {
kv_cache_utilization = (double)used_blocks / max_blocks;
kv_cache_utilization = used_blocks / max_blocks;
}
uint64_t max_token_capacity = max_blocks * tokens_per_block;
uint64_t max_token_capacity = static_cast<uint64_t>(max_blocks * tokens_per_block);

// Format the metrics according to the ORCA protocol
triton::common::TritonJson::Value orca_metrics(
Expand All @@ -3449,6 +3487,7 @@ std::string HTTPAPIServer::ExtractKVMetrics(
named_metrics.AddDouble("kv_cache_utilization", kv_cache_utilization);
named_metrics.AddUInt("max_token_capacity", max_token_capacity);

// TODO: Import and make this an actual proto.
orca_metrics.Add("named_metrics", std::move(named_metrics));

triton::common::TritonJson::WriteBuffer buffer;
Expand Down
10 changes: 10 additions & 0 deletions src/http_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,12 @@ class HTTPAPIServer : public HTTPServer {
evbuffer* buffer_ = nullptr;
};

private:
// One sample parsed from Prometheus-formatted text.
struct PromMetric {
  // Label pairs attached to the sample, keyed by label name. Empty when the
  // sample carries no labels.
  std::unordered_map<std::string, std::string> labels;
  // Numeric value of the sample. Defaults to 0.0 so a default-constructed
  // PromMetric never holds an indeterminate value.
  double value = 0.0;
};

protected:
explicit HTTPAPIServer(
const std::shared_ptr<TRITONSERVER_Server>& server,
Expand Down Expand Up @@ -564,6 +570,10 @@ class HTTPAPIServer : public HTTPServer {
std::string ExtractKVMetrics(
const std::string& prometheus_metrics);

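// Extracts every sample of the metric family 'metricFamily' from the
// Prometheus-formatted text 'input'. Each returned PromMetric holds one
// sample's label map and its numeric value.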
std::vector<PromMetric> MetricFamilyExtractor(
const std::string& input, const std::string& metricFamily);

// 'meta_data_root' is the root JSON document for 'input_metadata'.
// In TritonJson, the Value objects are references to the root document.
// Therefore the document must stay valid.
Expand Down

0 comments on commit 43a1b18

Please sign in to comment.