Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions src/ethereum_spec_tools/new_fork/codemod/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,45 @@ def leave_Assign( # noqa: D102

return updated_node.with_changes(value=self.value.deep_clone())

@override
def visit_AnnAssign_target(self, node: cst.AnnAssign) -> None: # noqa: D102
if self._in_assign_target:
raise Exception("already in assign target")
self._in_assign_target = True

@override
def leave_AnnAssign_target(self, node: cst.AnnAssign) -> None: # noqa: D102
if not self._in_assign_target:
raise Exception("not in assign target")
self._in_assign_target = False

@override
def visit_AnnAssign(self, node: cst.AnnAssign) -> None: # noqa: D102
if self._matches or self._in_assign_target:
raise Exception("nested assign")

@override
def leave_AnnAssign( # noqa: D102
self, original_node: cst.AnnAssign, updated_node: cst.AnnAssign
) -> cst.AnnAssign:
if self._in_assign_target:
raise Exception("still in assign target")

if not self._matches:
return updated_node

self._matches = False

for module, identifier in self.imports:
AddImportsVisitor.add_needed_import(
self.context, module, identifier
)
RemoveImportsVisitor.remove_unused_import(
self.context, module, identifier
)

return updated_node.with_changes(value=self.value.deep_clone())

@override
def visit_Name(self, node: cst.Name) -> None: # noqa: D102
if not self._in_assign_target:
Expand Down
2 changes: 1 addition & 1 deletion tests/json_infra/test_tools_new_fork.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def test_end_to_end(template_fork: str) -> None:
source = f.read()

assert '"""' not in source[:20]
assert "FORK_CRITERIA = ByTimestamp(7)" in source
assert "FORK_CRITERIA: ForkCriteria = ByTimestamp(7)" in source
assert template_fork.capitalize() not in source

with (fork_dir / "utils" / "hexadecimal.py").open("r") as f:
Expand Down
Loading