Skip to content

Commit

Permalink
Add distinct_on to dataframe api apache#11011
Browse files Browse the repository at this point in the history
  • Loading branch information
Omega359 committed Jun 19, 2024
1 parent 0f80b92 commit 4629f02
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 0 deletions.
116 changes: 116 additions & 0 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,37 @@ impl DataFrame {
})
}

/// Return a new `DataFrame` with duplicated rows removed as per the specified expression list
/// according to the provided sorting expressions grouped by the `DISTINCT ON` clause
/// expressions.
///
/// # Example
/// ```
/// # use datafusion::prelude::*;
/// # use datafusion::error::Result;
/// # #[tokio::main]
/// # async fn main() -> Result<()> {
/// let ctx = SessionContext::new();
/// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?;
/// let df = df.distinct_on(vec![col("a")], vec![col("a"), col("b")], None)?;
/// # Ok(())
/// # }
/// ```
pub fn distinct_on(
self,
on_expr: Vec<Expr>,
select_expr: Vec<Expr>,
sort_expr: Option<Vec<Expr>>,
) -> Result<DataFrame> {
let plan = LogicalPlanBuilder::from(self.plan)
.distinct_on(on_expr, select_expr, sort_expr)?
.build()?;
Ok(DataFrame {
session_state: self.session_state,
plan,
})
}

/// Return a new `DataFrame` that has statistics for a DataFrame.
///
/// Only summarizes numeric datatypes at the moment and returns nulls for
Expand Down Expand Up @@ -2190,6 +2221,91 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn test_distinct_on() -> Result<()> {
let t = test_table().await?;
let plan = t
.distinct_on(vec![col("c1")], vec![col("aggregate_test_100.c1")], None)
.unwrap();

let sql_plan =
create_plan("select distinct on (c1) c1 from aggregate_test_100").await?;

assert_same_plan(&plan.plan.clone(), &sql_plan);

let df_results = plan.clone().collect().await?;

#[rustfmt::skip]
assert_batches_sorted_eq!(
["+----+",
"| c1 |",
"+----+",
"| a |",
"| b |",
"| c |",
"| d |",
"| e |",
"+----+"],
&df_results
);

Ok(())
}

#[tokio::test]
async fn test_distinct_on_sort_by() -> Result<()> {
let t = test_table().await?;
let plan = t
.select(vec![col("c1")])
.unwrap()
.distinct_on(
vec![col("c1")],
vec![col("c1")],
Some(vec![col("c1").sort(true, true)]),
)
.unwrap()
.sort(vec![col("c1").sort(true, true)])
.unwrap();

let df_results = plan.clone().collect().await?;

#[rustfmt::skip]
assert_batches_sorted_eq!(
["+----+",
"| c1 |",
"+----+",
"| a |",
"| b |",
"| c |",
"| d |",
"| e |",
"+----+"],
&df_results
);

Ok(())
}

#[tokio::test]
async fn test_distinct_on_sort_by_unprojected() -> Result<()> {
let t = test_table().await?;
let err = t
.select(vec![col("c1")])
.unwrap()
.distinct_on(
vec![col("c1")],
vec![col("c1")],
Some(vec![col("c1").sort(true, true)]),
)
.unwrap()
// try to sort on some value not present in input to distinct
.sort(vec![col("c2").sort(true, true)])
.unwrap_err();
assert_eq!(err.strip_backtrace(), "Error during planning: For SELECT DISTINCT, ORDER BY expressions c2 must appear in select list");

Ok(())
}

#[tokio::test]
async fn join() -> Result<()> {
let left = test_table().await?.select_columns(&["c1", "c2"])?;
Expand Down
1 change: 1 addition & 0 deletions docs/source/user-guide/dataframe.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ execution. The plan is evaluated (executed) when an action method is invoked, su
| ------------------- | ------------------------------------------------------------------------------------------------------------------------------------------ |
| aggregate | Perform an aggregate query with optional grouping expressions. |
| distinct | Filter out duplicate rows. |
| distinct_on | Filter out duplicate rows based on provided expressions. |
| except | Calculate the exception of two DataFrames. The two DataFrames must have exactly the same schema |
| filter | Filter a DataFrame to only include rows that match the specified filter expression. |
| intersect | Calculate the intersection of two DataFrames. The two DataFrames must have exactly the same schema |
Expand Down

0 comments on commit 4629f02

Please sign in to comment.