Skip to content

Miscellaneous fixes #34

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jan 12, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 42 additions & 113 deletions src/sqlparser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,15 +95,7 @@ impl Parser {
"INSERT" => Ok(self.parse_insert()?),
"ALTER" => Ok(self.parse_alter()?),
"COPY" => Ok(self.parse_copy()?),
"TRUE" => {
self.prev_token();
self.parse_sql_value()
}
"FALSE" => {
self.prev_token();
self.parse_sql_value()
}
"NULL" => {
"TRUE" | "FALSE" | "NULL" => {
self.prev_token();
self.parse_sql_value()
}
Expand All @@ -116,7 +108,7 @@ impl Parser {
self.parse_cast_expression()
} else {
match self.peek_token() {
Some(Token::LParen) => self.parse_function_or_pg_cast(&id),
Some(Token::LParen) => self.parse_function(&id),
Some(Token::Period) => {
let mut id_parts: Vec<String> = vec![id];
while self.peek_token() == Some(Token::Period) {
Expand All @@ -136,19 +128,10 @@ impl Parser {
}
}
}
Token::Number(_) => {
self.prev_token();
self.parse_sql_value()
}
Token::String(_) => {
self.prev_token();
self.parse_sql_value()
}
Token::SingleQuotedString(_) => {
self.prev_token();
self.parse_sql_value()
}
Token::DoubleQuotedString(_) => {
Token::Number(_)
| Token::String(_)
| Token::SingleQuotedString(_)
| Token::DoubleQuotedString(_) => {
self.prev_token();
self.parse_sql_value()
}
Expand All @@ -168,15 +151,6 @@ impl Parser {
}
}

pub fn parse_function_or_pg_cast(&mut self, id: &str) -> Result<ASTNode, ParserError> {
let func = self.parse_function(&id)?;
if let Some(Token::DoubleColon) = self.peek_token() {
self.parse_pg_cast(func)
} else {
Ok(func)
}
}

pub fn parse_function(&mut self, id: &str) -> Result<ASTNode, ParserError> {
self.consume_token(&Token::LParen)?;
if let Ok(true) = self.consume_token(&Token::RParen) {
Expand Down Expand Up @@ -241,25 +215,13 @@ impl Parser {
})
}

/// Parse a postgresql casting style which is in the form or expr::datatype
/// Parse a postgresql casting style which is in the form of `expr::datatype`
pub fn parse_pg_cast(&mut self, expr: ASTNode) -> Result<ASTNode, ParserError> {
let _ = self.consume_token(&Token::DoubleColon)?;
let datatype = if let Ok(data_type) = self.parse_data_type() {
Ok(data_type)
} else if let Ok(table_name) = self.parse_tablename() {
Ok(SQLType::Custom(table_name))
} else {
parser_err!("Expecting datatype or identifier")
};
let pg_cast = ASTNode::SQLCast {
Ok(ASTNode::SQLCast {
expr: Box::new(expr),
data_type: datatype?,
};
if let Some(Token::DoubleColon) = self.peek_token() {
self.parse_pg_cast(pg_cast)
} else {
Ok(pg_cast)
}
data_type: self.parse_data_type()?,
})
}

/// Parse an expression infix (typically an operator)
Expand Down Expand Up @@ -362,11 +324,17 @@ impl Parser {
}
}

/// Return first non-whitespace token that has not yet been processed
pub fn peek_token(&self) -> Option<Token> {
self.peek_token_skip_whitespace()
if let Some(n) = self.til_non_whitespace() {
self.token_at(n)
} else {
None
}
}

pub fn skip_whitespace(&mut self) -> Option<Token> {
/// Get the next token skipping whitespace and increment the token index
pub fn next_token(&mut self) -> Option<Token> {
loop {
match self.next_token_no_skip() {
Some(Token::Whitespace(_)) => {
Expand Down Expand Up @@ -406,19 +374,6 @@ impl Parser {
}
}

pub fn peek_token_skip_whitespace(&self) -> Option<Token> {
if let Some(n) = self.til_non_whitespace() {
self.token_at(n)
} else {
None
}
}

/// Get the next token skipping whitespace and increment the token index
pub fn next_token(&mut self) -> Option<Token> {
self.skip_whitespace()
}

pub fn next_token_no_skip(&mut self) -> Option<Token> {
if self.index < self.tokens.len() {
self.index = self.index + 1;
Expand All @@ -428,9 +383,9 @@ impl Parser {
}
}

/// if prev token is whitespace skip it
/// if prev token is not whitespace skip it as well
pub fn prev_token_skip_whitespace(&mut self) -> Option<Token> {
/// Push back the last one non-whitespace token
pub fn prev_token(&mut self) -> Option<Token> {
// TODO: returned value is unused (available via peek_token)
loop {
match self.prev_token_no_skip() {
Some(Token::Whitespace(_)) => {
Expand All @@ -443,12 +398,8 @@ impl Parser {
}
}

pub fn prev_token(&mut self) -> Option<Token> {
self.prev_token_skip_whitespace()
}

/// Get the previous token and decrement the token index
pub fn prev_token_no_skip(&mut self) -> Option<Token> {
fn prev_token_no_skip(&mut self) -> Option<Token> {
if self.index > 0 {
self.index = self.index - 1;
Some(self.tokens[self.index].clone())
Expand Down Expand Up @@ -731,30 +682,13 @@ impl Parser {
"NULL" => Ok(Value::Null),
_ => return parser_err!(format!("No value parser for keyword {}", k)),
},
//TODO: parse the timestamp here
//TODO: parse the timestamp here (see parse_timestamp_value())
Token::Number(ref n) if n.contains(".") => match n.parse::<f64>() {
Ok(n) => Ok(Value::Double(n)),
Err(e) => {
let index = self.index;
self.prev_token();
if let Ok(timestamp) = self.parse_timestamp_value() {
println!("timstamp: {:?}", timestamp);
Ok(timestamp)
} else {
self.index = index;
parser_err!(format!("Could not parse '{}' as i64: {}", n, e))
}
}
Err(e) => parser_err!(format!("Could not parse '{}' as i64: {}", n, e)),
},
Token::Number(ref n) => match n.parse::<i64>() {
Ok(n) => {
// if let Some(Token::Minus) = self.peek_token() {
// self.prev_token();
// self.parse_timestamp_value()
// } else {
Ok(Value::Long(n))
// }
}
Ok(n) => Ok(Value::Long(n)),
Err(e) => parser_err!(format!("Could not parse '{}' as i64: {}", n, e)),
},
Token::Identifier(id) => Ok(Value::String(id.to_string())),
Expand Down Expand Up @@ -782,13 +716,13 @@ impl Parser {
}
}

/// Parse a literal integer/long
/// Parse a literal double
pub fn parse_literal_double(&mut self) -> Result<f64, ParserError> {
match self.next_token() {
Some(Token::Number(s)) => s.parse::<f64>().map_err(|e| {
ParserError::ParserError(format!("Could not parse '{}' as i64: {}", s, e))
ParserError::ParserError(format!("Could not parse '{}' as f64: {}", s, e))
}),
other => parser_err!(format!("Expected literal int, found {:?}", other)),
other => parser_err!(format!("Expected literal number, found {:?}", other)),
}
}

Expand Down Expand Up @@ -869,19 +803,17 @@ impl Parser {
self.consume_token(&Token::Colon)?;
let min = self.parse_literal_int()?;
self.consume_token(&Token::Colon)?;
// On one hand, the SQL specs defines <seconds fraction> ::= <unsigned integer>,
// so it would be more correct to parse it as such
let sec = self.parse_literal_double()?;
let _ = (sec.fract() * 1000.0).round();
if let Ok(true) = self.consume_token(&Token::Period) {
let ms = self.parse_literal_int()?;
Ok(NaiveTime::from_hms_milli(
hour as u32,
min as u32,
sec as u32,
ms as u32,
))
} else {
Ok(NaiveTime::from_hms(hour as u32, min as u32, sec as u32))
}
// On the other, chrono only supports nanoseconds, which should(?) fit in seconds-as-f64...
let nanos = (sec.fract() * 1_000_000_000.0).round();
Ok(NaiveTime::from_hms_nano(
hour as u32,
min as u32,
sec as u32,
nanos as u32,
))
}

/// Parse a SQL datatype (in the context of a CREATE TABLE statement for example)
Expand Down Expand Up @@ -973,13 +905,10 @@ impl Parser {
}
_ => parser_err!(format!("Invalid data type '{:?}'", k)),
},
Some(Token::Identifier(id)) => {
if let Ok(true) = self.consume_token(&Token::Period) {
let ids = self.parse_tablename()?;
Ok(SQLType::Custom(format!("{}.{}", id, ids)))
} else {
Ok(SQLType::Custom(id))
}
Some(Token::Identifier(_)) => {
self.prev_token();
let type_name = self.parse_tablename()?; // TODO: this actually reads a possibly schema-qualified name of a (custom) type
Ok(SQLType::Custom(type_name))
}
other => parser_err!(format!("Invalid data type: '{:?}'", other)),
}
Expand Down
48 changes: 35 additions & 13 deletions tests/sqlparser_postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -517,20 +517,40 @@ fn parse_create_table_from_pg_dump() {
ASTNode::SQLCreateTable { name, columns } => {
assert_eq!("public.customer", name);

let c_name = &columns[0];
assert_eq!("customer_id", c_name.name);
assert_eq!(SQLType::Int, c_name.data_type);
assert_eq!(false, c_name.allow_null);

let c_lat = &columns[1];
assert_eq!("store_id", c_lat.name);
assert_eq!(SQLType::SmallInt, c_lat.data_type);
assert_eq!(false, c_lat.allow_null);
let c_customer_id = &columns[0];
assert_eq!("customer_id", c_customer_id.name);
assert_eq!(SQLType::Int, c_customer_id.data_type);
assert_eq!(false, c_customer_id.allow_null);

let c_store_id = &columns[1];
assert_eq!("store_id", c_store_id.name);
assert_eq!(SQLType::SmallInt, c_store_id.data_type);
assert_eq!(false, c_store_id.allow_null);

let c_first_name = &columns[2];
assert_eq!("first_name", c_first_name.name);
assert_eq!(SQLType::Varchar(Some(45)), c_first_name.data_type);
assert_eq!(false, c_first_name.allow_null);

let c_create_date1 = &columns[8];
assert_eq!(
Some(Box::new(ASTNode::SQLCast {
expr: Box::new(ASTNode::SQLCast {
expr: Box::new(ASTNode::SQLValue(Value::SingleQuotedString(
"now".to_string()
))),
data_type: SQLType::Text
}),
data_type: SQLType::Date
})),
c_create_date1.default
);

let c_lng = &columns[2];
assert_eq!("first_name", c_lng.name);
assert_eq!(SQLType::Varchar(Some(45)), c_lng.data_type);
assert_eq!(false, c_lng.allow_null);
let c_release_year = &columns[10];
assert_eq!(
SQLType::Custom("public.year".to_string()),
c_release_year.data_type
);
}
_ => assert!(false),
}
Expand Down Expand Up @@ -637,13 +657,15 @@ fn parse_timestamps_example() {
let sql = "2016-02-15 09:43:33";
let _ = parse_sql(sql);
//TODO add assertion
//assert_eq!(sql, ast.to_string());
}

#[test]
fn parse_timestamps_with_millis_example() {
let sql = "2017-11-02 19:15:42.308637";
let _ = parse_sql(sql);
//TODO add assertion
//assert_eq!(sql, ast.to_string());
}

#[test]
Expand Down