Skip to content
Open
15 changes: 15 additions & 0 deletions deps/onemkl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,21 @@ extern "C" int onemklZgemm(syclQueue_t device_queue, onemklTranspose transA,
return 0;
}

// C-callable single-precision axpy: forwards to oneMKL's column-major BLAS
// `axpy` on the SYCL queue wrapped by `device_queue`.
//
//   n          number of elements to process
//   alpha      scalar multiplier applied to x
//   x, incx    input vector and its element stride
//   y, incy    vector updated in place and its element stride
//
// Anything returned by the oneMKL call is discarded; this wrapper returns void.
extern "C" void onemklSaxpy(syclQueue_t device_queue, int64_t n, float alpha,
                            const float *x, std::int64_t incx, float *y,
                            int64_t incy) {
    auto &queue = device_queue->val;
    oneapi::mkl::blas::column_major::axpy(queue, n, alpha, x, incx, y, incy);
}

// C-callable double-precision axpy: forwards to oneMKL's column-major BLAS
// `axpy` on the SYCL queue wrapped by `device_queue`.
//
//   n          number of elements to process
//   alpha      scalar multiplier applied to x
//   x, incx    input vector and its element stride
//   y, incy    vector updated in place and its element stride
//
// Anything returned by the oneMKL call is discarded; this wrapper returns void.
extern "C" void onemklDaxpy(syclQueue_t device_queue, int64_t n, double alpha,
                            const double *x, std::int64_t incx, double *y,
                            int64_t incy) {
    auto &queue = device_queue->val;
    oneapi::mkl::blas::column_major::axpy(queue, n, alpha, x, incx, y, incy);
}

// C-callable single-precision complex axpy: forwards to oneMKL's column-major
// BLAS `axpy` on the SYCL queue wrapped by `device_queue`.
//
// The C `float _Complex` buffers are reinterpreted as `std::complex<float>`
// pointers before being handed to oneMKL. `alpha` is passed through unchanged.
// NOTE(review): this presumably relies on an implicit conversion from
// `float _Complex` to `std::complex<float>` — confirm the toolchain provides it.
//
//   n          number of elements to process
//   alpha      complex scalar multiplier applied to x
//   x, incx    input vector and its element stride
//   y, incy    vector updated in place and its element stride
extern "C" void onemklCaxpy(syclQueue_t device_queue, int64_t n,
                            float _Complex alpha, const float _Complex *x,
                            std::int64_t incx, float _Complex *y,
                            int64_t incy) {
    const auto *x_cx = reinterpret_cast<const std::complex<float> *>(x);
    auto *y_cx = reinterpret_cast<std::complex<float> *>(y);
    oneapi::mkl::blas::column_major::axpy(device_queue->val, n, alpha, x_cx,
                                          incx, y_cx, incy);
}

// C-callable double-precision complex axpy: forwards to oneMKL's column-major
// BLAS `axpy` on the SYCL queue wrapped by `device_queue`.
//
// The C `double _Complex` buffers are reinterpreted as `std::complex<double>`
// pointers before being handed to oneMKL. `alpha` is passed through unchanged.
// NOTE(review): this presumably relies on an implicit conversion from
// `double _Complex` to `std::complex<double>` — confirm the toolchain provides it.
//
//   n          number of elements to process
//   alpha      complex scalar multiplier applied to x
//   x, incx    input vector and its element stride
//   y, incy    vector updated in place and its element stride
extern "C" void onemklZaxpy(syclQueue_t device_queue, int64_t n,
                            double _Complex alpha, const double _Complex *x,
                            std::int64_t incx, double _Complex *y,
                            int64_t incy) {
    const auto *x_cx = reinterpret_cast<const std::complex<double> *>(x);
    auto *y_cx = reinterpret_cast<std::complex<double> *>(y);
    oneapi::mkl::blas::column_major::axpy(device_queue->val, n, alpha, x_cx,
                                          incx, y_cx, incy);
}

// other

Expand Down
5 changes: 5 additions & 0 deletions deps/onemkl.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ int onemklZgemm(syclQueue_t device_queue, onemklTranspose transA,
const double _Complex *B, int64_t ldb, double _Complex beta,
double _Complex *C, int64_t ldc);

// axpy entry points, one per element type: S = float, D = double,
// C = float _Complex, Z = double _Complex.
// Each takes a queue handle, an element count `n`, a scalar `alpha`, a read-only
// vector `x` with increment `incx`, and a writable vector `y` with increment `incy`.
// All four return void, so no status is reported to the caller.
void onemklSaxpy(syclQueue_t device_queue, int64_t n, float alpha, const float *x, int64_t incx, float *y, int64_t incy);
void onemklDaxpy(syclQueue_t device_queue, int64_t n, double alpha, const double *x, int64_t incx, double *y, int64_t incy);
void onemklCaxpy(syclQueue_t device_queue, int64_t n, float _Complex alpha, const float _Complex *x, int64_t incx, float _Complex *y, int64_t incy);
void onemklZaxpy(syclQueue_t device_queue, int64_t n, double _Complex alpha, const double _Complex *x, int64_t incx, double _Complex *y, int64_t incy);

void onemklDestroy();
#ifdef __cplusplus
}
Expand Down
16 changes: 16 additions & 0 deletions lib/mkl/libonemkl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,19 @@ function onemklZgemm(device_queue, transA, transB, m, n, k, alpha, A, lda, B, ld
B::ZePtr{ComplexF64}, ldb::Int64, beta::ComplexF64,
C::ZePtr{ComplexF64}, ldc::Int64)::Cint
end

# Low-level binding for the `onemklSaxpy` symbol in `liboneapi_support`
# (Float32 axpy). Arguments are converted to the C types listed in the call;
# the C function returns nothing.
function onemklSaxpy(device_queue, n, alpha, x, incx, y, incy)
    return @ccall liboneapi_support.onemklSaxpy(
        device_queue::syclQueue_t,
        n::Int64,
        alpha::Cfloat,
        x::ZePtr{Cfloat},
        incx::Int64,
        y::ZePtr{Cfloat},
        incy::Int64,
    )::Cvoid
end

# Low-level binding for the `onemklDaxpy` symbol in `liboneapi_support`
# (Float64 axpy). Arguments are converted to the C types listed in the call;
# the C function returns nothing.
function onemklDaxpy(device_queue, n, alpha, x, incx, y, incy)
    return @ccall liboneapi_support.onemklDaxpy(
        device_queue::syclQueue_t,
        n::Int64,
        alpha::Cdouble,
        x::ZePtr{Cdouble},
        incx::Int64,
        y::ZePtr{Cdouble},
        incy::Int64,
    )::Cvoid
end

# Low-level binding for the `onemklCaxpy` symbol in `liboneapi_support`
# (ComplexF32 axpy). Arguments are converted to the C types listed in the call;
# the C function returns nothing.
function onemklCaxpy(device_queue, n, alpha, x, incx, y, incy)
    return @ccall liboneapi_support.onemklCaxpy(
        device_queue::syclQueue_t,
        n::Int64,
        alpha::ComplexF32,
        x::ZePtr{ComplexF32},
        incx::Int64,
        y::ZePtr{ComplexF32},
        incy::Int64,
    )::Cvoid
end

# Low-level binding for the `onemklZaxpy` symbol in `liboneapi_support`
# (ComplexF64 axpy). Arguments are converted to the C types listed in the call;
# the C function returns nothing.
function onemklZaxpy(device_queue, n, alpha, x, incx, y, incy)
    return @ccall liboneapi_support.onemklZaxpy(
        device_queue::syclQueue_t,
        n::Int64,
        alpha::ComplexF64,
        x::ZePtr{ComplexF64},
        incx::Int64,
        y::ZePtr{ComplexF64},
        incy::Int64,
    )::Cvoid
end
5 changes: 5 additions & 0 deletions lib/mkl/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ function gemm_dispatch!(C::oneStridedVecOrMat, A, B, alpha::Number=true, beta::N
end
end

# `LinearAlgebra.axpy!` for oneAPI strided vectors/matrices, forwarding to
# `oneMKL.axpy!` with the element count taken from `length(x)`.
#
# Arguments:
#   alpha  scalar multiplier (any `Number`; it is passed on to `oneMKL.axpy!` as is)
#   x      input array
#   y      array handed to `oneMKL.axpy!` as the destination
#
# `x` and `y` must share one element type `T <: onemklFloat`. The type parameter
# is bound to both arrays so that mixed-eltype calls do not select this method
# (previously `T` was declared but unused, and such calls only failed later
# with a `MethodError` inside `oneMKL.axpy!`).
#
# Throws `DimensionMismatch` when `x` and `y` differ in length.
# Returns whatever `oneMKL.axpy!` returns.
function LinearAlgebra.axpy!(alpha::Number, x::oneStridedVecOrMat{T},
                             y::oneStridedVecOrMat{T}) where T<:onemklFloat
    length(x) == length(y) ||
        throw(DimensionMismatch("axpy arguments have lengths $(length(x)) and $(length(y))"))
    return oneMKL.axpy!(length(x), alpha, x, y)
end

for NT in (Number, Real)
# NOTE: alpha/beta also ::Real to avoid ambiguities with certain Base methods
@eval begin
Expand Down
2 changes: 1 addition & 1 deletion lib/mkl/oneMKL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ using GPUArrays

include("libonemkl.jl")

# Union of element types used to constrain oneMKL-backed methods.
# NOTE(review): the two definitions below look like the removed and added sides
# of a one-line change (Float16 dropped) — confirm that only one of them is
# present in the actual file.
const onemklFloat = Union{Float64,Float32,Float16,ComplexF64,ComplexF32}
const onemklFloat = Union{Float64,Float32,ComplexF64,ComplexF32}

include("wrappers.jl")
include("linalg.jl")
Expand Down
22 changes: 20 additions & 2 deletions lib/mkl/wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,26 @@ function Base.convert(::Type{onemklTranspose}, trans::Char)
end
end



# level 1
## axpy
# Generate one `axpy!(n, alpha, x, y)` method per supported element type, each
# calling the matching `onemkl?axpy` binding.
#
# Every generated method:
#   * looks up the global queue for the context and device of `x`,
#   * converts `alpha` to the arrays' element type,
#   * passes `stride(x, 1)` / `stride(y, 1)` as the increments,
#   * returns `y`.
#
# NOTE(review): the `StridedArray` signature also matches host arrays, for which
# `context(x)` / `device(x)` are presumably undefined — confirm whether a
# device-specific array type was intended here.
for (fname, elty) in ((:onemklDaxpy, :Float64),
                      (:onemklSaxpy, :Float32),
                      (:onemklZaxpy, :ComplexF64),
                      (:onemklCaxpy, :ComplexF32))
    @eval function axpy!(n::Integer, alpha::Number,
                         x::StridedArray{$elty}, y::StridedArray{$elty})
        queue = global_queue(context(x), device(x))
        alpha = $elty(alpha)
        $fname(sycl_queue(queue), n, alpha, x, stride(x, 1), y, stride(y, 1))
        return y
    end
end
#
# BLAS
#
Expand Down
19 changes: 19 additions & 0 deletions test/onemkl.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using oneAPI
using oneAPI.oneMKL
using LinearAlgebra

# Problem sizes for the tests in this file.
# `m` is the vector length used by the level-1 tests below.
# NOTE(review): `n` and `k` are not referenced in the visible part of this file —
# presumably reserved for matrix tests; confirm or remove.
m = 20
n = 35
k = 13

#####
# Level-1 BLAS tests: for each element type in `eltypes` that oneMKL supports,
# compare `axpy!` results through the `testf` helper.
@testset "level 1" begin
    @testset for T in eltypes
        # Element types without an oneMKL wrapper are skipped.
        if T <: oneMKL.onemklFloat
            x = rand(T, m)
            y = rand(T, m)
            alpha = rand()
            @test testf(axpy!, alpha, x, y)
        end
    end
end