diff --git a/compiler/noirc_frontend/src/ast/expression.rs b/compiler/noirc_frontend/src/ast/expression.rs index d36966e2efe..9c9c0ded867 100644 --- a/compiler/noirc_frontend/src/ast/expression.rs +++ b/compiler/noirc_frontend/src/ast/expression.rs @@ -31,6 +31,7 @@ pub enum ExpressionKind { Cast(Box), Infix(Box), If(Box), + Match(Box), Variable(Path), Tuple(Vec), Lambda(Box), @@ -465,6 +466,12 @@ pub struct IfExpression { pub alternative: Option, } +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct MatchExpression { + pub expression: Expression, + pub rules: Vec<(/*pattern*/ Expression, /*branch*/ Expression)>, +} + #[derive(Debug, PartialEq, Eq, Clone)] pub struct Lambda { pub parameters: Vec<(Pattern, UnresolvedType)>, @@ -612,6 +619,7 @@ impl Display for ExpressionKind { Cast(cast) => cast.fmt(f), Infix(infix) => infix.fmt(f), If(if_expr) => if_expr.fmt(f), + Match(match_expr) => match_expr.fmt(f), Variable(path) => path.fmt(f), Constructor(constructor) => constructor.fmt(f), MemberAccess(access) => access.fmt(f), @@ -790,6 +798,16 @@ impl Display for IfExpression { } } +impl Display for MatchExpression { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!(f, "match {} {{", self.expression)?; + for (pattern, branch) in &self.rules { + writeln!(f, " {pattern} -> {branch},")?; + } + write!(f, "}}") + } +} + impl Display for Lambda { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let parameters = vecmap(&self.parameters, |(name, r#type)| format!("{name}: {type}")); diff --git a/compiler/noirc_frontend/src/ast/visitor.rs b/compiler/noirc_frontend/src/ast/visitor.rs index 30b8deb4925..a43bd0a5d3d 100644 --- a/compiler/noirc_frontend/src/ast/visitor.rs +++ b/compiler/noirc_frontend/src/ast/visitor.rs @@ -22,7 +22,7 @@ use crate::{ use super::{ ForBounds, FunctionReturnType, GenericTypeArgs, IntegerBitSize, ItemVisibility, - NoirEnumeration, Pattern, Signedness, TraitBound, TraitImplItemKind, TypePath, + MatchExpression, NoirEnumeration, Pattern, Signedness, TraitBound, TraitImplItemKind, TypePath, UnresolvedGenerics, UnresolvedTraitConstraint, UnresolvedType, UnresolvedTypeData, UnresolvedTypeExpression, }; @@ -222,6 +222,10 @@ pub trait Visitor { true } + fn visit_match_expression(&mut self, _: &MatchExpression, _: Span) -> bool { + true + } + fn visit_tuple(&mut self, _: &[Expression], _: Span) -> bool { true } @@ -866,6 +870,9 @@ impl Expression { ExpressionKind::If(if_expression) => { if_expression.accept(self.span, visitor); } + ExpressionKind::Match(match_expression) => { + match_expression.accept(self.span, visitor); + } ExpressionKind::Tuple(expressions) => { if visitor.visit_tuple(expressions, self.span) { visit_expressions(expressions, visitor); @@ -1073,6 +1080,22 @@ impl IfExpression { } } +impl MatchExpression { + pub fn accept(&self, span: Span, visitor: &mut impl Visitor) { + if visitor.visit_match_expression(self, span) { + self.accept_children(visitor); + } + } + + pub fn accept_children(&self, visitor: &mut impl Visitor) { + self.expression.accept(visitor); + for (pattern, branch) in &self.rules { + pattern.accept(visitor); + branch.accept(visitor); + } + } +} + impl Lambda { pub fn accept(&self, span: Span, visitor: &mut impl Visitor) { if visitor.visit_lambda(self, span) { diff --git a/compiler/noirc_frontend/src/elaborator/expressions.rs b/compiler/noirc_frontend/src/elaborator/expressions.rs index ff5ff48cbf4..16278995104 100644 --- a/compiler/noirc_frontend/src/elaborator/expressions.rs +++ b/compiler/noirc_frontend/src/elaborator/expressions.rs @@ -7,9 +7,9 @@ use crate::{ ast::{ ArrayLiteral, BlockExpression, CallExpression, CastExpression, ConstructorExpression, Expression, ExpressionKind, Ident, IfExpression, IndexExpression, InfixExpression, - ItemVisibility, Lambda, Literal, MemberAccessExpression, MethodCallExpression, Path, - PathSegment, PrefixExpression, StatementKind, UnaryOp, UnresolvedTypeData, - UnresolvedTypeExpression, + ItemVisibility, Lambda, Literal, MatchExpression, MemberAccessExpression, + MethodCallExpression, Path, PathSegment, PrefixExpression, StatementKind, UnaryOp, + UnresolvedTypeData, UnresolvedTypeExpression, }, hir::{ comptime::{self, InterpreterError}, @@ -51,6 +51,7 @@ impl<'context> Elaborator<'context> { ExpressionKind::Cast(cast) => self.elaborate_cast(*cast, expr.span), ExpressionKind::Infix(infix) => return self.elaborate_infix(*infix, expr.span), ExpressionKind::If(if_) => self.elaborate_if(*if_), + ExpressionKind::Match(match_) => self.elaborate_match(*match_), ExpressionKind::Variable(variable) => return self.elaborate_variable(variable), ExpressionKind::Tuple(tuple) => self.elaborate_tuple(tuple), ExpressionKind::Lambda(lambda) => self.elaborate_lambda(*lambda, None), @@ -926,6 +927,10 @@ impl<'context> Elaborator<'context> { (HirExpression::If(if_expr), ret_type) } + fn elaborate_match(&mut self, _match_expr: MatchExpression) -> (HirExpression, Type) { + (HirExpression::Error, Type::Error) + } + fn elaborate_tuple(&mut self, tuple: Vec) -> (HirExpression, Type) { let mut element_ids = Vec::with_capacity(tuple.len()); let mut element_types = Vec::with_capacity(tuple.len()); diff --git a/compiler/noirc_frontend/src/hir/comptime/display.rs b/compiler/noirc_frontend/src/hir/comptime/display.rs index 6be5e19577d..1be4bbe61ab 100644 --- a/compiler/noirc_frontend/src/hir/comptime/display.rs +++ b/compiler/noirc_frontend/src/hir/comptime/display.rs @@ -8,9 +8,9 @@ use crate::{ ArrayLiteral, AsTraitPath, AssignStatement, BlockExpression, CallExpression, CastExpression, ConstrainStatement, ConstructorExpression, Expression, ExpressionKind, ForBounds, ForLoopStatement, ForRange, GenericTypeArgs, IfExpression, IndexExpression, - InfixExpression, LValue, Lambda, LetStatement, Literal, MemberAccessExpression, - MethodCallExpression, Pattern, PrefixExpression, Statement, StatementKind, UnresolvedType, - UnresolvedTypeData, + InfixExpression, LValue, Lambda, LetStatement, Literal, MatchExpression, + MemberAccessExpression, MethodCallExpression, Pattern, PrefixExpression, Statement, + StatementKind, UnresolvedType, UnresolvedTypeData, }, hir_def::traits::TraitConstraint, node_interner::{InternedStatementKind, NodeInterner}, @@ -241,6 +241,7 @@ impl<'interner> TokenPrettyPrinter<'interner> { | Token::GreaterEqual | Token::Equal | Token::NotEqual + | Token::FatArrow | Token::Arrow => write!(f, " {token} "), Token::Assign => { if last_was_op { @@ -602,6 +603,14 @@ fn remove_interned_in_expression_kind( .alternative .map(|alternative| remove_interned_in_expression(interner, alternative)), })), + ExpressionKind::Match(match_expr) => ExpressionKind::Match(Box::new(MatchExpression { + expression: remove_interned_in_expression(interner, match_expr.expression), + rules: vecmap(match_expr.rules, |(pattern, branch)| { + let pattern = remove_interned_in_expression(interner, pattern); + let branch = remove_interned_in_expression(interner, branch); + (pattern, branch) + }), + })), ExpressionKind::Variable(_) => expr, ExpressionKind::Tuple(expressions) => ExpressionKind::Tuple(vecmap(expressions, |expr| { remove_interned_in_expression(interner, expr) diff --git a/compiler/noirc_frontend/src/lexer/lexer.rs b/compiler/noirc_frontend/src/lexer/lexer.rs index 0b7bd0991d9..771af3daba0 100644 --- a/compiler/noirc_frontend/src/lexer/lexer.rs +++ b/compiler/noirc_frontend/src/lexer/lexer.rs @@ -215,8 +215,19 @@ impl<'a> Lexer<'a> { Ok(prev_token.into_single_span(start)) } } + Token::Assign => { + let start = self.position; + if self.peek_char_is('=') { + self.next_char(); + Ok(Token::Equal.into_span(start, start + 1)) + } else if self.peek_char_is('>') { + self.next_char(); + Ok(Token::FatArrow.into_span(start, start + 1)) + } else { + Ok(prev_token.into_single_span(start)) + } + } Token::Bang => self.single_double_peek_token('=', prev_token, Token::NotEqual), - Token::Assign => self.single_double_peek_token('=', prev_token, Token::Equal), Token::Minus => self.single_double_peek_token('>', prev_token, Token::Arrow), Token::Colon => self.single_double_peek_token(':', prev_token, Token::DoubleColon), Token::Slash => { diff --git a/compiler/noirc_frontend/src/lexer/token.rs b/compiler/noirc_frontend/src/lexer/token.rs index 7d11b97ca16..d0a6f05e05a 100644 --- a/compiler/noirc_frontend/src/lexer/token.rs +++ b/compiler/noirc_frontend/src/lexer/token.rs @@ -91,6 +91,8 @@ pub enum BorrowedToken<'input> { RightBracket, /// -> Arrow, + /// => + FatArrow, /// | Pipe, /// # @@ -212,6 +214,8 @@ pub enum Token { RightBracket, /// -> Arrow, + /// => + FatArrow, /// | Pipe, /// # @@ -296,6 +300,7 @@ pub fn token_to_borrowed_token(token: &Token) -> BorrowedToken<'_> { Token::LeftBracket => BorrowedToken::LeftBracket, Token::RightBracket => BorrowedToken::RightBracket, Token::Arrow => BorrowedToken::Arrow, + Token::FatArrow => BorrowedToken::FatArrow, Token::Pipe => BorrowedToken::Pipe, Token::Pound => BorrowedToken::Pound, Token::Comma => BorrowedToken::Comma, @@ -473,6 +478,7 @@ impl fmt::Display for Token { Token::LeftBracket => write!(f, "["), Token::RightBracket => write!(f, "]"), Token::Arrow => write!(f, "->"), + Token::FatArrow => write!(f, "=>"), Token::Pipe => write!(f, "|"), Token::Pound => write!(f, "#"), Token::Comma => write!(f, ","), diff --git a/compiler/noirc_frontend/src/parser/parser/expression.rs b/compiler/noirc_frontend/src/parser/parser/expression.rs index 90e9e53921e..eff309154e3 100644 --- a/compiler/noirc_frontend/src/parser/parser/expression.rs +++ b/compiler/noirc_frontend/src/parser/parser/expression.rs @@ -4,7 +4,7 @@ use noirc_errors::Span; use crate::{ ast::{ ArrayLiteral, BlockExpression, CallExpression, CastExpression, ConstructorExpression, - Expression, ExpressionKind, Ident, IfExpression, IndexExpression, Literal, + Expression, ExpressionKind, Ident, IfExpression, IndexExpression, Literal, MatchExpression, MemberAccessExpression, MethodCallExpression, Statement, TypePath, UnaryOp, UnresolvedType, }, parser::{labels::ParsingRuleLabel, parser::parse_many::separated_by_comma, ParserErrorReason}, @@ -91,8 +91,7 @@ impl<'a> Parser<'a> { } /// AtomOrUnaryRightExpression - /// = Atom - /// | UnaryRightExpression + /// = Atom UnaryRightExpression* fn parse_atom_or_unary_right(&mut self, allow_constructors: bool) -> Option { let start_span = self.current_token_span; let mut atom = self.parse_atom(allow_constructors)?; @@ -311,6 +310,10 @@ impl<'a> Parser<'a> { return Some(kind); } + if let Some(kind) = self.parse_match_expr() { + return Some(kind); + } + if let Some(kind) = self.parse_lambda() { return Some(kind); } @@ -518,6 +521,49 @@ impl<'a> Parser<'a> { Some(ExpressionKind::If(Box::new(IfExpression { condition, consequence, alternative }))) } + /// MatchExpression = 'match' ExpressionExceptConstructor '{' MatchRule* '}' + pub(super) fn parse_match_expr(&mut self) -> Option { + let start_span = self.current_token_span; + if !self.eat_keyword(Keyword::Match) { + return None; + } + + let expression = self.parse_expression_except_constructor_or_error(); + + self.eat_left_brace(); + + let rules = self.parse_many( + "match cases", + without_separator().until(Token::RightBrace), + Self::parse_match_rule, + ); + + self.push_error(ParserErrorReason::ExperimentalFeature("Match expressions"), start_span); + Some(ExpressionKind::Match(Box::new(MatchExpression { expression, rules }))) + } + + /// MatchRule = Expression '->' (Block ','?) | (Expression ',') + fn parse_match_rule(&mut self) -> Option<(Expression, Expression)> { + let pattern = self.parse_expression()?; + self.eat_or_error(Token::FatArrow); + + let start_span = self.current_token_span; + let branch = match self.parse_block() { + Some(block) => { + let span = self.span_since(start_span); + let block = Expression::new(ExpressionKind::Block(block), span); + self.eat_comma(); // comma is optional if we have a block + block + } + None => { + let branch = self.parse_expression_or_error(); + self.eat_or_error(Token::Comma); + branch + } + }; + Some((pattern, branch)) + } + /// ComptimeExpression = 'comptime' Block fn parse_comptime_expr(&mut self) -> Option { if !self.eat_keyword(Keyword::Comptime) { diff --git a/compiler/noirc_frontend/src/parser/parser/statement.rs b/compiler/noirc_frontend/src/parser/parser/statement.rs index 005216b1deb..37013e91528 100644 --- a/compiler/noirc_frontend/src/parser/parser/statement.rs +++ b/compiler/noirc_frontend/src/parser/parser/statement.rs @@ -162,10 +162,13 @@ impl<'a> Parser<'a> { } if let Some(kind) = self.parse_if_expr() { - return Some(StatementKind::Expression(Expression { - kind, - span: self.span_since(start_span), - })); + let span = self.span_since(start_span); + return Some(StatementKind::Expression(Expression { kind, span })); + } + + if let Some(kind) = self.parse_match_expr() { + let span = self.span_since(start_span); + return Some(StatementKind::Expression(Expression { kind, span })); } if let Some(block) = self.parse_block() { diff --git a/tooling/lsp/src/requests/inlay_hint.rs b/tooling/lsp/src/requests/inlay_hint.rs index cbf4ed26ef9..8e091d1eb04 100644 --- a/tooling/lsp/src/requests/inlay_hint.rs +++ b/tooling/lsp/src/requests/inlay_hint.rs @@ -590,6 +590,7 @@ fn get_expression_name(expression: &Expression) -> Option { | ExpressionKind::InternedStatement(..) | ExpressionKind::Literal(..) | ExpressionKind::Unsafe(..) + | ExpressionKind::Match(_) | ExpressionKind::Error => None, } } diff --git a/tooling/nargo_fmt/src/formatter/expression.rs b/tooling/nargo_fmt/src/formatter/expression.rs index ef04276a605..98eabe10e7e 100644 --- a/tooling/nargo_fmt/src/formatter/expression.rs +++ b/tooling/nargo_fmt/src/formatter/expression.rs @@ -2,8 +2,8 @@ use noirc_frontend::{ ast::{ ArrayLiteral, BinaryOpKind, BlockExpression, CallExpression, CastExpression, ConstructorExpression, Expression, ExpressionKind, IfExpression, IndexExpression, - InfixExpression, Lambda, Literal, MemberAccessExpression, MethodCallExpression, - PrefixExpression, TypePath, UnaryOp, UnresolvedTypeData, + InfixExpression, Lambda, Literal, MatchExpression, MemberAccessExpression, + MethodCallExpression, PrefixExpression, TypePath, UnaryOp, UnresolvedTypeData, }, token::{Keyword, Token}, }; @@ -57,6 +57,9 @@ impl<'a, 'b> ChunkFormatter<'a, 'b> { false, // force multiple lines )); } + ExpressionKind::Match(match_expression) => { + group.group(self.format_match_expression(*match_expression)); + } ExpressionKind::Variable(path) => { group.text(self.chunk(|formatter| { formatter.format_path(path); @@ -895,6 +898,68 @@ impl<'a, 'b> ChunkFormatter<'a, 'b> { group } + pub(super) fn format_match_expression( + &mut self, + match_expression: MatchExpression, + ) -> ChunkGroup { + let group_tag = self.new_group_tag(); + let mut group = self.format_match_expression_with_group_tag(match_expression, group_tag); + force_if_chunks_to_multiple_lines(&mut group, group_tag); + group + } + + pub(super) fn format_match_expression_with_group_tag( + &mut self, + match_expression: MatchExpression, + group_tag: GroupTag, + ) -> ChunkGroup { + let mut group = ChunkGroup::new(); + group.tag = Some(group_tag); + group.force_multiple_lines = true; + + group.text(self.chunk(|formatter| { + formatter.write_keyword(Keyword::Match); + formatter.write_space(); + })); + + self.format_expression(match_expression.expression, &mut group); + group.trailing_comment(self.skip_comments_and_whitespace_chunk()); + group.space(self); + + group.text(self.chunk(|formatter| { + formatter.write_left_brace(); + })); + + group.increase_indentation(); + for (pattern, branch) in match_expression.rules { + group.line(); + self.format_expression(pattern, &mut group); + group.text(self.chunk(|formatter| { + formatter.write_space(); + formatter.write_token(Token::FatArrow); + formatter.write_space(); + })); + self.format_expression(branch, &mut group); + + // Add a trailing comma regardless of whether the user specified one or not + group.text(self.chunk(|formatter| { + if formatter.token == Token::Comma { + formatter.write_current_token_and_bump(); + } else { + formatter.write(","); + } + })); + } + group.decrease_indentation(); + group.line(); + + group.text(self.chunk(|formatter| { + formatter.write_right_brace(); + })); + + group + } + fn format_index_expression(&mut self, index: IndexExpression) -> ChunkGroup { let mut group = ChunkGroup::new(); self.format_expression(index.collection, &mut group); @@ -2326,4 +2391,19 @@ global y = 1; "; assert_format_with_max_width(src, expected, " Foo { a: 1 },".len() - 1); } + + #[test] + fn format_match() { + let src = "fn main() { match x { A=>B,C => {D}E=>(), } }"; + // We should remove the block on D for single expressions in the future, + // unless D is an if or match. + let expected = "fn main() { + match x { + A => B, + C => { D }, + E => (), + } +}\n"; + assert_format(src, expected); + } }