Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 39 additions & 9 deletions dissect/cstruct/types/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,13 +828,13 @@ def _generate_structure__init__(fields: list[Field]) -> FunctionType:
Args:
fields: List of field names.
"""
field_names = [field._name for field in fields]
mapping = _generate_co_mapping(fields)

template: FunctionType = _make_structure__init__(len(field_names))
template: FunctionType = _make_structure__init__(len(fields))
return type(template)(
template.__code__.replace(
co_names=tuple(chain.from_iterable(zip((f"__{name}_default__" for name in field_names), field_names))),
co_varnames=("self", *field_names),
co_names=_remap_co_values(template.__code__.co_names, mapping),
co_varnames=_remap_co_values(template.__code__.co_varnames, mapping),
),
template.__globals__ | {f"__{field._name}_default__": field.type.__default__() for field in fields},
argdefs=template.__defaults__,
Expand All @@ -847,20 +847,50 @@ def _generate_union__init__(fields: list[Field]) -> FunctionType:
Args:
fields: List of field names.
"""
field_names = [field._name for field in fields]
mapping = _generate_co_mapping(fields)

template: FunctionType = _make_union__init__(len(field_names))
template: FunctionType = _make_union__init__(len(fields))
return type(template)(
template.__code__.replace(
co_consts=(None, *field_names),
co_names=("object", "__setattr__", *(f"__{name}_default__" for name in field_names)),
co_varnames=("self", *field_names),
co_consts=_remap_co_values(template.__code__.co_consts, mapping),
co_names=_remap_co_values(template.__code__.co_names, mapping),
co_varnames=_remap_co_values(template.__code__.co_varnames, mapping),
),
template.__globals__ | {f"__{field._name}_default__": field.type.__default__() for field in fields},
argdefs=template.__defaults__,
)


def _generate_co_mapping(fields: list[Field]) -> dict[str, str]:
"""Generates a mapping of generated code object names to field names.

The generated code uses names like ``_0``, ``_1``, etc. for fields, and ``_0_default``, ``_1_default``, etc.
for default initializer values. Return a mapping of these names to the actual field names.

Args:
fields: List of field names.
"""
return {
key: value
for i, field in enumerate(fields)
for key, value in [(f"_{i}", field._name), (f"_{i}_default", f"__{field._name}_default__")]
}


def _remap_co_values(value: tuple[Any, ...], mapping: dict[str, str]) -> tuple[Any, ...]:
"""Remaps code object values using a mapping.

This is used to replace generated code object names with actual field names.

Args:
value: The original code object values.
mapping: A mapping of generated code object names to field names.
"""
# Only attempt to remap if the value is a string, otherwise return it as is
# This is to avoid issues with trying to remap non-hashable types, and we only need to replace strings anyway
return tuple(mapping.get(v, v) if isinstance(v, str) else v for v in value)


def _generate__eq__(fields: list[str]) -> FunctionType:
"""Generates an ``__eq__`` method for a class with the specified fields.

Expand Down