@@ -137,9 +137,8 @@ DNNLMatMulPrimitiveHandler::DNNLMatMulPrimitiveHandler(
 }
 
 void DNNLMatMulPrimitiveHandler::prepack_weight(
-    void* original_b_ptr, dnnl::memory::desc b_target_mem_desc) {
-  dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_,
-                                   {b_k_stride_, b_n_stride_});
+    void* original_b_ptr, dnnl::memory::desc original_b_md,
+    dnnl::memory::desc b_target_mem_desc) {
   dnnl::memory original_weight(original_b_md, default_engine(), original_b_ptr);
   dnnl::memory packed_weight(b_target_mem_desc, default_engine());
   {
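For context, prepack_weight presumably finishes with a oneDNN reorder from the caller-supplied layout (original_b_md) into the layout the primitive prefers (b_target_mem_desc); the visible lines only show the two memory objects being built. A minimal standalone sketch of that pattern with the plain oneDNN C++ API; the function name, engine, and stream handling here are illustrative, not the handler's actual members:

#include "dnnl.hpp"

// Hypothetical sketch: repack weights by reordering from the user's layout
// into the primitive's preferred layout.
dnnl::memory repack_weight(const dnnl::engine& eng, dnnl::stream& strm,
                           void* original_b_ptr,
                           const dnnl::memory::desc& original_b_md,
                           const dnnl::memory::desc& b_target_mem_desc) {
  dnnl::memory original_weight(original_b_md, eng, original_b_ptr);
  dnnl::memory packed_weight(b_target_mem_desc, eng);
  dnnl::reorder(original_weight, packed_weight)
      .execute(strm, original_weight, packed_weight);
  strm.wait();
  return packed_weight;
}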
@@ -250,7 +249,9 @@ W8A8MatMulPrimitiveHandler::W8A8MatMulPrimitiveHandler(const Args& args)
   if (a_qs_ == QuantizationStrategy::PER_TOKEN) {
     assert(!use_azp_);
   };
-  prepack_weight(args.b_ptr,
+  dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_,
+                                   {b_k_stride_, b_n_stride_});
+  prepack_weight(args.b_ptr, original_b_md,
                  create_primitive_desc(
                      MSizeCacheKey{.a_m_size = DNNL_RUNTIME_DIM_VAL,
                                    .use_bias = false,
@@ -412,12 +413,25 @@ MatMulPrimitiveHandler::MatMulPrimitiveHandler(const Args& args)
   assert(ab_type_ == dnnl::memory::data_type::f32 ||
          ab_type_ == dnnl::memory::data_type::bf16 ||
          ab_type_ == dnnl::memory::data_type::f16);
-  prepack_weight(args.b_ptr,
+
+  dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_,
+                                   {b_k_stride_, b_n_stride_});
+
+  prepack_weight(args.b_ptr, original_b_md,
                  create_primitive_desc(
-                     MSizeCacheKey{.a_m_size = DNNL_RUNTIME_DIM_VAL,
-                                   .a_m_stride = DNNL_RUNTIME_DIM_VAL,
-                                   .use_bias = false,
-                                   .bias_type = dnnl::memory::data_type::undef},
+                     MSizeCacheKey{
+#ifdef VLLM_USE_ACL
+                         // The Arm Compute Library (ACL) backend for oneDNN
+                         // does not support runtime dimensions, so we set M
+                         // to a default value.
+                         .a_m_size = 128,
+                         .a_m_stride = b_k_size_,
+#else
+                         .a_m_size = DNNL_RUNTIME_DIM_VAL,
+                         .a_m_stride = DNNL_RUNTIME_DIM_VAL,
+#endif
+                         .use_bias = false,
+                         .bias_type = dnnl::memory::data_type::undef},
                      true)
                      .weights_desc());
   init_runtime_memory_cache(args);
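A hedged aside on the #ifdef above: on the non-ACL path, oneDNN lets a matmul be created once with DNNL_RUNTIME_DIM_VAL placeholders and bound to the real M at execution time; ACL cannot do this, hence the fixed default. A small sketch of the two ways of declaring the activation descriptor (shapes are illustrative):

#include "dnnl.hpp"
using dt = dnnl::memory::data_type;

// Non-ACL path: M and its stride are placeholders resolved at execute time.
dnnl::memory::desc a_runtime_md({DNNL_RUNTIME_DIM_VAL, /*K=*/4096}, dt::f32,
                                {DNNL_RUNTIME_DIM_VAL, 1});

// ACL path: runtime dims are unsupported, so a fixed M is baked in,
// mirroring the a_m_size = 128 default chosen above.
dnnl::memory::desc a_fixed_md({/*M=*/128, /*K=*/4096}, dt::f32,
                              {/*K=*/4096, 1});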
@@ -443,13 +457,31 @@ void MatMulPrimitiveHandler::execute(ExecArgs& args) {
   c_storage->set_data_handle((void*)args.c_ptr);
   c_mem_desc->dims[0] = args.a_m_size;
 
+#ifndef VLLM_USE_ACL
+  // The ACL backend of oneDNN does not take a bias memory argument; with ACL
+  // we instead handle bias by: 1. copying it into the result tensor and
+  // 2. attaching a fused-sum post-op to the matmul primitive.
   if (args.use_bias) {
     auto&& [bias_storage, bias_mem_desc] = get_runtime_memory_ptr(2);
     bias_storage->set_data_handle((void*)args.bias_ptr);
   }
-
+#endif
   dnnl::matmul matmul = get_matmul_cache(args);
 
+  // With the ACL backend of oneDNN, the required weight memory format might
+  // change when the source tensor dims change. This rarely happens in
+  // practice, so it isn't a performance hit, but the API allows it.
+#ifdef VLLM_USE_ACL
+  auto new_expected_wei_desc =
+      dnnl::matmul::primitive_desc(
+          const_cast<dnnl_primitive_desc_t>(matmul.get_primitive_desc()))
+          .weights_desc();
+  if (new_expected_wei_desc != b_target_mem_desc_) {
+    prepack_weight(memory_cache_[DNNL_ARG_WEIGHTS].get_data_handle(),
+                   b_target_mem_desc_, new_expected_wei_desc);
+  }
+#endif
+
   auto&& [scratchpad_storage, scratchpad_mem_desc] = get_runtime_memory_ptr(3);
   scratchpad_storage->set_data_handle(
       DNNLScratchPadManager::get_dnnl_scratchpad_manager()->get_data<void>());
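What "copying it into the result tensor" means in the ACL bias path: with a fused-sum post-op the primitive computes C = A*B + C, so the bias row has to be broadcast into C before the primitive runs. A hypothetical helper illustrating that pre-fill; the handler's actual copy may be done differently:

#include <cstdint>
#include <cstring>

// Broadcast a length-N f32 bias row into an M x N row-major destination so
// that a fused-sum post-op effectively adds the bias to A * B.
static void broadcast_bias_into_dst(float* dst, const float* bias, int64_t M,
                                    int64_t N) {
  for (int64_t m = 0; m < M; ++m) {
    std::memcpy(dst + m * N, bias, sizeof(float) * N);
  }
}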
@@ -484,7 +516,13 @@ dnnl::matmul::primitive_desc MatMulPrimitiveHandler::create_primitive_desc(
   } else {
     a_md = dnnl::memory::desc({key.a_m_size, b_k_size_}, b_type_,
                               {key.a_m_stride, 1});
+#ifdef VLLM_USE_ACL
+    // ACL's backend of oneDNN always expects the weight format to be "any".
+    b_md = dnnl::memory::desc({b_k_size_, b_n_size_}, b_type_,
+                              dnnl::memory::format_tag::any);
+#else
     b_md = b_target_mem_desc_;
+#endif
   }
   dnnl::memory::desc c_md({key.a_m_size, b_n_size_}, c_type_,
                           dnnl::memory::format_tag::ab);
@@ -494,8 +532,18 @@ dnnl::matmul::primitive_desc MatMulPrimitiveHandler::create_primitive_desc(
 
   if (key.use_bias) {
     dnnl::memory::desc bias_md({1, b_n_size_}, key.bias_type, {b_n_size_, 1});
+    // Since ACL's matmuls don't support passing a bias_md, we apply the bias
+    // through a fused-sum post-op.
+#ifdef VLLM_USE_ACL
+    dnnl::post_ops post_ops;
+    post_ops.append_sum();
+    attr.set_post_ops(post_ops);
+    return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, c_md,
+                                        attr);
+#else
     return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, bias_md,
                                         c_md, attr);
+#endif
   } else {
     return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, c_md,
                                         attr);
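The format_tag::any branch follows oneDNN's usual layout-negotiation flow: describe the weights as "any", let the primitive descriptor pick its preferred layout, then reorder the user's weights into it; this is also what the re-prepack in execute() reacts to when that preference changes. A self-contained sketch under illustrative shapes, without bias or attribute handling:

#include "dnnl.hpp"

dnnl::memory prepack_for_any_layout(const dnnl::engine& eng,
                                    dnnl::stream& strm, void* user_b,
                                    dnnl::memory::dim M, dnnl::memory::dim K,
                                    dnnl::memory::dim N) {
  using dt = dnnl::memory::data_type;
  using tag = dnnl::memory::format_tag;
  dnnl::memory::desc a_md({M, K}, dt::f32, tag::ab);
  dnnl::memory::desc b_any_md({K, N}, dt::f32, tag::any);
  dnnl::memory::desc c_md({M, N}, dt::f32, tag::ab);
  auto pd = dnnl::matmul::primitive_desc(eng, a_md, b_any_md, c_md);

  // Reorder the user's row-major weights into whatever layout the primitive
  // chose for "any".
  dnnl::memory::desc b_user_md({K, N}, dt::f32, tag::ab);
  dnnl::memory user_mem(b_user_md, eng, user_b);
  dnnl::memory packed(pd.weights_desc(), eng);
  dnnl::reorder(user_mem, packed).execute(strm, user_mem, packed);
  strm.wait();
  return packed;
}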
@@ -511,13 +559,23 @@ void MatMulPrimitiveHandler::init_runtime_memory_cache(const Args& args) {
                    default_engine(), nullptr);
   set_runtime_memory_ptr(1, memory_cache_[DNNL_ARG_DST].get());
 
+  // ACL matmuls don't support bias_md, so we don't need these.
+#ifndef VLLM_USE_ACL
   memory_cache_[DNNL_ARG_BIAS] =
       dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
                    default_engine(), nullptr);
   set_runtime_memory_ptr(2, memory_cache_[DNNL_ARG_BIAS].get());
-
+#endif
   memory_cache_[DNNL_ARG_SCRATCHPAD] =
       dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
                    default_engine(), nullptr);
   set_runtime_memory_ptr(3, memory_cache_[DNNL_ARG_SCRATCHPAD].get());
 }
+
+bool is_onednn_acl_supported() {
+#ifdef VLLM_USE_ACL
+  return true;
+#else
+  return false;
+#endif
+}