Skip to content

Commit bbc8ea5

Browse files
authored
【PIR API adaptor No.127】llm_int8_linear (#58882)
* Update quantized_linear.py
* Update test_llm_int8_linear.py
* Update quantized_linear.py
* Update test_llm_int8_linear.py
1 parent c66c543 commit bbc8ea5

File tree

2 files changed: 11 additions and 5 deletions

python/paddle/nn/quant/quantized_linear.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,11 @@
1515
from paddle import _C_ops
1616
from paddle.base.data_feeder import check_dtype
1717
from paddle.base.framework import convert_np_dtype_to_dtype_
18-
from paddle.framework import LayerHelper, in_dynamic_mode
18+
from paddle.framework import (
19+
LayerHelper,
20+
in_dynamic_mode,
21+
in_dynamic_or_pir_mode,
22+
)
1923

2024

2125
def weight_quantize(x, algo="weight_only_int8"):
@@ -217,7 +221,7 @@ def llm_int8_linear(
217221
... print(out.shape)
218222
[1, 2, 32]
219223
"""
220-
if in_dynamic_mode():
224+
if in_dynamic_or_pir_mode():
221225
out = _C_ops.llm_int8_linear(x, weight, bias, weight_scale, threshold)
222226
return out
223227
else:

test/quantization/test_llm_int8_linear.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@
2323
from paddle.base import core
2424
from paddle.base.framework import default_main_program
2525
from paddle.framework import set_default_dtype
26+
from paddle.pir_utils import test_with_pir_api
2627

2728
np.random.seed(123)
2829
paddle.seed(123)
@@ -86,11 +87,12 @@ def get_llm_int8_linear_out(self):
8687
)
8788
return out.numpy()
8889

90+
@test_with_pir_api
8991
def get_llm_int8_linear_out_static(self):
9092
paddle.enable_static()
91-
main = base.Program()
92-
start = base.Program()
93-
with base.program_guard(main, start):
93+
main = base.static.Program()
94+
start = base.static.Program()
95+
with base.static.program_guard(main, start):
9496
x = paddle.static.data("x", self.x.shape, dtype=self.x.dtype)
9597

9698
weight = paddle.static.data(

0 commit comments

Comments
 (0)