
Commit 9247ade

fix eb4 (#62032)
1 parent 194ef8b commit 9247ade

3 files changed: 29 additions & 10 deletions


python/paddle/distributed/fleet/layers/mpu/mp_layers.py

Lines changed: 3 additions & 5 deletions
@@ -16,6 +16,7 @@
 
 import paddle
 from paddle.autograd import PyLayer
+from paddle.base import core
 from paddle.distributed import fleet
 from paddle.nn import functional as F
 

@@ -33,7 +34,7 @@
 
 
 def is_fused_matmul_bias_supported():
-    return hasattr(paddle._C_ops, 'fused_gemm_epilogue')
+    return hasattr(core.eager.ops.legacy, 'fused_gemm_epilogue')
 
 
 def is_fused_linear_param_grad_add_supported():

@@ -213,10 +214,7 @@ def forward(
         if not fuse_matmul_bias:
             return paddle._C_ops.linear(x, weight, bias)
         else:
-            result, _ = paddle._C_ops.fused_gemm_epilogue(
-                x, weight, bias, False, False, "none"
-            )
-            return result
+            return paddle._legacy_C_ops.fused_gemm_epilogue(x, weight, bias)
 
     @staticmethod
     def backward(ctx, dy):
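
Note on the eager-mode change above: the availability probe now targets the legacy eager op registry, and the call goes through paddle._legacy_C_ops. A minimal sketch of that probe-and-fallback pattern, with illustrative helper names that are not part of the patch:

import paddle
from paddle.base import core


def has_fused_gemm_epilogue():
    # Probe the legacy eager op registry (instead of paddle._C_ops, as in
    # this patch); the fused kernel is only registered in some builds.
    return hasattr(core.eager.ops.legacy, 'fused_gemm_epilogue')


def linear_maybe_fused(x, weight, bias):
    if has_fused_gemm_epilogue():
        # Legacy eager op; per the diff above it returns the output tensor
        # directly rather than a (result, reserve_space) tuple.
        return paddle._legacy_C_ops.fused_gemm_epilogue(x, weight, bias)
    # Fallback: unfused matmul followed by a bias add.
    return paddle.matmul(x, weight) + bias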

python/paddle/distributed/fleet/utils/sequence_parallel_utils.py

Lines changed: 2 additions & 1 deletion
@@ -17,6 +17,7 @@
 import paddle
 from paddle import distributed as dist
 from paddle.autograd import PyLayer
+from paddle.base import core
 from paddle.distributed import fleet
 from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
 from paddle.distributed.fleet.utils.hybrid_parallel_util import (

@@ -221,7 +222,7 @@ def is_fused_matmul_bias_supported():
         and not paddle.is_compiled_with_rocm()
         or paddle.is_compiled_with_xpu()
     ):
-        return hasattr(paddle._C_ops, "fused_gemm_epilogue")
+        return hasattr(core.eager.ops.legacy, "fused_gemm_epilogue")
     else:
         return False
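
The sequence-parallel utility applies the same registry probe, gated by the build's compile flags. A small hedged check script (the compile-flag queries are standard Paddle APIs; the import path follows the file shown in this hunk):

import paddle
from paddle.distributed.fleet.utils.sequence_parallel_utils import (
    is_fused_matmul_bias_supported,
)

# Quick capability report for the current build; the check now inspects the
# legacy eager op registry rather than paddle._C_ops.
print("compiled with CUDA:", paddle.is_compiled_with_cuda())
print("compiled with ROCm:", paddle.is_compiled_with_rocm())
print("compiled with XPU:", paddle.is_compiled_with_xpu())
print("fused matmul+bias available:", is_fused_matmul_bias_supported())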

python/paddle/incubate/nn/functional/fused_matmul_bias.py

Lines changed: 24 additions & 4 deletions
@@ -12,9 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from paddle import _C_ops
+from paddle import _C_ops, _legacy_C_ops
 from paddle.base.layer_helper import LayerHelper
-from paddle.framework import in_dynamic_or_pir_mode
+from paddle.framework import (
+    in_dynamic_mode,
+    in_pir_mode,
+)
 from paddle.tensor.linalg import matmul
 
 

@@ -56,7 +59,11 @@ def fused_matmul_bias(
     """
     if bias is None:
         return matmul(x, y, transpose_x, transpose_y, name)
-    if in_dynamic_or_pir_mode():
+    if in_dynamic_mode():
+        return _legacy_C_ops.fused_gemm_epilogue(
+            x, y, bias, 'trans_x', transpose_x, 'trans_y', transpose_y
+        )
+    if in_pir_mode():
         out, _ = _C_ops.fused_gemm_epilogue(
             x, y, bias, transpose_x, transpose_y, "none"
         )

@@ -146,7 +153,20 @@ def fused_linear_activation(
     if activation is None:
         activation = "none"
 
-    if in_dynamic_or_pir_mode():
+    if in_dynamic_mode():
+        return _legacy_C_ops.fused_gemm_epilogue(
+            x,
+            y,
+            bias,
+            'trans_x',
+            trans_x,
+            'trans_y',
+            trans_y,
+            'activation',
+            activation,
+        )
+
+    if in_pir_mode():
         out, _ = _C_ops.fused_gemm_epilogue(
             x,
             y,
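
The public entry points touched here are paddle.incubate.nn.functional.fused_matmul_bias and fused_linear_activation: in dynamic graph mode they now dispatch to _legacy_C_ops.fused_gemm_epilogue, while the PIR branch keeps the _C_ops call. A short usage sketch, assuming a build where the fused kernel is registered (shapes are illustrative):

import paddle
from paddle.incubate.nn.functional import fused_matmul_bias

x = paddle.randn([4, 16], dtype='float32')
y = paddle.randn([16, 32], dtype='float32')
bias = paddle.randn([32], dtype='float32')

# In dynamic graph mode this now routes through
# _legacy_C_ops.fused_gemm_epilogue; with bias=None it falls back to matmul.
out = fused_matmul_bias(x, y, bias)
print(out.shape)  # [4, 32]

# Reference result from the unfused path for comparison.
ref = paddle.matmul(x, y) + bias
print(paddle.allclose(out, ref).item())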
