22import torch
33import openvino as ov
44from openvino import opset16 as ov_ops
5-
5+ from openvino . properties import hint as ov_hints
66
77core = ov .Core ()
8+
9+ NPU_MUL = 32 # NPU uses FP16 x INT8 -> FP16 instead of INT8 x INT8 -> INT32 and FP16 output overflows
810OV_DEVICE : str = os .environ .get ("SDNQ_OPENVINO_DEVICE" , "CPU" )
911OV_COMPILED_CACHE : dict [tuple [str , tuple [int ,int ] | None , tuple [int ,int ] | None ], list [ov .InferRequest , str ]] = {}
12+ core .set_property (OV_DEVICE , {ov_hints .execution_mode : ov_hints .ExecutionMode .ACCURACY })
1013
1114
1215def ov_int_mm (A : torch .Tensor , B : torch .Tensor , infer_request : ov .InferRequest , out_name : str ) -> torch .Tensor :
@@ -16,12 +19,13 @@ def ov_int_mm(A: torch.Tensor, B: torch.Tensor, infer_request: ov.InferRequest,
1619 infer_request .set_tensor (out_name , ov .Tensor (C .numpy (), shared_memory = True ))
1720 infer_request .infer ()
1821 C = C .to (A .device )
22+ if OV_DEVICE == "NPU" :
23+ C .mul_ (NPU_MUL ** 2 )
1924 return C
2025
2126
2227@torch .library .custom_op ("sdnq::openvino_int_mm" , mutates_args = ())
2328def openvino_int_mm (Tensor_A : torch .Tensor , Tensor_B : torch .Tensor ) -> torch .Tensor :
24- global OV_COMPILED_CACHE , OV_DEVICE # pylint: disable=global-variable-not-assigned
2529 if OV_DEVICE in {"NPU" , "CPU" }:
2630 cache_key = (OV_DEVICE , Tensor_A .shape , Tensor_B .shape )
2731 else :
@@ -40,8 +44,12 @@ def openvino_int_mm(Tensor_A: torch.Tensor, Tensor_B: torch.Tensor) -> torch.Ten
4044 input_b = ov_ops .parameter (shape_b , ov .Type .i8 , name = "B" )
4145 low = ov_ops .constant (- 128.0 , dtype = ov .Type .f32 )
4246 high = ov_ops .constant (127.0 , dtype = ov .Type .f32 )
47+
4348 a = ov_ops .fake_quantize (ov_ops .convert (input_a , ov .Type .f32 ), low , high , low , high , 256 )
4449 b = ov_ops .fake_quantize (ov_ops .convert (input_b , ov .Type .f32 ), low , high , low , high , 256 )
50+ if OV_DEVICE == "NPU" :
51+ a = ov_ops .divide (a , ov_ops .constant (NPU_MUL , dtype = ov .Type .f32 ))
52+ b = ov_ops .divide (b , ov_ops .constant (NPU_MUL , dtype = ov .Type .f32 ))
4553
4654 ov_model = ov .Model ([ov_ops .matmul (a , b , False , False )], [input_a , input_b ], "ov_int8_mm" )
4755 ov_model = core .compile_model (ov_model , OV_DEVICE )
0 commit comments