Skip to content

Commit

Permalink
feat: implement stdvar_over_time function (#1291)
Browse files Browse the repository at this point in the history
* feat: implement stdvar_over_time function

* feat: add more test for stdvar_over_time

* feat: add stdvar_over_time to functions.rs
  • Loading branch information
haohuaijin authored Apr 3, 2023
1 parent 48c2841 commit a82f1f5
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/promql/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ mod test_util;

pub use aggr_over_time::{
AbsentOverTime, AvgOverTime, CountOverTime, LastOverTime, MaxOverTime, MinOverTime,
PresentOverTime, StddevOverTime, SumOverTime,
PresentOverTime, StddevOverTime, StdvarOverTime, SumOverTime,
};
use datafusion::arrow::array::ArrayRef;
use datafusion::error::DataFusionError;
Expand Down
79 changes: 78 additions & 1 deletion src/promql/src/functions/aggr_over_time.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,36 @@ pub fn present_over_time(_: &TimestampMillisecondArray, values: &Float64Array) -
}
}

// TODO(ruihang): support quantile_over_time, and stdvar_over_time
/// the population standard variance of the values in the specified interval.
/// DataFusion's implementation:
/// https://github.com/apache/arrow-datafusion/blob/292eb954fc0bad3a1febc597233ba26cb60bda3e/datafusion/physical-expr/src/aggregate/variance.rs#L224-#L241
#[range_fn(
name = "StdvarOverTime",
ret = "Float64Array",
display_name = "prom_stdvar_over_time"
)]
pub fn stdvar_over_time(_: &TimestampMillisecondArray, values: &Float64Array) -> Option<f64> {
if values.is_empty() {
None
} else {
let mut count = 0;
let mut mean: f64 = 0.0;
let mut result: f64 = 0.0;
for value in values {
let value = value.unwrap();
let new_count = count + 1;
let delta1 = value - mean;
let new_mean = delta1 / new_count as f64 + mean;
let delta2 = value - new_mean;
let new_result = result + delta1 * delta2;

count += 1;
mean = new_mean;
result = new_result;
}
Some(result / count as f64)
}
}

/// the population standard deviation of the values in the specified interval.
/// Prometheus's implementation: https://github.com/prometheus/prometheus/blob/f55ab2217984770aa1eecd0f2d5f54580029b1c0/promql/functions.go#L556-L569
Expand Down Expand Up @@ -154,6 +183,8 @@ pub fn stddev_over_time(_: &TimestampMillisecondArray, values: &Float64Array) ->
}
}

// TODO(ruihang): support quantile_over_time

#[cfg(test)]
mod test {
use super::*;
Expand Down Expand Up @@ -368,6 +399,52 @@ mod test {
);
}

#[test]
fn calculate_stdvar_over_time() {
let (ts_array, value_array) = build_test_range_arrays();
simple_range_udf_runner(
StdvarOverTime::scalar_udf(),
ts_array,
value_array,
vec![
Some(1417.8479276253622),
Some(808.999919713209),
Some(0.0),
None,
None,
Some(328.3638826418587),
Some(143.5964181766362),
Some(130.91830542386285),
Some(0.0),
None,
],
);

// add more assertions
let ts_array = Arc::new(TimestampMillisecondArray::from_iter(
[1000i64, 3000, 5000, 7000, 9000, 11000, 13000, 15000]
.into_iter()
.map(Some),
));
let values_array = Arc::new(Float64Array::from_iter([
1.5990505637277868,
1.5990505637277868,
1.5990505637277868,
0.0,
8.0,
8.0,
2.0,
3.0,
]));
let ranges = [(0, 3), (3, 5)];
simple_range_udf_runner(
StdvarOverTime::scalar_udf(),
RangeArray::from_ranges(ts_array, ranges).unwrap(),
RangeArray::from_ranges(values_array, ranges).unwrap(),
vec![Some(0.0), Some(10.559999999999999)],
);
}

#[test]
fn calculate_std_dev_over_time() {
let (ts_array, value_array) = build_test_range_arrays();
Expand Down

0 comments on commit a82f1f5

Please sign in to comment.