Commit ccc16ad

fadara01 authored and bigPYJ1151 committed
[cpu][perf] Accelerate unquantized-linear for AArch64 through oneDNN/ACL and weight prepack (vllm-project#25948)
Signed-off-by: Fadi Arafeh <[email protected]>
Co-authored-by: Li, Jiang <[email protected]>
1 parent 358b9f7 commit ccc16ad

8 files changed (+111, -16 lines)


cmake/cpu_extension.cmake

Lines changed: 2 additions & 1 deletion
@@ -213,6 +213,7 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON
     endif()
     set(ONEDNN_AARCH64_USE_ACL "ON")
     set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ENV{ACL_ROOT_DIR}/build/")
+    add_compile_definitions(VLLM_USE_ACL)
   endif()
 
   set(ONEDNN_LIBRARY_TYPE "STATIC")
@@ -226,7 +227,7 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON
   set(ONEDNN_ENABLE_ITT_TASKS "OFF")
   set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF")
   set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF")
-  set(ONEDNN_VERBOSE "OFF")
+  set(ONEDNN_VERBOSE "ON")
   set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)
 
   FetchContent_MakeAvailable(oneDNN)

csrc/cpu/dnnl_helper.cpp

Lines changed: 69 additions & 11 deletions
@@ -137,9 +137,8 @@ DNNLMatMulPrimitiveHandler::DNNLMatMulPrimitiveHandler(
 }
 
 void DNNLMatMulPrimitiveHandler::prepack_weight(
-    void* original_b_ptr, dnnl::memory::desc b_target_mem_desc) {
-  dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_,
-                                   {b_k_stride_, b_n_stride_});
+    void* original_b_ptr, dnnl::memory::desc original_b_md,
+    dnnl::memory::desc b_target_mem_desc) {
   dnnl::memory original_weight(original_b_md, default_engine(), original_b_ptr);
   dnnl::memory packed_weight(b_target_mem_desc, default_engine());
   {
@@ -250,7 +249,9 @@ W8A8MatMulPrimitiveHandler::W8A8MatMulPrimitiveHandler(const Args& args)
   if (a_qs_ == QuantizationStrategy::PER_TOKEN) {
     assert(!use_azp_);
   };
-  prepack_weight(args.b_ptr,
+  dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_,
+                                   {b_k_stride_, b_n_stride_});
+  prepack_weight(args.b_ptr, original_b_md,
                  create_primitive_desc(
                      MSizeCacheKey{.a_m_size = DNNL_RUNTIME_DIM_VAL,
                                    .use_bias = false,
@@ -412,12 +413,25 @@ MatMulPrimitiveHandler::MatMulPrimitiveHandler(const Args& args)
   assert(ab_type_ == dnnl::memory::data_type::f32 ||
          ab_type_ == dnnl::memory::data_type::bf16 ||
          ab_type_ == dnnl::memory::data_type::f16);
-  prepack_weight(args.b_ptr,
+
+  dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_,
+                                   {b_k_stride_, b_n_stride_});
+
+  prepack_weight(args.b_ptr, original_b_md,
                  create_primitive_desc(
-                     MSizeCacheKey{.a_m_size = DNNL_RUNTIME_DIM_VAL,
-                                   .a_m_stride = DNNL_RUNTIME_DIM_VAL,
-                                   .use_bias = false,
-                                   .bias_type = dnnl::memory::data_type::undef},
+                     MSizeCacheKey{
+#ifdef VLLM_USE_ACL
+                         // The Arm Compute Library (ACL) backend of oneDNN
+                         // does not support runtime dimensions, so we set M
+                         // to a default value
+                         .a_m_size = 128,
+                         .a_m_stride = b_k_size_,
+#else
+                         .a_m_size = DNNL_RUNTIME_DIM_VAL,
+                         .a_m_stride = DNNL_RUNTIME_DIM_VAL,
+#endif
+                         .use_bias = false,
+                         .bias_type = dnnl::memory::data_type::undef},
                      true)
                      .weights_desc());
   init_runtime_memory_cache(args);
@@ -443,13 +457,31 @@ void MatMulPrimitiveHandler::execute(ExecArgs& args) {
   c_storage->set_data_handle((void*)args.c_ptr);
   c_mem_desc->dims[0] = args.a_m_size;
 
+#ifndef VLLM_USE_ACL
+  // The bias memory argument is not supported in the ACL backend of oneDNN;
+  // there we handle bias by:
+  // 1. copying it into the result tensor
+  // 2. attaching a fused-sum post-op to the matmul primitive
   if (args.use_bias) {
     auto&& [bias_storage, bias_mem_desc] = get_runtime_memory_ptr(2);
     bias_storage->set_data_handle((void*)args.bias_ptr);
   }
-
+#endif
   dnnl::matmul matmul = get_matmul_cache(args);
 
+  // With the ACL backend of oneDNN, the required weight memory format might
+  // change when the source tensor dims change. This rarely happens in
+  // practice, so it isn't a performance hit, but we need to support it
+  // because the API allows for it.
+#ifdef VLLM_USE_ACL
+  auto new_expected_wei_desc =
+      dnnl::matmul::primitive_desc(
+          const_cast<dnnl_primitive_desc_t>(matmul.get_primitive_desc()))
+          .weights_desc();
+  if (new_expected_wei_desc != b_target_mem_desc_) {
+    prepack_weight(memory_cache_[DNNL_ARG_WEIGHTS].get_data_handle(),
+                   b_target_mem_desc_, new_expected_wei_desc);
+  }
+#endif
+
   auto&& [scratchpad_storage, scratchpad_mem_desc] = get_runtime_memory_ptr(3);
   scratchpad_storage->set_data_handle(
       DNNLScratchPadManager::get_dnnl_scratchpad_manager()->get_data<void>());
@@ -484,7 +516,13 @@ dnnl::matmul::primitive_desc MatMulPrimitiveHandler::create_primitive_desc(
   } else {
     a_md = dnnl::memory::desc({key.a_m_size, b_k_size_}, b_type_,
                               {key.a_m_stride, 1});
+#ifdef VLLM_USE_ACL
+    // The ACL backend of oneDNN always expects the weight format to be "any"
+    b_md = dnnl::memory::desc({b_k_size_, b_n_size_}, b_type_,
                               dnnl::memory::format_tag::any);
+#else
     b_md = b_target_mem_desc_;
+#endif
   }
   dnnl::memory::desc c_md({key.a_m_size, b_n_size_}, c_type_,
                           dnnl::memory::format_tag::ab);
@@ -494,8 +532,18 @@ dnnl::matmul::primitive_desc MatMulPrimitiveHandler::create_primitive_desc(
 
   if (key.use_bias) {
     dnnl::memory::desc bias_md({1, b_n_size_}, key.bias_type, {b_n_size_, 1});
+    // Since ACL's matmuls don't support passing a bias_md, we apply the bias
+    // through a fused-sum post-op
+#ifdef VLLM_USE_ACL
+    dnnl::post_ops post_ops;
+    post_ops.append_sum();
+    attr.set_post_ops(post_ops);
+    return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, c_md,
+                                        attr);
+#else
     return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, bias_md,
                                         c_md, attr);
+#endif
   } else {
     return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, c_md,
                                         attr);
@@ -511,13 +559,23 @@ void MatMulPrimitiveHandler::init_runtime_memory_cache(const Args& args) {
                    default_engine(), nullptr);
   set_runtime_memory_ptr(1, memory_cache_[DNNL_ARG_DST].get());
 
+  // ACL matmuls don't support bias_md, so we don't need these
+#ifndef VLLM_USE_ACL
   memory_cache_[DNNL_ARG_BIAS] =
       dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
                    default_engine(), nullptr);
   set_runtime_memory_ptr(2, memory_cache_[DNNL_ARG_BIAS].get());
-
+#endif
   memory_cache_[DNNL_ARG_SCRATCHPAD] =
       dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
                    default_engine(), nullptr);
   set_runtime_memory_ptr(3, memory_cache_[DNNL_ARG_SCRATCHPAD].get());
 }
+
+bool is_onednn_acl_supported() {
+#ifdef VLLM_USE_ACL
+  return true;
+#else
+  return false;
+#endif
+}

csrc/cpu/dnnl_helper.h

Lines changed: 1 addition & 1 deletion
@@ -101,7 +101,7 @@ class DNNLMatMulPrimitiveHandler {
  protected:
   DNNLMatMulPrimitiveHandler(const Args& args, dnnl::memory::data_type b_type);
 
-  void prepack_weight(void* original_b_ptr,
+  void prepack_weight(void* original_b_ptr, dnnl::memory::desc original_b_md,
                       dnnl::memory::desc b_target_mem_desc);
 
   void set_runtime_memory_ptr(size_t index, dnnl_memory* memory_ptr);

csrc/cpu/dnnl_kernels.cpp

Lines changed: 22 additions & 1 deletion
@@ -527,21 +527,42 @@ void onednn_mm(torch::Tensor& c, // [M, OC], row-major
   MatMulPrimitiveHandler* ptr =
       reinterpret_cast<MatMulPrimitiveHandler*>(handler);
 
+  // ACL matmuls expect contiguous source tensors
+#ifdef VLLM_USE_ACL
+  torch::Tensor a_contig = a.contiguous();
+#endif
+
   MatMulPrimitiveHandler::ExecArgs exec_args;
+
+#ifdef VLLM_USE_ACL
+  exec_args.a_m_size = a_contig.size(0);
+  exec_args.a_m_stride = a_contig.stride(0);
+#else
   exec_args.a_m_size = a.size(0);
   exec_args.a_m_stride = a.stride(0);
-
+#endif
   VLLM_DISPATCH_FLOATING_TYPES(a.scalar_type(), "onednn_mm", [&] {
     if (bias.has_value()) {
       exec_args.use_bias = true;
       exec_args.bias_type = get_dnnl_type<scalar_t>();
+#ifdef VLLM_USE_ACL
+      // ACL matmuls in oneDNN do not support a bias.
+      // We handle a matmul with bias by doing: c = bias; c += matmul(a, b)
+      c.copy_(bias.value());
+#else
       exec_args.bias_ptr = bias->data_ptr<scalar_t>();
+#endif
     } else {
       exec_args.use_bias = false;
       exec_args.bias_type = get_dnnl_type<void>();
      exec_args.bias_ptr = nullptr;
     }
+#ifdef VLLM_USE_ACL
+    exec_args.a_ptr = a_contig.data_ptr<scalar_t>();
+#else
     exec_args.a_ptr = a.data_ptr<scalar_t>();
+
+#endif
     exec_args.c_ptr = c.data_ptr<scalar_t>();
 
     ptr->execute(exec_args);
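
To make the ACL bias handling above concrete, here is a minimal PyTorch sketch (illustrative only, not the kernel code) checking that "c = bias; c += matmul(a, b)" matches the usual fused-bias result:

import torch

# Sketch of the ACL-path bias strategy used in onednn_mm above:
# copy the bias into the output, then accumulate the matmul on top of it.
M, K, N = 4, 8, 16
a = torch.randn(M, K)
b = torch.randn(K, N)
bias = torch.randn(N)

c = torch.empty(M, N)
c.copy_(bias)   # step 1: broadcast-copy the bias into the result tensor
c += a @ b      # step 2: what the fused-sum post-op contributes

assert torch.allclose(c, a @ b + bias)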

csrc/cpu/torch_bindings.cpp

Lines changed: 5 additions & 0 deletions
@@ -27,6 +27,8 @@ int64_t create_onednn_mm_handler(const torch::Tensor& b,
 void onednn_mm(torch::Tensor& c, const torch::Tensor& a,
                const std::optional<torch::Tensor>& bias, int64_t handler);
 
+bool is_onednn_acl_supported();
+
 void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query,
                         torch::Tensor& kv_cache, double scale,
                         torch::Tensor& block_tables, torch::Tensor& seq_lens);
@@ -181,6 +183,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "int handler) -> ()");
   ops.impl("onednn_mm", torch::kCPU, &onednn_mm);
 
+  // Check if oneDNN was built with ACL backend
+  ops.def("is_onednn_acl_supported() -> bool", &is_onednn_acl_supported);
+
   // Create oneDNN W8A8 handler
   ops.def(
       "create_onednn_scaled_mm_handler(Tensor b, Tensor b_scales, ScalarType "

setup.py

Lines changed: 5 additions & 0 deletions
@@ -205,6 +205,11 @@ def configure(self, ext: CMakeExtension) -> None:
         # Make sure we use the nvcc from CUDA_HOME
         if _is_cuda():
             cmake_args += [f'-DCMAKE_CUDA_COMPILER={CUDA_HOME}/bin/nvcc']
+
+        other_cmake_args = os.environ.get("CMAKE_ARGS")
+        if other_cmake_args:
+            cmake_args += other_cmake_args.split()
+
         subprocess.check_call(
             ['cmake', ext.cmake_lists_dir, *build_tool, *cmake_args],
             cwd=self.build_temp)
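
The net effect of this setup.py change is that extra flags placed in the CMAKE_ARGS environment variable are forwarded to the cmake invocation. A small sketch of the splitting behaviour, with purely hypothetical flag values:

import os

# Hypothetical flags; setup.py splits CMAKE_ARGS on whitespace and appends
# the pieces to the cmake argument list it already builds.
os.environ["CMAKE_ARGS"] = "-DCMAKE_BUILD_TYPE=Release -DONEDNN_VERBOSE=ON"

cmake_args = ["-DEXISTING_FLAG=1"]  # stand-in for the args setup.py assembles
other_cmake_args = os.environ.get("CMAKE_ARGS")
if other_cmake_args:
    cmake_args += other_cmake_args.split()

print(cmake_args)
# ['-DEXISTING_FLAG=1', '-DCMAKE_BUILD_TYPE=Release', '-DONEDNN_VERBOSE=ON']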

vllm/_custom_ops.py

Lines changed: 4 additions & 0 deletions
@@ -1926,6 +1926,10 @@ def __del__(self):
     _supports_onednn = False
 
 
+def is_onednn_acl_supported():
+    return torch.ops._C.is_onednn_acl_supported()
+
+
 def create_onednn_mm(
         weight: torch.Tensor,  # [K, N]
         primitive_cache_size: int = 128,
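
A hedged usage sketch of the new Python-side check (the module path follows the file above; the op only resolves when the CPU extension was built with the oneDNN path enabled):

from vllm import _custom_ops as ops

# True only when the CPU extension was compiled with VLLM_USE_ACL,
# i.e. oneDNN was built against the Arm Compute Library.
if ops._supports_onednn and ops.is_onednn_acl_supported():
    print("oneDNN/ACL path available for unquantized linear layers")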

vllm/model_executor/layers/utils.py

Lines changed: 3 additions & 2 deletions
@@ -165,8 +165,9 @@ def dispatch_cpu_unquantized_gemm(
         if remove_weight:
             layer.weight = torch.nn.Parameter(torch.empty(0),
                                               requires_grad=False)
-    elif (ops._supports_onednn
-          and current_platform.get_cpu_architecture() == CpuArchEnum.X86):
+    elif ops._supports_onednn and (current_platform.get_cpu_architecture()
+                                   == CpuArchEnum.X86
+                                   or ops.is_onednn_acl_supported()):
         origin_weight = layer.weight
         if remove_weight:
             layer.weight = torch.nn.Parameter(torch.empty(0),
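
For readability, the updated dispatch condition can be restated as a standalone predicate; the helper name is illustrative and the import paths are assumptions based on the identifiers used in the file:

from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum


def _use_onednn_unquantized_gemm() -> bool:
    # Previously x86-only; now also taken on AArch64 builds whose oneDNN
    # uses the Arm Compute Library backend (VLLM_USE_ACL).
    return ops._supports_onednn and (
        current_platform.get_cpu_architecture() == CpuArchEnum.X86
        or ops.is_onednn_acl_supported())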
