Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support 1 or 3 arg in generate_series() UDTF #13856

Merged
merged 2 commits into from
Dec 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 99 additions & 69 deletions datafusion/functions-table/src/generate_series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,36 +22,53 @@ use async_trait::async_trait;
use datafusion_catalog::Session;
use datafusion_catalog::TableFunctionImpl;
use datafusion_catalog::TableProvider;
use datafusion_common::{not_impl_err, plan_err, Result, ScalarValue};
use datafusion_common::{plan_err, Result, ScalarValue};
use datafusion_expr::{Expr, TableType};
use datafusion_physical_plan::memory::{LazyBatchGenerator, LazyMemoryExec};
use datafusion_physical_plan::ExecutionPlan;
use parking_lot::RwLock;
use std::fmt;
use std::sync::Arc;

/// Table that generates a series of integers from `start`(inclusive) to `end`(inclusive)
/// Arguments of a `generate_series` call after its literals have been read.
///
/// Built by `GenerateSeriesFunc::call` and matched on when the table is
/// scanned to set up the initial generator state.
#[derive(Debug, Clone)]
enum GenSeriesArgs {
/// At least one argument (start, end or step) was a NULL literal, so no
/// series is generated.
ContainsNull,
/// Every argument was a non-null integer. `start` and `end` are inclusive
/// bounds and `step` is the increment between consecutive values.
AllNotNullArgs { start: i64, end: i64, step: i64 },
}

/// Table that generates a series of integers from `start` (inclusive) to
/// `end` (inclusive), incrementing by `step`.
#[derive(Debug, Clone)]
struct GenerateSeriesTable {
// Schema passed on to the execution plan and to the generator state.
schema: SchemaRef,
// NOTE(review): `start` / `end` appear to be the pre-change fields that this
// diff replaces with `args`; the flattened view shows both sides — confirm
// against the merged file.
// None if input is Null
start: Option<i64>,
// None if input is Null
end: Option<i64>,
// Parsed call arguments, or the marker that one of them was NULL.
args: GenSeriesArgs,
}

/// Table state that generates a series of integers from `start`(inclusive) to `end`(inclusive)
/// Generator state that produces a series of integers from `start` (inclusive)
/// to `end` (inclusive), incrementing by `step`, one batch at a time.
#[derive(Debug, Clone)]
struct GenerateSeriesState {
// Schema of every emitted record batch.
schema: SchemaRef,
start: i64, // Kept for display
// Inclusive bound; generation stops once `current` moves past it.
end: i64,
// Amount added to `current` after each emitted value.
step: i64,
// Upper limit on the number of values placed in one batch.
batch_size: usize,

/// Tracks current position when generating table
current: i64,
}

impl GenerateSeriesState {
    /// Reports whether `val` has moved past the inclusive `end` bound.
    ///
    /// For a positive `step` the series counts upwards, so "past" means
    /// strictly greater than `end`; for any other `step` it means strictly
    /// less than `end`. A value equal to `end` is never past it.
    fn reach_end(&self, val: i64) -> bool {
        match val.cmp(&self.end) {
            std::cmp::Ordering::Greater => self.step > 0,
            std::cmp::Ordering::Less => self.step <= 0,
            std::cmp::Ordering::Equal => false,
        }
    }
}

/// Detail to display for 'Explain' plan
impl fmt::Display for GenerateSeriesState {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
Expand All @@ -65,19 +82,19 @@ impl fmt::Display for GenerateSeriesState {

// NOTE(review): this hunk is a flattened diff. It appears to interleave the
// removed range-based implementation (`batch_end` / `from_iter_values`) with
// the added step-aware loop (`buf` / `reach_end`), so it does not read as one
// coherent function body — confirm against the merged file.
impl LazyBatchGenerator for GenerateSeriesState {
// Produces the next batch of series values, or `Ok(None)` when nothing is left.
fn generate_next_batch(&mut self) -> Result<Option<RecordBatch>> {
// Check if we've reached the end
if self.current > self.end {
// Collect up to `batch_size` values, stopping early once `current` has
// moved past `end` in the direction given by `step`.
let mut buf = Vec::with_capacity(self.batch_size);
while buf.len() < self.batch_size && !self.reach_end(self.current) {
buf.push(self.current);
// NOTE(review): plain `+=` on i64 — can overflow when `current + step`
// leaves the i64 range (e.g. `end` close to i64::MAX); a checked add
// may be needed here.
self.current += self.step;
}
let array = Int64Array::from(buf);

// No values collected means the series is exhausted.
if array.is_empty() {
return Ok(None);
}

// Construct batch
let batch_end = (self.current + self.batch_size as i64 - 1).min(self.end);
let array = Int64Array::from_iter_values(self.current..=batch_end);
let batch = RecordBatch::try_new(self.schema.clone(), vec![Arc::new(array)])?;

// Update current position for next batch
self.current = batch_end + 1;

Ok(Some(batch))
}
}
Expand All @@ -104,77 +121,90 @@ impl TableProvider for GenerateSeriesTable {
_limit: Option<usize>,
) -> Result<Arc<dyn ExecutionPlan>> {
let batch_size = state.config_options().execution.batch_size;
match (self.start, self.end) {
(Some(start), Some(end)) => {
if start > end {
return plan_err!(
"End value must be greater than or equal to start value"
);
}

Ok(Arc::new(LazyMemoryExec::try_new(
self.schema.clone(),
vec![Arc::new(RwLock::new(GenerateSeriesState {
schema: self.schema.clone(),
start,
end,
current: start,
batch_size,
}))],
)?))
}
_ => {
// Either start or end is None, return a generator that outputs 0 rows
Ok(Arc::new(LazyMemoryExec::try_new(
self.schema.clone(),
vec![Arc::new(RwLock::new(GenerateSeriesState {
schema: self.schema.clone(),
start: 0,
end: 0,
current: 1,
batch_size,
}))],
)?))
}
}

let state = match self.args {
// if args have null, then return 0 row
GenSeriesArgs::ContainsNull => GenerateSeriesState {
schema: self.schema.clone(),
start: 0,
end: 0,
step: 1,
current: 1,
batch_size,
},
GenSeriesArgs::AllNotNullArgs { start, end, step } => GenerateSeriesState {
schema: self.schema.clone(),
start,
end,
step,
current: start,
batch_size,
},
};

Ok(Arc::new(LazyMemoryExec::try_new(
self.schema.clone(),
vec![Arc::new(RwLock::new(state))],
)?))
}
}

/// The `generate_series` table function. It carries no state; its
/// `TableFunctionImpl::call` turns literal arguments into a table provider.
#[derive(Debug)]
pub struct GenerateSeriesFunc {}

// NOTE(review): this hunk is a flattened diff. The old two-argument checks
// (`not_impl_err!`, `exprs.len() != 2`, separate `start` / `end` matches, the
// `{ schema, start, end }` constructor) appear interleaved with the new
// 1-to-3 argument implementation — confirm against the merged file.
impl TableFunctionImpl for GenerateSeriesFunc {
// Check input `exprs` type and number. Input validity check (e.g. start <= end)
// will be performed in `TableProvider::scan`
// NOTE(review): the direction and zero-step checks further down in this
// function run here in `call`, so the sentence above looks stale.
fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
// TODO: support 1 or 3 arguments following DuckDB:
// <https://duckdb.org/docs/sql/functions/list#generate_series>
if exprs.len() == 3 || exprs.len() == 1 {
return not_impl_err!("generate_series does not support 1 or 3 arguments");
if exprs.is_empty() || exprs.len() > 3 {
return plan_err!("generate_series function requires 1 to 3 arguments");
}

if exprs.len() != 2 {
return plan_err!("generate_series expects 2 arguments");
// Collect the non-null integer literals in argument order; NULL literals
// are skipped so a length mismatch later reveals that one was present.
let mut normalize_args = Vec::new();
for expr in exprs {
match expr {
Expr::Literal(ScalarValue::Null) => {}
Expr::Literal(ScalarValue::Int64(Some(n))) => normalize_args.push(*n),
// NOTE(review): this message says "First argument" whichever argument
// position failed. `ScalarValue::Int64(None)` also lands in this arm
// rather than being treated as NULL.
_ => return plan_err!("First argument must be an integer literal"),
};
}

let start = match &exprs[0] {
Expr::Literal(ScalarValue::Null) => None,
Expr::Literal(ScalarValue::Int64(Some(n))) => Some(*n),
_ => return plan_err!("First argument must be an integer literal"),
};

let end = match &exprs[1] {
Expr::Literal(ScalarValue::Null) => None,
Expr::Literal(ScalarValue::Int64(Some(n))) => Some(*n),
_ => return plan_err!("Second argument must be an integer literal"),
};

// Single non-nullable Int64 column named "value".
let schema = Arc::new(Schema::new(vec![Field::new(
"value",
DataType::Int64,
false,
)]));

Ok(Arc::new(GenerateSeriesTable { schema, start, end }))
if normalize_args.len() != exprs.len() {
// contain null
return Ok(Arc::new(GenerateSeriesTable {
schema,
args: GenSeriesArgs::ContainsNull,
}));
}

// Fill in defaults: one argument is `end` with start 0 and step 1;
// two arguments are `start, end` with step 1.
let (start, end, step) = match &normalize_args[..] {
[end] => (0, *end, 1),
[start, end] => (*start, *end, 1),
[start, end, step] => (*start, *end, *step),
_ => {
return plan_err!("generate_series function requires 1 to 3 arguments");
}
};

// Reject a step that points away from `end`.
if start > end && step > 0 {
return plan_err!("start is bigger than end, but increment is positive: cannot generate infinite series");
}

if start < end && step < 0 {
return plan_err!("start is smaller than end, but increment is negative: cannot generate infinite series");
}

// A zero step would never advance the series.
if step == 0 {
return plan_err!("step cannot be zero");
}

Ok(Arc::new(GenerateSeriesTable {
schema,
args: GenSeriesArgs::AllNotNullArgs { start, end, step },
}))
}
}
63 changes: 55 additions & 8 deletions datafusion/sqllogictest/test_files/table_functions.slt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,18 @@
# under the License.

# Test generate_series table function
query I
SELECT * FROM generate_series(6)
----
0
1
2
3
4
5
6



query I rowsort
SELECT * FROM generate_series(1, 5)
Expand All @@ -39,11 +51,35 @@ SELECT * FROM generate_series(3, 6)
5
6

# Test generating more rows than batch_size, so the output spans several batches
query I
SELECT count(v1) FROM generate_series(-66666,66666) t1(v1)
----
133333




query I rowsort
SELECT SUM(v1) FROM generate_series(1, 5) t1(v1)
----
15

query I
SELECT * FROM generate_series(6, -1, -2)
----
6
4
2
0

query I
SELECT * FROM generate_series(6, 66, 666)
----
6



# Test generate_series with WHERE clause
query I rowsort
SELECT * FROM generate_series(1, 10) t1(v1) WHERE v1 % 2 = 0
Expand Down Expand Up @@ -93,6 +129,10 @@ ON a.v1 = b.v1 - 1
2 3
3 4

#
# Test generate_series with null arguments
#

query I
SELECT * FROM generate_series(NULL, 5)
----
Expand All @@ -105,6 +145,11 @@ query I
SELECT * FROM generate_series(NULL, NULL)
----

query I
SELECT * FROM generate_series(1, 5, NULL)
----


query TT
EXPLAIN SELECT * FROM generate_series(1, 5)
----
Expand All @@ -115,20 +160,22 @@ physical_plan LazyMemoryExec: partitions=1, batch_generators=[generate_series: s
# Test generate_series with invalid arguments
#

query error DataFusion error: Error during planning: End value must be greater than or equal to start value
query error DataFusion error: Error during planning: start is bigger than end, but increment is positive: cannot generate infinite series
SELECT * FROM generate_series(5, 1)

statement error DataFusion error: This feature is not implemented: generate_series does not support 1 or 3 arguments
SELECT * FROM generate_series(1, 5, NULL)
query error DataFusion error: Error during planning: start is smaller than end, but increment is negative: cannot generate infinite series
SELECT * FROM generate_series(-6, 6, -1)

query error DataFusion error: Error during planning: step cannot be zero
SELECT * FROM generate_series(-6, 6, 0)

query error DataFusion error: Error during planning: start is bigger than end, but increment is positive: cannot generate infinite series
SELECT * FROM generate_series(6, -6, 1)

statement error DataFusion error: This feature is not implemented: generate_series does not support 1 or 3 arguments
SELECT * FROM generate_series(1)

statement error DataFusion error: Error during planning: generate_series expects 2 arguments
statement error DataFusion error: Error during planning: generate_series function requires 1 to 3 arguments
SELECT * FROM generate_series(1, 2, 3, 4)

statement error DataFusion error: Error during planning: Second argument must be an integer literal
SELECT * FROM generate_series(1, '2')

statement error DataFusion error: Error during planning: First argument must be an integer literal
SELECT * FROM generate_series('foo', 'bar')
Expand Down
Loading