diff --git a/src/ethereum_spec_tools/new_fork/codemod/constant.py b/src/ethereum_spec_tools/new_fork/codemod/constant.py index 1a502d2a98..ce3204b90d 100644 --- a/src/ethereum_spec_tools/new_fork/codemod/constant.py +++ b/src/ethereum_spec_tools/new_fork/codemod/constant.py @@ -124,6 +124,45 @@ def leave_Assign( # noqa: D102 return updated_node.with_changes(value=self.value.deep_clone()) + @override + def visit_AnnAssign_target(self, node: cst.AnnAssign) -> None: # noqa: D102 + if self._in_assign_target: + raise Exception("already in assign target") + self._in_assign_target = True + + @override + def leave_AnnAssign_target(self, node: cst.AnnAssign) -> None: # noqa: D102 + if not self._in_assign_target: + raise Exception("not in assign target") + self._in_assign_target = False + + @override + def visit_AnnAssign(self, node: cst.AnnAssign) -> None: # noqa: D102 + if self._matches or self._in_assign_target: + raise Exception("nested assign") + + @override + def leave_AnnAssign( # noqa: D102 + self, original_node: cst.AnnAssign, updated_node: cst.AnnAssign + ) -> cst.AnnAssign: + if self._in_assign_target: + raise Exception("still in assign target") + + if not self._matches: + return updated_node + + self._matches = False + + for module, identifier in self.imports: + AddImportsVisitor.add_needed_import( + self.context, module, identifier + ) + RemoveImportsVisitor.remove_unused_import( + self.context, module, identifier + ) + + return updated_node.with_changes(value=self.value.deep_clone()) + @override def visit_Name(self, node: cst.Name) -> None: # noqa: D102 if not self._in_assign_target: diff --git a/tests/json_infra/test_tools_new_fork.py b/tests/json_infra/test_tools_new_fork.py index c500c87c5e..12d0c2fbd2 100644 --- a/tests/json_infra/test_tools_new_fork.py +++ b/tests/json_infra/test_tools_new_fork.py @@ -62,7 +62,7 @@ def test_end_to_end(template_fork: str) -> None: source = f.read() assert '"""' not in source[:20] - assert "FORK_CRITERIA = ByTimestamp(7)" in source + assert "FORK_CRITERIA: ForkCriteria = ByTimestamp(7)" in source assert template_fork.capitalize() not in source with (fork_dir / "utils" / "hexadecimal.py").open("r") as f: