Skip to content

Commit 00f108c

Browse files
wanghuancoderLuckycheng222
authored andcommitted
checkout pir not support out (PaddlePaddle#74685)
* checkout pir not support out * refine
1 parent 50e889f commit 00f108c

File tree

3 files changed

+38
-0
lines changed

3 files changed

+38
-0
lines changed

paddle/fluid/pir/dialect/op_generator/python_c_gen.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,9 @@
7373
// Parse Attributes
7474
{attrs}
7575
76+
// Parse input_out if needed
77+
{input_out}
78+
7679
// Check Reminding Params validity if needed
7780
{check_remaining_params_valid}
7881
// Call Pre_Process before calling dygraph function if needed
@@ -166,6 +169,9 @@
166169
// Parse Attributes
167170
{attrs_py_obj}
168171
172+
// Parse input_out if needed
173+
{input_out}
174+
169175
// Check for mutable attrs
170176
{init_attrs}
171177
{cast_attrs}
@@ -646,6 +652,13 @@ def _gen_one_impl(self, op_info, op_name):
646652
args=', '.join(input_name_list + attr_name_list),
647653
)
648654
elif len(mutable_attr_name_list) > 0:
655+
get_input_out_str = ""
656+
if (
657+
not op_name[-1:] == "_"
658+
and not op_name[-4:] == "grad"
659+
and "sparse" not in op_name
660+
):
661+
get_input_out_str = "Check_PIR_not_support_out(kwargs);"
649662
ret = MUTABLE_ATTR_API_IMPL_TEMPLATE.format(
650663
api_name=op_name,
651664
check_params_count=self._gen_check_params_count(
@@ -666,8 +679,16 @@ def _gen_one_impl(self, op_info, op_name):
666679
+ mutable_attr_name_list
667680
+ no_mutable_attr_name_list
668681
),
682+
input_out=get_input_out_str,
669683
)
670684
else:
685+
get_input_out_str = ""
686+
if (
687+
not op_name[-1:] == "_"
688+
and not op_name[-4:] == "grad"
689+
and "sparse" not in op_name
690+
):
691+
get_input_out_str = "Check_PIR_not_support_out(kwargs);"
671692
ret = NO_MUTABLE_ATTR_API_IMPL_TEMPLATE.format(
672693
api_name=op_name,
673694
check_params_count=self._gen_check_params_count(
@@ -682,6 +703,7 @@ def _gen_one_impl(self, op_info, op_name):
682703
need_check=need_check_params_count
683704
),
684705
pre_process=self._gen_pre_process(pre_process),
706+
input_out=get_input_out_str,
685707
)
686708
ret = re.sub(r' +\n', '', ret)
687709
return ret

paddle/fluid/pybind/eager_utils.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3184,4 +3184,18 @@ paddle::optional<Tensor*> GetInputOutTensorFromKwargs(PyObject* kwargs) {
31843184
return paddle::none;
31853185
}
31863186

3187+
void Check_PIR_not_support_out(PyObject* kwargs) {
3188+
if (!kwargs) {
3189+
return;
3190+
}
3191+
PyObject* obj = PyDict_GetItemString(kwargs, "out");
3192+
if (obj) {
3193+
static std::once_flag once_flag;
3194+
std::call_once(once_flag, [&] {
3195+
LOG(WARNING) << "Paddle static graph(PIR) not support input out tensor "
3196+
"for now!!!!!";
3197+
});
3198+
}
3199+
}
3200+
31873201
} // namespace paddle::pybind

paddle/fluid/pybind/eager_utils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -540,6 +540,8 @@ void EagerSetDeviceId();
540540

541541
paddle::optional<Tensor*> GetInputOutTensorFromKwargs(PyObject* kwargs);
542542

543+
void Check_PIR_not_support_out(PyObject* kwargs);
544+
543545
/*----------------------for arg parse-----------------------------*/
544546
paddle::Tensor& GetTensorFromArgsOrKWArgs(
545547
const std::string& op_type,

0 commit comments

Comments
 (0)