Skip to content

Commit 58f0892

Browse files
authored
Support static graph code-gen for scalar and int_array (#48792)
* add suppport_tensor for code_gen to static graph * support code-gen for int_array * polish code * fix bug of data_type
1 parent ff8b2cb commit 58f0892

File tree

16 files changed

+294
-727
lines changed

16 files changed

+294
-727
lines changed

paddle/fluid/operators/crop_tensor_op.cc

Lines changed: 0 additions & 320 deletions
This file was deleted.

paddle/fluid/operators/generator/CMakeLists.txt

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -108,18 +108,23 @@ execute_process(
108108
--op_compat_yaml_path ${CMAKE_SOURCE_DIR}/paddle/phi/api/yaml/op_compat.yaml
109109
--output_op_path "${generated_op_path}.tmp" --output_arg_map_path
110110
"${generated_argument_mapping_path}.tmp"
111+
RESULT_VARIABLE _result)
112+
if(${_result})
113+
message(FATAL_ERROR "operator codegen failed, exiting.")
114+
endif()
115+
116+
execute_process(
117+
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/paddle/fluid/operators/generator
111118
COMMAND
112119
${PYTHON_EXECUTABLE} generate_sparse_op.py --ops_yaml_path
113120
./parsed_ops/sparse_ops.parsed.yaml --backward_ops_yaml_path
114121
./parsed_ops/sparse_backward.parsed.yaml --output_op_path
115122
"${generated_sparse_ops_path}.tmp" --output_arg_map_path
116123
"${generated_sparse_argument_mapping_path}.tmp"
117-
RESULT_VARIABLE _results)
118-
foreach(_result in ${_results})
119-
if(${_result})
120-
message(FATAL_ERROR "operator codegen failed, exiting.")
121-
endif()
122-
endforeach()
124+
RESULT_VARIABLE _result)
125+
if(${_result})
126+
message(FATAL_ERROR "sparse operator codegen failed, exiting.")
127+
endif()
123128

124129
if(EXISTS "${generated_op_path}.tmp" AND EXISTS "${generated_op_path}")
125130
execute_process(COMMAND ${CMAKE_COMMAND} -E copy_if_different

paddle/fluid/operators/generator/filters.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -114,17 +114,44 @@ def to_input_name(s):
114114
return match.group(2)
115115

116116

117+
def to_scalar_tensor_name(attr):
118+
if 'tensor_name' in attr:
119+
return attr['tensor_name']
120+
return to_pascal_case(attr['name']) + 'Tensor'
121+
122+
123+
def to_int_array_tensor_name(attr):
124+
if 'tensor_name' in attr:
125+
return attr['tensor_name']
126+
return to_pascal_case(attr['name']) + 'Tensor'
127+
128+
129+
def to_int_array_tensors_name(attr):
130+
if 'tensors_name' in attr:
131+
return attr['tensors_name']
132+
return to_pascal_case(attr['name']) + 'TensorList'
133+
134+
117135
def cartesian_prod_attrs(attrs):
118136
items = []
119137
for attr in attrs:
120138
type_name = attr["typename"]
121139
name = attr["name"]
122140
if type_name == "Scalar":
123-
items.append((name, "{}Tensor".format(name)))
141+
items.append((name, to_scalar_tensor_name(attr)))
124142
elif type_name == "IntArray":
125-
items.append(
126-
(name, "{}Tensor".format(name), "{}TensorList".format(name))
127-
)
143+
if 'tensor_name' not in attr and 'manual_flag' in attr:
144+
items.append((name, to_int_array_tensors_name(attr)))
145+
elif 'tensors_name' not in attr and 'manual_flag' in attr:
146+
items.append((name, to_int_array_tensor_name(attr)))
147+
else:
148+
items.append(
149+
(
150+
name,
151+
to_int_array_tensor_name(attr),
152+
to_int_array_tensors_name(attr),
153+
)
154+
)
128155
else:
129156
items.append((name,))
130157

0 commit comments

Comments
 (0)