-
Notifications
You must be signed in to change notification settings - Fork 808
[SYCL] Represent JointMatrixINTEL type as extension type #8343
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
b9529a6
366c3e3
5d16304
3edb1bd
8d29f3c
055c129
3e8d12a
a6e4e23
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -51,65 +51,6 @@ void CodeGenTypes::addRecordTypeName(const RecordDecl *RD, | |
| StringRef suffix) { | ||
| SmallString<256> TypeName; | ||
| llvm::raw_svector_ostream OS(TypeName); | ||
| // If RD is spirv_JointMatrixINTEL type, mangle differently. | ||
| if (CGM.getTriple().isSPIRV() || CGM.getTriple().isSPIR()) { | ||
| if (RD->getQualifiedNameAsString() == "__spv::__spirv_JointMatrixINTEL") { | ||
| if (auto TemplateDecl = dyn_cast<ClassTemplateSpecializationDecl>(RD)) { | ||
| ArrayRef<TemplateArgument> TemplateArgs = | ||
| TemplateDecl->getTemplateArgs().asArray(); | ||
| OS << "spirv.JointMatrixINTEL."; | ||
| for (auto &TemplateArg : TemplateArgs) { | ||
| OS << "_"; | ||
| if (TemplateArg.getKind() == TemplateArgument::Type) { | ||
| llvm::Type *TTy = ConvertType(TemplateArg.getAsType()); | ||
| if (TTy->isIntegerTy()) { | ||
| switch (TTy->getIntegerBitWidth()) { | ||
| case 8: | ||
| OS << "char"; | ||
| break; | ||
| case 16: | ||
| OS << "short"; | ||
| break; | ||
| case 32: | ||
| OS << "int"; | ||
| break; | ||
| case 64: | ||
| OS << "long"; | ||
| break; | ||
| default: | ||
| OS << "i" << TTy->getIntegerBitWidth(); | ||
| break; | ||
| } | ||
| } else if (TTy->isHalfTy()) { | ||
| OS << "half"; | ||
| } else if (TTy->isFloatTy()) { | ||
| OS << "float"; | ||
| } else if (TTy->isDoubleTy()) { | ||
| OS << "double"; | ||
| } else if (TTy->isBFloatTy()) { | ||
| OS << "bfloat16"; | ||
| } else if (TTy->isStructTy()) { | ||
| StringRef LlvmTyName = TTy->getStructName(); | ||
| // Emit half/bfloat16/tf32 for sycl[::*]::{half,bfloat16,tf32} | ||
| if (LlvmTyName.startswith("class.sycl::") || | ||
| LlvmTyName.startswith("class.__sycl_internal::")) | ||
| LlvmTyName = LlvmTyName.rsplit("::").second; | ||
| if (LlvmTyName != "half" && LlvmTyName != "bfloat16" && | ||
| LlvmTyName != "tf32") | ||
| llvm_unreachable("Wrong matrix base type!"); | ||
| OS << LlvmTyName; | ||
| } else { | ||
| llvm_unreachable("Wrong matrix base type!"); | ||
| } | ||
| } else if (TemplateArg.getKind() == TemplateArgument::Integral) { | ||
| OS << TemplateArg.getAsIntegral(); | ||
| } | ||
| } | ||
| Ty->setName(OS.str()); | ||
| return; | ||
| } | ||
| } | ||
| } | ||
| OS << RD->getKindName() << '.'; | ||
|
|
||
| // FIXME: We probably want to make more tweaks to the printing policy. For | ||
|
|
@@ -460,6 +401,78 @@ llvm::Type *CodeGenTypes::ConvertFunctionTypeInternal(QualType QFT) { | |
| return ResultType; | ||
| } | ||
|
|
||
| template <bool NeedTypeInterpret = false> | ||
| llvm::Type *getJointMatrixINTELExtType(llvm::Type *CompTy, | ||
| ArrayRef<TemplateArgument> TemplateArgs, | ||
| const unsigned Val = 0) { | ||
| // TODO: we should actually have exactly 5 template parameters: 1 for | ||
| // type and 4 for type parameters. But in previous version of the SPIR-V | ||
| // spec we have Layout matrix type parameter, that was later removed. | ||
| // Once we update to the newest version of the spec - this should be updated. | ||
| assert((TemplateArgs.size() == 5 || TemplateArgs.size() == 6) && | ||
| "Wrong JointMatrixINTEL template parameters number"); | ||
| // This is required to represent optional Optional | ||
| // 'Component Type Interpretation' parameter | ||
| using ParamsType = | ||
| typename std::conditional<NeedTypeInterpret, SmallVector<unsigned, 6>, | ||
| SmallVector<unsigned, 5>>::type; | ||
| ParamsType Params; | ||
| if constexpr (NeedTypeInterpret) | ||
| Params = {0, 0, 0, 0, 0, Val}; | ||
| else | ||
| Params = {0, 0, 0, 0, 0}; | ||
| for (size_t I = 1; I != TemplateArgs.size(); ++I) { | ||
| assert(TemplateArgs[I].getKind() == TemplateArgument::Integral && | ||
| "Wrong JointMatrixINTEL template parameter"); | ||
| Params[I - 1] = TemplateArgs[I].getAsIntegral().getExtValue(); | ||
| } | ||
| return llvm::TargetExtType::get(CompTy->getContext(), | ||
| "spirv.JointMatrixINTEL", {CompTy}, Params); | ||
| } | ||
|
|
||
| /// ConvertSYCLJointMatrixINTELType - Convert SYCL joint_matrix type | ||
| /// which is represented as a pointer to a structure to LLVM extension type | ||
| /// with the parameters that follow SPIR-V JointMatrixINTEL type. | ||
| /// The expected representation is: | ||
| /// target("spirv.JointMatrixINTEL", %element_type, %rows%, %cols%, %scope%, | ||
| /// %use%, (optional) %element_type_interpretation%) | ||
| llvm::Type *CodeGenTypes::ConvertSYCLJointMatrixINTELType(RecordDecl *RD) { | ||
| auto *TemplateDecl = cast<ClassTemplateSpecializationDecl>(RD); | ||
| ArrayRef<TemplateArgument> TemplateArgs = | ||
| TemplateDecl->getTemplateArgs().asArray(); | ||
| assert(TemplateArgs[0].getKind() == TemplateArgument::Type && | ||
| "1st JointMatrixINTEL template parameter must be type"); | ||
| llvm::Type *CompTy = ConvertType(TemplateArgs[0].getAsType()); | ||
|
|
||
| // Per JointMatrixINTEL spec the type can have an Optional | ||
MrSidims marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| // 'Component Type Interpretation' parameter. We should emit it in case | ||
| // if on SYCL level joint matrix accepts 'bfloat16' or 'tf32' objects as | ||
| // matrix's components. Yet bfloat16 should be represented as 'int16' and | ||
MrSidims marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| // 'tf32' as 'float' types. | ||
| if (CompTy->isStructTy()) { | ||
| StringRef LlvmTyName = CompTy->getStructName(); | ||
| // Emit half/int16/float for sycl[::*]::{half,bfloat16,tf32} | ||
| if (LlvmTyName.startswith("class.sycl::") || | ||
| LlvmTyName.startswith("class.__sycl_internal::")) | ||
| LlvmTyName = LlvmTyName.rsplit("::").second; | ||
| if (LlvmTyName == "half") { | ||
| CompTy = llvm::Type::getHalfTy(getLLVMContext()); | ||
| return getJointMatrixINTELExtType(CompTy, TemplateArgs); | ||
| } else if (LlvmTyName == "tf32") { | ||
| CompTy = llvm::Type::getFloatTy(getLLVMContext()); | ||
| // 'tf32' interpretation is mapped to '0' | ||
| return getJointMatrixINTELExtType<true>(CompTy, TemplateArgs, 0); | ||
| } else if (LlvmTyName == "bfloat16") { | ||
| CompTy = llvm::Type::getInt16Ty(getLLVMContext()); | ||
|
||
| // 'bfloat16' interpretation is mapped to '1' | ||
| return getJointMatrixINTELExtType<true>(CompTy, TemplateArgs, 1); | ||
| } else { | ||
| llvm_unreachable("Wrong matrix base type!"); | ||
| } | ||
| } | ||
| return getJointMatrixINTELExtType(CompTy, TemplateArgs); | ||
| } | ||
|
|
||
| /// ConvertType - Convert the specified type to its LLVM form. | ||
| llvm::Type *CodeGenTypes::ConvertType(QualType T) { | ||
| T = Context.getCanonicalType(T); | ||
|
|
@@ -745,6 +758,18 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) { | |
| llvm::Type *PointeeType = ConvertTypeForMem(ETy); | ||
| if (PointeeType->isVoidTy()) | ||
| PointeeType = llvm::Type::getInt8Ty(getLLVMContext()); | ||
| if (CGM.getTriple().isSPIRV() || CGM.getTriple().isSPIR()) { | ||
| const Type *ClangETy = ETy.getTypePtrOrNull(); | ||
| if (ClangETy && ClangETy->isStructureOrClassType()) { | ||
| RecordDecl *RD = ClangETy->getAsCXXRecordDecl(); | ||
| if (RD && | ||
| RD->getQualifiedNameAsString() == "__spv::__spirv_JointMatrixINTEL") { | ||
| ResultType = ConvertSYCLJointMatrixINTELType(RD); | ||
| break; | ||
| } | ||
| } | ||
| } | ||
|
|
||
| unsigned AS = getTargetAddressSpace(ETy); | ||
| ResultType = llvm::PointerType::get(PointeeType, AS); | ||
| break; | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,18 +5,18 @@ | |
| #include <stdint.h> | ||
|
|
||
| namespace __spv { | ||
| template <typename T, size_t R, size_t C, uint32_t U, uint32_t S> | ||
| template <typename T, size_t R, size_t C, uint32_t L, uint32_t S, uint32_t U> | ||
| struct __spirv_JointMatrixINTEL; | ||
| } | ||
|
|
||
| // CHECK: @_Z2f1{{.*}}(%spirv.JointMatrixINTEL._float_5_10_0_1 | ||
| void f1(__spv::__spirv_JointMatrixINTEL<float, 5, 10, 0, 1> *matrix) {} | ||
| // CHECK: @_Z2f1{{.*}}(target("spirv.JointMatrixINTEL", float, 5, 10, 0, 1, 0) | ||
| void f1(__spv::__spirv_JointMatrixINTEL<float, 5, 10, 0, 1, 0> *matrix) {} | ||
|
|
||
| // CHECK: @_Z2f2{{.*}}(%spirv.JointMatrixINTEL._long_10_2_0_0 | ||
| void f2(__spv::__spirv_JointMatrixINTEL<uint64_t, 10, 2, 0, 0> *matrix) {} | ||
| // CHECK: @_Z2f2{{.*}}(target("spirv.JointMatrixINTEL", i64, 10, 2, 0, 0, 0) | ||
| void f2(__spv::__spirv_JointMatrixINTEL<uint64_t, 10, 2, 0, 0, 0> *matrix) {} | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @aelovikov-intel here is the test for unsigned. Would you mind if I won't duplicate it in sycl headers?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm, I'm not sure what's the purpose of the SYCL RT tests then, but yes, no need for unsigned there.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Well, matrix API is changing overtime, so having several compilation tests using real headers is good to have. One the extension is stable we will remove them (note, the E2E tests were in different repo and require specific hardware) |
||
|
|
||
| // CHECK: @_Z2f3{{.*}}(%spirv.JointMatrixINTEL._char_10_2_0_0 | ||
| void f3(__spv::__spirv_JointMatrixINTEL<char, 10, 2, 0, 0> *matrix) {} | ||
| // CHECK: @_Z2f3{{.*}}(target("spirv.JointMatrixINTEL", i8, 10, 2, 0, 0, 0) | ||
| void f3(__spv::__spirv_JointMatrixINTEL<char, 10, 2, 0, 0, 0> *matrix) {} | ||
|
|
||
| namespace sycl { | ||
| class half {}; | ||
|
|
@@ -25,17 +25,17 @@ namespace sycl { | |
| } | ||
| typedef sycl::half my_half; | ||
|
|
||
| // CHECK: @_Z2f4{{.*}}(%spirv.JointMatrixINTEL._half_10_2_0_0 | ||
| void f4(__spv::__spirv_JointMatrixINTEL<my_half, 10, 2, 0, 0> *matrix) {} | ||
| // CHECK: @_Z2f4{{.*}}(target("spirv.JointMatrixINTEL", half, 10, 2, 0, 0, 0) | ||
| void f4(__spv::__spirv_JointMatrixINTEL<my_half, 10, 2, 0, 0, 0> *matrix) {} | ||
|
|
||
| // CHECK: @_Z2f5{{.*}}(%spirv.JointMatrixINTEL._bfloat16_10_2_0_0 | ||
| void f5(__spv::__spirv_JointMatrixINTEL<sycl::bfloat16, 10, 2, 0, 0> *matrix) {} | ||
| // CHECK: @_Z2f5{{.*}}(target("spirv.JointMatrixINTEL", i16, 10, 2, 0, 0, 0, 1) | ||
| void f5(__spv::__spirv_JointMatrixINTEL<sycl::bfloat16, 10, 2, 0, 0, 0> *matrix) {} | ||
|
|
||
| // CHECK: @_Z2f6{{.*}}(%spirv.JointMatrixINTEL._i128_10_2_0_0 | ||
| void f6(__spv::__spirv_JointMatrixINTEL<_BitInt(128), 10, 2, 0, 0> *matrix) {} | ||
| // CHECK: @_Z2f6{{.*}}(target("spirv.JointMatrixINTEL", i128, 10, 2, 0, 0, 0) | ||
| void f6(__spv::__spirv_JointMatrixINTEL<_BitInt(128), 10, 2, 0, 0, 0> *matrix) {} | ||
|
|
||
| // CHECK: @_Z2f7{{.*}}(%spirv.JointMatrixINTEL._tf32_10_2_0_0 | ||
| void f7(__spv::__spirv_JointMatrixINTEL<sycl::tf32, 10, 2, 0, 0> *matrix) {} | ||
| // CHECK: @_Z2f7{{.*}}(target("spirv.JointMatrixINTEL", float, 10, 2, 0, 0, 0, 0) | ||
| void f7(__spv::__spirv_JointMatrixINTEL<sycl::tf32, 10, 2, 0, 0, 0> *matrix) {} | ||
|
|
||
| // CHECK: @_Z2f8{{.*}}(%spirv.JointMatrixINTEL._double_5_10_0_1 | ||
| void f8(__spv::__spirv_JointMatrixINTEL<double, 5, 10, 0, 1> *matrix) {} | ||
| // CHECK: @_Z2f8{{.*}}(target("spirv.JointMatrixINTEL", double, 5, 10, 0, 1, 0) | ||
| void f8(__spv::__spirv_JointMatrixINTEL<double, 5, 10, 0, 1, 0> *matrix) {} | ||
Uh oh!
There was an error while loading. Please reload this page.