Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sycl/include/CL/sycl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include <CL/sycl/item.hpp>
#include <CL/sycl/kernel.hpp>
#include <CL/sycl/kernel_bundle.hpp>
#include <CL/sycl/kernel_handler.hpp>
#include <CL/sycl/marray.hpp>
#include <CL/sycl/multi_ptr.hpp>
#include <CL/sycl/nd_item.hpp>
Expand All @@ -47,6 +48,7 @@
#include <CL/sycl/range.hpp>
#include <CL/sycl/reduction.hpp>
#include <CL/sycl/sampler.hpp>
#include <CL/sycl/specialization_id.hpp>
#include <CL/sycl/stream.hpp>
#include <CL/sycl/types.hpp>
#include <CL/sycl/usm.hpp>
Expand Down
74 changes: 68 additions & 6 deletions sycl/include/CL/sycl/detail/cg_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <CL/sycl/interop_handle.hpp>
#include <CL/sycl/interop_handler.hpp>
#include <CL/sycl/kernel.hpp>
#include <CL/sycl/kernel_handler.hpp>
#include <CL/sycl/nd_item.hpp>
#include <CL/sycl/range.hpp>

Expand Down Expand Up @@ -122,6 +123,32 @@ class NDRDescT {
size_t Dims;
};

template <typename Func>
struct KernelArgsInfo : KernelArgsInfo<decltype(&Func::operator())> {};

template <typename RetType, typename Func, class... Args>
struct KernelArgsInfo<RetType (Func::*)(Args...) const> {
constexpr static size_t argsCount() { return sizeof...(Args); };
constexpr static bool hasArgs() { return sizeof...(Args) > 0; };
template <std::size_t ArgNum> struct ArgType {
typedef typename std::tuple_element<ArgNum, std::tuple<Args...>>::type type;
};
};

template <typename Func>
std::enable_if_t<KernelArgsInfo<Func>::hasArgs(),
bool> constexpr kernelHandlerIsLastElementTypeOfKernel() {
return std::is_same<typename KernelArgsInfo<Func>::template ArgType<
KernelArgsInfo<Func>::argsCount() - 1>::type,
kernel_handler>::value;
}

template <typename Func>
std::enable_if_t<!KernelArgsInfo<Func>::hasArgs(),
bool> constexpr kernelHandlerIsLastElementTypeOfKernel() {
return false;
}

// The pure virtual class aimed to store lambda/functors of any type.
class HostKernelBase {
public:
Expand Down Expand Up @@ -197,7 +224,7 @@ class HostKernel : public HostKernelBase {
template <class ArgT = KernelArgType>
typename detail::enable_if_t<std::is_same<ArgT, void>::value>
runOnHost(const NDRDescT &) {
MKernel();
runKernelWithoutArg<decltype(MKernel)>();
}

template <class ArgT = KernelArgType>
Expand Down Expand Up @@ -228,7 +255,8 @@ class HostKernel : public HostKernelBase {
store_id(&ID);
store_item(&Item);
}
MKernel(ID);
runKernelWithArg<const sycl::id<Dims> &,
decltype(MKernel)>(ID);
});
}

Expand All @@ -253,7 +281,8 @@ class HostKernel : public HostKernelBase {
store_id(&ID);
store_item(&ItemWithOffset);
}
MKernel(Item);
runKernelWithArg<sycl::item<Dims, /*Offset=*/false>, decltype(MKernel)>(
Item);
});
}

Expand Down Expand Up @@ -286,7 +315,9 @@ class HostKernel : public HostKernelBase {
store_id(&ID);
store_item(&Item);
}
MKernel(Item);
runKernelWithArg<
sycl::item<Dims, /*Offset=*/true>,
decltype(MKernel)>(Item);
});
}

Expand Down Expand Up @@ -336,7 +367,7 @@ class HostKernel : public HostKernelBase {
auto g = NDItem.get_group();
store_group(&g);
}
MKernel(NDItem);
runKernelWithArg<const sycl::nd_item<Dims>, decltype(MKernel)>(NDItem);
});
});
}
Expand Down Expand Up @@ -364,11 +395,42 @@ class HostKernel : public HostKernelBase {
detail::NDLoop<Dims>::iterate(NGroups, [&](const id<Dims> &GroupID) {
sycl::group<Dims> Group =
IDBuilder::createGroup<Dims>(GlobalSize, LocalSize, NGroups, GroupID);
MKernel(Group);
runKernelWithArg<sycl::group<Dims>, decltype(MKernel)>(Group);
});
}

~HostKernel() = default;

private:
template <class KernelT>
std::enable_if_t<detail::kernelHandlerIsLastElementTypeOfKernel<KernelT>(),
void>
runKernelWithoutArg() {
kernel_handler KH;
MKernel(KH);
}

template <class KernelT>
std::enable_if_t<!detail::kernelHandlerIsLastElementTypeOfKernel<KernelT>(),
void>
runKernelWithoutArg() {
MKernel();
}

template <typename ArgType, class KernelT>
std::enable_if_t<detail::kernelHandlerIsLastElementTypeOfKernel<KernelT>(),
void>
runKernelWithArg(ArgType Arg) {
kernel_handler KH;
MKernel(Arg, KH);
}

template <typename ArgType, class KernelT>
std::enable_if_t<!detail::kernelHandlerIsLastElementTypeOfKernel<KernelT>(),
void>
runKernelWithArg(ArgType Arg) {
MKernel(Arg);
}
};

} // namespace detail
Expand Down
Loading