@@ -782,8 +782,7 @@ def _transform_module(self, module: KFlatModule) -> KFlatModule: ...
782782class RulePass (SingleModulePass , ABC ):
783783 @final
784784 def _transform_module (self , module : KFlatModule ) -> KFlatModule :
785- sentences = tuple (self ._transform_rule (sent ) if isinstance (sent , KRule ) else sent for sent in module .sentences )
786- return module .let (sentences = sentences )
785+ return module .map_sentences (self ._transform_rule , of_type = KRule )
787786
788787 @abstractmethod
789788 def _transform_rule (self , rule : KRule ) -> KRule : ...
@@ -880,8 +879,7 @@ def _transform_module(self, module: KFlatModule) -> KFlatModule:
880879 Atts .UNIT in att for _ , att in concat_atts .items ()
881880 ) # TODO Could be saved with a different attribute structure: concat(Element, Unit)
882881
883- sentences = tuple (self ._update (sent , concat_atts ) if isinstance (sent , KSyntaxSort ) else sent for sent in module )
884- return module .let (sentences = sentences )
882+ return module .map_sentences (lambda syntax_sort : self ._update (syntax_sort , concat_atts ), of_type = KSyntaxSort )
885883
886884 @staticmethod
887885 def _update (syntax_sort : KSyntaxSort , concat_atts : Mapping [KSort , KAtt ]) -> KSyntaxSort :
@@ -914,14 +912,13 @@ class AddDomainValueAtts(SingleModulePass):
914912
915913 def _transform_module (self , module : KFlatModule ) -> KFlatModule :
916914 token_sorts = self ._token_sorts (module )
917- sentences = tuple (self ._update (sent , token_sorts ) if isinstance (sent , KSyntaxSort ) else sent for sent in module )
918- return module .let (sentences = sentences )
919915
920- @staticmethod
921- def _update (syntax_sort : KSyntaxSort , token_sorts : set [KSort ]) -> KSyntaxSort :
922- if syntax_sort .sort not in token_sorts :
923- return syntax_sort
924- return syntax_sort .let (att = syntax_sort .att .update ([Atts .HAS_DOMAIN_VALUES (None )]))
916+ def update (syntax_sort : KSyntaxSort ) -> KSyntaxSort :
917+ if syntax_sort .sort not in token_sorts :
918+ return syntax_sort
919+ return syntax_sort .let (att = syntax_sort .att .update ([Atts .HAS_DOMAIN_VALUES (None )]))
920+
921+ return module .map_sentences (update , of_type = KSyntaxSort )
925922
926923 @staticmethod
927924 def _token_sorts (module : KFlatModule ) -> set [KSort ]:
@@ -959,13 +956,23 @@ def execute(self, definition: KDefinition) -> KDefinition:
959956 if len (definition .modules ) > 1 :
960957 raise ValueError ('Expected a single module' )
961958 module = definition .modules [0 ]
962-
963959 rules = self ._rules_by_klabel (module )
964960
965- sentences = tuple (
966- self ._update (sent , definition , rules ) if isinstance (sent , KProduction ) else sent for sent in module
967- )
968- module = module .let (sentences = sentences )
961+ def update (production : KProduction ) -> KProduction :
962+ if not production .klabel :
963+ return production
964+
965+ klabel = production .klabel
966+
967+ if any (Atts .ANYWHERE in rule .att for rule in rules .get (klabel , [])):
968+ return production .let (att = production .att .update ([Atts .ANYWHERE (None )]))
969+
970+ if klabel .name in definition .overloads :
971+ return production .let (att = production .att .update ([Atts .ANYWHERE (None )]))
972+
973+ return production
974+
975+ module = module .map_sentences (update , of_type = KProduction )
969976 return KDefinition (module .name , (module ,))
970977
971978 @staticmethod
@@ -984,21 +991,6 @@ def _rules_by_klabel(module: KFlatModule) -> dict[KLabel, list[KRule]]:
984991 res .setdefault (label , []).append (rule )
985992 return res
986993
987- @staticmethod
988- def _update (production : KProduction , definition : KDefinition , rules : Mapping [KLabel , list [KRule ]]) -> KProduction :
989- if not production .klabel :
990- return production
991-
992- klabel = production .klabel
993-
994- if any (Atts .ANYWHERE in rule .att for rule in rules .get (klabel , [])):
995- return production .let (att = production .att .update ([Atts .ANYWHERE (None )]))
996-
997- if klabel .name in definition .overloads :
998- return production .let (att = production .att .update ([Atts .ANYWHERE (None )]))
999-
1000- return production
1001-
1002994
1003995@dataclass
1004996class AddSymbolAtts (SingleModulePass ):
@@ -1008,9 +1000,7 @@ class AddSymbolAtts(SingleModulePass):
10081000 pred : Callable [[KAtt ], bool ]
10091001
10101002 def _transform_module (self , module : KFlatModule ) -> KFlatModule :
1011- return module .let (
1012- sentences = tuple (self ._update (sent ) if isinstance (sent , KProduction ) else sent for sent in module )
1013- )
1003+ return module .map_sentences (self ._update , of_type = KProduction )
10141004
10151005 def _update (self , production : KProduction ) -> KProduction :
10161006 if not production .klabel : # filter for symbol productions
@@ -1090,8 +1080,7 @@ def __init__(self, hook_namespaces: Iterable[str] = ()):
10901080 self .active_prefixes = tuple (f'{ namespace } .' for namespace in namespaces )
10911081
10921082 def _transform_module (self , module : KFlatModule ) -> KFlatModule :
1093- sentences = tuple (self ._update (sent ) if isinstance (sent , KProduction ) else sent for sent in module )
1094- return module .let (sentences = sentences )
1083+ return module .map_sentences (self ._update , of_type = KProduction )
10951084
10961085 def _update (self , production : KProduction ) -> KProduction :
10971086 if not production .klabel :
@@ -1120,8 +1109,7 @@ def __init__(self, keys: Iterable[AttKey]):
11201109 self .keys = frozenset (keys )
11211110
11221111 def _transform_module (self , module : KFlatModule ) -> KFlatModule :
1123- sentences = tuple (self ._update (sent ) if isinstance (sent , KProduction ) else sent for sent in module )
1124- return module .let (sentences = sentences )
1112+ return module .map_sentences (self ._update , of_type = KProduction )
11251113
11261114 def _update (self , production : KProduction ) -> KProduction :
11271115 if not production .klabel :
@@ -1135,8 +1123,7 @@ class AddDefaultFormatAtts(SingleModulePass):
11351123 """Add a default format attribute value to each symbol profuction missing one."""
11361124
11371125 def _transform_module (self , module : KFlatModule ) -> KFlatModule :
1138- sentences = tuple (self ._update (sent ) if isinstance (sent , KProduction ) else sent for sent in module )
1139- return module .let (sentences = sentences )
1126+ return module .map_sentences (self ._update , of_type = KProduction )
11401127
11411128 @staticmethod
11421129 def _update (production : KProduction ) -> KProduction :
@@ -1154,8 +1141,7 @@ class DiscardFormatAtts(SingleModulePass):
11541141 """Remove format attributes from symbol productions with items other than terminals and non-terminals."""
11551142
11561143 def _transform_module (self , module : KFlatModule ) -> KFlatModule :
1157- sentences = tuple (self ._update (sent ) if isinstance (sent , KProduction ) else sent for sent in module )
1158- return module .let (sentences = sentences )
1144+ return module .map_sentences (self ._update , of_type = KProduction )
11591145
11601146 @staticmethod
11611147 def _update (production : KProduction ) -> KProduction :
@@ -1173,8 +1159,7 @@ class InlineFormatTerminals(SingleModulePass):
11731159 """For a terminal `"foo"` change `%i` to `%cfoo%r`. For a non-terminal, decrease the index."""
11741160
11751161 def _transform_module (self , module : KFlatModule ) -> KFlatModule :
1176- sentences = tuple (self ._update (sent ) if isinstance (sent , KProduction ) else sent for sent in module )
1177- return module .let (sentences = sentences )
1162+ return module .map_sentences (self ._update , of_type = KProduction )
11781163
11791164 @staticmethod
11801165 def _update (production : KProduction ) -> KProduction :
@@ -1229,8 +1214,7 @@ def _inline_terminals(formatt: Format, production: KProduction) -> Format:
12291214@dataclass
12301215class AddColorAtts (SingleModulePass ):
12311216 def _transform_module (self , module : KFlatModule ) -> KFlatModule :
1232- sentences = tuple (self ._update (sent ) if isinstance (sent , KProduction ) else sent for sent in module )
1233- return module .let (sentences = sentences )
1217+ return module .map_sentences (self ._update , of_type = KProduction )
12341218
12351219 @staticmethod
12361220 def _update (production : KProduction ) -> KProduction :
@@ -1254,8 +1238,7 @@ def _update(production: KProduction) -> KProduction:
12541238@dataclass
12551239class AddTerminalAtts (SingleModulePass ):
12561240 def _transform_module (self , module : KFlatModule ) -> KFlatModule :
1257- sentences = tuple (self ._update (sent ) if isinstance (sent , KProduction ) else sent for sent in module )
1258- return module .let (sentences = sentences )
1241+ return module .map_sentences (self ._update , of_type = KProduction )
12591242
12601243 @staticmethod
12611244 def _update (production : KProduction ) -> KProduction :
@@ -1276,34 +1259,41 @@ def execute(self, definition: KDefinition) -> KDefinition:
12761259 raise ValueError ('Expected a single module' )
12771260 module = definition .modules [0 ]
12781261
1279- sentences = tuple ( self . _update ( sent , definition ) if isinstance ( sent , KProduction ) else sent for sent in module )
1280- module = module . let ( sentences = sentences )
1281- return KDefinition ( module . name , ( module ,))
1262+ def update ( production : KProduction ) -> KProduction :
1263+ if not production . klabel :
1264+ return production
12821265
1283- @staticmethod
1284- def _update (production : KProduction , definition : KDefinition ) -> KSentence :
1285- if not production .klabel :
1286- return production
1266+ if Atts .FORMAT not in production .att :
1267+ return production
12871268
1288- if Atts .FORMAT not in production .att :
1289- return production
1269+ tags = sorted (definition .priorities .get (production .klabel .name , []))
1270+ priorities = tuple (
1271+ KApply (tag ).to_dict () for tag in tags if tag not in BUILTIN_LABELS
1272+ ) # TODO Add KType to pyk.kast.att
1273+ return production .let (att = production .att .update ([Atts .PRIORITIES (priorities )]))
12901274
1291- tags = sorted (definition .priorities .get (production .klabel .name , []))
1292- priorities = tuple (
1293- KApply (tag ).to_dict () for tag in tags if tag not in BUILTIN_LABELS
1294- ) # TODO Add KType to pyk.kast.att
1295- return production .let (att = production .att .update ([Atts .PRIORITIES (priorities )]))
1275+ module = module .map_sentences (update , of_type = KProduction )
1276+ return KDefinition (module .name , (module ,))
12961277
12971278
12981279@dataclass
12991280class AddAssocAtts (SingleModulePass ):
13001281 def _transform_module (self , module : KFlatModule ) -> KFlatModule :
13011282 left_assocs = self ._assocs (module , KAssoc .LEFT )
13021283 right_assocs = self ._assocs (module , KAssoc .RIGHT )
1303- sentences = tuple (
1304- self ._update (sent , left_assocs , right_assocs ) if isinstance (sent , KProduction ) else sent for sent in module
1305- )
1306- return module .let (sentences = sentences )
1284+
1285+ def update (production : KProduction ) -> KProduction :
1286+ if not production .klabel :
1287+ return production
1288+
1289+ if Atts .FORMAT not in production .att :
1290+ return production
1291+
1292+ left = tuple (KApply (tag ).to_dict () for tag in sorted (left_assocs .get (production .klabel .name , [])))
1293+ right = tuple (KApply (tag ).to_dict () for tag in sorted (right_assocs .get (production .klabel .name , [])))
1294+ return production .let (att = production .att .update ([Atts .LEFT (left ), Atts .RIGHT (right )]))
1295+
1296+ return module .map_sentences (update , of_type = KProduction )
13071297
13081298 @staticmethod
13091299 def _assocs (module : KFlatModule , assoc : KAssoc ) -> dict [str , set [str ]]:
@@ -1319,19 +1309,3 @@ def insert(dct: dict[str, set[str]], *, key: str, value: str) -> dict[str, set[s
13191309 return dct
13201310
13211311 return reduce (lambda res , pair : insert (res , key = pair [0 ], value = pair [1 ]), pairs , {})
1322-
1323- @staticmethod
1324- def _update (
1325- production : KProduction ,
1326- left_assocs : Mapping [str , set [str ]],
1327- right_assocs : Mapping [str , set [str ]],
1328- ) -> KProduction :
1329- if not production .klabel :
1330- return production
1331-
1332- if Atts .FORMAT not in production .att :
1333- return production
1334-
1335- left = tuple (KApply (tag ).to_dict () for tag in sorted (left_assocs .get (production .klabel .name , [])))
1336- right = tuple (KApply (tag ).to_dict () for tag in sorted (right_assocs .get (production .klabel .name , [])))
1337- return production .let (att = production .att .update ([Atts .LEFT (left ), Atts .RIGHT (right )]))
0 commit comments