Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,8 @@ MARKDOWN_FILES = $(CURDIR)/README.md \
lint: pyspec
@$(MDFORMAT_VENV) --number --wrap=80 $(MARKDOWN_FILES)
@$(CODESPELL_VENV) . --skip "./.git,$(VENV),$(PYSPEC_DIR)/.mypy_cache" -I .codespell-whitelist
@$(PYTHON_VENV) -m isort --quiet $(CURDIR)/tests $(CURDIR)/pysetup $(CURDIR)/setup.py
@$(PYTHON_VENV) -m black --quiet $(CURDIR)/tests $(CURDIR)/pysetup $(CURDIR)/setup.py
@$(PYTHON_VENV) -m pylint --rcfile $(PYLINT_CONFIG) $(PYLINT_SCOPE)
@$(PYTHON_VENV) -m ruff check --fix --quiet $(CURDIR)/tests $(CURDIR)/pysetup $(CURDIR)/setup.py
@$(PYTHON_VENV) -m ruff format --quiet $(CURDIR)/tests $(CURDIR)/pysetup $(CURDIR)/setup.py
@$(PYTHON_VENV) -m mypy --config-file $(MYPY_CONFIG) $(MYPY_SCOPE)

###############################################################################
Expand Down
37 changes: 24 additions & 13 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,13 @@ test = [
"pytest==8.4.0",
]
lint = [
"black==25.1.0",
"codespell==2.4.1",
"isort==6.0.1",
"mdformat-gfm-alerts==1.0.2",
"mdformat-gfm==0.4.1",
"mdformat-toc==0.3.0",
"mdformat==0.7.22",
"mypy==1.16.0",
"pylint==3.3.7",
"ruff==0.11.12",
]
generator = [
"filelock==3.18.0",
Expand All @@ -61,16 +59,29 @@ docs = [
"mkdocs==1.6.1",
]

[tool.black]
[tool.ruff]
line-length = 100

[tool.isort]
profile = "black"
line_length = 100
combine_as_imports = true
known_first_party = ["eth2spec"]
order_by_type = false
skip_glob = [
"tests/core/pyspec/eth2spec/*/mainnet.py",
"tests/core/pyspec/eth2spec/*/minimal.py",
[tool.ruff.lint]
select = [
"F", # https://docs.astral.sh/ruff/rules/#pyflakes-f
"I", # https://docs.astral.sh/ruff/rules/#isort-i
"PL", # https://docs.astral.sh/ruff/rules/#pylint-pl
"UP", # https://docs.astral.sh/ruff/rules/#pyupgrade-up
]
ignore = [
"PLR0911", # https://docs.astral.sh/ruff/rules/too-many-return-statements/
"PLR0912", # https://docs.astral.sh/ruff/rules/too-many-branches/
"PLR0913", # https://docs.astral.sh/ruff/rules/too-many-arguments/
"PLR0915", # https://docs.astral.sh/ruff/rules/too-many-statements/
"PLR1714", # https://docs.astral.sh/ruff/rules/repeated-equality-comparison/
"PLR2004", # https://docs.astral.sh/ruff/rules/magic-value-comparison/
"PLW0128", # https://docs.astral.sh/ruff/rules/redeclared-assigned-name/
"PLW0603", # https://docs.astral.sh/ruff/rules/global-statement/
"PLW2901", # https://docs.astral.sh/ruff/rules/redefined-loop-name/
]

[tool.ruff.lint.isort]
combine-as-imports = true
known-first-party = ["eth2spec"]
order-by-type = false
36 changes: 17 additions & 19 deletions pysetup/helpers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re
import textwrap
from functools import reduce
from typing import Dict, List, TypeVar, Union
from typing import TypeVar

from .constants import CONSTANT_DEP_SUNDRY_CONSTANTS_FUNCTIONS
from .md_doc_paths import PREVIOUS_FORK_OF
Expand All @@ -23,8 +23,8 @@ def collect_prev_forks(fork: str) -> list[str]:


def requires_mypy_type_ignore(value: str) -> bool:
return value.startswith(("ByteVector")) or (
value.startswith(("Vector")) and any(k in value for k in ["ceillog2", "floorlog2"])
return value.startswith("ByteVector") or (
value.startswith("Vector") and any(k in value for k in ["ceillog2", "floorlog2"])
)


Expand All @@ -34,13 +34,13 @@ def make_function_abstract(protocol_def: ProtocolDefinition, key: str):


def objects_to_spec(
preset_name: str, spec_object: SpecObject, fork: str, ordered_class_objects: Dict[str, str]
preset_name: str, spec_object: SpecObject, fork: str, ordered_class_objects: dict[str, str]
) -> str:
"""
Given all the objects that constitute a spec, combine them into a single pyfile.
"""

def gen_new_type_definitions(custom_types: Dict[str, str]) -> str:
def gen_new_type_definitions(custom_types: dict[str, str]) -> str:
return "\n\n".join(
[
(
Expand Down Expand Up @@ -89,11 +89,9 @@ def format_protocol(protocol_name: str, protocol_def: ProtocolDefinition) -> str
# Access global dict of config vars for runtime configurables
# Ignore variable between quotes and doubles quotes
for name in spec_object.config_vars.keys():
functions_spec = re.sub(
r"(?<!['\"])\b%s\b(?!['\"])" % name, "config." + name, functions_spec
)
functions_spec = re.sub(rf"(?<!['\"])\b{name}\b(?!['\"])", "config." + name, functions_spec)
ordered_class_objects_spec = re.sub(
r"(?<!['\"])\b%s\b(?!['\"])" % name, "config." + name, ordered_class_objects_spec
rf"(?<!['\"])\b{name}\b(?!['\"])", "config." + name, ordered_class_objects_spec
)

def format_config_var(name: str, vardef) -> str:
Expand Down Expand Up @@ -202,17 +200,17 @@ def format_constant(name: str, vardef: VariableDefinition) -> str:
format_constant(k, v) for k, v in spec_object.preset_vars.items()
)
ssz_dep_constants = "\n".join(
map(lambda x: "%s = %s" % (x, hardcoded_ssz_dep_constants[x]), hardcoded_ssz_dep_constants)
map(lambda x: f"{x} = {hardcoded_ssz_dep_constants[x]}", hardcoded_ssz_dep_constants)
)
ssz_dep_constants_verification = "\n".join(
map(
lambda x: "assert %s == %s" % (x, spec_object.ssz_dep_constants[x]),
lambda x: f"assert {x} == {spec_object.ssz_dep_constants[x]}",
filtered_ssz_dep_constants,
)
)
func_dep_presets_verification = "\n".join(
map(
lambda x: "assert %s == %s # noqa: E501" % (x, spec_object.func_dep_presets[x]),
lambda x: f"assert {x} == {spec_object.func_dep_presets[x]} # noqa: E501",
filtered_hardcoded_func_dep_presets,
)
)
Expand Down Expand Up @@ -247,8 +245,8 @@ def format_constant(name: str, vardef: VariableDefinition) -> str:


def combine_protocols(
old_protocols: Dict[str, ProtocolDefinition], new_protocols: Dict[str, ProtocolDefinition]
) -> Dict[str, ProtocolDefinition]:
old_protocols: dict[str, ProtocolDefinition], new_protocols: dict[str, ProtocolDefinition]
) -> dict[str, ProtocolDefinition]:
for key, value in new_protocols.items():
if key not in old_protocols:
old_protocols[key] = value
Expand All @@ -261,7 +259,7 @@ def combine_protocols(
T = TypeVar("T")


def combine_dicts(old_dict: Dict[str, T], new_dict: Dict[str, T]) -> Dict[str, T]:
def combine_dicts(old_dict: dict[str, T], new_dict: dict[str, T]) -> dict[str, T]:
return {**old_dict, **new_dict}


Expand Down Expand Up @@ -305,7 +303,7 @@ def combine_dicts(old_dict: Dict[str, T], new_dict: Dict[str, T]) -> Dict[str, T
]


def dependency_order_class_objects(objects: Dict[str, str], custom_types: Dict[str, str]) -> None:
def dependency_order_class_objects(objects: dict[str, str], custom_types: dict[str, str]) -> None:
"""
Determines which SSZ Object is dependent on which other and orders them appropriately
"""
Expand All @@ -332,7 +330,7 @@ def dependency_order_class_objects(objects: Dict[str, str], custom_types: Dict[s
objects[item] = objects.pop(item)


def combine_ssz_objects(old_objects: Dict[str, str], new_objects: Dict[str, str]) -> Dict[str, str]:
def combine_ssz_objects(old_objects: dict[str, str], new_objects: dict[str, str]) -> dict[str, str]:
"""
Takes in old spec and new spec ssz objects, combines them,
and returns the newer versions of the objects in dependency order.
Expand Down Expand Up @@ -378,11 +376,11 @@ def combine_spec_objects(spec0: SpecObject, spec1: SpecObject) -> SpecObject:
)


def parse_config_vars(conf: Dict[str, str]) -> Dict[str, Union[str, List[Dict[str, str]]]]:
def parse_config_vars(conf: dict[str, str]) -> dict[str, str | list[dict[str, str]]]:
"""
Parses a dict of basic str/int/list types into a dict for insertion into the spec code.
"""
out: Dict[str, Union[str, List[Dict[str, str]]]] = dict()
out: dict[str, str | list[dict[str, str]]] = dict()
for k, v in conf.items():
if isinstance(v, list):
# A special case for list of records
Expand Down
2 changes: 1 addition & 1 deletion pysetup/md_doc_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def is_post_fork(a, b) -> bool:
prev_fork = PREVIOUS_FORK_OF[a]
if prev_fork == b:
return True
elif prev_fork == None:
elif prev_fork is None:
return False
else:
return is_post_fork(prev_fork, b)
Expand Down
59 changes: 29 additions & 30 deletions pysetup/md_to_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
import json
import re
import string
from functools import lru_cache
from collections.abc import Iterator, Mapping
from functools import cache
from pathlib import Path
from typing import cast, Dict, Iterator, Mapping, Optional, Tuple
from typing import cast

from marko.block import BlankLine, Document, FencedCode, Heading, HTMLBlock
from marko.element import Element
Expand All @@ -31,7 +32,7 @@ def __init__(
self.preset_name = preset_name

self.document_iterator: Iterator[Element] = self._parse_document(file_name)
self.all_custom_types: Dict[str, str] = {}
self.all_custom_types: dict[str, str] = {}
self.current_heading_name: str | None = None

# Use a single dict to hold all SpecObject fields
Expand Down Expand Up @@ -59,7 +60,7 @@ def run(self) -> SpecObject:
self._finalize_types()
return self._build_spec_object()

def _get_next_element(self) -> Optional[Element]:
def _get_next_element(self) -> Element | None:
"""
Returns the next non-blank element in the document.
"""
Expand Down Expand Up @@ -240,7 +241,7 @@ def _process_table(self, table: Table) -> None:
self.spec["constant_vars"][name] = value_def

@staticmethod
def _get_table_row_fields(row: TableRow) -> tuple[str, str, Optional[str]]:
def _get_table_row_fields(row: TableRow) -> tuple[str, str, str | None]:
"""
Extracts the name, value, and description fields from a table row element.
"""
Expand Down Expand Up @@ -294,9 +295,9 @@ def _process_list_of_records_table(self, table: Table, list_of_records_name: str
# For mainnet, check that the spec config & file config are the same
# For minimal, we expect this to be different; just use the file config
if self.preset_name == "mainnet":
assert (
list_of_records_spec == list_of_records_config_file
), f"list of records mismatch: {list_of_records_spec} vs {list_of_records_config_file}"
assert list_of_records_spec == list_of_records_config_file, (
f"list of records mismatch: {list_of_records_spec} vs {list_of_records_config_file}"
)

# Set the config variable
self.spec["config_vars"][list_of_records_name] = list_of_records_config_file
Expand Down Expand Up @@ -435,21 +436,21 @@ def _build_spec_object(self) -> SpecObject:
)


@lru_cache(maxsize=None)
def _get_name_from_heading(heading: Heading) -> Optional[str]:
@cache
def _get_name_from_heading(heading: Heading) -> str | None:
last_child = heading.children[-1]
if isinstance(last_child, CodeSpan):
return last_child.children
return None


@lru_cache(maxsize=None)
@cache
def _get_source_from_code_block(block: FencedCode) -> str:
return block.children[0].children.strip()


@lru_cache(maxsize=None)
def _get_self_type_from_source(fn: ast.FunctionDef) -> Optional[str]:
@cache
def _get_self_type_from_source(fn: ast.FunctionDef) -> str | None:
args = fn.args.args
if len(args) == 0:
return None
Expand All @@ -460,8 +461,8 @@ def _get_self_type_from_source(fn: ast.FunctionDef) -> Optional[str]:
return args[0].annotation.id


@lru_cache(maxsize=None)
def _get_class_info_from_ast(cls: ast.ClassDef) -> Tuple[str, Optional[str]]:
@cache
def _get_class_info_from_ast(cls: ast.ClassDef) -> tuple[str, str | None]:
base = cls.bases[0]
if isinstance(base, ast.Name):
parent_class = base.id
Expand All @@ -475,7 +476,7 @@ def _get_class_info_from_ast(cls: ast.ClassDef) -> Tuple[str, Optional[str]]:
return cls.name, parent_class


@lru_cache(maxsize=None)
@cache
def _is_constant_id(name: str) -> bool:
"""
Checks if the given name follows the convention for constant identifiers.
Expand All @@ -485,16 +486,16 @@ def _is_constant_id(name: str) -> bool:
return all(map(lambda c: c in string.ascii_uppercase + "_" + string.digits, name[1:]))


@lru_cache(maxsize=None)
def _load_kzg_trusted_setups(preset_name: str) -> Tuple[list[str], list[str], list[str]]:
@cache
def _load_kzg_trusted_setups(preset_name: str) -> tuple[list[str], list[str], list[str]]:
trusted_setups_file_path = (
str(Path(__file__).parent.parent)
+ "/presets/"
+ preset_name
+ "/trusted_setups/trusted_setup_4096.json"
)

with open(trusted_setups_file_path, "r") as f:
with open(trusted_setups_file_path) as f:
json_data = json.load(f)
trusted_setup_G1_monomial = json_data["g1_monomial"]
trusted_setup_G1_lagrange = json_data["g1_lagrange"]
Expand All @@ -503,8 +504,8 @@ def _load_kzg_trusted_setups(preset_name: str) -> Tuple[list[str], list[str], li
return trusted_setup_G1_monomial, trusted_setup_G1_lagrange, trusted_setup_G2_monomial


@lru_cache(maxsize=None)
def _load_curdleproofs_crs(preset_name: str) -> Dict[str, list[str]]:
@cache
def _load_curdleproofs_crs(preset_name: str) -> dict[str, list[str]]:
"""
NOTE: File generated from https://github.com/asn-d6/curdleproofs/blob/8e8bf6d4191fb6a844002f75666fb7009716319b/tests/crs.rs#L53-L67
"""
Expand All @@ -515,7 +516,7 @@ def _load_curdleproofs_crs(preset_name: str) -> Dict[str, list[str]]:
+ "/trusted_setups/curdleproofs_crs.json"
)

with open(file_path, "r") as f:
with open(file_path) as f:
json_data = json.load(f)

return json_data
Expand All @@ -532,10 +533,8 @@ def _load_curdleproofs_crs(preset_name: str) -> Dict[str, list[str]]:
}


@lru_cache(maxsize=None)
def _parse_value(
name: str, typed_value: str, type_hint: Optional[str] = None
) -> VariableDefinition:
@cache
def _parse_value(name: str, typed_value: str, type_hint: str | None = None) -> VariableDefinition:
comment = None
if name in ("ROOT_OF_UNITY_EXTENDED", "ROOTS_OF_UNITY_EXTENDED", "ROOTS_OF_UNITY_REDUCED"):
comment = "noqa: E501"
Expand Down Expand Up @@ -585,7 +584,7 @@ def _update_constant_vars_with_curdleproofs_crs(
)


@lru_cache(maxsize=None)
@cache
def parse_markdown(content: str) -> Document:
return gfm.parse(content)

Expand Down Expand Up @@ -613,9 +612,9 @@ def check_yaml_matches_spec(
else:
raise ValueError(f"Variable {var} should be a string in the yaml file.")
try:
assert yaml[var_name] == repr(
eval(updated_value)
), f"mismatch for {var_name}: {yaml[var_name]} vs {eval(updated_value)}"
assert yaml[var_name] == repr(eval(updated_value)), (
f"mismatch for {var_name}: {yaml[var_name]} vs {eval(updated_value)}"
)
except NameError:
# Okay it's probably something more serious, let's ignore
pass
Expand Down
6 changes: 2 additions & 4 deletions pysetup/spec_builders/altair.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Dict

from ..constants import ALTAIR, OPTIMIZED_BLS_AGGREGATE_PUBKEYS
from .base import BaseSpecBuilder

Expand Down Expand Up @@ -39,15 +37,15 @@ def compute_merkle_proof(object: SSZObject,
return build_proof(object.get_backing(), index)"""

@classmethod
def hardcoded_ssz_dep_constants(cls) -> Dict[str, str]:
def hardcoded_ssz_dep_constants(cls) -> dict[str, str]:
return {
"FINALIZED_ROOT_GINDEX": "GeneralizedIndex(105)",
"CURRENT_SYNC_COMMITTEE_GINDEX": "GeneralizedIndex(54)",
"NEXT_SYNC_COMMITTEE_GINDEX": "GeneralizedIndex(55)",
}

@classmethod
def implement_optimizations(cls, functions: Dict[str, str]) -> Dict[str, str]:
def implement_optimizations(cls, functions: dict[str, str]) -> dict[str, str]:
if "eth_aggregate_pubkeys" in functions:
functions["eth_aggregate_pubkeys"] = OPTIMIZED_BLS_AGGREGATE_PUBKEYS.strip()
return functions
Loading
Loading