diff --git a/python/paddle/base/layer_helper.py b/python/paddle/base/layer_helper.py
index add33e0a4edeed..00a7a4b729a58e 100644
--- a/python/paddle/base/layer_helper.py
+++ b/python/paddle/base/layer_helper.py
@@ -15,6 +15,7 @@
 import copy
 
 import paddle
+from paddle import _C_ops
 
 from . import unique_name
 from .dygraph_utils import _append_activation_in_dygraph
@@ -162,6 +163,20 @@ def append_activation(self, input_var):
         if in_dygraph_mode():
             res = _append_activation_in_dygraph(input_var, act_type, use_cudnn)
             return res
+        elif in_pir_mode():
+
+            def _append_activation_in_pir(input, act=None, use_cudnn=None):
+                if act is None:
+                    return input
+
+                attrs = ()
+                if use_cudnn:
+                    attrs = ('use_cudnn', use_cudnn)
+
+                act_op = getattr(_C_ops, act)
+                return act_op(input, *attrs)
+
+            return _append_activation_in_pir(input_var, act_type, use_cudnn)
         else:
             tmp = self.create_variable_for_type_inference(dtype=input_var.dtype)
             self.append_op(
diff --git a/python/paddle/static/input.py b/python/paddle/static/input.py
index 554aaf28a1cbd5..ee1b1e5b2d3dc0 100644
--- a/python/paddle/static/input.py
+++ b/python/paddle/static/input.py
@@ -26,6 +26,10 @@
 )
 from paddle.base.layer_helper import LayerHelper
 from paddle.base.libpaddle import DataType
+from paddle.base.libpaddle.pir import (
+    get_current_insertion_point,
+    set_insertion_point,
+)
 
 from ..base.variable_index import _setitem_static
 
@@ -132,10 +136,11 @@ def _reset_data_op_insertion_point():
         ir_dtype = dtype
         if not isinstance(ir_dtype, DataType):
             ir_dtype = paddle.pir.core.convert_np_dtype_to_dtype_(dtype)
+        prev_insertion_point = get_current_insertion_point()
         _reset_data_op_insertion_point()
         out = paddle._pir_ops.data(name, shape, ir_dtype, core.Place())
         out.lod_level = lod_level
-        paddle.pir.reset_insertion_point_to_end()
+        set_insertion_point(prev_insertion_point)
         return out
 
     out = helper.create_global_variable(
diff --git a/python/paddle/static/nn/common.py b/python/paddle/static/nn/common.py
index 88cd6ab9e0b5a8..1235859c40a20a 100644
--- a/python/paddle/static/nn/common.py
+++ b/python/paddle/static/nn/common.py
@@ -219,6 +219,12 @@ def fc_base(
             attr=param_attr, shape=param_shape, dtype=dtype, is_bias=False
         )
         if in_pir_mode():
+            if len(input_var.shape) > 2:
+                new_shape = (
+                    input_var.shape[0],
+                    np.prod(input_var.shape[1:]),
+                )
+                input_var = paddle.reshape(input_var, new_shape)
             tmp = paddle.matmul(input_var, w)
         else:
             tmp = helper.create_variable_for_type_inference(dtype)
diff --git a/test/ir/test_ir_subgraph_python_interface.py b/test/ir/test_ir_subgraph_python_interface.py
index 7c1258dc8f8379..af0feb5de78847 100644
--- a/test/ir/test_ir_subgraph_python_interface.py
+++ b/test/ir/test_ir_subgraph_python_interface.py
@@ -18,8 +18,9 @@
 
 import paddle
 from paddle import base
-from paddle.base import core
+from paddle.base import core, in_pir_mode
 from paddle.base.framework import IrGraph, Program, program_guard
+from paddle.pir_utils import test_with_pir_api
 from paddle.static.quantization import QuantizationTransformPass
 
 paddle.enable_static()
@@ -76,6 +77,7 @@ def false_func():
         # be destructed and the sub_graphs will be empty.
         return graph, all_sub_graphs
 
+    @test_with_pir_api
     def test_quant_sub_graphs(self, use_cuda=False):
         graph, sub_graphs = self.build_graph_with_sub_graph()
         place = base.CUDAPlace(0) if use_cuda else base.CPUPlace()
@@ -86,12 +88,13 @@ def test_quant_sub_graphs(self, use_cuda=False):
             weight_quantize_type='range_abs_max',
         )
         Find_inserted_quant_op = False
-        for sub_graph in sub_graphs:
-            transform_pass.apply(sub_graph)
-            for op in sub_graph.all_op_nodes():
-                if 'quantize' in op.name():
-                    Find_inserted_quant_op = True
-        self.assertTrue(Find_inserted_quant_op)
+        if not in_pir_mode():
+            for sub_graph in sub_graphs:
+                transform_pass.apply(sub_graph)
+                for op in sub_graph.all_op_nodes():
+                    if 'quantize' in op.name():
+                        Find_inserted_quant_op = True
+            self.assertTrue(Find_inserted_quant_op)
 
     def test_quant_sub_graphs_cpu(self):
         self.test_quant_sub_graphs(use_cuda=False)
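
Note on the `layer_helper.py` change: the new PIR branch resolves the activation op from `_C_ops` by name with `getattr`, and passes the optional attribute as a flat `('name', value)` pair splatted into the call. A minimal standalone sketch of that dispatch pattern; `_FakeOps` is a hypothetical stand-in for `paddle._C_ops` so the snippet runs without a Paddle build:

```python
# Standalone sketch of the getattr-based dispatch used by
# _append_activation_in_pir in the diff above. _FakeOps is a
# hypothetical stand-in for paddle._C_ops.
class _FakeOps:
    @staticmethod
    def relu(x, *attrs):
        print("relu called with attrs:", attrs)
        return max(x, 0.0)


def append_activation(x, act=None, use_cudnn=None, ops=_FakeOps):
    if act is None:
        return x
    # Attributes travel as a flat ('name', value) pair, mirroring the
    # calling convention the diff uses for _C_ops.
    attrs = ('use_cudnn', use_cudnn) if use_cudnn else ()
    return getattr(ops, act)(x, *attrs)


append_activation(-3.0, 'relu', use_cudnn=True)
# prints: relu called with attrs: ('use_cudnn', True)
```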
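Note on the `input.py` change: instead of resetting the insertion point to the block end after emitting the `data` op, the code now saves the caller's insertion point and restores it, so ops built afterwards land where they would have before. A hedged sketch of the same pattern as a reusable context manager, assuming the `get_current_insertion_point`/`set_insertion_point` bindings the diff imports; `preserved_insertion_point` is a hypothetical name, not part of Paddle's API:

```python
# Hedged sketch: the save/restore pattern from the diff, wrapped as a
# context manager. Assumes a Paddle build exposing the same
# paddle.base.libpaddle.pir bindings the diff imports.
from contextlib import contextmanager

from paddle.base.libpaddle.pir import (
    get_current_insertion_point,
    set_insertion_point,
)


@contextmanager
def preserved_insertion_point():  # hypothetical helper name
    prev = get_current_insertion_point()
    try:
        yield
    finally:
        # Restore the caller's position rather than jumping to the
        # block end, so later ops keep building where they left off.
        set_insertion_point(prev)
```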
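Note on the `fc_base` change: `paddle.matmul` in the PIR path expects a 2-D input here, so inputs of rank > 2 are first collapsed to `(batch, prod(trailing dims))`, matching fully-connected semantics. A NumPy sketch of the shape arithmetic (shapes are illustrative):

```python
# NumPy sketch of the rank>2 flattening added to fc_base: collapse all
# trailing dims into one before the 2-D matmul with the weight matrix.
import numpy as np

x = np.random.rand(4, 3, 8, 8)          # e.g. an (N, C, H, W) input
w = np.random.rand(3 * 8 * 8, 10)       # (flattened_features, size)

new_shape = (x.shape[0], int(np.prod(x.shape[1:])))
out = x.reshape(new_shape) @ w          # -> (4, 10)
print(out.shape)
```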