Skip to content

Commit bd3eac0

Browse files
committed
SDNQ fix NPU accuracy
1 parent 7d14e81 commit bd3eac0

2 files changed

Lines changed: 10 additions & 3 deletions

File tree

modules/sdnq/kernels/openvino_mm.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@
22
import torch
33
import openvino as ov
44
from openvino import opset16 as ov_ops
5-
5+
from openvino.properties import hint as ov_hints
66

77
core = ov.Core()
8+
9+
NPU_MUL = 32 # NPU computes FP16 x INT8 -> FP16 instead of INT8 x INT8 -> INT32, so the FP16 output overflows; on NPU both matmul inputs are divided by NPU_MUL and the result is multiplied by NPU_MUL**2
810
OV_DEVICE: str = os.environ.get("SDNQ_OPENVINO_DEVICE", "CPU")
911
OV_COMPILED_CACHE: dict[tuple[str, tuple[int,int] | None, tuple[int,int] | None], list[ov.InferRequest, str]] = {}
12+
core.set_property(OV_DEVICE, {ov_hints.execution_mode: ov_hints.ExecutionMode.ACCURACY})
1013

1114

1215
def ov_int_mm(A: torch.Tensor, B: torch.Tensor, infer_request: ov.InferRequest, out_name: str) -> torch.Tensor:
@@ -16,12 +19,13 @@ def ov_int_mm(A: torch.Tensor, B: torch.Tensor, infer_request: ov.InferRequest,
1619
infer_request.set_tensor(out_name, ov.Tensor(C.numpy(), shared_memory=True))
1720
infer_request.infer()
1821
C = C.to(A.device)
22+
if OV_DEVICE == "NPU":
23+
C.mul_(NPU_MUL**2)
1924
return C
2025

2126

2227
@torch.library.custom_op("sdnq::openvino_int_mm", mutates_args=())
2328
def openvino_int_mm(Tensor_A: torch.Tensor, Tensor_B: torch.Tensor) -> torch.Tensor:
24-
global OV_COMPILED_CACHE, OV_DEVICE # pylint: disable=global-variable-not-assigned
2529
if OV_DEVICE in {"NPU", "CPU"}:
2630
cache_key = (OV_DEVICE, Tensor_A.shape, Tensor_B.shape)
2731
else:
@@ -40,8 +44,12 @@ def openvino_int_mm(Tensor_A: torch.Tensor, Tensor_B: torch.Tensor) -> torch.Ten
4044
input_b = ov_ops.parameter(shape_b, ov.Type.i8, name="B")
4145
low = ov_ops.constant(-128.0, dtype=ov.Type.f32)
4246
high = ov_ops.constant(127.0, dtype=ov.Type.f32)
47+
4348
a = ov_ops.fake_quantize(ov_ops.convert(input_a, ov.Type.f32), low, high, low, high, 256)
4449
b = ov_ops.fake_quantize(ov_ops.convert(input_b, ov.Type.f32), low, high, low, high, 256)
50+
if OV_DEVICE == "NPU":
51+
a = ov_ops.divide(a, ov_ops.constant(NPU_MUL, dtype=ov.Type.f32))
52+
b = ov_ops.divide(b, ov_ops.constant(NPU_MUL, dtype=ov.Type.f32))
4553

4654
ov_model = ov.Model([ov_ops.matmul(a, b, False, False)], [input_a, input_b], "ov_int8_mm")
4755
ov_model = core.compile_model(ov_model, OV_DEVICE)

modules/sdnq/quant_utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,6 @@ def build_hadamard(n: int, dtype: torch.dtype | None = None, device: torch.devic
114114
HADAMARD_MATRIX_CACHE = {}
115115
@devices.inference_context()
116116
def get_hadamard(n: int, dtype: torch.dtype | None = None, device: torch.device | None = None):
117-
global HADAMARD_MATRIX_CACHE # pylint: disable=global-variable-not-assigned
118117
device = devices.normalize_device(device)
119118
H_key = (n, device, dtype)
120119
H = HADAMARD_MATRIX_CACHE.get(H_key, None)

0 commit comments

Comments
 (0)