diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 6aeb00685..95e670093 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -251,7 +251,7 @@ pub enum Expr { }, MapAccess { column: Box, - key: String, + keys: Vec, }, /// Scalar function call e.g. `LEFT(foo, 5)` Function(Function), @@ -280,7 +280,17 @@ impl fmt::Display for Expr { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Expr::Identifier(s) => write!(f, "{}", s), - Expr::MapAccess { column, key } => write!(f, "{}[\"{}\"]", column, key), + Expr::MapAccess { column, keys } => { + write!(f, "{}", column)?; + for k in keys { + match k { + k @ Value::Number(_, _) => write!(f, "[{}]", k)?, + Value::SingleQuotedString(s) => write!(f, "[\"{}\"]", s)?, + _ => write!(f, "[{}]", k)?, + } + } + Ok(()) + } Expr::Wildcard => f.write_str("*"), Expr::QualifiedWildcard(q) => write!(f, "{}.*", display_separated(q, ".")), Expr::CompoundIdentifier(s) => write!(f, "{}", display_separated(s, ".")), diff --git a/src/parser.rs b/src/parser.rs index 63bbc8ffd..84a82b121 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -951,13 +951,20 @@ impl<'a> Parser<'a> { } pub fn parse_map_access(&mut self, expr: Expr) -> Result { - let key = self.parse_literal_string()?; + let key = self.parse_map_key()?; let tok = self.consume_token(&Token::RBracket); debug!("Tok: {}", tok); + let mut key_parts: Vec = vec![key]; + while self.consume_token(&Token::LBracket) { + let key = self.parse_map_key()?; + let tok = self.consume_token(&Token::RBracket); + debug!("Tok: {}", tok); + key_parts.push(key); + } match expr { e @ Expr::Identifier(_) | e @ Expr::CompoundIdentifier(_) => Ok(Expr::MapAccess { column: Box::new(e), - key, + keys: key_parts, }), _ => Ok(expr), } @@ -1995,6 +2002,21 @@ impl<'a> Parser<'a> { } } + /// Parse a map key string + pub fn parse_map_key(&mut self) -> Result { + match self.next_token() { + Token::Word(Word { value, keyword, .. }) if keyword == Keyword::NoKeyword => { + Ok(Value::SingleQuotedString(value)) + } + Token::SingleQuotedString(s) => Ok(Value::SingleQuotedString(s)), + #[cfg(not(feature = "bigdecimal"))] + Token::Number(s, _) => Ok(Value::Number(s, false)), + #[cfg(feature = "bigdecimal")] + Token::Number(s, _) => Ok(Value::Number(s.parse().unwrap(), false)), + unexpected => self.expected("literal string or number", unexpected), + } + } + /// Parse a SQL datatype (in the context of a CREATE TABLE statement for example) pub fn parse_data_type(&mut self) -> Result { match self.next_token() { diff --git a/tests/sqlparser_hive.rs b/tests/sqlparser_hive.rs index 585be989b..d933f0f25 100644 --- a/tests/sqlparser_hive.rs +++ b/tests/sqlparser_hive.rs @@ -194,7 +194,7 @@ fn rename_table() { #[test] fn map_access() { - let rename = "SELECT a.b[\"asdf\"] FROM db.table WHERE a = 2"; + let rename = r#"SELECT a.b["asdf"] FROM db.table WHERE a = 2"#; hive().verified_stmt(rename); } diff --git a/tests/sqlparser_postgres.rs b/tests/sqlparser_postgres.rs index 2e66d313b..43baeb5a5 100644 --- a/tests/sqlparser_postgres.rs +++ b/tests/sqlparser_postgres.rs @@ -18,6 +18,9 @@ mod test_utils; use test_utils::*; +#[cfg(feature = "bigdecimal")] +use bigdecimal::BigDecimal; +use sqlparser::ast::Expr::{Identifier, MapAccess}; use sqlparser::ast::*; use sqlparser::dialect::{GenericDialect, PostgreSqlDialect}; use sqlparser::parser::ParserError; @@ -669,6 +672,57 @@ fn parse_pg_regex_match_ops() { } } +#[test] +fn parse_map_access_expr() { + #[cfg(not(feature = "bigdecimal"))] + let zero = "0".to_string(); + #[cfg(feature = "bigdecimal")] + let zero = BigDecimal::parse_bytes(b"0", 10).unwrap(); + let sql = "SELECT foo[0] FROM foos"; + let select = pg_and_generic().verified_only_select(sql); + assert_eq!( + &MapAccess { + column: Box::new(Identifier(Ident { + value: "foo".to_string(), + quote_style: None + })), + keys: vec![Value::Number(zero.clone(), false)] + }, + expr_from_projection(only(&select.projection)), + ); + let sql = "SELECT foo[0][0] FROM foos"; + let select = pg_and_generic().verified_only_select(sql); + assert_eq!( + &MapAccess { + column: Box::new(Identifier(Ident { + value: "foo".to_string(), + quote_style: None + })), + keys: vec![ + Value::Number(zero.clone(), false), + Value::Number(zero.clone(), false) + ] + }, + expr_from_projection(only(&select.projection)), + ); + let sql = r#"SELECT bar[0]["baz"]["fooz"] FROM foos"#; + let select = pg_and_generic().verified_only_select(sql); + assert_eq!( + &MapAccess { + column: Box::new(Identifier(Ident { + value: "bar".to_string(), + quote_style: None + })), + keys: vec![ + Value::Number(zero, false), + Value::SingleQuotedString("baz".to_string()), + Value::SingleQuotedString("fooz".to_string()) + ] + }, + expr_from_projection(only(&select.projection)), + ); +} + fn pg() -> TestedDialects { TestedDialects { dialects: vec![Box::new(PostgreSqlDialect {})],