diff --git a/src/sqlparser.rs b/src/sqlparser.rs
index a9a78aea7..c7733b572 100644
--- a/src/sqlparser.rs
+++ b/src/sqlparser.rs
@@ -95,15 +95,7 @@ impl Parser {
                 "INSERT" => Ok(self.parse_insert()?),
                 "ALTER" => Ok(self.parse_alter()?),
                 "COPY" => Ok(self.parse_copy()?),
-                "TRUE" => {
-                    self.prev_token();
-                    self.parse_sql_value()
-                }
-                "FALSE" => {
-                    self.prev_token();
-                    self.parse_sql_value()
-                }
-                "NULL" => {
+                "TRUE" | "FALSE" | "NULL" => {
                     self.prev_token();
                     self.parse_sql_value()
                 }
@@ -116,7 +108,7 @@ impl Parser {
                     self.parse_cast_expression()
                 } else {
                     match self.peek_token() {
-                        Some(Token::LParen) => self.parse_function_or_pg_cast(&id),
+                        Some(Token::LParen) => self.parse_function(&id),
                         Some(Token::Period) => {
                             let mut id_parts: Vec<String> = vec![id];
                             while self.peek_token() == Some(Token::Period) {
@@ -136,19 +128,10 @@ impl Parser {
                     }
                 }
             }
-            Token::Number(_) => {
-                self.prev_token();
-                self.parse_sql_value()
-            }
-            Token::String(_) => {
-                self.prev_token();
-                self.parse_sql_value()
-            }
-            Token::SingleQuotedString(_) => {
-                self.prev_token();
-                self.parse_sql_value()
-            }
-            Token::DoubleQuotedString(_) => {
+            Token::Number(_)
+            | Token::String(_)
+            | Token::SingleQuotedString(_)
+            | Token::DoubleQuotedString(_) => {
                 self.prev_token();
                 self.parse_sql_value()
             }
@@ -168,15 +151,6 @@ impl Parser {
         }
     }
 
-    pub fn parse_function_or_pg_cast(&mut self, id: &str) -> Result<ASTNode, ParserError> {
-        let func = self.parse_function(&id)?;
-        if let Some(Token::DoubleColon) = self.peek_token() {
-            self.parse_pg_cast(func)
-        } else {
-            Ok(func)
-        }
-    }
-
     pub fn parse_function(&mut self, id: &str) -> Result<ASTNode, ParserError> {
         self.consume_token(&Token::LParen)?;
         if let Ok(true) = self.consume_token(&Token::RParen) {
@@ -241,25 +215,13 @@ impl Parser {
         })
     }
 
-    /// Parse a postgresql casting style which is in the form or expr::datatype
+    /// Parse a postgresql casting style which is in the form of `expr::datatype`
     pub fn parse_pg_cast(&mut self, expr: ASTNode) -> Result<ASTNode, ParserError> {
         let _ = self.consume_token(&Token::DoubleColon)?;
-        let datatype = if let Ok(data_type) = self.parse_data_type() {
-            Ok(data_type)
-        } else if let Ok(table_name) = self.parse_tablename() {
-            Ok(SQLType::Custom(table_name))
-        } else {
-            parser_err!("Expecting datatype or identifier")
-        };
-        let pg_cast = ASTNode::SQLCast {
+        Ok(ASTNode::SQLCast {
             expr: Box::new(expr),
-            data_type: datatype?,
-        };
-        if let Some(Token::DoubleColon) = self.peek_token() {
-            self.parse_pg_cast(pg_cast)
-        } else {
-            Ok(pg_cast)
-        }
+            data_type: self.parse_data_type()?,
+        })
     }
 
     /// Parse an expression infix (typically an operator)
@@ -362,11 +324,17 @@ impl Parser {
         }
     }
 
+    /// Return first non-whitespace token that has not yet been processed
     pub fn peek_token(&self) -> Option<Token> {
-        self.peek_token_skip_whitespace()
+        if let Some(n) = self.til_non_whitespace() {
+            self.token_at(n)
+        } else {
+            None
+        }
     }
 
-    pub fn skip_whitespace(&mut self) -> Option<Token> {
+    /// Get the next token skipping whitespace and increment the token index
+    pub fn next_token(&mut self) -> Option<Token> {
         loop {
             match self.next_token_no_skip() {
                 Some(Token::Whitespace(_)) => {
@@ -406,19 +374,6 @@ impl Parser {
         }
     }
 
-    pub fn peek_token_skip_whitespace(&self) -> Option<Token> {
-        if let Some(n) = self.til_non_whitespace() {
-            self.token_at(n)
-        } else {
-            None
-        }
-    }
-
-    /// Get the next token skipping whitespace and increment the token index
-    pub fn next_token(&mut self) -> Option<Token> {
-        self.skip_whitespace()
-    }
-
     pub fn next_token_no_skip(&mut self) -> Option<Token> {
         if self.index < self.tokens.len() {
             self.index = self.index + 1;
@@ -428,9 +383,9 @@ impl Parser {
         }
     }
 
-    /// if prev token is whitespace skip it
-    /// if prev token is not whitespace skipt it as well
-    pub fn prev_token_skip_whitespace(&mut self) -> Option<Token> {
+    /// Push back the last one non-whitespace token
+    pub fn prev_token(&mut self) -> Option<Token> {
+        // TODO: returned value is unused (available via peek_token)
         loop {
             match self.prev_token_no_skip() {
                 Some(Token::Whitespace(_)) => {
@@ -443,12 +398,8 @@ impl Parser {
         }
     }
 
-    pub fn prev_token(&mut self) -> Option<Token> {
-        self.prev_token_skip_whitespace()
-    }
-
     /// Get the previous token and decrement the token index
-    pub fn prev_token_no_skip(&mut self) -> Option<Token> {
+    fn prev_token_no_skip(&mut self) -> Option<Token> {
         if self.index > 0 {
             self.index = self.index - 1;
             Some(self.tokens[self.index].clone())
@@ -731,30 +682,13 @@ impl Parser {
                 "NULL" => Ok(Value::Null),
                 _ => return parser_err!(format!("No value parser for keyword {}", k)),
             },
-            //TODO: parse the timestamp here
+            //TODO: parse the timestamp here (see parse_timestamp_value())
             Token::Number(ref n) if n.contains(".") => match n.parse::<f64>() {
                 Ok(n) => Ok(Value::Double(n)),
-                Err(e) => {
-                    let index = self.index;
-                    self.prev_token();
-                    if let Ok(timestamp) = self.parse_timestamp_value() {
-                        println!("timstamp: {:?}", timestamp);
-                        Ok(timestamp)
-                    } else {
-                        self.index = index;
-                        parser_err!(format!("Could not parse '{}' as i64: {}", n, e))
-                    }
-                }
+                Err(e) => parser_err!(format!("Could not parse '{}' as i64: {}", n, e)),
             },
             Token::Number(ref n) => match n.parse::<i64>() {
-                Ok(n) => {
-                    // if let Some(Token::Minus) = self.peek_token() {
-                    //     self.prev_token();
-                    //     self.parse_timestamp_value()
-                    // } else {
-                    Ok(Value::Long(n))
-                    // }
-                }
+                Ok(n) => Ok(Value::Long(n)),
                 Err(e) => parser_err!(format!("Could not parse '{}' as i64: {}", n, e)),
             },
             Token::Identifier(id) => Ok(Value::String(id.to_string())),
@@ -782,13 +716,13 @@ impl Parser {
         }
     }
 
-    /// Parse a literal integer/long
+    /// Parse a literal double
     pub fn parse_literal_double(&mut self) -> Result<f64, ParserError> {
         match self.next_token() {
             Some(Token::Number(s)) => s.parse::<f64>().map_err(|e| {
-                ParserError::ParserError(format!("Could not parse '{}' as i64: {}", s, e))
+                ParserError::ParserError(format!("Could not parse '{}' as f64: {}", s, e))
            }),
-            other => parser_err!(format!("Expected literal int, found {:?}", other)),
+            other => parser_err!(format!("Expected literal number, found {:?}", other)),
        }
    }
 
@@ -869,19 +803,17 @@ impl Parser {
         self.consume_token(&Token::Colon)?;
         let min = self.parse_literal_int()?;
         self.consume_token(&Token::Colon)?;
+        // On one hand, the SQL specs defines <seconds value> ::= <unsigned integer>,
+        // so it would be more correct to parse it as such
         let sec = self.parse_literal_double()?;
-        let _ = (sec.fract() * 1000.0).round();
-        if let Ok(true) = self.consume_token(&Token::Period) {
-            let ms = self.parse_literal_int()?;
-            Ok(NaiveTime::from_hms_milli(
-                hour as u32,
-                min as u32,
-                sec as u32,
-                ms as u32,
-            ))
-        } else {
-            Ok(NaiveTime::from_hms(hour as u32, min as u32, sec as u32))
-        }
+        // On the other, chrono only supports nanoseconds, which should(?) fit in seconds-as-f64...
+        let nanos = (sec.fract() * 1_000_000_000.0).round();
+        Ok(NaiveTime::from_hms_nano(
+            hour as u32,
+            min as u32,
+            sec as u32,
+            nanos as u32,
+        ))
     }
 
     /// Parse a SQL datatype (in the context of a CREATE TABLE statement for example)
@@ -973,13 +905,10 @@ impl Parser {
                 }
                 _ => parser_err!(format!("Invalid data type '{:?}'", k)),
             },
-            Some(Token::Identifier(id)) => {
-                if let Ok(true) = self.consume_token(&Token::Period) {
-                    let ids = self.parse_tablename()?;
-                    Ok(SQLType::Custom(format!("{}.{}", id, ids)))
-                } else {
-                    Ok(SQLType::Custom(id))
-                }
+            Some(Token::Identifier(_)) => {
+                self.prev_token();
+                let type_name = self.parse_tablename()?; // TODO: this actually reads a possibly schema-qualified name of a (custom) type
+                Ok(SQLType::Custom(type_name))
             }
             other => parser_err!(format!("Invalid data type: '{:?}'", other)),
         }
diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs
index e731e56f7..343705839 100644
--- a/tests/sqlparser_postgres.rs
+++ b/tests/sqlparser_postgres.rs
@@ -517,20 +517,40 @@ fn parse_create_table_from_pg_dump() {
         ASTNode::SQLCreateTable { name, columns } => {
             assert_eq!("public.customer", name);
 
-            let c_name = &columns[0];
-            assert_eq!("customer_id", c_name.name);
-            assert_eq!(SQLType::Int, c_name.data_type);
-            assert_eq!(false, c_name.allow_null);
-
-            let c_lat = &columns[1];
-            assert_eq!("store_id", c_lat.name);
-            assert_eq!(SQLType::SmallInt, c_lat.data_type);
-            assert_eq!(false, c_lat.allow_null);
+            let c_customer_id = &columns[0];
+            assert_eq!("customer_id", c_customer_id.name);
+            assert_eq!(SQLType::Int, c_customer_id.data_type);
+            assert_eq!(false, c_customer_id.allow_null);
+
+            let c_store_id = &columns[1];
+            assert_eq!("store_id", c_store_id.name);
+            assert_eq!(SQLType::SmallInt, c_store_id.data_type);
+            assert_eq!(false, c_store_id.allow_null);
+
+            let c_first_name = &columns[2];
+            assert_eq!("first_name", c_first_name.name);
+            assert_eq!(SQLType::Varchar(Some(45)), c_first_name.data_type);
+            assert_eq!(false, c_first_name.allow_null);
+
+            let c_create_date1 = &columns[8];
+            assert_eq!(
+                Some(Box::new(ASTNode::SQLCast {
+                    expr: Box::new(ASTNode::SQLCast {
+                        expr: Box::new(ASTNode::SQLValue(Value::SingleQuotedString(
+                            "now".to_string()
+                        ))),
+                        data_type: SQLType::Text
+                    }),
+                    data_type: SQLType::Date
+                })),
+                c_create_date1.default
+            );
 
-            let c_lng = &columns[2];
-            assert_eq!("first_name", c_lng.name);
-            assert_eq!(SQLType::Varchar(Some(45)), c_lng.data_type);
-            assert_eq!(false, c_lng.allow_null);
+            let c_release_year = &columns[10];
+            assert_eq!(
+                SQLType::Custom("public.year".to_string()),
+                c_release_year.data_type
+            );
         }
         _ => assert!(false),
     }
@@ -637,6 +657,7 @@ fn parse_timestamps_example() {
     let sql = "2016-02-15 09:43:33";
     let _ = parse_sql(sql);
     //TODO add assertion
+    //assert_eq!(sql, ast.to_string());
 }
 
 #[test]
@@ -644,6 +665,7 @@ fn parse_timestamps_with_millis_example() {
     let sql = "2017-11-02 19:15:42.308637";
     let _ = parse_sql(sql);
     //TODO add assertion
+    //assert_eq!(sql, ast.to_string());
 }
 
 #[test]
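
Note on the sub-second handling introduced in parse_time above: a minimal standalone sketch of the same conversion (not part of the patch; it assumes chrono 0.4's NaiveTime::from_hms_nano / from_hms_micro and hard-codes the literal from the parse_timestamps_with_millis_example test):

    use chrono::NaiveTime;

    fn main() {
        // "19:15:42.308637": hour/minute parsed as integers, seconds as f64,
        // with the fractional part converted to nanoseconds as parse_time now does.
        let sec: f64 = "42.308637".parse().unwrap();
        let nanos = (sec.fract() * 1_000_000_000.0).round();
        let time = NaiveTime::from_hms_nano(19, 15, sec as u32, nanos as u32);
        assert_eq!(time, NaiveTime::from_hms_micro(19, 15, 42, 308_637));
        println!("{}", time); // prints 19:15:42.308637
    }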