24 changes: 12 additions & 12 deletions python/paddle/jit/sot/opcode_translator/executor/function_graph.py
@@ -74,7 +74,7 @@
NullVariable,
PaddleLayerVariable,
ParameterVariable,
- SymbolicIntVariable,
+ SymbolicVariable,
TensorVariable,
VariableBase,
VariableFactory,
@@ -89,7 +89,7 @@ def convert_to_meta(inputs: Any):
"""

def func(x):
- if isinstance(x, (TensorVariable, SymbolicIntVariable)):
+ if isinstance(x, (TensorVariable, SymbolicVariable)):
return x.meta
if isinstance(x, VariableBase):
return x.get_py_value()
@@ -104,7 +104,7 @@ def convert_to_symbol(inputs: Any):
"""

def func(x):
- if isinstance(x, (TensorVariable, SymbolicIntVariable)):
+ if isinstance(x, (TensorVariable, SymbolicVariable)):
return x.get_symbol()
if isinstance(x, VariableBase):
return x.get_py_value()
@@ -119,7 +119,7 @@ def record_symbols(SIR, *args, **kwargs):
non_params = set()

def fn(value):
- if isinstance(value, (TensorVariable, SymbolicIntVariable)):
+ if isinstance(value, (TensorVariable, SymbolicVariable)):
symbol_meta_map[value.get_symbol()] = value.meta
if isinstance(value, ParameterVariable):
params.add(value.get_symbol())
@@ -415,16 +415,16 @@ def start_compile(self, *ret_vars: VariableBase):
found = False
for variable in self.input_variables:
if (
- isinstance(variable, (TensorVariable, SymbolicIntVariable))
+ isinstance(variable, (TensorVariable, SymbolicVariable))
and variable.get_symbol().name == name
):
- if isinstance(variable, SymbolicIntVariable):
+ if isinstance(variable, SymbolicVariable):
self.pycode_gen.gen_load_object(
paddle.to_tensor, "___paddle_to_tensor"
)
variable.tracker.gen_instructions(self.pycode_gen)
found = True
- if isinstance(variable, SymbolicIntVariable):
+ if isinstance(variable, SymbolicVariable):
self.pycode_gen.gen_call_function(1)
break
assert found, f"can't find input {name} in SIR."
@@ -619,7 +619,7 @@ def symbolic_call(

log(3, f" inputs : {inputs_symbols}", "\n")

- var_cls = SymbolicIntVariable if is_symbolic_int else TensorVariable
+ var_cls = SymbolicVariable if is_symbolic_int else TensorVariable
outputs = map_if(
out_metas,
pred=lambda x: isinstance(x, MetaInfo),
@@ -714,7 +714,7 @@ def remove_global_guarded_variable(self, variable: VariableBase):

def _find_tensor_outputs(
self, outputs: list[VariableBase]
- ) -> OrderedSet[TensorVariable | SymbolicIntVariable]:
+ ) -> OrderedSet[TensorVariable | SymbolicVariable]:
"""
Return all TensorVariable. find TensorVariables participating in networking from the output Variables

@@ -724,9 +724,9 @@

def is_graph_output(
var,
- ) -> TypeGuard[TensorVariable | SymbolicIntVariable]:
+ ) -> TypeGuard[TensorVariable | SymbolicVariable]:
return isinstance(var.tracker, DummyTracker) and isinstance(
- var, (TensorVariable, SymbolicIntVariable)
+ var, (TensorVariable, SymbolicVariable)
)

def collect_related_dummy_tensor(var):
@@ -741,7 +741,7 @@ def collect_related_dummy_tensor(var):
return []

output_tensors: OrderedSet[
- TensorVariable | SymbolicIntVariable
+ TensorVariable | SymbolicVariable
] = OrderedSet()
# Find Tensor Variables from outputs.
for output in outputs:
Next file in the diff (path not shown):
@@ -48,7 +48,7 @@
NumpyVariable,
RangeVariable,
SliceVariable,
- SymbolicIntVariable,
+ SymbolicVariable,
TupleVariable,
VariableBase,
VariableFactory,
@@ -886,7 +886,7 @@ def is_not_func(var: VariableBase, other: VariableBase):
if unary_fn in fallback_tensor_unary_method:
Dispatcher.register(
unary_fn,
("TensorVariable | SymbolicIntVariable",),
("TensorVariable | SymbolicVariable",),
raise_break_graph_fn,
)
continue
@@ -912,7 +912,7 @@ def is_not_func(var: VariableBase, other: VariableBase):
)
Dispatcher.register(
unary_fn,
("SymbolicIntVariable",),
("SymbolicVariable",),
partial(
lambda magic_name, var: var.graph.call_symbolic_method(
magic_name, var
@@ -932,7 +932,7 @@ def is_not_func(var: VariableBase, other: VariableBase):
binary_fn,
(
"TensorVariable",
"TensorVariable | SymbolicIntVariable | ConstantVariable | NumpyVariable",
"TensorVariable | SymbolicVariable | ConstantVariable | NumpyVariable",
),
partial(
lambda magic_name, var, other: var.graph.call_tensor_method(
@@ -944,8 +944,8 @@ def is_not_func(var: VariableBase, other: VariableBase):
Dispatcher.register(
binary_fn,
(
"SymbolicIntVariable",
"ConstantVariable | SymbolicIntVariable",
"SymbolicVariable",
"ConstantVariable | SymbolicVariable",
),
partial(
lambda magic_name, var, other: var.graph.call_symbolic_method(
@@ -960,7 +960,7 @@ def is_not_func(var: VariableBase, other: VariableBase):

@Dispatcher.register_decorator(operator.mod)
def tensor_mod_dispatcher(
- var: ConstantVariable | SymbolicIntVariable,
+ var: ConstantVariable | SymbolicVariable,
other: TensorVariable,
):
if var.get_py_type() is str:
@@ -973,7 +973,7 @@ def tensor_mod_dispatcher(
Dispatcher.register(
binary_fn,
(
"SymbolicIntVariable | ConstantVariable | NumpyVariable",
"SymbolicVariable | ConstantVariable | NumpyVariable",
"TensorVariable",
),
partial(
@@ -986,7 +986,7 @@ def tensor_mod_dispatcher(

Dispatcher.register(
binary_fn,
("ConstantVariable", "SymbolicIntVariable"),
("ConstantVariable", "SymbolicVariable"),
partial(
lambda magic_name, var, other: var.graph.call_symbolic_method(
magic_name, var, other
Next file in the diff (path not shown):
@@ -31,7 +31,7 @@
ObjectVariable,
ParameterVariable,
SliceVariable,
- SymbolicIntVariable,
+ SymbolicVariable,
TensorVariable,
)
from .callable import ( # noqa: F401
Next file in the diff (path not shown):
@@ -599,7 +599,7 @@ def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
return None


- class SymbolicIntVariable(VariableBase):
+ class SymbolicVariable(VariableBase):
"""
TODO
"""
@@ -684,7 +684,7 @@ def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
symbolic_input.setdefault(value, 0)
symbolic_input[value] += 1
# TODO(zrr1999): determine frequency
- return SymbolicIntVariable(value, graph, tracker)
+ return SymbolicVariable(value, graph, tracker)
return None


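The registrations in the diff dispatch on type-name pattern strings such as "ConstantVariable | SymbolicVariable". For readers unfamiliar with that mechanism, below is a minimal, self-contained sketch of how such string-keyed dispatch can work. It is only an illustration written for this note, not the Paddle SOT implementation: the Dispatcher, the variable classes, and the handler are simplified stand-ins, and the real Dispatcher's matching rules may differ.

import operator


class VariableBase:
    pass


class TensorVariable(VariableBase):
    pass


class SymbolicVariable(VariableBase):
    pass


class ConstantVariable(VariableBase):
    pass


class Dispatcher:
    # Each entry: (dispatched function, per-argument type-name patterns, handler).
    _handlers = []

    @classmethod
    def register(cls, fn, type_patterns, handler):
        cls._handlers.append((fn, type_patterns, handler))

    @classmethod
    def dispatch(cls, fn, *args):
        for registered_fn, patterns, handler in cls._handlers:
            if registered_fn is not fn or len(patterns) != len(args):
                continue
            # A pattern such as "ConstantVariable | SymbolicVariable" matches when
            # the argument's class name is any of the '|'-separated names.
            if all(
                type(arg).__name__ in {name.strip() for name in pattern.split("|")}
                for arg, pattern in zip(args, patterns)
            ):
                return handler(*args)
        raise TypeError(f"no handler registered for {fn} with {args}")


# Mirrors the shape of the registrations in the diff; the handler is illustrative only.
Dispatcher.register(
    operator.add,
    ("SymbolicVariable", "ConstantVariable | SymbolicVariable"),
    lambda var, other: f"symbolic __add__({type(var).__name__}, {type(other).__name__})",
)

print(Dispatcher.dispatch(operator.add, SymbolicVariable(), ConstantVariable()))
# Prints: symbolic __add__(SymbolicVariable, ConstantVariable)

One practical upside of matching on class-name strings rather than the classes themselves is that registrations can be declared without importing every variable class up front; whether that is the motivation in the actual codebase is not stated in this diff.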