-
Notifications
You must be signed in to change notification settings - Fork 644
MSL: add initial cooperative matrix support #2596
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
2047b7b
MSL: add initial cooperative matrix support
kabu1204 6fc910e
MSL: reject cooperative matrix muladd operand flags
kabu1204 188e392
Workaround some MSVC shenanigans
HansKristian-Work cda74fe
MSL: Fix ptr-cast prepass and coopmat typed load/store
kabu1204 0e67b32
Update some stray references.
HansKristian-Work 7c69662
Indentation fixes.
HansKristian-Work 8f37924
Simplify unsupported coopmat check.
HansKristian-Work 7594d2b
Revert questionable change to bitcast.
HansKristian-Work File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
24 changes: 24 additions & 0 deletions
24
reference/shaders-msl-no-opt/asm/comp/cooperative-matrix-bfloat.asm.msl31.comp
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,24 @@ | ||
| #include <metal_stdlib> | ||
| #include <simd/simd.h> | ||
| #include <metal_simdgroup_matrix> | ||
|
|
||
| using namespace metal; | ||
|
|
||
| struct SSBO | ||
| { | ||
| bfloat data[1]; | ||
| }; | ||
|
|
||
| kernel void main0(device SSBO& ssbo [[buffer(0)]]) | ||
| { | ||
| simdgroup_bfloat8x8 _21; | ||
| simdgroup_load(_21, &ssbo.data[0u], 8u); | ||
| simdgroup_bfloat8x8 _22; | ||
| simdgroup_load(_22, &ssbo.data[0u], 8u); | ||
| simdgroup_bfloat8x8 _23; | ||
| simdgroup_load(_23, &ssbo.data[0u], 8u); | ||
| simdgroup_bfloat8x8 _24; | ||
| simdgroup_multiply_accumulate(_24, _21, _22, _23); | ||
| simdgroup_store(_24, &ssbo.data[0u], 8u); | ||
| } | ||
|
|
||
16 changes: 16 additions & 0 deletions
16
reference/shaders-msl-no-opt/asm/comp/cooperative-matrix-length.asm.msl31.comp
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,16 @@ | ||
| #include <metal_stdlib> | ||
| #include <simd/simd.h> | ||
| #include <metal_simdgroup_matrix> | ||
|
|
||
| using namespace metal; | ||
|
|
||
| struct SSBO | ||
| { | ||
| uint data[1]; | ||
| }; | ||
|
|
||
| kernel void main0(device SSBO& ssbo [[buffer(0)]]) | ||
| { | ||
| ssbo.data[0u] = uint(sizeof(simdgroup_float8x8::storage_type) / sizeof(float)); | ||
| } | ||
|
|
21 changes: 21 additions & 0 deletions
21
reference/shaders-msl-no-opt/asm/comp/cooperative-matrix-load-store.asm.msl31.comp
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,21 @@ | ||
| #include <metal_stdlib> | ||
| #include <simd/simd.h> | ||
| #include <metal_simdgroup_matrix> | ||
|
|
||
| using namespace metal; | ||
|
|
||
| struct SSBO | ||
| { | ||
| float data[1]; | ||
| }; | ||
|
|
||
| kernel void main0(device SSBO& ssbo [[buffer(0)]]) | ||
| { | ||
| simdgroup_float8x8 _20; | ||
| simdgroup_load(_20, &ssbo.data[0u], 8u); | ||
| simdgroup_store(_20, &ssbo.data[0u], 8u); | ||
| simdgroup_float8x8 _21; | ||
| simdgroup_load(_21, &ssbo.data[0u], 8u, ulong2(0), true); | ||
| simdgroup_store(_21, &ssbo.data[0u], 8u, ulong2(0), true); | ||
| } | ||
|
|
38 changes: 38 additions & 0 deletions
38
reference/shaders-msl-no-opt/asm/comp/cooperative-matrix-muladd.asm.msl31.comp
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,38 @@ | ||
| #include <metal_stdlib> | ||
| #include <simd/simd.h> | ||
| #include <metal_simdgroup_matrix> | ||
|
|
||
| using namespace metal; | ||
|
|
||
| struct SSBO32 | ||
| { | ||
| float data[1]; | ||
| }; | ||
|
|
||
| struct SSBO16 | ||
| { | ||
| half data[1]; | ||
| }; | ||
|
|
||
| kernel void main0(device SSBO32& ssbo32 [[buffer(0)]], device SSBO16& ssbo16 [[buffer(1)]]) | ||
| { | ||
| simdgroup_float8x8 _30; | ||
| simdgroup_load(_30, &ssbo32.data[0u], 8u); | ||
| simdgroup_float8x8 _31; | ||
| simdgroup_load(_31, &ssbo32.data[0u], 8u); | ||
| simdgroup_float8x8 _32; | ||
| simdgroup_load(_32, &ssbo32.data[0u], 8u); | ||
| simdgroup_float8x8 _33; | ||
| simdgroup_multiply_accumulate(_33, _30, _31, _32); | ||
| simdgroup_store(_33, &ssbo32.data[0u], 8u); | ||
| simdgroup_half8x8 _35; | ||
| simdgroup_load(_35, &ssbo16.data[0u], 8u); | ||
| simdgroup_half8x8 _36; | ||
| simdgroup_load(_36, &ssbo16.data[0u], 8u); | ||
| simdgroup_half8x8 _37; | ||
| simdgroup_load(_37, &ssbo16.data[0u], 8u); | ||
| simdgroup_half8x8 _38; | ||
| simdgroup_multiply_accumulate(_38, _35, _36, _37); | ||
| simdgroup_store(_38, &ssbo16.data[0u], 8u); | ||
| } | ||
|
|
56 changes: 56 additions & 0 deletions
56
reference/shaders-msl-no-opt/asm/comp/cooperative-matrix-workgroup-cast-load-store.asm.msl31.comp
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,56 @@ | ||
| #pragma clang diagnostic ignored "-Wmissing-prototypes" | ||
| #pragma clang diagnostic ignored "-Wmissing-braces" | ||
|
|
||
| #include <metal_stdlib> | ||
| #include <simd/simd.h> | ||
| #include <metal_simdgroup_matrix> | ||
|
|
||
| using namespace metal; | ||
|
|
||
| template<typename T, size_t Num> | ||
| struct spvUnsafeArray | ||
| { | ||
| T elements[Num ? Num : 1]; | ||
|
|
||
| thread T& operator [] (size_t pos) thread | ||
| { | ||
| return elements[pos]; | ||
| } | ||
| constexpr const thread T& operator [] (size_t pos) const thread | ||
| { | ||
| return elements[pos]; | ||
| } | ||
|
|
||
| device T& operator [] (size_t pos) device | ||
| { | ||
| return elements[pos]; | ||
| } | ||
| constexpr const device T& operator [] (size_t pos) const device | ||
| { | ||
| return elements[pos]; | ||
| } | ||
|
|
||
| constexpr const constant T& operator [] (size_t pos) const constant | ||
| { | ||
| return elements[pos]; | ||
| } | ||
|
|
||
| threadgroup T& operator [] (size_t pos) threadgroup | ||
| { | ||
| return elements[pos]; | ||
| } | ||
| constexpr const threadgroup T& operator [] (size_t pos) const threadgroup | ||
| { | ||
| return elements[pos]; | ||
| } | ||
| }; | ||
|
|
||
| kernel void main0() | ||
| { | ||
| threadgroup spvUnsafeArray<uchar, 128> _15; | ||
| _15[0u] = uchar(0); | ||
| simdgroup_half8x8 _20; | ||
| simdgroup_load(_20, reinterpret_cast<threadgroup half*>(&_15[0u]), (16u) / 2u); | ||
| simdgroup_store(_20, reinterpret_cast<threadgroup half*>(&_15[0u]), (16u) / 2u); | ||
| } | ||
|
|
58 changes: 58 additions & 0 deletions
58
reference/shaders-msl-no-opt/asm/comp/cooperative-matrix-workgroup-load-store.asm.msl31.comp
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,58 @@ | ||
| #pragma clang diagnostic ignored "-Wmissing-prototypes" | ||
| #pragma clang diagnostic ignored "-Wmissing-braces" | ||
|
|
||
| #include <metal_stdlib> | ||
| #include <simd/simd.h> | ||
| #include <metal_simdgroup_matrix> | ||
|
|
||
| using namespace metal; | ||
|
|
||
| template<typename T, size_t Num> | ||
| struct spvUnsafeArray | ||
| { | ||
| T elements[Num ? Num : 1]; | ||
|
|
||
| thread T& operator [] (size_t pos) thread | ||
| { | ||
| return elements[pos]; | ||
| } | ||
| constexpr const thread T& operator [] (size_t pos) const thread | ||
| { | ||
| return elements[pos]; | ||
| } | ||
|
|
||
| device T& operator [] (size_t pos) device | ||
| { | ||
| return elements[pos]; | ||
| } | ||
| constexpr const device T& operator [] (size_t pos) const device | ||
| { | ||
| return elements[pos]; | ||
| } | ||
|
|
||
| constexpr const constant T& operator [] (size_t pos) const constant | ||
| { | ||
| return elements[pos]; | ||
| } | ||
|
|
||
| threadgroup T& operator [] (size_t pos) threadgroup | ||
| { | ||
| return elements[pos]; | ||
| } | ||
| constexpr const threadgroup T& operator [] (size_t pos) const threadgroup | ||
| { | ||
| return elements[pos]; | ||
| } | ||
| }; | ||
|
|
||
| kernel void main0() | ||
| { | ||
| threadgroup spvUnsafeArray<float, 64> _14; | ||
| simdgroup_float8x8 _18; | ||
| simdgroup_load(_18, &_14[0u], 8u); | ||
| simdgroup_store(_18, &_14[0u], 8u); | ||
| simdgroup_float8x8 _19; | ||
| simdgroup_load(_19, &_14[0u], 8u, ulong2(0), true); | ||
| simdgroup_store(_19, &_14[0u], 8u, ulong2(0), true); | ||
| } | ||
|
|
52 changes: 52 additions & 0 deletions
52
shaders-msl-no-opt/asm/comp/cooperative-matrix-bfloat.asm.msl31.comp
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,52 @@ | ||
| ; SPIR-V | ||
| ; Version: 1.6 | ||
| ; Generator: Khronos SPIR-V Tools Assembler; 0 | ||
| ; Bound: 50 | ||
| ; Schema: 0 | ||
| OpCapability Shader | ||
| OpCapability CooperativeMatrixKHR | ||
| OpCapability BFloat16TypeKHR | ||
| OpCapability BFloat16CooperativeMatrixKHR | ||
| OpCapability VulkanMemoryModel | ||
| OpExtension "SPV_KHR_cooperative_matrix" | ||
| OpExtension "SPV_KHR_bfloat16" | ||
| OpExtension "SPV_KHR_vulkan_memory_model" | ||
| OpMemoryModel Logical Vulkan | ||
| OpEntryPoint GLCompute %main "main" | ||
| OpExecutionMode %main LocalSize 32 1 1 | ||
| OpName %main "main" | ||
| OpName %SSBO "SSBO" | ||
| OpMemberName %SSBO 0 "data" | ||
| OpName %ssbo "ssbo" | ||
| OpDecorate %arr_bf16 ArrayStride 2 | ||
| OpMemberDecorate %SSBO 0 Offset 0 | ||
| OpDecorate %SSBO Block | ||
| OpDecorate %ssbo DescriptorSet 0 | ||
| OpDecorate %ssbo Binding 0 | ||
| %void = OpTypeVoid | ||
| %3 = OpTypeFunction %void | ||
| %bfloat = OpTypeFloat 16 BFloat16KHR | ||
| %uint = OpTypeInt 32 0 | ||
| %uint_0 = OpConstant %uint 0 | ||
| %uint_1 = OpConstant %uint 1 | ||
| %uint_2 = OpConstant %uint 2 | ||
| %uint_3 = OpConstant %uint 3 | ||
| %uint_8 = OpConstant %uint 8 | ||
| %arr_bf16 = OpTypeRuntimeArray %bfloat | ||
| %SSBO = OpTypeStruct %arr_bf16 | ||
| %ptr_ssbo_SSBO = OpTypePointer StorageBuffer %SSBO | ||
| %ssbo = OpVariable %ptr_ssbo_SSBO StorageBuffer | ||
| %ptr_ssbo_bf16 = OpTypePointer StorageBuffer %bfloat | ||
| %coopmat_bf16_A = OpTypeCooperativeMatrixKHR %bfloat %uint_3 %uint_8 %uint_8 %uint_0 | ||
| %coopmat_bf16_B = OpTypeCooperativeMatrixKHR %bfloat %uint_3 %uint_8 %uint_8 %uint_1 | ||
| %coopmat_bf16_acc = OpTypeCooperativeMatrixKHR %bfloat %uint_3 %uint_8 %uint_8 %uint_2 | ||
| %main = OpFunction %void None %3 | ||
| %5 = OpLabel | ||
| %p0 = OpAccessChain %ptr_ssbo_bf16 %ssbo %uint_0 %uint_0 | ||
| %bf_A = OpCooperativeMatrixLoadKHR %coopmat_bf16_A %p0 %uint_0 %uint_8 | ||
| %bf_B = OpCooperativeMatrixLoadKHR %coopmat_bf16_B %p0 %uint_0 %uint_8 | ||
| %bf_C = OpCooperativeMatrixLoadKHR %coopmat_bf16_acc %p0 %uint_0 %uint_8 | ||
| %bf_D = OpCooperativeMatrixMulAddKHR %coopmat_bf16_acc %bf_A %bf_B %bf_C | ||
| OpCooperativeMatrixStoreKHR %p0 %bf_D %uint_0 %uint_8 | ||
| OpReturn | ||
| OpFunctionEnd |
42 changes: 42 additions & 0 deletions
42
shaders-msl-no-opt/asm/comp/cooperative-matrix-length.asm.msl31.comp
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,42 @@ | ||
| ; SPIR-V | ||
| ; Version: 1.6 | ||
| ; Generator: Khronos SPIR-V Tools Assembler; 0 | ||
| ; Bound: 24 | ||
| ; Schema: 0 | ||
| OpCapability Shader | ||
| OpCapability CooperativeMatrixKHR | ||
| OpCapability VulkanMemoryModel | ||
| OpExtension "SPV_KHR_cooperative_matrix" | ||
| OpExtension "SPV_KHR_vulkan_memory_model" | ||
| OpMemoryModel Logical Vulkan | ||
| OpEntryPoint GLCompute %main "main" | ||
| OpExecutionMode %main LocalSize 32 1 1 | ||
| OpName %main "main" | ||
| OpName %SSBO "SSBO" | ||
| OpMemberName %SSBO 0 "data" | ||
| OpName %ssbo "ssbo" | ||
| OpDecorate %arr_uint ArrayStride 4 | ||
| OpMemberDecorate %SSBO 0 Offset 0 | ||
| OpDecorate %SSBO Block | ||
| OpDecorate %ssbo DescriptorSet 0 | ||
| OpDecorate %ssbo Binding 0 | ||
| %void = OpTypeVoid | ||
| %3 = OpTypeFunction %void | ||
| %uint = OpTypeInt 32 0 | ||
| %float = OpTypeFloat 32 | ||
| %uint_0 = OpConstant %uint 0 | ||
| %uint_3 = OpConstant %uint 3 | ||
| %uint_8 = OpConstant %uint 8 | ||
| %arr_uint = OpTypeRuntimeArray %uint | ||
| %SSBO = OpTypeStruct %arr_uint | ||
| %ptr_ssbo_SSBO = OpTypePointer StorageBuffer %SSBO | ||
| %ssbo = OpVariable %ptr_ssbo_SSBO StorageBuffer | ||
| %ptr_ssbo_uint = OpTypePointer StorageBuffer %uint | ||
| %coopmat_a = OpTypeCooperativeMatrixKHR %float %uint_3 %uint_8 %uint_8 %uint_0 | ||
| %main = OpFunction %void None %3 | ||
| %5 = OpLabel | ||
| %len = OpCooperativeMatrixLengthKHR %uint %coopmat_a | ||
| %p = OpAccessChain %ptr_ssbo_uint %ssbo %uint_0 %uint_0 | ||
| OpStore %p %len | ||
| OpReturn | ||
| OpFunctionEnd |
51 changes: 51 additions & 0 deletions
51
shaders-msl-no-opt/asm/comp/cooperative-matrix-load-store.asm.msl31.comp
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,51 @@ | ||
| ; SPIR-V | ||
| ; Version: 1.6 | ||
| ; Generator: Khronos SPIR-V Tools Assembler; 0 | ||
| ; Bound: 50 | ||
| ; Schema: 0 | ||
| OpCapability Shader | ||
| OpCapability CooperativeMatrixKHR | ||
| OpCapability VulkanMemoryModel | ||
| OpExtension "SPV_KHR_cooperative_matrix" | ||
| OpExtension "SPV_KHR_vulkan_memory_model" | ||
| OpMemoryModel Logical Vulkan | ||
| OpEntryPoint GLCompute %main "main" | ||
| OpExecutionMode %main LocalSize 32 1 1 | ||
| OpName %main "main" | ||
| OpName %SSBO "SSBO" | ||
| OpMemberName %SSBO 0 "data" | ||
| OpName %ssbo "ssbo" | ||
| OpDecorate %arr_float ArrayStride 4 | ||
| OpMemberDecorate %SSBO 0 Offset 0 | ||
| OpDecorate %SSBO Block | ||
| OpDecorate %ssbo DescriptorSet 0 | ||
| OpDecorate %ssbo Binding 0 | ||
| %void = OpTypeVoid | ||
| %3 = OpTypeFunction %void | ||
| %float = OpTypeFloat 32 | ||
| %uint = OpTypeInt 32 0 | ||
| %uint_0 = OpConstant %uint 0 | ||
| %uint_1 = OpConstant %uint 1 | ||
| %uint_2 = OpConstant %uint 2 | ||
| %uint_3 = OpConstant %uint 3 | ||
| %uint_8 = OpConstant %uint 8 | ||
| %arr_float = OpTypeRuntimeArray %float | ||
| %SSBO = OpTypeStruct %arr_float | ||
| %ptr_ssbo_SSBO = OpTypePointer StorageBuffer %SSBO | ||
| %ssbo = OpVariable %ptr_ssbo_SSBO StorageBuffer | ||
| %ptr_ssbo_float = OpTypePointer StorageBuffer %float | ||
| %coopmat_a = OpTypeCooperativeMatrixKHR %float %uint_3 %uint_8 %uint_8 %uint_0 | ||
| %coopmat_acc = OpTypeCooperativeMatrixKHR %float %uint_3 %uint_8 %uint_8 %uint_2 | ||
| %main = OpFunction %void None %3 | ||
| %5 = OpLabel | ||
| ; Row-major load from offset 0 | ||
| %p0 = OpAccessChain %ptr_ssbo_float %ssbo %uint_0 %uint_0 | ||
| %mat_a = OpCooperativeMatrixLoadKHR %coopmat_a %p0 %uint_0 %uint_8 | ||
| ; Row-major store to offset 0 | ||
| OpCooperativeMatrixStoreKHR %p0 %mat_a %uint_0 %uint_8 | ||
| ; Column-major load from offset 0 | ||
| %mat_col = OpCooperativeMatrixLoadKHR %coopmat_acc %p0 %uint_1 %uint_8 | ||
| ; Column-major store to offset 0 | ||
| OpCooperativeMatrixStoreKHR %p0 %mat_col %uint_1 %uint_8 | ||
| OpReturn | ||
| OpFunctionEnd |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is missing a test for shared-memory (MSL `threadgroup`) load-store. The MSL spec says it is supported.
Also, this is missing tests for load-store where the pointer's element type differs from the matrix's element type. E.g., in SPIR-V you can load a float16_t coopmat from a uint8_t array of data. Casting the pointer and recomputing the row/column stride should be enough to make that work.