25 changes: 4 additions & 21 deletions python/paddle/incubate/nn/functional/fused_layer_norm.py
@@ -18,7 +18,7 @@

import paddle
from paddle import _C_ops
from paddle.framework import LayerHelper, in_dynamic_mode, in_pir_mode
from paddle.framework import LayerHelper, in_dynamic_or_pir_mode

if TYPE_CHECKING:
from paddle import Tensor
@@ -108,8 +108,7 @@ def fused_layer_norm(
>>> epsilon = 1e-6
>>> paddle_layernorm = paddle.incubate.nn.functional.fused_layer_norm(paddle_x, paddle_weight, paddle_bias, epsilon, 1)
"""

if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.fused_bias_residual_layernorm(
x,
bias,
@@ -124,23 +123,7 @@
quant_max_bound,
quant_min_bound,
)
elif in_pir_mode():
out, residual_out, _, _ = _C_ops.fused_bias_residual_layernorm(
x,
bias,
residual,
norm_weight,
norm_bias,
epsilon,
residual_alpha,
begin_norm_axis,
quant_scale,
quant_round_type,
quant_max_bound,
quant_min_bound,
)
return (out, residual_out) if residual is not None else out

# static mode
helper = LayerHelper('fused_layernorm', **locals())
out = None
if quant_scale <= 0:
@@ -183,4 +166,4 @@ def fused_layer_norm(
},
outputs=outputs_dict,
)
return (out, residual_out) if residual is not None else out
return (out, residual_out, outputs_dict['mean'], outputs_dict['variance'])
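With this change, dynamic and PIR execution share the single `in_dynamic_or_pir_mode()` branch, and the static-graph branch of `fused_layer_norm` now returns the full tuple `(out, residual_out, mean, variance)` instead of dropping the statistics. Callers that only want the normalized output index the result, as the updated tests do. A minimal usage sketch follows; the shapes, dtype, and the assumption that the fused kernel is available in the current build are illustrative, not part of this PR.

import paddle
from paddle.incubate.nn.functional import fused_layer_norm

# Illustrative inputs (shapes and dtype are assumptions for this sketch).
x = paddle.randn([2, 256], dtype='float32')
norm_weight = paddle.ones([256], dtype='float32')
norm_bias = paddle.zeros([256], dtype='float32')

# Mirrors the updated tests: index [0] to keep only the normalized output;
# the remaining entries carry residual_out and the normalization statistics.
out = fused_layer_norm(
    x, norm_weight, norm_bias, 1e-6, begin_norm_axis=1
)[0]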
22 changes: 4 additions & 18 deletions python/paddle/incubate/nn/functional/fused_rms_norm.py
@@ -18,7 +18,7 @@

import paddle
from paddle import _C_ops
from paddle.framework import LayerHelper, in_dynamic_mode, in_pir_mode
from paddle.framework import LayerHelper, in_dynamic_or_pir_mode

if TYPE_CHECKING:
from paddle import Tensor
@@ -102,7 +102,7 @@ def fused_rms_norm(
>>> epsilon = 1e-6
>>> paddle_rmsnorm = paddle.incubate.nn.functional.fused_rms_norm(paddle_x, paddle_weight, paddle_bias, epsilon, 1)
"""
if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.rms_norm(
x,
bias,
@@ -116,21 +116,7 @@
quant_max_bound,
quant_min_bound,
)
if in_pir_mode():
out, residual_out = _C_ops.rms_norm(
x,
bias,
residual,
norm_weight,
norm_bias,
epsilon,
begin_norm_axis,
quant_scale,
quant_round_type,
quant_max_bound,
quant_min_bound,
)
return (out, residual_out) if residual is not None else out
# static mode
helper = LayerHelper('rms_norm', **locals())
out = None
if quant_scale <= 0:
@@ -167,4 +153,4 @@ def fused_rms_norm(
},
outputs=outputs_dict,
)
return (out, residual_out) if residual is not None else out
return (out, residual_out, outputs_dict['inv_var'])
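`fused_rms_norm` gets the same treatment: one `in_dynamic_or_pir_mode()` branch covers dynamic and PIR execution, and the static branch now returns `(out, residual_out, inv_var)`. A hedged sketch of the resulting calling convention, with illustrative shapes and under the assumption that, as with `fused_layer_norm`, index 0 of the returned tuple is the normalized output and the fused kernel is available for these inputs:

import paddle
from paddle.incubate.nn.functional import fused_rms_norm

# Illustrative inputs (shapes and dtype are assumptions for this sketch).
x = paddle.randn([2, 256], dtype='float32')
norm_weight = paddle.ones([256], dtype='float32')
norm_bias = paddle.zeros([256], dtype='float32')

# Keep only the normalized output; the trailing entries (residual_out and,
# in the static graph, inv_var) are dropped here.
out = fused_rms_norm(
    x, norm_weight, norm_bias, 1e-6, begin_norm_axis=1
)[0]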
40 changes: 20 additions & 20 deletions test/legacy_test/test_fused_layernorm_op.py
@@ -448,7 +448,7 @@ def check_layernorm(self, x_np, gamma_np, beta_np, dtype):
beta_static,
self.epsilon,
begin_norm_axis=1,
)
)[0]
exe = paddle.static.Executor(self.place)
out_s = exe.run(
feed={
@@ -498,7 +498,7 @@ def check_layernorm_int8(self, x_np, gamma_np, beta_np, dtype):
quant_round_type=self.quant_round_type,
quant_max_bound=self.quant_max_bound,
quant_min_bound=self.quant_min_bound,
)
)[0]
exe = paddle.static.Executor(self.place)
out_s = exe.run(
feed={
@@ -546,7 +546,7 @@ def check_residual_bias_add(self, x_np, residual_np, bias_np, dtype):
quant_round_type=self.quant_round_type,
quant_max_bound=self.quant_max_bound,
quant_min_bound=self.quant_min_bound,
)
)[0]

exe = paddle.static.Executor(self.place)
out_s = exe.run(
@@ -556,7 +556,7 @@ def check_residual_bias_add(self, x_np, residual_np, bias_np, dtype):
"bias_static": bias_np.astype(dtype),
},
fetch_list=[
outs[0]
outs
], # NOTE: Only fetch `out`, because `residual_out` will not be initialized if both `norm_weight` and `norm_bias` are None.
)
return out_s, paddle_naive_residual_out
@@ -597,7 +597,7 @@ def check_residual_bias_layernorm(
beta_static = paddle.static.data(
name="beta_static", shape=[self.cols], dtype='float32'
)
outs = paddle.incubate.nn.functional.fused_layer_norm(
outs, residual = paddle.incubate.nn.functional.fused_layer_norm(
x_static,
gamma_static,
beta_static,
@@ -606,7 +606,7 @@
residual_alpha=self.residual_alpha,
bias=bias_static,
residual=residual_static,
)
)[:2]

exe = paddle.static.Executor(self.place)
out_s = exe.run(
@@ -617,7 +617,7 @@ def check_residual_bias_layernorm(
"residual_static": residual_np.astype(dtype),
"bias_static": bias_np.astype(dtype),
},
fetch_list=[outs],
fetch_list=[outs, residual],
)
return out_s, paddle_naive_layernorm_out, paddle_naive_residual_out

@@ -667,7 +667,7 @@ def check_residual_bias_layernorm_int8(
beta_static = paddle.static.data(
name="beta_static", shape=[self.cols], dtype='float32'
)
outs = paddle.incubate.nn.functional.fused_layer_norm(
outs, residual = paddle.incubate.nn.functional.fused_layer_norm(
x_static,
gamma_static,
beta_static,
@@ -680,7 +680,7 @@
quant_round_type=self.quant_round_type,
quant_max_bound=self.quant_max_bound,
quant_min_bound=self.quant_min_bound,
)
)[:2]

exe = paddle.static.Executor(self.place)
out_s = exe.run(
@@ -691,7 +691,7 @@ def check_residual_bias_layernorm_int8(
"residual_static": residual_np.astype(dtype),
"bias_static": bias_np.astype(dtype),
},
fetch_list=[outs],
fetch_list=[outs, residual],
)
return out_s, paddle_naive_layernorm_out, paddle_naive_residual_out

@@ -847,7 +847,7 @@ def check_layernorm(self, x_np, gamma_np, beta_np, dtype):

paddle_layernorm_out = paddle.incubate.nn.functional.fused_layer_norm(
x, gamma, beta, self.epsilon, begin_norm_axis=1
)
)[0]
paddle_naive_layernorm_out = naive_layer_norm(
x, gamma, beta, self.epsilon
)
@@ -869,7 +869,7 @@ def check_residual_bias_add(self, x_np, residual_np, bias_np, dtype):
bias=bias,
residual=residual,
residual_alpha=self.residual_alpha,
)
)[0]

paddle_naive_residual_out = naive_residual_bias_add(
x, residual, bias, self.residual_alpha
@@ -919,7 +919,7 @@ def test_residual_bias_add(self):
self.x_np, self.residual_np, self.bias_np, 'float32'
)
np.testing.assert_allclose(
paddle_residual_bias_out[0].numpy(),
paddle_residual_bias_out.numpy(),
paddle_naive_residual_bias_out.numpy(),
rtol=1e-3,
atol=1e-3,
@@ -931,7 +931,7 @@ def test_layernorm(self):
)

np.testing.assert_allclose(
paddle_layernorm[0].numpy(),
paddle_layernorm.numpy(),
paddle_naive_layernorm.numpy(),
rtol=1e-3,
atol=1e-3,
@@ -1016,7 +1016,7 @@ def check_layernorm(self, x_np, gamma_np, beta_np, dtype):
beta_static,
self.epsilon,
begin_norm_axis=1,
)
)[0]
exe = paddle.static.Executor(self.place)
out_s = exe.run(
feed={
@@ -1060,7 +1060,7 @@ def check_residual_bias_add(self, x_np, residual_np, bias_np, dtype):
bias=bias_static,
residual=residual_static,
residual_alpha=self.residual_alpha,
)
)[0]

exe = paddle.static.Executor(self.place)
out_s = exe.run(
@@ -1070,7 +1070,7 @@ def check_residual_bias_add(self, x_np, residual_np, bias_np, dtype):
"bias_static": bias_np.astype(dtype),
},
fetch_list=[
outs[0]
outs
], # NOTE: Only fetch `out`, because `residual_out` will not be initialized if both `norm_weight` and `norm_bias` are None.
)
return out_s, paddle_naive_residual_out
@@ -1111,7 +1111,7 @@ def check_residual_bias_layernorm(
beta_static = paddle.static.data(
name="beta_static", shape=[self.cols], dtype='float32'
)
outs = paddle.incubate.nn.functional.fused_layer_norm(
outs, residual = paddle.incubate.nn.functional.fused_layer_norm(
x_static,
gamma_static,
beta_static,
@@ -1120,7 +1120,7 @@
residual_alpha=self.residual_alpha,
bias=bias_static,
residual=residual_static,
)
)[:2]

exe = paddle.static.Executor(self.place)
out_s = exe.run(
@@ -1131,7 +1131,7 @@ def check_residual_bias_layernorm(
"residual_static": residual_np.astype(dtype),
"bias_static": bias_np.astype(dtype),
},
fetch_list=[outs],
fetch_list=[outs, residual],
)
return out_s, paddle_naive_layernorm_out, paddle_naive_residual_out
