Added helper function in HTTPAPIServer to pull metrics for use in HandleGenerate, adding kv_utilization and max_token_capacity to the inference request's response header.
BenjaminBraunDev committed Dec 9, 2024
1 parent e0f0734 commit 713c8de
Showing 2 changed files with 130 additions and 0 deletions.
115 changes: 115 additions & 0 deletions src/http_server.cc
@@ -3225,6 +3225,34 @@ HTTPAPIServer::HandleGenerate(
req, RestrictedCategory::INFERENCE, restricted_apis_);

AddContentTypeHeader(req, "application/json");

// logic to add kv_cache metrics to response header
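// If extraction succeeds, the response carries a header of the form
// (illustrative values):
//   endpoint-load-metrics: JSON {"named_metrics":{"kv_cache_utilization":0.25,"max_token_capacity":6400}}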
// Get the metrics in Prometheus format
TRITONSERVER_Metrics* metrics = nullptr;
TRITONSERVER_Error* err = TRITONSERVER_ServerMetrics(server_.get(), &metrics);
if (err == nullptr) {
const char* base;
size_t byte_size;
err = TRITONSERVER_MetricsFormatted(
metrics, TRITONSERVER_METRIC_PROMETHEUS, &base, &byte_size);
if (err == nullptr) {
std::string kv_utilization(base, byte_size);
// Extract the KV utilization metrics from the Prometheus formatted string.
std::string extracted_kv_metrics = ExtractKVMetrics(kv_utilization);
if (!extracted_kv_metrics.empty()) {
evhtp_headers_add_header(
req->headers_out,
evhtp_header_new("endpoint-load-metrics", extracted_kv_metrics.c_str(), 1, 1));
}
}
}
TRITONSERVER_MetricsDelete(metrics);
// Handle potential errors
if (err != nullptr) {
LOG_ERROR << "Failed to get KV metrics: " << TRITONSERVER_ErrorMessage(err);
TRITONSERVER_ErrorDelete(err);
}

if (req->method != htp_method_POST) {
RETURN_AND_RESPOND_WITH_ERR(
req, EVHTP_RES_METHNALLOWED, "Method Not Allowed");
@@ -3381,6 +3409,93 @@ HTTPAPIServer::HandleGenerate(
request_release_payload.release();
}

// TODO: Add an example and how it's used.
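// Illustrative sketch (the metric family and label come from ExtractKVMetrics
// below; the numeric value is hypothetical):
//   Input line: nv_trt_llm_kv_cache_block_metrics{kv_cache_block_type="used"} 10
//   Result:     one PromMetric with labels = {"kv_cache_block_type": "used"}
//               and value = 10.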
std::vector<HTTPAPIServer::PromMetric> HTTPAPIServer::MetricFamilyExtractor(
const std::string& input, const std::string& metricFamily)
{
std::vector<PromMetric> metrics;
// Construct the regex pattern using the provided metricFamily
std::string patternStr = metricFamily + R"((?:{(.*?)})?\s+(\d+(?:\.\d+)?))";
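// Capture group 1 is the optional {...} label block; capture group 2 is the
// numeric sample value (integer or decimal).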
re2::RE2 pattern(patternStr);
re2::StringPiece inputPiece(input);

std::string labelString;
std::string metric_value;

while (re2::RE2::FindAndConsume(&inputPiece, pattern, &labelString, &metric_value)) {
PromMetric metric;

// Extract labels if they exist
if (!labelString.empty()) {
re2::RE2 labelPattern(R"((\w+)=\"([^\"]+)\")");
re2::StringPiece labelPiece(labelString);
std::string key, value;
while (re2::RE2::FindAndConsume(&labelPiece, labelPattern, &key, &value)) {
metric.labels[key] = value;
}
}

// Assign the value
metric.value = stod(metric_value);
metrics.push_back(metric);
}

return metrics;
}

std::string HTTPAPIServer::ExtractKVMetrics(
const std::string& prometheus_metrics)
{
std::string metric_family = "nv_trt_llm_kv_cache_block_metrics";
std::vector<PromMetric> kv_cache_metrics = MetricFamilyExtractor(prometheus_metrics, metric_family);

double tokens_per_block = -1;
double used_blocks = -1;
double max_blocks = -1;

for (const auto& metric : kv_cache_metrics) {
if (metric.labels.count("kv_cache_block_type") > 0) {
std::string type = metric.labels.at("kv_cache_block_type");
if (type == "tokens_per") {
tokens_per_block = metric.value;
} else if (type == "used") {
used_blocks = metric.value;
} else if (type == "max") {
max_blocks = metric.value;
}
}
}

// One or more of the kv metrics was not found or invalid.
if (tokens_per_block < 0 || used_blocks < 0 || max_blocks < 0) {
return "";
}

// Calculate derived metrics
double kv_cache_utilization = 0;
if (max_blocks > 0) {
kv_cache_utilization = used_blocks / max_blocks;
}
uint64_t max_token_capacity = static_cast<uint64_t>(max_blocks * tokens_per_block);
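// Illustrative example (hypothetical values): used_blocks = 25, max_blocks = 100,
// and tokens_per_block = 64 give kv_cache_utilization = 0.25 and
// max_token_capacity = 6400.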

// Format the metrics according to the ORCA protocol
triton::common::TritonJson::Value orca_metrics(
triton::common::TritonJson::ValueType::OBJECT);
triton::common::TritonJson::Value named_metrics(
orca_metrics, triton::common::TritonJson::ValueType::OBJECT);

named_metrics.AddDouble("kv_cache_utilization", kv_cache_utilization);
named_metrics.AddUInt("max_token_capacity", max_token_capacity);

// TODO: Import and make this an actual proto.
orca_metrics.Add("named_metrics", std::move(named_metrics));

triton::common::TritonJson::WriteBuffer buffer;
orca_metrics.Write(&buffer);

return std::string("JSON ") + buffer.Contents();
}

TRITONSERVER_Error*
HTTPAPIServer::ModelInputMetadata(
const std::string& model_name, const int64_t model_version,
15 changes: 15 additions & 0 deletions src/http_server.h
@@ -455,6 +455,12 @@ class HTTPAPIServer : public HTTPServer {
evbuffer* buffer_ = nullptr;
};

private:
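// A single parsed Prometheus sample: a map of label key/value pairs plus the
// numeric metric value.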
struct PromMetric {
std::unordered_map<std::string, std::string> labels;
double value;
};

protected:
explicit HTTPAPIServer(
const std::shared_ptr<TRITONSERVER_Server>& server,
@@ -559,6 +565,15 @@ class HTTPAPIServer : public HTTPServer {
evhtp_request_t* req, const std::string& model_name,
const std::string& model_version_str, bool streaming);

// Helper function to get the KV-cache utilization metrics for the
// infer response header
std::string ExtractKVMetrics(
const std::string& prometheus_metrics);

// Extracts all metrics of a given family, each as a struct with a map of labels and a value
std::vector<PromMetric> MetricFamilyExtractor(
const std::string& input, const std::string& metricFamily);

// 'meta_data_root' is the root JSON document for 'input_metadata'.
// In TritonJson, the Value objects are references to the root document.
// Therefore the document must stay valid.
