@@ -22,22 +22,25 @@ use async_trait::async_trait;
22
22
use datafusion_catalog:: Session ;
23
23
use datafusion_catalog:: TableFunctionImpl ;
24
24
use datafusion_catalog:: TableProvider ;
25
- use datafusion_common:: { not_impl_err , plan_err, Result , ScalarValue } ;
25
+ use datafusion_common:: { plan_err, Result , ScalarValue } ;
26
26
use datafusion_expr:: { Expr , TableType } ;
27
27
use datafusion_physical_plan:: memory:: { LazyBatchGenerator , LazyMemoryExec } ;
28
28
use datafusion_physical_plan:: ExecutionPlan ;
29
29
use parking_lot:: RwLock ;
30
30
use std:: fmt;
31
31
use std:: sync:: Arc ;
32
32
33
+ #[ derive( Debug , Clone ) ]
34
+ enum GenSeriesArgs {
35
+ ContainsNull ,
36
+ AllNotNullArgs { start : i64 , end : i64 , step : i64 } ,
37
+ }
38
+
33
39
/// Table that generates a series of integers from `start`(inclusive) to `end`(inclusive)
34
40
#[ derive( Debug , Clone ) ]
35
41
struct GenerateSeriesTable {
36
42
schema : SchemaRef ,
37
- // None if input is Null
38
- start : Option < i64 > ,
39
- // None if input is Null
40
- end : Option < i64 > ,
43
+ args : GenSeriesArgs ,
41
44
}
42
45
43
46
/// Table state that generates a series of integers from `start`(inclusive) to `end`(inclusive)
@@ -46,12 +49,23 @@ struct GenerateSeriesState {
46
49
schema : SchemaRef ,
47
50
start : i64 , // Kept for display
48
51
end : i64 ,
52
+ step : i64 ,
49
53
batch_size : usize ,
50
54
51
55
/// Tracks current position when generating table
52
56
current : i64 ,
53
57
}
54
58
59
+ impl GenerateSeriesState {
60
+ fn reach_end ( & self , val : i64 ) -> bool {
61
+ if self . step > 0 {
62
+ return val > self . end ;
63
+ }
64
+
65
+ val < self . end
66
+ }
67
+ }
68
+
55
69
/// Detail to display for 'Explain' plan
56
70
impl fmt:: Display for GenerateSeriesState {
57
71
fn fmt ( & self , f : & mut fmt:: Formatter ) -> fmt:: Result {
@@ -65,19 +79,19 @@ impl fmt::Display for GenerateSeriesState {
65
79
66
80
impl LazyBatchGenerator for GenerateSeriesState {
67
81
fn generate_next_batch ( & mut self ) -> Result < Option < RecordBatch > > {
68
- // Check if we've reached the end
69
- if self . current > self . end {
82
+ let mut buf = Vec :: with_capacity ( self . batch_size ) ;
83
+ while buf. len ( ) < self . batch_size && !self . reach_end ( self . current ) {
84
+ buf. push ( self . current ) ;
85
+ self . current += self . step ;
86
+ }
87
+ let array = Int64Array :: from ( buf) ;
88
+
89
+ if array. is_empty ( ) {
70
90
return Ok ( None ) ;
71
91
}
72
92
73
- // Construct batch
74
- let batch_end = ( self . current + self . batch_size as i64 - 1 ) . min ( self . end ) ;
75
- let array = Int64Array :: from_iter_values ( self . current ..=batch_end) ;
76
93
let batch = RecordBatch :: try_new ( self . schema . clone ( ) , vec ! [ Arc :: new( array) ] ) ?;
77
94
78
- // Update current position for next batch
79
- self . current = batch_end + 1 ;
80
-
81
95
Ok ( Some ( batch) )
82
96
}
83
97
}
@@ -104,77 +118,90 @@ impl TableProvider for GenerateSeriesTable {
104
118
_limit : Option < usize > ,
105
119
) -> Result < Arc < dyn ExecutionPlan > > {
106
120
let batch_size = state. config_options ( ) . execution . batch_size ;
107
- match ( self . start , self . end ) {
108
- ( Some ( start) , Some ( end) ) => {
109
- if start > end {
110
- return plan_err ! (
111
- "End value must be greater than or equal to start value"
112
- ) ;
113
- }
114
-
115
- Ok ( Arc :: new ( LazyMemoryExec :: try_new (
116
- self . schema . clone ( ) ,
117
- vec ! [ Arc :: new( RwLock :: new( GenerateSeriesState {
118
- schema: self . schema. clone( ) ,
119
- start,
120
- end,
121
- current: start,
122
- batch_size,
123
- } ) ) ] ,
124
- ) ?) )
125
- }
126
- _ => {
127
- // Either start or end is None, return a generator that outputs 0 rows
128
- Ok ( Arc :: new ( LazyMemoryExec :: try_new (
129
- self . schema . clone ( ) ,
130
- vec ! [ Arc :: new( RwLock :: new( GenerateSeriesState {
131
- schema: self . schema. clone( ) ,
132
- start: 0 ,
133
- end: 0 ,
134
- current: 1 ,
135
- batch_size,
136
- } ) ) ] ,
137
- ) ?) )
138
- }
139
- }
121
+
122
+ let state = match self . args {
123
+ // if args have null, then return 0 row
124
+ GenSeriesArgs :: ContainsNull => GenerateSeriesState {
125
+ schema : self . schema . clone ( ) ,
126
+ start : 0 ,
127
+ end : 0 ,
128
+ step : 1 ,
129
+ current : 1 ,
130
+ batch_size,
131
+ } ,
132
+ GenSeriesArgs :: AllNotNullArgs { start, end, step } => GenerateSeriesState {
133
+ schema : self . schema . clone ( ) ,
134
+ start,
135
+ end,
136
+ step,
137
+ current : start,
138
+ batch_size,
139
+ } ,
140
+ } ;
141
+
142
+ Ok ( Arc :: new ( LazyMemoryExec :: try_new (
143
+ self . schema . clone ( ) ,
144
+ vec ! [ Arc :: new( RwLock :: new( state) ) ] ,
145
+ ) ?) )
140
146
}
141
147
}
142
148
143
149
#[ derive( Debug ) ]
144
150
pub struct GenerateSeriesFunc { }
145
151
146
152
impl TableFunctionImpl for GenerateSeriesFunc {
147
- // Check input `exprs` type and number. Input validity check (e.g. start <= end)
148
- // will be performed in `TableProvider::scan`
149
153
fn call ( & self , exprs : & [ Expr ] ) -> Result < Arc < dyn TableProvider > > {
150
- // TODO: support 1 or 3 arguments following DuckDB:
151
- // <https://duckdb.org/docs/sql/functions/list#generate_series>
152
- if exprs. len ( ) == 3 || exprs. len ( ) == 1 {
153
- return not_impl_err ! ( "generate_series does not support 1 or 3 arguments" ) ;
154
+ if exprs. is_empty ( ) || exprs. len ( ) > 3 {
155
+ return plan_err ! ( "generate_series function requires 1 to 3 arguments" ) ;
154
156
}
155
157
156
- if exprs. len ( ) != 2 {
157
- return plan_err ! ( "generate_series expects 2 arguments" ) ;
158
+ let mut normalize_args = Vec :: new ( ) ;
159
+ for expr in exprs {
160
+ match expr {
161
+ Expr :: Literal ( ScalarValue :: Null ) => { }
162
+ Expr :: Literal ( ScalarValue :: Int64 ( Some ( n) ) ) => normalize_args. push ( * n) ,
163
+ _ => return plan_err ! ( "First argument must be an integer literal" ) ,
164
+ } ;
158
165
}
159
166
160
- let start = match & exprs[ 0 ] {
161
- Expr :: Literal ( ScalarValue :: Null ) => None ,
162
- Expr :: Literal ( ScalarValue :: Int64 ( Some ( n) ) ) => Some ( * n) ,
163
- _ => return plan_err ! ( "First argument must be an integer literal" ) ,
164
- } ;
165
-
166
- let end = match & exprs[ 1 ] {
167
- Expr :: Literal ( ScalarValue :: Null ) => None ,
168
- Expr :: Literal ( ScalarValue :: Int64 ( Some ( n) ) ) => Some ( * n) ,
169
- _ => return plan_err ! ( "Second argument must be an integer literal" ) ,
170
- } ;
171
-
172
167
let schema = Arc :: new ( Schema :: new ( vec ! [ Field :: new(
173
168
"value" ,
174
169
DataType :: Int64 ,
175
170
false ,
176
171
) ] ) ) ;
177
172
178
- Ok ( Arc :: new ( GenerateSeriesTable { schema, start, end } ) )
173
+ if normalize_args. len ( ) != exprs. len ( ) {
174
+ // contain null
175
+ return Ok ( Arc :: new ( GenerateSeriesTable {
176
+ schema,
177
+ args : GenSeriesArgs :: ContainsNull ,
178
+ } ) ) ;
179
+ }
180
+
181
+ let ( start, end, step) = match & normalize_args[ ..] {
182
+ [ end] => ( 0 , * end, 1 ) ,
183
+ [ start, end] => ( * start, * end, 1 ) ,
184
+ [ start, end, step] => ( * start, * end, * step) ,
185
+ _ => {
186
+ return plan_err ! ( "generate_series function requires 1 to 3 arguments" ) ;
187
+ }
188
+ } ;
189
+
190
+ if start > end && step > 0 {
191
+ return plan_err ! ( "start is bigger than end, but increment is positive: cannot generate infinite series" ) ;
192
+ }
193
+
194
+ if start < end && step < 0 {
195
+ return plan_err ! ( "start is smaller than end, but increment is negative: cannot generate infinite series" ) ;
196
+ }
197
+
198
+ if step == 0 {
199
+ return plan_err ! ( "step cannot be zero" ) ;
200
+ }
201
+
202
+ Ok ( Arc :: new ( GenerateSeriesTable {
203
+ schema,
204
+ args : GenSeriesArgs :: AllNotNullArgs { start, end, step } ,
205
+ } ) )
179
206
}
180
207
}
0 commit comments