diff --git a/lark/tree.py b/lark/tree.py index 9dccadd7..daf834a9 100644 --- a/lark/tree.py +++ b/lark/tree.py @@ -70,6 +70,8 @@ def meta(self) -> Meta: def __repr__(self): return 'Tree(%r, %r)' % (self.data, self.children) + __match_args__ = ("data", "children") + def _pretty_label(self): return self.data diff --git a/tests/test_pattern_matching.py b/tests/test_pattern_matching.py index b86bd543..ef66ae37 100644 --- a/tests/test_pattern_matching.py +++ b/tests/test_pattern_matching.py @@ -1,6 +1,6 @@ from unittest import TestCase, main -from lark import Token +from lark import Token, Tree class TestPatternMatching(TestCase): @@ -46,6 +46,51 @@ def test_matches_with_bad_token_type(self): case _: pass + def test_match_on_tree(self): + tree1 = Tree('a', [Tree(x, y) for x, y in zip('bcd', 'xyz')]) + tree2 = Tree('a', [ + Tree('b', [Token('T', 'x')]), + Tree('c', [Token('T', 'y')]), + Tree('d', [Tree('z', [Token('T', 'zz'), Tree('zzz', 'zzz')])]), + ]) + + match tree1: + case Tree('X', []): + assert False + case Tree('a', []): + assert False + case Tree(_, 'b'): + assert False + case Tree('X', _): + assert False + tree = Tree('q', [Token('T', 'x')]) + match tree: + case Tree('q', [Token('T', 'x')]): + pass + case _: + assert False + tr = Tree('a', [Tree('b', [Token('T', 'a')])]) + match tr: + case Tree('a', [Tree('b', [Token('T', 'a')])]): + pass + case _: + assert False + # test nested trees + match tree2: + case Tree('a', [ + Tree('b', [Token('T', 'x')]), + Tree('c', [Token('T', 'y')]), + Tree('d', [ + Tree('z', [ + Token('T', 'zz'), + Tree('zzz', 'zzz') + ]) + ]) + ]): + pass + case _: + assert False + if __name__ == '__main__':