3 changes: 2 additions & 1 deletion cmake/cpu_extension.cmake
@@ -213,6 +213,7 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON
endif()
set(ONEDNN_AARCH64_USE_ACL "ON")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ENV{ACL_ROOT_DIR}/build/")
add_compile_definitions(VLLM_USE_ACL)
endif()

set(ONEDNN_LIBRARY_TYPE "STATIC")
@@ -226,7 +227,7 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON
set(ONEDNN_ENABLE_ITT_TASKS "OFF")
set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF")
set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF")
set(ONEDNN_VERBOSE "OFF")
set(ONEDNN_VERBOSE "ON")
set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)

FetchContent_MakeAvailable(oneDNN)
80 changes: 69 additions & 11 deletions csrc/cpu/dnnl_helper.cpp
@@ -137,9 +137,8 @@ DNNLMatMulPrimitiveHandler::DNNLMatMulPrimitiveHandler(
}

void DNNLMatMulPrimitiveHandler::prepack_weight(
void* original_b_ptr, dnnl::memory::desc b_target_mem_desc) {
dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_,
{b_k_stride_, b_n_stride_});
void* original_b_ptr, dnnl::memory::desc original_b_md,
dnnl::memory::desc b_target_mem_desc) {
dnnl::memory original_weight(original_b_md, default_engine(), original_b_ptr);
dnnl::memory packed_weight(b_target_mem_desc, default_engine());
{
@@ -250,7 +249,9 @@ W8A8MatMulPrimitiveHandler::W8A8MatMulPrimitiveHandler(const Args& args)
if (a_qs_ == QuantizationStrategy::PER_TOKEN) {
assert(!use_azp_);
};
prepack_weight(args.b_ptr,
dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_,
{b_k_stride_, b_n_stride_});
prepack_weight(args.b_ptr, original_b_md,
create_primitive_desc(
MSizeCacheKey{.a_m_size = DNNL_RUNTIME_DIM_VAL,
.use_bias = false,
@@ -412,12 +413,25 @@ MatMulPrimitiveHandler::MatMulPrimitiveHandler(const Args& args)
assert(ab_type_ == dnnl::memory::data_type::f32 ||
ab_type_ == dnnl::memory::data_type::bf16 ||
ab_type_ == dnnl::memory::data_type::f16);
prepack_weight(args.b_ptr,

dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_,
{b_k_stride_, b_n_stride_});

prepack_weight(args.b_ptr, original_b_md,
create_primitive_desc(
MSizeCacheKey{.a_m_size = DNNL_RUNTIME_DIM_VAL,
.a_m_stride = DNNL_RUNTIME_DIM_VAL,
.use_bias = false,
.bias_type = dnnl::memory::data_type::undef},
MSizeCacheKey{
#ifdef VLLM_USE_ACL
// The Arm Compute Library (ACL) backend for oneDNN does not
// support runtime dimensions, so we set M to a default value
.a_m_size = 128,
.a_m_stride = b_k_size_,
#else
.a_m_size = DNNL_RUNTIME_DIM_VAL,
.a_m_stride = DNNL_RUNTIME_DIM_VAL,
#endif
.use_bias = false,
.bias_type = dnnl::memory::data_type::undef},
true)
.weights_desc());
init_runtime_memory_cache(args);
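
For context (not part of this diff): on the ACL path, the constructor above prepacks the weights by creating a matmul primitive descriptor with a fixed placeholder M and a weight descriptor in format "any", querying which layout that primitive wants, and reordering the original row-major weights into it. A minimal, self-contained oneDNN sketch of that flow follows; names such as pack_weights, b_plain and b_any are illustrative, not vLLM identifiers.

// Sketch only: query the weight layout a matmul primitive prefers and pack into it.
#include <dnnl.hpp>

dnnl::memory pack_weights(void* b_ptr, dnnl::memory::dim K, dnnl::memory::dim N,
                          dnnl::engine& eng, dnnl::stream& strm) {
  using dt = dnnl::memory::data_type;
  using tag = dnnl::memory::format_tag;

  // Plain row-major description of the original weights.
  dnnl::memory::desc b_plain({K, N}, dt::f32, tag::ab);
  // "any" lets the implementation choose its preferred (possibly blocked) layout.
  dnnl::memory::desc b_any({K, N}, dt::f32, tag::any);

  // A placeholder M (e.g. 128) stands in for the runtime batch dimension,
  // since the ACL backend cannot use DNNL_RUNTIME_DIM_VAL.
  dnnl::memory::desc a_md({128, K}, dt::f32, tag::ab);
  dnnl::memory::desc c_md({128, N}, dt::f32, tag::ab);
  auto pd = dnnl::matmul::primitive_desc(eng, a_md, b_any, c_md);

  // Reorder from the user layout into whatever layout the primitive asked for.
  dnnl::memory b_user(b_plain, eng, b_ptr);
  dnnl::memory b_packed(pd.weights_desc(), eng);
  dnnl::reorder(b_user, b_packed).execute(strm, b_user, b_packed);
  strm.wait();
  return b_packed;
}
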
@@ -443,13 +457,31 @@ void MatMulPrimitiveHandler::execute(ExecArgs& args) {
c_storage->set_data_handle((void*)args.c_ptr);
c_mem_desc->dims[0] = args.a_m_size;

#ifndef VLLM_USE_ACL
// Bias is not supported in the ACL backend of oneDNN; with ACL we handle bias by:
// 1. copying it into the result tensor
// 2. attaching a fused-sum post-op to the matmul primitive
if (args.use_bias) {
auto&& [bias_storage, bias_mem_desc] = get_runtime_memory_ptr(2);
bias_storage->set_data_handle((void*)args.bias_ptr);
}

#endif
dnnl::matmul matmul = get_matmul_cache(args);

// With the ACL backend of oneDNN, the required memory format might change when
// the source tensor dims change. This does not really happen in practice, so it
// isn't a performance hit, but we need to support it because the API allows it.
#ifdef VLLM_USE_ACL
auto new_expected_wei_desc =
dnnl::matmul::primitive_desc(
const_cast<dnnl_primitive_desc_t>(matmul.get_primitive_desc()))
.weights_desc();
if (new_expected_wei_desc != b_target_mem_desc_) {
prepack_weight(memory_cache_[DNNL_ARG_WEIGHTS].get_data_handle(),
b_target_mem_desc_, new_expected_wei_desc);
}
#endif

auto&& [scratchpad_storage, scratchpad_mem_desc] = get_runtime_memory_ptr(3);
scratchpad_storage->set_data_handle(
DNNLScratchPadManager::get_dnnl_scratchpad_manager()->get_data<void>());
@@ -484,7 +516,13 @@ dnnl::matmul::primitive_desc MatMulPrimitiveHandler::create_primitive_desc(
} else {
a_md = dnnl::memory::desc({key.a_m_size, b_k_size_}, b_type_,
{key.a_m_stride, 1});
#ifdef VLLM_USE_ACL
// The ACL backend of oneDNN always expects the weight format to be "any"
b_md = dnnl::memory::desc({b_k_size_, b_n_size_}, b_type_,
dnnl::memory::format_tag::any);
#else
b_md = b_target_mem_desc_;
#endif
}
dnnl::memory::desc c_md({key.a_m_size, b_n_size_}, c_type_,
dnnl::memory::format_tag::ab);
@@ -494,8 +532,18 @@

if (key.use_bias) {
dnnl::memory::desc bias_md({1, b_n_size_}, key.bias_type, {b_n_size_, 1});
// Since ACL's matmuls don't support passing a bias_md, we apply the bias
// through a fused-sum post-op
#ifdef VLLM_USE_ACL
dnnl::post_ops post_ops;
post_ops.append_sum();
attr.set_post_ops(post_ops);
return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, c_md,
attr);
#else
return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, bias_md,
c_md, attr);
#endif
} else {
return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, c_md,
attr);
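
As a side note (not part of this diff): the fused-sum path above relies on the destination already holding the bias before the matmul runs, so that the primitive computes C = bias + A * B. A small, self-contained oneDNN program illustrating just that mechanism, assuming plain f32 row-major tensors, could be:

// Sketch only: emulate matmul bias with a fused-sum post-op by preloading C.
#include <vector>
#include <dnnl.hpp>

int main() {
  using dt = dnnl::memory::data_type;
  using tag = dnnl::memory::format_tag;

  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::stream strm(eng);

  const dnnl::memory::dim M = 2, K = 3, N = 4;
  std::vector<float> a(M * K, 1.0f);   // A: all ones
  std::vector<float> b(K * N, 1.0f);   // B: all ones
  std::vector<float> bias(N, 0.5f);    // per-output-channel bias
  std::vector<float> c(M * N);

  // Step 1: C must already contain the (row-broadcast) bias.
  for (dnnl::memory::dim m = 0; m < M; ++m)
    for (dnnl::memory::dim n = 0; n < N; ++n) c[m * N + n] = bias[n];

  dnnl::memory::desc a_md({M, K}, dt::f32, tag::ab);
  dnnl::memory::desc b_md({K, N}, dt::f32, tag::ab);
  dnnl::memory::desc c_md({M, N}, dt::f32, tag::ab);

  // Step 2: a sum post-op turns the matmul into C = C + A * B.
  dnnl::post_ops po;
  po.append_sum();
  dnnl::primitive_attr attr;
  attr.set_post_ops(po);

  auto pd = dnnl::matmul::primitive_desc(eng, a_md, b_md, c_md, attr);
  dnnl::matmul mm(pd);

  dnnl::memory a_mem(a_md, eng, a.data());
  dnnl::memory b_mem(b_md, eng, b.data());
  dnnl::memory c_mem(c_md, eng, c.data());
  mm.execute(strm, {{DNNL_ARG_SRC, a_mem},
                    {DNNL_ARG_WEIGHTS, b_mem},
                    {DNNL_ARG_DST, c_mem}});
  strm.wait();
  // Every element of C is now K * 1.0 + 0.5 = 3.5.
  return 0;
}
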
@@ -511,13 +559,23 @@ void MatMulPrimitiveHandler::init_runtime_memory_cache(const Args& args) {
default_engine(), nullptr);
set_runtime_memory_ptr(1, memory_cache_[DNNL_ARG_DST].get());

// ACL matmuls don't support bias_md, so we don't need these
#ifndef VLLM_USE_ACL
memory_cache_[DNNL_ARG_BIAS] =
dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
default_engine(), nullptr);
set_runtime_memory_ptr(2, memory_cache_[DNNL_ARG_BIAS].get());

#endif
memory_cache_[DNNL_ARG_SCRATCHPAD] =
dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
default_engine(), nullptr);
set_runtime_memory_ptr(3, memory_cache_[DNNL_ARG_SCRATCHPAD].get());
}

bool is_onednn_acl_supported() {
#ifdef VLLM_USE_ACL
return true;
#else
return false;
#endif
}
2 changes: 1 addition & 1 deletion csrc/cpu/dnnl_helper.h
@@ -101,7 +101,7 @@ class DNNLMatMulPrimitiveHandler {
protected:
DNNLMatMulPrimitiveHandler(const Args& args, dnnl::memory::data_type b_type);

void prepack_weight(void* original_b_ptr,
void prepack_weight(void* original_b_ptr, dnnl::memory::desc original_b_md,
dnnl::memory::desc b_target_mem_desc);

void set_runtime_memory_ptr(size_t index, dnnl_memory* memory_ptr);
23 changes: 22 additions & 1 deletion csrc/cpu/dnnl_kernels.cpp
@@ -527,21 +527,42 @@ void onednn_mm(torch::Tensor& c, // [M, OC], row-major
MatMulPrimitiveHandler* ptr =
reinterpret_cast<MatMulPrimitiveHandler*>(handler);

// ACL matmuls expect contiguous source tensors
#ifdef VLLM_USE_ACL
torch::Tensor a_contig = a.contiguous();
#endif

MatMulPrimitiveHandler::ExecArgs exec_args;

#ifdef VLLM_USE_ACL
exec_args.a_m_size = a_contig.size(0);
exec_args.a_m_stride = a_contig.stride(0);
#else
exec_args.a_m_size = a.size(0);
exec_args.a_m_stride = a.stride(0);

#endif
VLLM_DISPATCH_FLOATING_TYPES(a.scalar_type(), "onednn_mm", [&] {
if (bias.has_value()) {
exec_args.use_bias = true;
exec_args.bias_type = get_dnnl_type<scalar_t>();
#ifdef VLLM_USE_ACL
// ACL matmuls in oneDNN do not support a bias.
// We handle a matmul with bias by doing: c = bias; c += matmul(a, b)
c.copy_(bias.value());
#else
exec_args.bias_ptr = bias->data_ptr<scalar_t>();
#endif
} else {
exec_args.use_bias = false;
exec_args.bias_type = get_dnnl_type<void>();
exec_args.bias_ptr = nullptr;
}
#ifdef VLLM_USE_ACL
exec_args.a_ptr = a_contig.data_ptr<scalar_t>();
#else
exec_args.a_ptr = a.data_ptr<scalar_t>();

#endif
exec_args.c_ptr = c.data_ptr<scalar_t>();

ptr->execute(exec_args);
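
For completeness (not part of this diff): the "c = bias; c += matmul(a, b)" handling above is numerically equivalent to an ordinary biased matmul. A quick libtorch sketch checking that equivalence (shapes here are arbitrary examples) might be:

// Sketch only: verify that preloading the bias and accumulating the matmul
// matches a @ w + bias.
#include <torch/torch.h>

int main() {
  auto a = torch::randn({4, 8});     // activations [M, K]
  auto w = torch::randn({8, 16});    // weights [K, N]
  auto bias = torch::randn({16});    // per-output-channel bias [N]

  auto c = torch::empty({4, 16});
  c.copy_(bias);                     // step 1: broadcast the bias into the result
  c += a.matmul(w);                  // step 2: accumulate the matmul (the fused sum)

  TORCH_CHECK(torch::allclose(c, a.matmul(w) + bias));
  return 0;
}
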
5 changes: 5 additions & 0 deletions csrc/cpu/torch_bindings.cpp
@@ -27,6 +27,8 @@ int64_t create_onednn_mm_handler(const torch::Tensor& b,
void onednn_mm(torch::Tensor& c, const torch::Tensor& a,
const std::optional<torch::Tensor>& bias, int64_t handler);

bool is_onednn_acl_supported();

void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query,
torch::Tensor& kv_cache, double scale,
torch::Tensor& block_tables, torch::Tensor& seq_lens);
@@ -181,6 +183,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"int handler) -> ()");
ops.impl("onednn_mm", torch::kCPU, &onednn_mm);

// Check if oneDNN was built with ACL backend
ops.def("is_onednn_acl_supported() -> bool", &is_onednn_acl_supported);

// Create oneDNN W8A8 handler
ops.def(
"create_onednn_scaled_mm_handler(Tensor b, Tensor b_scales, ScalarType "
5 changes: 5 additions & 0 deletions setup.py
@@ -205,6 +205,11 @@ def configure(self, ext: CMakeExtension) -> None:
# Make sure we use the nvcc from CUDA_HOME
if _is_cuda():
cmake_args += [f'-DCMAKE_CUDA_COMPILER={CUDA_HOME}/bin/nvcc']

other_cmake_args = os.environ.get("CMAKE_ARGS")
if other_cmake_args:
cmake_args += other_cmake_args.split()
Comment on lines +209 to +211

Member:
Why is this change required? It looks like no arg is added to CMAKE_ARGS in this PR.

Contributor Author:
To build oneDNN with ACL you currently need to set CMAKE_ARGS="-DVLLM_BUILD_ACL=ON".
I agree that it is not well documented how one would build vllm's oneDNN with ACL.
A new PR is coming soon to build oneDNN with ACL by default.


subprocess.check_call(
['cmake', ext.cmake_lists_dir, *build_tool, *cmake_args],
cwd=self.build_temp)
4 changes: 4 additions & 0 deletions vllm/_custom_ops.py
@@ -1926,6 +1926,10 @@ def __del__(self):
_supports_onednn = False


def is_onednn_acl_supported():
return torch.ops._C.is_onednn_acl_supported()


def create_onednn_mm(
weight: torch.Tensor, # [K, N]
primitive_cache_size: int = 128,
5 changes: 3 additions & 2 deletions vllm/model_executor/layers/utils.py
@@ -165,8 +165,9 @@ def dispatch_cpu_unquantized_gemm(
if remove_weight:
layer.weight = torch.nn.Parameter(torch.empty(0),
requires_grad=False)
elif (ops._supports_onednn
and current_platform.get_cpu_architecture() == CpuArchEnum.X86):
elif ops._supports_onednn and (current_platform.get_cpu_architecture()
== CpuArchEnum.X86
or ops.is_onednn_acl_supported()):
origin_weight = layer.weight
if remove_weight:
layer.weight = torch.nn.Parameter(torch.empty(0),