@@ -46,6 +46,17 @@ def lam_sub(grammar: Grammar, node: RawNode) -> NL:
4646 return Node (type = node [0 ], children = node [3 ], context = node [2 ])
4747
4848
# Placeholder node stored in the node slot of stack entries while the parser
# is backtracking, so that speculative parsing never has to touch real nodes.
FAKE_NODE = (-1, None, None, None)
51+
52+
def stack_copy(
    stack: List[Tuple[DFAS, int, RawNode]]
) -> List[Tuple[DFAS, int, RawNode]]:
    """Return a nodeless copy of *stack*.

    Each entry of the result holds a deep copy of the original entry's DFA,
    the original entry's second element unchanged, and ``FAKE_NODE`` in
    place of the original node. The nodes of *stack* are never copied.
    """
    nodeless: List[Tuple[DFAS, int, RawNode]] = []
    for entry_dfa, entry_label, _ in stack:
        nodeless.append((copy.deepcopy(entry_dfa), entry_label, FAKE_NODE))
    return nodeless
58+
59+
4960class Recorder :
5061 def __init__ (self , parser : "Parser" , ilabels : List [int ], context : Context ) -> None :
5162 self .parser = parser
@@ -54,21 +65,45 @@ def __init__(self, parser: "Parser", ilabels: List[int], context: Context) -> No
5465
5566 self ._dead_ilabels : Set [int ] = set ()
5667 self ._start_point = self .parser .stack
57- self ._points = {ilabel : copy . deepcopy (self ._start_point ) for ilabel in ilabels }
68+ self ._points = {ilabel : stack_copy (self ._start_point ) for ilabel in ilabels }
5869
@property
def ilabels(self) -> Set[int]:
    """Labels present in exactly one of ``_ilabels`` and ``_dead_ilabels``.

    NOTE(review): dead labels appear to always be drawn from the candidate
    labels, so this is presumably "the labels that are still alive" --
    confirm against the callers.
    """
    return self._dead_ilabels ^ set(self._ilabels)
6273
@contextmanager
def switch_to(self, ilabel: int) -> Iterator[None]:
    """Run the ``with`` body against the stack recorded for *ilabel*.

    While the body runs, the parser's state operations are patched (see
    ``patch``) and ``parser.stack`` points at the stack stored for
    *ilabel*. A ``ParseError`` raised by the body is swallowed and marks
    *ilabel* as dead. The parser's stack is always reset to the starting
    stack on exit.
    """
    with self.patch():
        self.parser.stack = self._points[ilabel]
        try:
            yield
        except ParseError:
            # This alternative cannot continue; remember that and move on.
            self._dead_ilabels.add(ilabel)
        finally:
            self.parser.stack = self._start_point
84+
@contextmanager
def patch(self) -> Iterator[None]:
    """
    Patch basic state operations (push/pop/shift) with node-level
    immutable variants.

    For every name in ``self.parser.STATE_OPERATIONS`` the parser
    attribute of that name is replaced by the parser's ``<name>_safe``
    attribute for the duration of the ``with`` block. The previous
    attribute values are put back on exit, whether or not the block (or
    the patching itself) raised.

    The safe variants still operate on the stack, but they won't create
    any new nodes or modify the contents of any existing nodes. This
    saves us a ton of time when we are backtracking, since we want to
    restore the initial state as quickly as possible, which can only be
    done by having as few mutations as possible.
    """
    parser = self.parser
    original_functions = {}
    # Patch inside the ``try`` so that a failure part-way through (for
    # example a missing ``*_safe`` variant) still restores every operation
    # that has already been swapped.
    try:
        for name in parser.STATE_OPERATIONS:
            original_functions[name] = getattr(parser, name)
            setattr(parser, name, getattr(parser, name + "_safe"))
        yield
    finally:
        for name, func in original_functions.items():
            setattr(parser, name, func)
72107
73108 def add_token (self , tok_type : int , tok_val : Text , raw : bool = False ) -> None :
74109 func : Callable [..., Any ]
@@ -317,6 +352,8 @@ def classify(self, type: int, value: Text, context: Context) -> List[int]:
317352 raise ParseError ("bad token" , type , value , context )
318353 return [ilabel ]
319354
# Names of the parser's basic state operations. Each one is expected to have
# a matching ``<name>_safe`` method, which is looked up by name when the
# operations are patched for backtracking.
STATE_OPERATIONS = ["shift", "push", "pop"]
356+
320357 def shift (self , type : int , value : Text , newstate : int , context : Context ) -> None :
321358 """Shift a token. (Internal)"""
322359 dfa , state , node = self .stack [- 1 ]
@@ -344,3 +381,22 @@ def pop(self) -> None:
344381 else :
345382 self .rootnode = newnode
346383 self .rootnode .used_names = self .used_names
384+
def shift_safe(
    self, type: int, value: Text, newstate: int, context: Context
) -> None:
    """Immutable (node-level) version of shift().

    Only moves the top stack entry to *newstate*: its DFA is kept and its
    node slot is set to ``FAKE_NODE``. *type*, *value* and *context* are
    not used by this variant.
    """
    top_dfa, _, _ = self.stack[-1]
    self.stack[-1] = (top_dfa, newstate, FAKE_NODE)
391+
def push_safe(
    self, type: int, newdfa: DFAS, newstate: int, context: Context
) -> None:
    """Immutable (node-level) version of push().

    Moves the current top stack entry to *newstate*, then pushes a new
    entry for *newdfa* starting in state 0. Both entries get ``FAKE_NODE``
    in their node slot. *type* and *context* are not used by this variant.
    """
    top_dfa, _, _ = self.stack[-1]
    self.stack[-1] = (top_dfa, newstate, FAKE_NODE)
    self.stack.append((newdfa, 0, FAKE_NODE))
399+
def pop_safe(self) -> None:
    """Immutable (node-level) version of pop().

    Discards the top stack entry and does nothing else.
    """
    self.stack.pop()
0 commit comments