diff --git a/python/paddle/jit/dy2static/convert_operators.py b/python/paddle/jit/dy2static/convert_operators.py index 2ed81d8a632895..f47b7613bfaf0f 100644 --- a/python/paddle/jit/dy2static/convert_operators.py +++ b/python/paddle/jit/dy2static/convert_operators.py @@ -24,6 +24,7 @@ from paddle.base.dygraph.base import _convert_into_variable, in_to_static_mode from paddle.base.framework import Variable, core, default_main_program from paddle.framework import use_pir_api +from paddle.jit.utils import OrderedSet from paddle.pir import Value from paddle.static.amp.fp16_utils import AmpOptions from paddle.utils import is_sequence, map_structure @@ -207,9 +208,9 @@ def _run_paddle_while( helper = GetterSetterHelper(getter, setter, return_name_ids, push_pop_names) _convert_tensor_arrray_if_necessary(helper, push_pop_names) - union_name = (set(return_name_ids) if return_name_ids else set()) | ( - set(push_pop_names) if push_pop_names else set() - ) + union_name = ( + OrderedSet(return_name_ids) if return_name_ids else OrderedSet() + ) | (OrderedSet(push_pop_names) if push_pop_names else OrderedSet()) union_name = list(union_name) def new_body_fn(*args): @@ -444,9 +445,9 @@ def _run_paddle_cond( if return_name_ids is None and push_pop_names is None: union_name = None else: - union_name = (set(return_name_ids) if return_name_ids else set()) | ( - set(push_pop_names) if push_pop_names else set() - ) + union_name = ( + OrderedSet(return_name_ids) if return_name_ids else OrderedSet() + ) | (OrderedSet(push_pop_names) if push_pop_names else OrderedSet()) union_name = list(union_name) def new_true_fn(): diff --git a/python/paddle/jit/dy2static/utils.py b/python/paddle/jit/dy2static/utils.py index a6ff64f25303f8..7a188f522dba64 100644 --- a/python/paddle/jit/dy2static/utils.py +++ b/python/paddle/jit/dy2static/utils.py @@ -35,6 +35,7 @@ from paddle.base.layer_helper import LayerHelper from paddle.base.wrapped_decorator import signature_safe_contextmanager from paddle.framework import CUDAPinnedPlace +from paddle.jit.utils import OrderedSet from paddle.utils import flatten from .ast_utils import ast_to_source_code @@ -479,9 +480,9 @@ class GetterSetterHelper: def __init__(self, getter_func, setter_func, *name_lists): name_lists = ([] if x is None else x for x in name_lists) - name_sets = (set(x) for x in name_lists) + name_sets = (OrderedSet(x) for x in name_lists) self._union = list( - functools.reduce(lambda x, y: x | y, name_sets, set()) + functools.reduce(lambda x, y: x | y, name_sets, OrderedSet()) ) self._union.sort() self.getter = getter_func diff --git a/python/paddle/jit/sot/opcode_translator/executor/function_graph.py b/python/paddle/jit/sot/opcode_translator/executor/function_graph.py index b7b3efc6fe65d1..8f87e19cd4d288 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/function_graph.py +++ b/python/paddle/jit/sot/opcode_translator/executor/function_graph.py @@ -24,6 +24,7 @@ from functools import cached_property from typing import Any, Callable +from paddle.jit.utils import OrderedSet from paddle.utils import flatten from ...infer_meta import ( @@ -38,7 +39,6 @@ from ...utils import ( ENV_SHOW_TRACKERS, NameGenerator, - OrderedSet, inner_error_default_handler, is_inplace_api, is_paddle_api, diff --git a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py index 55177f0601b9fc..5691b2c25b4863 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py +++ b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py @@ -27,6 +27,8 @@ import opcode +from paddle.jit.utils import OrderedSet + from ...profiler import EventGuard, event_register from ...psdb import NO_BREAKGRAPH_CODES from ...utils import ( @@ -34,7 +36,6 @@ BreakGraphError, FallbackError, InnerError, - OrderedSet, SotUndefinedVar, get_static_function, log, diff --git a/python/paddle/jit/sot/opcode_translator/executor/pycode_generator.py b/python/paddle/jit/sot/opcode_translator/executor/pycode_generator.py index bf89950b7d858b..a675b6065de9e5 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/pycode_generator.py +++ b/python/paddle/jit/sot/opcode_translator/executor/pycode_generator.py @@ -27,11 +27,11 @@ import opcode import paddle +from paddle.jit.utils import OrderedSet from ...utils import ( FallbackError, InnerError, - OrderedSet, ResumeFnNameFactory, is_clean_code, list_contain_by_id, diff --git a/python/paddle/jit/sot/opcode_translator/instruction_utils/opcode_analysis.py b/python/paddle/jit/sot/opcode_translator/instruction_utils/opcode_analysis.py index dcda7558e5a395..f0211167f44498 100644 --- a/python/paddle/jit/sot/opcode_translator/instruction_utils/opcode_analysis.py +++ b/python/paddle/jit/sot/opcode_translator/instruction_utils/opcode_analysis.py @@ -17,7 +17,9 @@ import dataclasses from enum import Enum -from ...utils import InnerError, OrderedSet +from paddle.jit.utils import OrderedSet + +from ...utils import InnerError from .instruction_utils import Instruction from .opcode_info import ALL_JUMP, HAS_FREE, HAS_LOCAL, UNCONDITIONAL_JUMP diff --git a/python/paddle/jit/sot/symbolic/statement_ir.py b/python/paddle/jit/sot/symbolic/statement_ir.py index 8bd68c533770c5..8be61aaf522e94 100644 --- a/python/paddle/jit/sot/symbolic/statement_ir.py +++ b/python/paddle/jit/sot/symbolic/statement_ir.py @@ -24,9 +24,10 @@ from typing import Any, Callable import paddle +from paddle.jit.utils import OrderedSet from paddle.utils import flatten, map_structure -from ..utils import NameGenerator, OrderedSet, Singleton, flatten_extend +from ..utils import NameGenerator, Singleton, flatten_extend class Reference: # to unify weak_ref and strong_ref diff --git a/python/paddle/jit/sot/utils/__init__.py b/python/paddle/jit/sot/utils/__init__.py index 4b2f77823088ec..8e74d65aa99e6d 100644 --- a/python/paddle/jit/sot/utils/__init__.py +++ b/python/paddle/jit/sot/utils/__init__.py @@ -47,7 +47,6 @@ ConstTypes, GraphLogger, NameGenerator, - OrderedSet, ResumeFnNameFactory, Singleton, SotUndefinedVar, diff --git a/python/paddle/jit/sot/utils/utils.py b/python/paddle/jit/sot/utils/utils.py index 3e5caa10f9ef9d..b51272c0e38e5e 100644 --- a/python/paddle/jit/sot/utils/utils.py +++ b/python/paddle/jit/sot/utils/utils.py @@ -22,7 +22,7 @@ from collections import OrderedDict from contextlib import contextmanager from enum import Enum -from typing import Any, Generic, Iterable, Iterator, TypeVar +from typing import Any, Generic, TypeVar from weakref import WeakValueDictionary import numpy as np @@ -384,224 +384,6 @@ def hashable(obj): return False -class OrderedSet(Generic[T]): - """ - A set that preserves the order of insertion. - """ - - _data: dict[T, None] - - def __init__(self, items: Iterable[T] | None = None): - """ - Examples: - >>> s = OrderedSet([1, 2, 3]) - >>> s - OrderedSet(1, 2, 3) - >>> s = OrderedSet() - >>> s - OrderedSet() - """ - self._data = dict.fromkeys(items) if items is not None else {} - - def __iter__(self) -> Iterator[T]: - """ - Examples: - >>> s = OrderedSet([1, 2, 3]) - >>> for item in s: - ... print(item) - 1 - 2 - 3 - """ - return iter(self._data) - - def __or__(self, other: OrderedSet[T]) -> OrderedSet[T]: - """ - Union two sets. - - Args: - other: Another set to be unioned. - - Returns: - The union of two sets. - - Examples: - >>> s1 = OrderedSet([1, 2, 3]) - >>> s2 = OrderedSet([2, 3, 4]) - >>> s1 | s2 - OrderedSet(1, 2, 3, 4) - """ - return OrderedSet(list(self) + list(other)) - - def __ior__(self, other: OrderedSet[T]): - """ - Union two sets in place. - - Args: - other: Another set to be unioned. - - Examples: - >>> s1 = OrderedSet([1, 2, 3]) - >>> s2 = OrderedSet([2, 3, 4]) - >>> s1 |= s2 - >>> s1 - OrderedSet(1, 2, 3, 4) - """ - self._data.update(dict.fromkeys(other)) - return self - - def __and__(self, other: OrderedSet[T]) -> OrderedSet[T]: - """ - Intersect two sets. - - Args: - other: Another set to be intersected. - - Returns: - The intersection of two sets. - - Examples: - >>> s1 = OrderedSet([1, 2, 3]) - >>> s2 = OrderedSet([2, 3, 4]) - >>> s1 & s2 - OrderedSet(2, 3) - """ - return OrderedSet([item for item in self if item in other]) - - def __iand__(self, other: OrderedSet[T]): - """ - Intersect two sets in place. - - Args: - other: Another set to be intersected. - - Examples: - >>> s1 = OrderedSet([1, 2, 3]) - >>> s2 = OrderedSet([2, 3, 4]) - >>> s1 &= s2 - >>> s1 - OrderedSet(2, 3) - """ - self._data = {item: None for item in self if item in other} - return self - - def __sub__(self, other: OrderedSet[T]) -> OrderedSet[T]: - """ - Subtract two sets. - - Args: - other: Another set to be subtracted. - - Returns: - The subtraction of two sets. - - Examples: - >>> s1 = OrderedSet([1, 2, 3]) - >>> s2 = OrderedSet([2, 3, 4]) - >>> s1 - s2 - OrderedSet(1) - """ - return OrderedSet([item for item in self if item not in other]) - - def __isub__(self, other: OrderedSet[T]): - """ - Subtract two sets in place. - - Args: - other: Another set to be subtracted. - - Examples: - >>> s1 = OrderedSet([1, 2, 3]) - >>> s2 = OrderedSet([2, 3, 4]) - >>> s1 -= s2 - >>> s1 - OrderedSet(1) - """ - self._data = {item: None for item in self if item not in other} - return self - - def add(self, item: T): - """ - Add an item to the set. - - Args: - item: The item to be added. - - Examples: - >>> s = OrderedSet([1, 2, 3]) - >>> s.add(4) - >>> s - OrderedSet(1, 2, 3, 4) - """ - self._data.setdefault(item) - - def remove(self, item: T): - """ - Remove an item from the set. - - Args: - item: The item to be removed. - - Examples: - >>> s = OrderedSet([1, 2, 3]) - >>> s.remove(2) - >>> s - OrderedSet(1, 3) - """ - del self._data[item] - - def __contains__(self, item: T) -> bool: - """ - Examples: - >>> s = OrderedSet([1, 2, 3]) - >>> 1 in s - True - >>> 4 in s - False - """ - return item in self._data - - def __len__(self) -> int: - """ - Examples: - >>> s = OrderedSet([1, 2, 3]) - >>> len(s) - 3 - """ - return len(self._data) - - def __bool__(self) -> bool: - """ - Examples: - >>> s = OrderedSet([1, 2, 3]) - >>> bool(s) - True - >>> s = OrderedSet() - >>> bool(s) - False - """ - return bool(self._data) - - def __eq__(self, other: object) -> bool: - """ - Examples: - >>> s1 = OrderedSet([1, 2, 3]) - >>> s2 = OrderedSet([1, 2, 3]) - >>> s1 == s2 - True - >>> s3 = OrderedSet([3, 2, 1]) - >>> s1 == s3 - False - """ - if not isinstance(other, OrderedSet): - return NotImplemented - return list(self) == list(other) - - def __repr__(self) -> str: - data_repr = ", ".join(map(repr, self._data)) - return f"OrderedSet({data_repr})" - - class StepState(Enum): COLLECT_INFO = 1 RUN_SOT = 2 diff --git a/python/paddle/jit/utils.py b/python/paddle/jit/utils.py new file mode 100644 index 00000000000000..00cc32e5b62626 --- /dev/null +++ b/python/paddle/jit/utils.py @@ -0,0 +1,282 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Generic, Iterable, Iterator, TypeVar + +T = TypeVar("T") + + +class OrderedSet(Generic[T]): + """ + A set that preserves the order of insertion. + """ + + _data: dict[T, None] + + def __init__(self, items: Iterable[T] | None = None): + """ + Examples: + >>> s = OrderedSet([1, 2, 3]) + >>> s + OrderedSet(1, 2, 3) + >>> s = OrderedSet() + >>> s + OrderedSet() + """ + self._data = dict.fromkeys(items) if items is not None else {} + + def __iter__(self) -> Iterator[T]: + """ + Examples: + >>> s = OrderedSet([1, 2, 3]) + >>> for item in s: + ... print(item) + 1 + 2 + 3 + """ + return iter(self._data) + + def __or__(self, other: OrderedSet[T]) -> OrderedSet[T]: + """ + Union two sets. + + Args: + other: Another set to be unioned. + + Returns: + The union of two sets. + + Examples: + >>> s1 = OrderedSet([1, 2, 3]) + >>> s2 = OrderedSet([2, 3, 4]) + >>> s1 | s2 + OrderedSet(1, 2, 3, 4) + """ + return OrderedSet(list(self) + list(other)) + + def __ior__(self, other: OrderedSet[T]): + """ + Union two sets in place. + + Args: + other: Another set to be unioned. + + Examples: + >>> s1 = OrderedSet([1, 2, 3]) + >>> s2 = OrderedSet([2, 3, 4]) + >>> s1 |= s2 + >>> s1 + OrderedSet(1, 2, 3, 4) + """ + self._data.update(dict.fromkeys(other)) + return self + + def __and__(self, other: OrderedSet[T]) -> OrderedSet[T]: + """ + Intersect two sets. + + Args: + other: Another set to be intersected. + + Returns: + The intersection of two sets. + + Examples: + >>> s1 = OrderedSet([1, 2, 3]) + >>> s2 = OrderedSet([2, 3, 4]) + >>> s1 & s2 + OrderedSet(2, 3) + """ + return OrderedSet([item for item in self if item in other]) + + def __iand__(self, other: OrderedSet[T]): + """ + Intersect two sets in place. + + Args: + other: Another set to be intersected. + + Examples: + >>> s1 = OrderedSet([1, 2, 3]) + >>> s2 = OrderedSet([2, 3, 4]) + >>> s1 &= s2 + >>> s1 + OrderedSet(2, 3) + """ + self._data = {item: None for item in self if item in other} + return self + + def __sub__(self, other: OrderedSet[T]) -> OrderedSet[T]: + """ + Subtract two sets. + + Args: + other: Another set to be subtracted. + + Returns: + The subtraction of two sets. + + Examples: + >>> s1 = OrderedSet([1, 2, 3]) + >>> s2 = OrderedSet([2, 3, 4]) + >>> s1 - s2 + OrderedSet(1) + """ + return OrderedSet([item for item in self if item not in other]) + + def __isub__(self, other: OrderedSet[T]): + """ + Subtract two sets in place. + + Args: + other: Another set to be subtracted. + + Examples: + >>> s1 = OrderedSet([1, 2, 3]) + >>> s2 = OrderedSet([2, 3, 4]) + >>> s1 -= s2 + >>> s1 + OrderedSet(1) + """ + self._data = {item: None for item in self if item not in other} + return self + + def __xor__(self, other: OrderedSet[T]) -> OrderedSet[T]: + """ + Symmetric difference of two sets. + + Args: + other: Another set to be xor'ed. + + Returns: + The symmetric difference of two sets. + + Examples: + >>> s1 = OrderedSet([1, 2, 3]) + >>> s2 = OrderedSet([2, 3, 4]) + >>> s1 ^ s2 + OrderedSet(1, 4) + """ + return OrderedSet( + [item for item in self if item not in other] + ) | OrderedSet([item for item in other if item not in self]) + + def __ixor__(self, other: OrderedSet[T]): + """ + Symmetric difference of two sets in place. + + Args: + other: Another set to be xor'ed. + + Examples: + >>> s1 = OrderedSet([1, 2, 3]) + >>> s2 = OrderedSet([2, 3, 4]) + >>> s1 ^= s2 + >>> s1 + OrderedSet(1, 4) + """ + # TODO(Python3.8-cleanup): Use dict union syntax when Python 3.9 is + # minimum supported version. + # self._data = {item: None for item in self if item not in other} | { + # item: None for item in other if item not in self + # } + self._data = { + **{item: None for item in self if item not in other}, + **{item: None for item in other if item not in self}, + } + return self + + def add(self, item: T): + """ + Add an item to the set. + + Args: + item: The item to be added. + + Examples: + >>> s = OrderedSet([1, 2, 3]) + >>> s.add(4) + >>> s + OrderedSet(1, 2, 3, 4) + """ + self._data.setdefault(item) + + def remove(self, item: T): + """ + Remove an item from the set. + + Args: + item: The item to be removed. + + Examples: + >>> s = OrderedSet([1, 2, 3]) + >>> s.remove(2) + >>> s + OrderedSet(1, 3) + """ + del self._data[item] + + def __contains__(self, item: T) -> bool: + """ + Examples: + >>> s = OrderedSet([1, 2, 3]) + >>> 1 in s + True + >>> 4 in s + False + """ + return item in self._data + + def __len__(self) -> int: + """ + Examples: + >>> s = OrderedSet([1, 2, 3]) + >>> len(s) + 3 + """ + return len(self._data) + + def __bool__(self) -> bool: + """ + Examples: + >>> s = OrderedSet([1, 2, 3]) + >>> bool(s) + True + >>> s = OrderedSet() + >>> bool(s) + False + """ + return bool(self._data) + + def __eq__(self, other: object) -> bool: + """ + Examples: + >>> s1 = OrderedSet([1, 2, 3]) + >>> s2 = OrderedSet([1, 2, 3]) + >>> s1 == s2 + True + >>> s3 = OrderedSet([3, 2, 1]) + >>> s1 == s3 + False + """ + if not isinstance(other, OrderedSet): + return NotImplemented + return list(self) == list(other) + + def __repr__(self) -> str: + data_repr = ", ".join(map(repr, self._data)) + return f"OrderedSet({data_repr})" diff --git a/test/dygraph_to_static/test_ordered_set.py b/test/dygraph_to_static/test_ordered_set.py new file mode 100644 index 00000000000000..db72da87af27e7 --- /dev/null +++ b/test/dygraph_to_static/test_ordered_set.py @@ -0,0 +1,107 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from paddle.jit.utils import OrderedSet + + +class TestOrderedSet(unittest.TestCase): + def test_iter(self): + s = OrderedSet([1, 2, 3]) + self.assertEqual(list(s), [1, 2, 3]) + + def test_or(self): + s1 = OrderedSet([1, 2, 3]) + s2 = OrderedSet([2, 3, 4]) + self.assertEqual(s1 | s2, OrderedSet([1, 2, 3, 4])) + + def test_ior(self): + s1 = OrderedSet([1, 2, 3]) + s2 = OrderedSet([2, 3, 4]) + s1 |= s2 + self.assertEqual(s1, OrderedSet([1, 2, 3, 4])) + + def test_and(self): + s1 = OrderedSet([1, 2, 3]) + s2 = OrderedSet([2, 3, 4]) + self.assertEqual(s1 & s2, OrderedSet([2, 3])) + + def test_iand(self): + s1 = OrderedSet([1, 2, 3]) + s2 = OrderedSet([2, 3, 4]) + s1 &= s2 + self.assertEqual(s1, OrderedSet([2, 3])) + + def test_sub(self): + s1 = OrderedSet([1, 2, 3]) + s2 = OrderedSet([2, 3, 4]) + self.assertEqual(s1 - s2, OrderedSet([1])) + + def test_isub(self): + s1 = OrderedSet([1, 2, 3]) + s2 = OrderedSet([2, 3, 4]) + s1 -= s2 + self.assertEqual(s1, OrderedSet([1])) + + def test_xor(self): + s1 = OrderedSet([1, 2, 3]) + s2 = OrderedSet([2, 3, 4]) + self.assertEqual(s1 ^ s2, OrderedSet([1, 4])) + + def test_ixor(self): + s1 = OrderedSet([1, 2, 3]) + s2 = OrderedSet([2, 3, 4]) + s1 ^= s2 + self.assertEqual(s1, OrderedSet([1, 4])) + + def test_add(self): + s = OrderedSet([1, 2, 3]) + s.add(4) + self.assertEqual(s, OrderedSet([1, 2, 3, 4])) + + def test_remove(self): + s = OrderedSet([1, 2, 3]) + s.remove(2) + self.assertEqual(s, OrderedSet([1, 3])) + + def test_contains(self): + s = OrderedSet([1, 2, 3]) + self.assertTrue(2 in s) + self.assertFalse(4 in s) + + def test_len(self): + s = OrderedSet([1, 2, 3]) + self.assertEqual(len(s), 3) + + def test_bool(self): + s = OrderedSet([1, 2, 3]) + self.assertTrue(bool(s)) + s = OrderedSet() + self.assertFalse(bool(s)) + + def test_eq(self): + s1 = OrderedSet([1, 2, 3]) + s2 = OrderedSet([1, 2, 3]) + self.assertEqual(s1, s2) + s3 = OrderedSet([3, 2, 1]) + self.assertNotEqual(s1, s3) + + def test_repr(self): + s = OrderedSet([1, 2, 3]) + self.assertEqual(repr(s), "OrderedSet(1, 2, 3)") + + +if __name__ == '__main__': + unittest.main()