Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions python/paddle/jit/dy2static/convert_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from paddle.base.dygraph.base import _convert_into_variable, in_to_static_mode
from paddle.base.framework import Variable, core, default_main_program
from paddle.framework import use_pir_api
from paddle.jit.utils import OrderedSet
from paddle.pir import Value
from paddle.static.amp.fp16_utils import AmpOptions
from paddle.utils import is_sequence, map_structure
Expand Down Expand Up @@ -207,9 +208,9 @@ def _run_paddle_while(
helper = GetterSetterHelper(getter, setter, return_name_ids, push_pop_names)
_convert_tensor_arrray_if_necessary(helper, push_pop_names)

union_name = (set(return_name_ids) if return_name_ids else set()) | (
set(push_pop_names) if push_pop_names else set()
)
union_name = (
OrderedSet(return_name_ids) if return_name_ids else OrderedSet()
) | (OrderedSet(push_pop_names) if push_pop_names else OrderedSet())
union_name = list(union_name)

def new_body_fn(*args):
Expand Down Expand Up @@ -444,9 +445,9 @@ def _run_paddle_cond(
if return_name_ids is None and push_pop_names is None:
union_name = None
else:
union_name = (set(return_name_ids) if return_name_ids else set()) | (
set(push_pop_names) if push_pop_names else set()
)
union_name = (
OrderedSet(return_name_ids) if return_name_ids else OrderedSet()
) | (OrderedSet(push_pop_names) if push_pop_names else OrderedSet())
union_name = list(union_name)

def new_true_fn():
Expand Down
5 changes: 3 additions & 2 deletions python/paddle/jit/dy2static/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from paddle.base.layer_helper import LayerHelper
from paddle.base.wrapped_decorator import signature_safe_contextmanager
from paddle.framework import CUDAPinnedPlace
from paddle.jit.utils import OrderedSet
from paddle.utils import flatten

from .ast_utils import ast_to_source_code
Expand Down Expand Up @@ -479,9 +480,9 @@ class GetterSetterHelper:

def __init__(self, getter_func, setter_func, *name_lists):
name_lists = ([] if x is None else x for x in name_lists)
name_sets = (set(x) for x in name_lists)
name_sets = (OrderedSet(x) for x in name_lists)
self._union = list(
functools.reduce(lambda x, y: x | y, name_sets, set())
functools.reduce(lambda x, y: x | y, name_sets, OrderedSet())
)
self._union.sort()
self.getter = getter_func
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from functools import cached_property
from typing import Any, Callable

from paddle.jit.utils import OrderedSet
from paddle.utils import flatten

from ...infer_meta import (
Expand All @@ -38,7 +39,6 @@
from ...utils import (
ENV_SHOW_TRACKERS,
NameGenerator,
OrderedSet,
inner_error_default_handler,
is_inplace_api,
is_paddle_api,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,15 @@

import opcode

from paddle.jit.utils import OrderedSet

from ...profiler import EventGuard, event_register
from ...psdb import NO_BREAKGRAPH_CODES
from ...utils import (
ENV_MIN_GRAPH_SIZE,
BreakGraphError,
FallbackError,
InnerError,
OrderedSet,
SotUndefinedVar,
get_static_function,
log,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@
import opcode

import paddle
from paddle.jit.utils import OrderedSet

from ...utils import (
FallbackError,
InnerError,
OrderedSet,
ResumeFnNameFactory,
is_clean_code,
list_contain_by_id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
import dataclasses
from enum import Enum

from ...utils import InnerError, OrderedSet
from paddle.jit.utils import OrderedSet

from ...utils import InnerError
from .instruction_utils import Instruction
from .opcode_info import ALL_JUMP, HAS_FREE, HAS_LOCAL, UNCONDITIONAL_JUMP

Expand Down
3 changes: 2 additions & 1 deletion python/paddle/jit/sot/symbolic/statement_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@
from typing import Any, Callable

import paddle
from paddle.jit.utils import OrderedSet
from paddle.utils import flatten, map_structure

from ..utils import NameGenerator, OrderedSet, Singleton, flatten_extend
from ..utils import NameGenerator, Singleton, flatten_extend


class Reference: # to unify weak_ref and strong_ref
Expand Down
1 change: 0 additions & 1 deletion python/paddle/jit/sot/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
ConstTypes,
GraphLogger,
NameGenerator,
OrderedSet,
ResumeFnNameFactory,
Singleton,
SotUndefinedVar,
Expand Down
220 changes: 1 addition & 219 deletions python/paddle/jit/sot/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from collections import OrderedDict
from contextlib import contextmanager
from enum import Enum
from typing import Any, Generic, Iterable, Iterator, TypeVar
from typing import Any, Generic, TypeVar
from weakref import WeakValueDictionary

import numpy as np
Expand Down Expand Up @@ -384,224 +384,6 @@ def hashable(obj):
return False


class OrderedSet(Generic[T]):
"""
A set that preserves the order of insertion.
"""

_data: dict[T, None]

def __init__(self, items: Iterable[T] | None = None):
"""
Examples:
>>> s = OrderedSet([1, 2, 3])
>>> s
OrderedSet(1, 2, 3)
>>> s = OrderedSet()
>>> s
OrderedSet()
"""
self._data = dict.fromkeys(items) if items is not None else {}

def __iter__(self) -> Iterator[T]:
"""
Examples:
>>> s = OrderedSet([1, 2, 3])
>>> for item in s:
... print(item)
1
2
3
"""
return iter(self._data)

def __or__(self, other: OrderedSet[T]) -> OrderedSet[T]:
"""
Union two sets.

Args:
other: Another set to be unioned.

Returns:
The union of two sets.

Examples:
>>> s1 = OrderedSet([1, 2, 3])
>>> s2 = OrderedSet([2, 3, 4])
>>> s1 | s2
OrderedSet(1, 2, 3, 4)
"""
return OrderedSet(list(self) + list(other))

def __ior__(self, other: OrderedSet[T]):
"""
Union two sets in place.

Args:
other: Another set to be unioned.

Examples:
>>> s1 = OrderedSet([1, 2, 3])
>>> s2 = OrderedSet([2, 3, 4])
>>> s1 |= s2
>>> s1
OrderedSet(1, 2, 3, 4)
"""
self._data.update(dict.fromkeys(other))
return self

def __and__(self, other: OrderedSet[T]) -> OrderedSet[T]:
"""
Intersect two sets.

Args:
other: Another set to be intersected.

Returns:
The intersection of two sets.

Examples:
>>> s1 = OrderedSet([1, 2, 3])
>>> s2 = OrderedSet([2, 3, 4])
>>> s1 & s2
OrderedSet(2, 3)
"""
return OrderedSet([item for item in self if item in other])

def __iand__(self, other: OrderedSet[T]):
"""
Intersect two sets in place.

Args:
other: Another set to be intersected.

Examples:
>>> s1 = OrderedSet([1, 2, 3])
>>> s2 = OrderedSet([2, 3, 4])
>>> s1 &= s2
>>> s1
OrderedSet(2, 3)
"""
self._data = {item: None for item in self if item in other}
return self

def __sub__(self, other: OrderedSet[T]) -> OrderedSet[T]:
"""
Subtract two sets.

Args:
other: Another set to be subtracted.

Returns:
The subtraction of two sets.

Examples:
>>> s1 = OrderedSet([1, 2, 3])
>>> s2 = OrderedSet([2, 3, 4])
>>> s1 - s2
OrderedSet(1)
"""
return OrderedSet([item for item in self if item not in other])

def __isub__(self, other: OrderedSet[T]):
"""
Subtract two sets in place.

Args:
other: Another set to be subtracted.

Examples:
>>> s1 = OrderedSet([1, 2, 3])
>>> s2 = OrderedSet([2, 3, 4])
>>> s1 -= s2
>>> s1
OrderedSet(1)
"""
self._data = {item: None for item in self if item not in other}
return self

def add(self, item: T):
"""
Add an item to the set.

Args:
item: The item to be added.

Examples:
>>> s = OrderedSet([1, 2, 3])
>>> s.add(4)
>>> s
OrderedSet(1, 2, 3, 4)
"""
self._data.setdefault(item)

def remove(self, item: T):
"""
Remove an item from the set.

Args:
item: The item to be removed.

Examples:
>>> s = OrderedSet([1, 2, 3])
>>> s.remove(2)
>>> s
OrderedSet(1, 3)
"""
del self._data[item]

def __contains__(self, item: T) -> bool:
"""
Examples:
>>> s = OrderedSet([1, 2, 3])
>>> 1 in s
True
>>> 4 in s
False
"""
return item in self._data

def __len__(self) -> int:
"""
Examples:
>>> s = OrderedSet([1, 2, 3])
>>> len(s)
3
"""
return len(self._data)

def __bool__(self) -> bool:
"""
Examples:
>>> s = OrderedSet([1, 2, 3])
>>> bool(s)
True
>>> s = OrderedSet()
>>> bool(s)
False
"""
return bool(self._data)

def __eq__(self, other: object) -> bool:
"""
Examples:
>>> s1 = OrderedSet([1, 2, 3])
>>> s2 = OrderedSet([1, 2, 3])
>>> s1 == s2
True
>>> s3 = OrderedSet([3, 2, 1])
>>> s1 == s3
False
"""
if not isinstance(other, OrderedSet):
return NotImplemented
return list(self) == list(other)

def __repr__(self) -> str:
data_repr = ", ".join(map(repr, self._data))
return f"OrderedSet({data_repr})"


class StepState(Enum):
COLLECT_INFO = 1
RUN_SOT = 2
Expand Down
Loading