diff --git a/clang/lib/CodeGen/CodeGenTypes.cpp b/clang/lib/CodeGen/CodeGenTypes.cpp index bf8e4006e1d69..ed063a36e5705 100644 --- a/clang/lib/CodeGen/CodeGenTypes.cpp +++ b/clang/lib/CodeGen/CodeGenTypes.cpp @@ -51,65 +51,6 @@ void CodeGenTypes::addRecordTypeName(const RecordDecl *RD, StringRef suffix) { SmallString<256> TypeName; llvm::raw_svector_ostream OS(TypeName); - // If RD is spirv_JointMatrixINTEL type, mangle differently. - if (CGM.getTriple().isSPIRV() || CGM.getTriple().isSPIR()) { - if (RD->getQualifiedNameAsString() == "__spv::__spirv_JointMatrixINTEL") { - if (auto TemplateDecl = dyn_cast(RD)) { - ArrayRef TemplateArgs = - TemplateDecl->getTemplateArgs().asArray(); - OS << "spirv.JointMatrixINTEL."; - for (auto &TemplateArg : TemplateArgs) { - OS << "_"; - if (TemplateArg.getKind() == TemplateArgument::Type) { - llvm::Type *TTy = ConvertType(TemplateArg.getAsType()); - if (TTy->isIntegerTy()) { - switch (TTy->getIntegerBitWidth()) { - case 8: - OS << "char"; - break; - case 16: - OS << "short"; - break; - case 32: - OS << "int"; - break; - case 64: - OS << "long"; - break; - default: - OS << "i" << TTy->getIntegerBitWidth(); - break; - } - } else if (TTy->isHalfTy()) { - OS << "half"; - } else if (TTy->isFloatTy()) { - OS << "float"; - } else if (TTy->isDoubleTy()) { - OS << "double"; - } else if (TTy->isBFloatTy()) { - OS << "bfloat16"; - } else if (TTy->isStructTy()) { - StringRef LlvmTyName = TTy->getStructName(); - // Emit half/bfloat16/tf32 for sycl[::*]::{half,bfloat16,tf32} - if (LlvmTyName.startswith("class.sycl::") || - LlvmTyName.startswith("class.__sycl_internal::")) - LlvmTyName = LlvmTyName.rsplit("::").second; - if (LlvmTyName != "half" && LlvmTyName != "bfloat16" && - LlvmTyName != "tf32") - llvm_unreachable("Wrong matrix base type!"); - OS << LlvmTyName; - } else { - llvm_unreachable("Wrong matrix base type!"); - } - } else if (TemplateArg.getKind() == TemplateArgument::Integral) { - OS << TemplateArg.getAsIntegral(); - } - } - Ty->setName(OS.str()); - return; - } - } - } OS << RD->getKindName() << '.'; // FIXME: We probably want to make more tweaks to the printing policy. For @@ -460,6 +401,78 @@ llvm::Type *CodeGenTypes::ConvertFunctionTypeInternal(QualType QFT) { return ResultType; } +template +llvm::Type *getJointMatrixINTELExtType(llvm::Type *CompTy, + ArrayRef TemplateArgs, + const unsigned Val = 0) { + // TODO: we should actually have exactly 5 template parameters: 1 for + // type and 4 for type parameters. But in previous version of the SPIR-V + // spec we have Layout matrix type parameter, that was later removed. + // Once we update to the newest version of the spec - this should be updated. + assert((TemplateArgs.size() == 5 || TemplateArgs.size() == 6) && + "Wrong JointMatrixINTEL template parameters number"); + // This is required to represent optional 'Component Type Interpretation' + // parameter + using ParamsType = + typename std::conditional, + SmallVector>::type; + ParamsType Params; + if constexpr (NeedTypeInterpret) + Params = {0, 0, 0, 0, 0, Val}; + else + Params = {0, 0, 0, 0, 0}; + for (size_t I = 1; I != TemplateArgs.size(); ++I) { + assert(TemplateArgs[I].getKind() == TemplateArgument::Integral && + "Wrong JointMatrixINTEL template parameter"); + Params[I - 1] = TemplateArgs[I].getAsIntegral().getExtValue(); + } + return llvm::TargetExtType::get(CompTy->getContext(), + "spirv.JointMatrixINTEL", {CompTy}, Params); +} + +/// ConvertSYCLJointMatrixINTELType - Convert SYCL joint_matrix type +/// which is represented as a pointer to a structure to LLVM extension type +/// with the parameters that follow SPIR-V JointMatrixINTEL type. +/// The expected representation is: +/// target("spirv.JointMatrixINTEL", %element_type, %rows%, %cols%, %scope%, +/// %use%, (optional) %element_type_interpretation%) +llvm::Type *CodeGenTypes::ConvertSYCLJointMatrixINTELType(RecordDecl *RD) { + auto *TemplateDecl = cast(RD); + ArrayRef TemplateArgs = + TemplateDecl->getTemplateArgs().asArray(); + assert(TemplateArgs[0].getKind() == TemplateArgument::Type && + "1st JointMatrixINTEL template parameter must be type"); + llvm::Type *CompTy = ConvertType(TemplateArgs[0].getAsType()); + + // Per JointMatrixINTEL spec the type can have an optional + // 'Component Type Interpretation' parameter. We should emit it in case + // if on SYCL level joint matrix accepts 'bfloat16' or 'tf32' objects as + // matrix's components. Yet 'bfloat16' should be represented as 'int16' and + // 'tf32' as 'float' types. + if (CompTy->isStructTy()) { + StringRef LlvmTyName = CompTy->getStructName(); + // Emit half/int16/float for sycl[::*]::{half,bfloat16,tf32} + if (LlvmTyName.startswith("class.sycl::") || + LlvmTyName.startswith("class.__sycl_internal::")) + LlvmTyName = LlvmTyName.rsplit("::").second; + if (LlvmTyName == "half") { + CompTy = llvm::Type::getHalfTy(getLLVMContext()); + return getJointMatrixINTELExtType(CompTy, TemplateArgs); + } else if (LlvmTyName == "tf32") { + CompTy = llvm::Type::getFloatTy(getLLVMContext()); + // 'tf32' interpretation is mapped to '0' + return getJointMatrixINTELExtType(CompTy, TemplateArgs, 0); + } else if (LlvmTyName == "bfloat16") { + CompTy = llvm::Type::getInt16Ty(getLLVMContext()); + // 'bfloat16' interpretation is mapped to '1' + return getJointMatrixINTELExtType(CompTy, TemplateArgs, 1); + } else { + llvm_unreachable("Wrong matrix base type!"); + } + } + return getJointMatrixINTELExtType(CompTy, TemplateArgs); +} + /// ConvertType - Convert the specified type to its LLVM form. llvm::Type *CodeGenTypes::ConvertType(QualType T) { T = Context.getCanonicalType(T); @@ -745,6 +758,18 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) { llvm::Type *PointeeType = ConvertTypeForMem(ETy); if (PointeeType->isVoidTy()) PointeeType = llvm::Type::getInt8Ty(getLLVMContext()); + if (CGM.getTriple().isSPIRV() || CGM.getTriple().isSPIR()) { + const Type *ClangETy = ETy.getTypePtrOrNull(); + if (ClangETy && ClangETy->isStructureOrClassType()) { + RecordDecl *RD = ClangETy->getAsCXXRecordDecl(); + if (RD && + RD->getQualifiedNameAsString() == "__spv::__spirv_JointMatrixINTEL") { + ResultType = ConvertSYCLJointMatrixINTELType(RD); + break; + } + } + } + unsigned AS = getTargetAddressSpace(ETy); ResultType = llvm::PointerType::get(PointeeType, AS); break; diff --git a/clang/lib/CodeGen/CodeGenTypes.h b/clang/lib/CodeGen/CodeGenTypes.h index e76fda95513f6..3f198b2a3de1a 100644 --- a/clang/lib/CodeGen/CodeGenTypes.h +++ b/clang/lib/CodeGen/CodeGenTypes.h @@ -133,6 +133,14 @@ class CodeGenTypes { /// memory representation is usually i8 or i32, depending on the target. llvm::Type *ConvertTypeForMem(QualType T, bool ForBitField = false); + /// ConvertSYCLJointMatrixINTELType - Convert SYCL joint_matrix type + /// which is represented as a pointer to a structure to LLVM extension type + /// with the parameters that follow SPIR-V JointMatrixINTEL type. + /// The expected representation is: + /// target("spirv.JointMatrixINTEL", %element_type, %rows%, %cols%, %scope%, + /// %use%, (optional) %element_type_interpretation%) + llvm::Type *ConvertSYCLJointMatrixINTELType(RecordDecl *RD); + /// GetFunctionType - Get the LLVM function type for \arg Info. llvm::FunctionType *GetFunctionType(const CGFunctionInfo &Info); diff --git a/clang/test/CodeGenSYCL/matrix.cpp b/clang/test/CodeGenSYCL/matrix.cpp index 69469811047fd..b2c0c51adba6e 100644 --- a/clang/test/CodeGenSYCL/matrix.cpp +++ b/clang/test/CodeGenSYCL/matrix.cpp @@ -5,18 +5,18 @@ #include namespace __spv { - template + template struct __spirv_JointMatrixINTEL; } -// CHECK: @_Z2f1{{.*}}(%spirv.JointMatrixINTEL._float_5_10_0_1 -void f1(__spv::__spirv_JointMatrixINTEL *matrix) {} +// CHECK: @_Z2f1{{.*}}(target("spirv.JointMatrixINTEL", float, 5, 10, 0, 1, 0) +void f1(__spv::__spirv_JointMatrixINTEL *matrix) {} -// CHECK: @_Z2f2{{.*}}(%spirv.JointMatrixINTEL._long_10_2_0_0 -void f2(__spv::__spirv_JointMatrixINTEL *matrix) {} +// CHECK: @_Z2f2{{.*}}(target("spirv.JointMatrixINTEL", i64, 10, 2, 0, 0, 0) +void f2(__spv::__spirv_JointMatrixINTEL *matrix) {} -// CHECK: @_Z2f3{{.*}}(%spirv.JointMatrixINTEL._char_10_2_0_0 -void f3(__spv::__spirv_JointMatrixINTEL *matrix) {} +// CHECK: @_Z2f3{{.*}}(target("spirv.JointMatrixINTEL", i8, 10, 2, 0, 0, 0) +void f3(__spv::__spirv_JointMatrixINTEL *matrix) {} namespace sycl { class half {}; @@ -25,17 +25,17 @@ namespace sycl { } typedef sycl::half my_half; -// CHECK: @_Z2f4{{.*}}(%spirv.JointMatrixINTEL._half_10_2_0_0 -void f4(__spv::__spirv_JointMatrixINTEL *matrix) {} +// CHECK: @_Z2f4{{.*}}(target("spirv.JointMatrixINTEL", half, 10, 2, 0, 0, 0) +void f4(__spv::__spirv_JointMatrixINTEL *matrix) {} -// CHECK: @_Z2f5{{.*}}(%spirv.JointMatrixINTEL._bfloat16_10_2_0_0 -void f5(__spv::__spirv_JointMatrixINTEL *matrix) {} +// CHECK: @_Z2f5{{.*}}(target("spirv.JointMatrixINTEL", i16, 10, 2, 0, 0, 0, 1) +void f5(__spv::__spirv_JointMatrixINTEL *matrix) {} -// CHECK: @_Z2f6{{.*}}(%spirv.JointMatrixINTEL._i128_10_2_0_0 -void f6(__spv::__spirv_JointMatrixINTEL<_BitInt(128), 10, 2, 0, 0> *matrix) {} +// CHECK: @_Z2f6{{.*}}(target("spirv.JointMatrixINTEL", i128, 10, 2, 0, 0, 0) +void f6(__spv::__spirv_JointMatrixINTEL<_BitInt(128), 10, 2, 0, 0, 0> *matrix) {} -// CHECK: @_Z2f7{{.*}}(%spirv.JointMatrixINTEL._tf32_10_2_0_0 -void f7(__spv::__spirv_JointMatrixINTEL *matrix) {} +// CHECK: @_Z2f7{{.*}}(target("spirv.JointMatrixINTEL", float, 10, 2, 0, 0, 0, 0) +void f7(__spv::__spirv_JointMatrixINTEL *matrix) {} -// CHECK: @_Z2f8{{.*}}(%spirv.JointMatrixINTEL._double_5_10_0_1 -void f8(__spv::__spirv_JointMatrixINTEL *matrix) {} +// CHECK: @_Z2f8{{.*}}(target("spirv.JointMatrixINTEL", double, 5, 10, 0, 1, 0) +void f8(__spv::__spirv_JointMatrixINTEL *matrix) {} diff --git a/sycl/test/matrix/legacy/matrix-int8-test.cpp b/sycl/test/matrix/legacy/matrix-int8-test.cpp index a0c2edb62c2f1..852c877b46fc4 100644 --- a/sycl/test/matrix/legacy/matrix-int8-test.cpp +++ b/sycl/test/matrix/legacy/matrix-int8-test.cpp @@ -1,8 +1,8 @@ // RUN: %clangxx -fsycl -fsycl-device-only -O2 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=1 -S -emit-llvm -o - %s | FileCheck %s -// CHECK-DAG: %spirv.JointMatrixINTEL._char_12_48_0_3 = type opaque -// CHECK-DAG: %spirv.JointMatrixINTEL._int_12_12_0_3 = type opaque -// CHECK-DAG: %spirv.JointMatrixINTEL._char_48_12_3_3 = type opaque +// CHECK-DAG: target("spirv.JointMatrixINTEL", i8, 12, 48, 0, 3, 0) +// CHECK-DAG: target("spirv.JointMatrixINTEL", i32, 12, 12, 0, 3, 0) +// CHECK-DAG: target("spirv.JointMatrixINTEL", i8, 48, 12, 3, 3, 0) #include #include diff --git a/sycl/test/matrix/matrix-int8-test.cpp b/sycl/test/matrix/matrix-int8-test.cpp index de8721bca3b09..99f60423ca212 100644 --- a/sycl/test/matrix/matrix-int8-test.cpp +++ b/sycl/test/matrix/matrix-int8-test.cpp @@ -1,8 +1,8 @@ // RUN: %clangxx -fsycl -fsycl-device-only -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 -O2 -S -emit-llvm -o - %s | FileCheck %s -// CHECK-DAG: %spirv.JointMatrixINTEL._char_12_48_0_3_0 = type opaque -// CHECK-DAG: %spirv.JointMatrixINTEL._int_12_12_3_3_2 = type opaque -// CHECK-DAG: %spirv.JointMatrixINTEL._char_48_12_2_3_1 = type opaque +// CHECK-DAG: target("spirv.JointMatrixINTEL", i8, 12, 48, 0, 3, 0) +// CHECK-DAG: target("spirv.JointMatrixINTEL", i32, 12, 12, 3, 3, 2) +// CHECK-DAG: target("spirv.JointMatrixINTEL", i8, 48, 12, 2, 3, 1) #include #include