Skip to content

Commit 5803583

Browse files
committed
Support 1 or 3 args in generate_series() UDTF
1 parent 9f530dd commit 5803583

File tree

2 files changed

+149
-75
lines changed

2 files changed

+149
-75
lines changed

datafusion/functions-table/src/generate_series.rs

Lines changed: 94 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -22,22 +22,25 @@ use async_trait::async_trait;
2222
use datafusion_catalog::Session;
2323
use datafusion_catalog::TableFunctionImpl;
2424
use datafusion_catalog::TableProvider;
25-
use datafusion_common::{not_impl_err, plan_err, Result, ScalarValue};
25+
use datafusion_common::{plan_err, Result, ScalarValue};
2626
use datafusion_expr::{Expr, TableType};
2727
use datafusion_physical_plan::memory::{LazyBatchGenerator, LazyMemoryExec};
2828
use datafusion_physical_plan::ExecutionPlan;
2929
use parking_lot::RwLock;
3030
use std::fmt;
3131
use std::sync::Arc;
3232

33+
/// Normalized arguments of a `generate_series` call.
#[derive(Debug, Clone)]
34+
enum GenSeriesArgs {
35+
/// At least one argument was a NULL literal; scanning such a table yields zero rows.
ContainsNull,
36+
/// Every argument was a non-null integer literal. `start` and `end` are both inclusive.
AllNotNullArgs { start: i64, end: i64, step: i64 },
37+
}
38+
3339
/// Table that generates a series of integers from `start` (inclusive) to `end` (inclusive),
/// advancing by `step` between consecutive values
3440
#[derive(Debug, Clone)]
3541
struct GenerateSeriesTable {
3642
schema: SchemaRef,
37-
// None if input is Null
38-
start: Option<i64>,
39-
// None if input is Null
40-
end: Option<i64>,
43+
// Normalized call arguments (see `GenSeriesArgs`)
args: GenSeriesArgs,
4144
}
4245

4346
/// Table state that generates a series of integers from `start`(inclusive) to `end`(inclusive)
@@ -46,12 +49,23 @@ struct GenerateSeriesState {
4649
schema: SchemaRef,
4750
start: i64, // Kept for display
4851
end: i64,
52+
step: i64,
4953
batch_size: usize,
5054

5155
/// Tracks current position when generating table
5256
current: i64,
5357
}
5458

59+
impl GenerateSeriesState {
60+
fn reach_end(&self, val: i64) -> bool {
61+
if self.step > 0 {
62+
return val > self.end;
63+
}
64+
65+
val < self.end
66+
}
67+
}
68+
5569
/// Detail to display for 'Explain' plan
5670
impl fmt::Display for GenerateSeriesState {
5771
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
@@ -65,19 +79,19 @@ impl fmt::Display for GenerateSeriesState {
6579

6680
impl LazyBatchGenerator for GenerateSeriesState {
6781
fn generate_next_batch(&mut self) -> Result<Option<RecordBatch>> {
68-
// Check if we've reached the end
69-
if self.current > self.end {
82+
let mut buf = Vec::with_capacity(self.batch_size);
83+
while buf.len() < self.batch_size && !self.reach_end(self.current) {
84+
buf.push(self.current);
85+
self.current += self.step;
86+
}
87+
let array = Int64Array::from(buf);
88+
89+
if array.is_empty() {
7090
return Ok(None);
7191
}
7292

73-
// Construct batch
74-
let batch_end = (self.current + self.batch_size as i64 - 1).min(self.end);
75-
let array = Int64Array::from_iter_values(self.current..=batch_end);
7693
let batch = RecordBatch::try_new(self.schema.clone(), vec![Arc::new(array)])?;
7794

78-
// Update current position for next batch
79-
self.current = batch_end + 1;
80-
8195
Ok(Some(batch))
8296
}
8397
}
@@ -104,77 +118,90 @@ impl TableProvider for GenerateSeriesTable {
104118
_limit: Option<usize>,
105119
) -> Result<Arc<dyn ExecutionPlan>> {
106120
let batch_size = state.config_options().execution.batch_size;
107-
match (self.start, self.end) {
108-
(Some(start), Some(end)) => {
109-
if start > end {
110-
return plan_err!(
111-
"End value must be greater than or equal to start value"
112-
);
113-
}
114-
115-
Ok(Arc::new(LazyMemoryExec::try_new(
116-
self.schema.clone(),
117-
vec![Arc::new(RwLock::new(GenerateSeriesState {
118-
schema: self.schema.clone(),
119-
start,
120-
end,
121-
current: start,
122-
batch_size,
123-
}))],
124-
)?))
125-
}
126-
_ => {
127-
// Either start or end is None, return a generator that outputs 0 rows
128-
Ok(Arc::new(LazyMemoryExec::try_new(
129-
self.schema.clone(),
130-
vec![Arc::new(RwLock::new(GenerateSeriesState {
131-
schema: self.schema.clone(),
132-
start: 0,
133-
end: 0,
134-
current: 1,
135-
batch_size,
136-
}))],
137-
)?))
138-
}
139-
}
121+
122+
let state = match self.args {
123+
// if args have null, then return 0 row
124+
GenSeriesArgs::ContainsNull => GenerateSeriesState {
125+
schema: self.schema.clone(),
126+
start: 0,
127+
end: 0,
128+
step: 1,
129+
current: 1,
130+
batch_size,
131+
},
132+
GenSeriesArgs::AllNotNullArgs { start, end, step } => GenerateSeriesState {
133+
schema: self.schema.clone(),
134+
start,
135+
end,
136+
step,
137+
current: start,
138+
batch_size,
139+
},
140+
};
141+
142+
Ok(Arc::new(LazyMemoryExec::try_new(
143+
self.schema.clone(),
144+
vec![Arc::new(RwLock::new(state))],
145+
)?))
140146
}
141147
}
142148

143149
/// The `generate_series` table function; its behavior lives in the `TableFunctionImpl` impl.
#[derive(Debug)]
144150
pub struct GenerateSeriesFunc {}
145151

146152
impl TableFunctionImpl for GenerateSeriesFunc {
147-
// Check input `exprs` type and number. Input validity check (e.g. start <= end)
148-
// will be performed in `TableProvider::scan`
149153
fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
150-
// TODO: support 1 or 3 arguments following DuckDB:
151-
// <https://duckdb.org/docs/sql/functions/list#generate_series>
152-
if exprs.len() == 3 || exprs.len() == 1 {
153-
return not_impl_err!("generate_series does not support 1 or 3 arguments");
154+
if exprs.is_empty() || exprs.len() > 3 {
155+
return plan_err!("generate_series function requires 1 to 3 arguments");
154156
}
155157

156-
if exprs.len() != 2 {
157-
return plan_err!("generate_series expects 2 arguments");
158+
let mut normalize_args = Vec::new();
159+
for expr in exprs {
160+
match expr {
161+
Expr::Literal(ScalarValue::Null) => {}
162+
Expr::Literal(ScalarValue::Int64(Some(n))) => normalize_args.push(*n),
163+
_ => return plan_err!("First argument must be an integer literal"),
164+
};
158165
}
159166

160-
let start = match &exprs[0] {
161-
Expr::Literal(ScalarValue::Null) => None,
162-
Expr::Literal(ScalarValue::Int64(Some(n))) => Some(*n),
163-
_ => return plan_err!("First argument must be an integer literal"),
164-
};
165-
166-
let end = match &exprs[1] {
167-
Expr::Literal(ScalarValue::Null) => None,
168-
Expr::Literal(ScalarValue::Int64(Some(n))) => Some(*n),
169-
_ => return plan_err!("Second argument must be an integer literal"),
170-
};
171-
172167
let schema = Arc::new(Schema::new(vec![Field::new(
173168
"value",
174169
DataType::Int64,
175170
false,
176171
)]));
177172

178-
Ok(Arc::new(GenerateSeriesTable { schema, start, end }))
173+
if normalize_args.len() != exprs.len() {
174+
// contain null
175+
return Ok(Arc::new(GenerateSeriesTable {
176+
schema,
177+
args: GenSeriesArgs::ContainsNull,
178+
}));
179+
}
180+
181+
let (start, end, step) = match &normalize_args[..] {
182+
[end] => (0, *end, 1),
183+
[start, end] => (*start, *end, 1),
184+
[start, end, step] => (*start, *end, *step),
185+
_ => {
186+
return plan_err!("generate_series function requires 1 to 3 arguments");
187+
}
188+
};
189+
190+
if start > end && step > 0 {
191+
return plan_err!("start is bigger than end, but increment is positive: cannot generate infinite series");
192+
}
193+
194+
if start < end && step < 0 {
195+
return plan_err!("start is smaller than end, but increment is negative: cannot generate infinite series");
196+
}
197+
198+
if step == 0 {
199+
return plan_err!("step cannot be zero");
200+
}
201+
202+
Ok(Arc::new(GenerateSeriesTable {
203+
schema,
204+
args: GenSeriesArgs::AllNotNullArgs { start, end, step },
205+
}))
179206
}
180207
}

datafusion/sqllogictest/test_files/table_functions.slt

Lines changed: 55 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,18 @@
1616
# under the License.
1717

1818
# Test generate_series table function
19+
query I
20+
SELECT * FROM generate_series(6)
21+
----
22+
0
23+
1
24+
2
25+
3
26+
4
27+
5
28+
6
29+
30+
1931

2032
query I rowsort
2133
SELECT * FROM generate_series(1, 5)
@@ -39,11 +51,35 @@ SELECT * FROM generate_series(3, 6)
3951
5
4052
6
4153

54+
# Number of generated rows exceeds batch_size (exercises multi-batch generation)
55+
query I
56+
SELECT count(v1) FROM generate_series(-66666,66666) t1(v1)
57+
----
58+
133333
59+
60+
61+
62+
4263
query I rowsort
4364
SELECT SUM(v1) FROM generate_series(1, 5) t1(v1)
4465
----
4566
15
4667

68+
query I
69+
SELECT * FROM generate_series(6, -1, -2)
70+
----
71+
6
72+
4
73+
2
74+
0
75+
76+
query I
77+
SELECT * FROM generate_series(6, 66, 666)
78+
----
79+
6
80+
81+
82+
4783
# Test generate_series with WHERE clause
4884
query I rowsort
4985
SELECT * FROM generate_series(1, 10) t1(v1) WHERE v1 % 2 = 0
@@ -93,6 +129,10 @@ ON a.v1 = b.v1 - 1
93129
2 3
94130
3 4
95131

132+
#
133+
# Test generate_series with null arguments
134+
#
135+
96136
query I
97137
SELECT * FROM generate_series(NULL, 5)
98138
----
@@ -105,6 +145,11 @@ query I
105145
SELECT * FROM generate_series(NULL, NULL)
106146
----
107147

148+
query I
149+
SELECT * FROM generate_series(1, 5, NULL)
150+
----
151+
152+
108153
query TT
109154
EXPLAIN SELECT * FROM generate_series(1, 5)
110155
----
@@ -115,20 +160,22 @@ physical_plan LazyMemoryExec: partitions=1, batch_generators=[generate_series: s
115160
# Test generate_series with invalid arguments
116161
#
117162

118-
query error DataFusion error: Error during planning: End value must be greater than or equal to start value
163+
query error DataFusion error: Error during planning: start is bigger than end, but increment is positive: cannot generate infinite series
119164
SELECT * FROM generate_series(5, 1)
120165

121-
statement error DataFusion error: This feature is not implemented: generate_series does not support 1 or 3 arguments
122-
SELECT * FROM generate_series(1, 5, NULL)
166+
query error DataFusion error: Error during planning: start is smaller than end, but increment is negative: cannot generate infinite series
167+
SELECT * FROM generate_series(-6, 6, -1)
168+
169+
query error DataFusion error: Error during planning: step cannot be zero
170+
SELECT * FROM generate_series(-6, 6, 0)
171+
172+
query error DataFusion error: Error during planning: start is bigger than end, but increment is positive: cannot generate infinite series
173+
SELECT * FROM generate_series(6, -6, 1)
123174

124-
statement error DataFusion error: This feature is not implemented: generate_series does not support 1 or 3 arguments
125-
SELECT * FROM generate_series(1)
126175

127-
statement error DataFusion error: Error during planning: generate_series expects 2 arguments
176+
statement error DataFusion error: Error during planning: generate_series function requires 1 to 3 arguments
128177
SELECT * FROM generate_series(1, 2, 3, 4)
129178

130-
statement error DataFusion error: Error during planning: Second argument must be an integer literal
131-
SELECT * FROM generate_series(1, '2')
132179

133180
statement error DataFusion error: Error during planning: First argument must be an integer literal
134181
SELECT * FROM generate_series('foo', 'bar')

0 commit comments

Comments
 (0)