Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#include <metal_stdlib>
#include <simd/simd.h>
#include <metal_simdgroup_matrix>

using namespace metal;

// Buffer layout bound at [[buffer(0)]] by main0: a single array of bfloat elements.
// NOTE(review): data is declared with one element, but the kernel loads 8x8
// tiles starting at data[0]; presumably this stands in for a runtime-sized
// array -- confirm against the shader this was generated from.
struct SSBO
{
bfloat data[1];
};

// Compute kernel: loads three 8x8 bfloat simdgroup matrices from the start of
// ssbo.data, multiply-accumulates them, and stores the result back to the same
// location. Every load/store passes 8u as the third argument (the
// elements-per-row stride in the MSL simdgroup_load/simdgroup_store signature).
kernel void main0(device SSBO& ssbo [[buffer(0)]])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing a test for shared memory load-store. MSL spec says it's supported.

Also, this is missing tests for load-store with different types. E.g. in SPIR-V you can load a float16_t coopmat via a uint8_t array of data. Just pointer casting and recomputing the row/col stride should work.

{
// All three operands are read from the same address, &ssbo.data[0u].
simdgroup_bfloat8x8 _21;
simdgroup_load(_21, &ssbo.data[0u], 8u);
simdgroup_bfloat8x8 _22;
simdgroup_load(_22, &ssbo.data[0u], 8u);
simdgroup_bfloat8x8 _23;
simdgroup_load(_23, &ssbo.data[0u], 8u);
// _24 is only written here: first argument is the destination (d = a * b + c per the MSL spec).
simdgroup_bfloat8x8 _24;
simdgroup_multiply_accumulate(_24, _21, _22, _23);
// Result overwrites the input tile.
simdgroup_store(_24, &ssbo.data[0u], 8u);
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#include <metal_stdlib>
#include <simd/simd.h>
#include <metal_simdgroup_matrix>

using namespace metal;

// Buffer layout bound at [[buffer(0)]] by main0: a single array of uint elements.
// Only data[0] is written by the kernel in this file.
struct SSBO
{
uint data[1];
};

// Compute kernel: writes sizeof(simdgroup_float8x8::storage_type) / sizeof(float),
// converted to uint, into ssbo.data[0].
// NOTE(review): presumably the MSL lowering of a cooperative-matrix length
// query (number of float elements held per thread) -- confirm that
// storage_type is the intended per-thread storage.
kernel void main0(device SSBO& ssbo [[buffer(0)]])
{
ssbo.data[0u] = uint(sizeof(simdgroup_float8x8::storage_type) / sizeof(float));
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#include <metal_stdlib>
#include <simd/simd.h>
#include <metal_simdgroup_matrix>

using namespace metal;

// Buffer layout bound at [[buffer(0)]] by main0: a single array of float elements.
// NOTE(review): data is declared with one element, but the kernel loads and
// stores 8x8 tiles starting at data[0]; presumably a stand-in for a
// runtime-sized array -- confirm.
struct SSBO
{
float data[1];
};

// Compute kernel: round-trips an 8x8 float simdgroup matrix through
// ssbo.data twice -- once with the default load/store arguments and once with
// an explicit origin of ulong2(0) and the final bool argument set to true
// (the transpose flag in the MSL simdgroup_load/simdgroup_store signature).
kernel void main0(device SSBO& ssbo [[buffer(0)]])
{
// Non-transposed load and store, stride argument 8u.
simdgroup_float8x8 _20;
simdgroup_load(_20, &ssbo.data[0u], 8u);
simdgroup_store(_20, &ssbo.data[0u], 8u);
// Same address and stride, but with the transpose argument set on both calls.
simdgroup_float8x8 _21;
simdgroup_load(_21, &ssbo.data[0u], 8u, ulong2(0), true);
simdgroup_store(_21, &ssbo.data[0u], 8u, ulong2(0), true);
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#include <metal_stdlib>
#include <simd/simd.h>
#include <metal_simdgroup_matrix>

using namespace metal;

// Buffer layout bound at [[buffer(0)]] by main0: a single array of float elements.
// NOTE(review): declared with one element although 8x8 tiles are loaded from
// data[0]; presumably a stand-in for a runtime-sized array -- confirm.
struct SSBO32
{
float data[1];
};

// Buffer layout bound at [[buffer(1)]] by main0: a single array of half elements.
// NOTE(review): declared with one element although 8x8 tiles are loaded from
// data[0]; presumably a stand-in for a runtime-sized array -- confirm.
struct SSBO16
{
half data[1];
};

// Compute kernel: performs the same load / multiply-accumulate / store
// sequence twice, first on 8x8 float matrices backed by ssbo32 (buffer 0) and
// then on 8x8 half matrices backed by ssbo16 (buffer 1). Within each half, all
// three operands are read from data[0] with a stride argument of 8u and the
// result is written back to the same address.
kernel void main0(device SSBO32& ssbo32 [[buffer(0)]], device SSBO16& ssbo16 [[buffer(1)]])
{
// float path: _33 = _30 * _31 + _32 (argument order d, a, b, c per the MSL spec).
simdgroup_float8x8 _30;
simdgroup_load(_30, &ssbo32.data[0u], 8u);
simdgroup_float8x8 _31;
simdgroup_load(_31, &ssbo32.data[0u], 8u);
simdgroup_float8x8 _32;
simdgroup_load(_32, &ssbo32.data[0u], 8u);
simdgroup_float8x8 _33;
simdgroup_multiply_accumulate(_33, _30, _31, _32);
simdgroup_store(_33, &ssbo32.data[0u], 8u);
// half path: _38 = _35 * _36 + _37, independent of the float path above.
simdgroup_half8x8 _35;
simdgroup_load(_35, &ssbo16.data[0u], 8u);
simdgroup_half8x8 _36;
simdgroup_load(_36, &ssbo16.data[0u], 8u);
simdgroup_half8x8 _37;
simdgroup_load(_37, &ssbo16.data[0u], 8u);
simdgroup_half8x8 _38;
simdgroup_multiply_accumulate(_38, _35, _36, _37);
simdgroup_store(_38, &ssbo16.data[0u], 8u);
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
; SPIR-V
; Version: 1.6
; Generator: Khronos SPIR-V Tools Assembler; 0
; Bound: 50
; Schema: 0
;
; Test module: bfloat16 cooperative-matrix multiply-add over a storage buffer.
; Loads three 8x8 bfloat16 cooperative matrices (one per %coopmat_bf16_* type)
; from element 0 of %ssbo, combines them with OpCooperativeMatrixMulAddKHR, and
; stores the result back to the same pointer.
; NOTE(review): the header above says Version 1.6, but OpEntryPoint lists no
; interface variables. SPIR-V 1.4+ requires every referenced global (here
; %ssbo) in the interface list -- confirm which target env this is assembled for.
OpCapability Shader
OpCapability CooperativeMatrixKHR
OpCapability BFloat16TypeKHR
OpCapability BFloat16CooperativeMatrixKHR
OpCapability VulkanMemoryModel
OpExtension "SPV_KHR_cooperative_matrix"
OpExtension "SPV_KHR_bfloat16"
OpExtension "SPV_KHR_vulkan_memory_model"
OpMemoryModel Logical Vulkan
OpEntryPoint GLCompute %main "main"
OpExecutionMode %main LocalSize 32 1 1
OpName %main "main"
OpName %SSBO "SSBO"
OpMemberName %SSBO 0 "data"
OpName %ssbo "ssbo"
; Buffer layout: runtime array of 2-byte elements at offset 0, set 0 / binding 0.
OpDecorate %arr_bf16 ArrayStride 2
OpMemberDecorate %SSBO 0 Offset 0
OpDecorate %SSBO Block
OpDecorate %ssbo DescriptorSet 0
OpDecorate %ssbo Binding 0
%void = OpTypeVoid
%3 = OpTypeFunction %void
%bfloat = OpTypeFloat 16 BFloat16KHR
%uint = OpTypeInt 32 0
; Constants double as enum operands below: scope (3), rows/columns/stride (8),
; matrix use (0, 1, 2) and memory layout (0).
%uint_0 = OpConstant %uint 0
%uint_1 = OpConstant %uint 1
%uint_2 = OpConstant %uint 2
%uint_3 = OpConstant %uint 3
%uint_8 = OpConstant %uint 8
%arr_bf16 = OpTypeRuntimeArray %bfloat
%SSBO = OpTypeStruct %arr_bf16
%ptr_ssbo_SSBO = OpTypePointer StorageBuffer %SSBO
%ssbo = OpVariable %ptr_ssbo_SSBO StorageBuffer
%ptr_ssbo_bf16 = OpTypePointer StorageBuffer %bfloat
; Operand order: component type, scope, rows, columns, use.
; Scope 3 is Subgroup; use 0/1/2 is MatrixA / MatrixB / MatrixAccumulator.
%coopmat_bf16_A = OpTypeCooperativeMatrixKHR %bfloat %uint_3 %uint_8 %uint_8 %uint_0
%coopmat_bf16_B = OpTypeCooperativeMatrixKHR %bfloat %uint_3 %uint_8 %uint_8 %uint_1
%coopmat_bf16_acc = OpTypeCooperativeMatrixKHR %bfloat %uint_3 %uint_8 %uint_8 %uint_2
%main = OpFunction %void None %3
%5 = OpLabel
; Pointer to data[0]; every load and the store share it.
%p0 = OpAccessChain %ptr_ssbo_bf16 %ssbo %uint_0 %uint_0
; Loads use layout 0 (row-major) with a stride of 8.
%bf_A = OpCooperativeMatrixLoadKHR %coopmat_bf16_A %p0 %uint_0 %uint_8
%bf_B = OpCooperativeMatrixLoadKHR %coopmat_bf16_B %p0 %uint_0 %uint_8
%bf_C = OpCooperativeMatrixLoadKHR %coopmat_bf16_acc %p0 %uint_0 %uint_8
; D = A * B + C
%bf_D = OpCooperativeMatrixMulAddKHR %coopmat_bf16_acc %bf_A %bf_B %bf_C
; Result overwrites the input tile.
OpCooperativeMatrixStoreKHR %p0 %bf_D %uint_0 %uint_8
OpReturn
OpFunctionEnd
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
; SPIR-V
; Version: 1.6
; Generator: Khronos SPIR-V Tools Assembler; 0
; Bound: 24
; Schema: 0
;
; Test module: cooperative-matrix length query.
; Evaluates OpCooperativeMatrixLengthKHR on an 8x8 float cooperative-matrix
; type and stores the resulting uint into data[0] of %ssbo.
; NOTE(review): the header above says Version 1.6, but OpEntryPoint lists no
; interface variables. SPIR-V 1.4+ requires every referenced global (here
; %ssbo) in the interface list -- confirm which target env this is assembled for.
OpCapability Shader
OpCapability CooperativeMatrixKHR
OpCapability VulkanMemoryModel
OpExtension "SPV_KHR_cooperative_matrix"
OpExtension "SPV_KHR_vulkan_memory_model"
OpMemoryModel Logical Vulkan
OpEntryPoint GLCompute %main "main"
OpExecutionMode %main LocalSize 32 1 1
OpName %main "main"
OpName %SSBO "SSBO"
OpMemberName %SSBO 0 "data"
OpName %ssbo "ssbo"
; Buffer layout: runtime array of 4-byte uints at offset 0, set 0 / binding 0.
OpDecorate %arr_uint ArrayStride 4
OpMemberDecorate %SSBO 0 Offset 0
OpDecorate %SSBO Block
OpDecorate %ssbo DescriptorSet 0
OpDecorate %ssbo Binding 0
%void = OpTypeVoid
%3 = OpTypeFunction %void
%uint = OpTypeInt 32 0
%float = OpTypeFloat 32
%uint_0 = OpConstant %uint 0
%uint_3 = OpConstant %uint 3
%uint_8 = OpConstant %uint 8
%arr_uint = OpTypeRuntimeArray %uint
%SSBO = OpTypeStruct %arr_uint
%ptr_ssbo_SSBO = OpTypePointer StorageBuffer %SSBO
%ssbo = OpVariable %ptr_ssbo_SSBO StorageBuffer
%ptr_ssbo_uint = OpTypePointer StorageBuffer %uint
; Operand order: component type, scope (3 = Subgroup), rows, columns, use (0 = MatrixA).
%coopmat_a = OpTypeCooperativeMatrixKHR %float %uint_3 %uint_8 %uint_8 %uint_0
%main = OpFunction %void None %3
%5 = OpLabel
; The length op takes the matrix *type* as its operand, not a matrix value.
%len = OpCooperativeMatrixLengthKHR %uint %coopmat_a
%p = OpAccessChain %ptr_ssbo_uint %ssbo %uint_0 %uint_0
OpStore %p %len
OpReturn
OpFunctionEnd
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
; SPIR-V
; Version: 1.6
; Generator: Khronos SPIR-V Tools Assembler; 0
; Bound: 50
; Schema: 0
;
; Test module: cooperative-matrix load/store in both memory layouts.
; Round-trips an 8x8 float matrix through data[0] of %ssbo twice: once with
; layout operand 0 (row-major) and once with layout operand 1 (column-major),
; both with a stride of 8.
; NOTE(review): the header above says Version 1.6, but OpEntryPoint lists no
; interface variables. SPIR-V 1.4+ requires every referenced global (here
; %ssbo) in the interface list -- confirm which target env this is assembled for.
OpCapability Shader
OpCapability CooperativeMatrixKHR
OpCapability VulkanMemoryModel
OpExtension "SPV_KHR_cooperative_matrix"
OpExtension "SPV_KHR_vulkan_memory_model"
OpMemoryModel Logical Vulkan
OpEntryPoint GLCompute %main "main"
OpExecutionMode %main LocalSize 32 1 1
OpName %main "main"
OpName %SSBO "SSBO"
OpMemberName %SSBO 0 "data"
OpName %ssbo "ssbo"
; Buffer layout: runtime array of 4-byte floats at offset 0, set 0 / binding 0.
OpDecorate %arr_float ArrayStride 4
OpMemberDecorate %SSBO 0 Offset 0
OpDecorate %SSBO Block
OpDecorate %ssbo DescriptorSet 0
OpDecorate %ssbo Binding 0
%void = OpTypeVoid
%3 = OpTypeFunction %void
%float = OpTypeFloat 32
%uint = OpTypeInt 32 0
%uint_0 = OpConstant %uint 0
%uint_1 = OpConstant %uint 1
%uint_2 = OpConstant %uint 2
%uint_3 = OpConstant %uint 3
%uint_8 = OpConstant %uint 8
%arr_float = OpTypeRuntimeArray %float
%SSBO = OpTypeStruct %arr_float
%ptr_ssbo_SSBO = OpTypePointer StorageBuffer %SSBO
%ssbo = OpVariable %ptr_ssbo_SSBO StorageBuffer
%ptr_ssbo_float = OpTypePointer StorageBuffer %float
; Operand order: component type, scope (3 = Subgroup), rows, columns, use.
; %coopmat_a has use 0 (MatrixA); %coopmat_acc has use 2 (MatrixAccumulator).
%coopmat_a = OpTypeCooperativeMatrixKHR %float %uint_3 %uint_8 %uint_8 %uint_0
%coopmat_acc = OpTypeCooperativeMatrixKHR %float %uint_3 %uint_8 %uint_8 %uint_2
%main = OpFunction %void None %3
%5 = OpLabel
; Row-major load from offset 0
%p0 = OpAccessChain %ptr_ssbo_float %ssbo %uint_0 %uint_0
%mat_a = OpCooperativeMatrixLoadKHR %coopmat_a %p0 %uint_0 %uint_8
; Row-major store to offset 0
OpCooperativeMatrixStoreKHR %p0 %mat_a %uint_0 %uint_8
; Column-major load from offset 0
; (uses the accumulator-use type, unlike the row-major load above)
%mat_col = OpCooperativeMatrixLoadKHR %coopmat_acc %p0 %uint_1 %uint_8
; Column-major store to offset 0
OpCooperativeMatrixStoreKHR %p0 %mat_col %uint_1 %uint_8
OpReturn
OpFunctionEnd
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
; SPIR-V
; Version: 1.6
; Generator: Khronos SPIR-V Tools Assembler; 0
; Bound: 60
; Schema: 0
;
; Test module: cooperative-matrix multiply-add for float32 and float16.
; Runs the same sequence twice -- load A, B and C from data[0], compute
; A * B + C, store the result back -- first on %ssbo32 (binding 0) with float
; matrices, then on %ssbo16 (binding 1) with half matrices.
; NOTE(review): the header above says Version 1.6, but OpEntryPoint lists no
; interface variables. SPIR-V 1.4+ requires every referenced global (here
; %ssbo32 and %ssbo16) in the interface list -- confirm which target env this
; is assembled for.
OpCapability Shader
OpCapability Float16
OpCapability CooperativeMatrixKHR
OpCapability VulkanMemoryModel
OpExtension "SPV_KHR_cooperative_matrix"
OpExtension "SPV_KHR_vulkan_memory_model"
OpMemoryModel Logical Vulkan
OpEntryPoint GLCompute %main "main"
OpExecutionMode %main LocalSize 32 1 1
OpName %main "main"
OpName %SSBO32 "SSBO32"
OpMemberName %SSBO32 0 "data"
OpName %ssbo32 "ssbo32"
OpName %SSBO16 "SSBO16"
OpMemberName %SSBO16 0 "data"
OpName %ssbo16 "ssbo16"
; float buffer: 4-byte array stride, set 0 / binding 0.
OpDecorate %arr_float ArrayStride 4
OpMemberDecorate %SSBO32 0 Offset 0
OpDecorate %SSBO32 Block
OpDecorate %ssbo32 DescriptorSet 0
OpDecorate %ssbo32 Binding 0
; half buffer: 2-byte array stride, set 0 / binding 1.
OpDecorate %arr_half ArrayStride 2
OpMemberDecorate %SSBO16 0 Offset 0
OpDecorate %SSBO16 Block
OpDecorate %ssbo16 DescriptorSet 0
OpDecorate %ssbo16 Binding 1
%void = OpTypeVoid
%3 = OpTypeFunction %void
%float = OpTypeFloat 32
%half = OpTypeFloat 16
%uint = OpTypeInt 32 0
; Constants double as enum operands below: scope (3), rows/columns/stride (8),
; matrix use (0, 1, 2) and memory layout (0).
%uint_0 = OpConstant %uint 0
%uint_1 = OpConstant %uint 1
%uint_2 = OpConstant %uint 2
%uint_3 = OpConstant %uint 3
%uint_8 = OpConstant %uint 8
%arr_float = OpTypeRuntimeArray %float
%SSBO32 = OpTypeStruct %arr_float
%ptr_ssbo_SSBO32 = OpTypePointer StorageBuffer %SSBO32
%ssbo32 = OpVariable %ptr_ssbo_SSBO32 StorageBuffer
%arr_half = OpTypeRuntimeArray %half
%SSBO16 = OpTypeStruct %arr_half
%ptr_ssbo_SSBO16 = OpTypePointer StorageBuffer %SSBO16
%ssbo16 = OpVariable %ptr_ssbo_SSBO16 StorageBuffer
%ptr_ssbo_float = OpTypePointer StorageBuffer %float
%ptr_ssbo_half = OpTypePointer StorageBuffer %half
; Operand order for the types below: component type, scope (3 = Subgroup),
; rows, columns, use (0 = MatrixA, 1 = MatrixB, 2 = MatrixAccumulator).
; float32 cooperative matrix types
%coopmat_f32_A = OpTypeCooperativeMatrixKHR %float %uint_3 %uint_8 %uint_8 %uint_0
%coopmat_f32_B = OpTypeCooperativeMatrixKHR %float %uint_3 %uint_8 %uint_8 %uint_1
%coopmat_f32_acc = OpTypeCooperativeMatrixKHR %float %uint_3 %uint_8 %uint_8 %uint_2
; half cooperative matrix types
%coopmat_f16_A = OpTypeCooperativeMatrixKHR %half %uint_3 %uint_8 %uint_8 %uint_0
%coopmat_f16_B = OpTypeCooperativeMatrixKHR %half %uint_3 %uint_8 %uint_8 %uint_1
%coopmat_f16_acc = OpTypeCooperativeMatrixKHR %half %uint_3 %uint_8 %uint_8 %uint_2
%main = OpFunction %void None %3
%5 = OpLabel
; float32 muladd: D = A * B + C
; All operands are read row-major (layout 0), stride 8, from ssbo32.data[0].
%p_f32 = OpAccessChain %ptr_ssbo_float %ssbo32 %uint_0 %uint_0
%f_A = OpCooperativeMatrixLoadKHR %coopmat_f32_A %p_f32 %uint_0 %uint_8
%f_B = OpCooperativeMatrixLoadKHR %coopmat_f32_B %p_f32 %uint_0 %uint_8
%f_C = OpCooperativeMatrixLoadKHR %coopmat_f32_acc %p_f32 %uint_0 %uint_8
%f_D = OpCooperativeMatrixMulAddKHR %coopmat_f32_acc %f_A %f_B %f_C
OpCooperativeMatrixStoreKHR %p_f32 %f_D %uint_0 %uint_8
; half muladd: D = A * B + C
; Same pattern against ssbo16.data[0].
%p_f16 = OpAccessChain %ptr_ssbo_half %ssbo16 %uint_0 %uint_0
%h_A = OpCooperativeMatrixLoadKHR %coopmat_f16_A %p_f16 %uint_0 %uint_8
%h_B = OpCooperativeMatrixLoadKHR %coopmat_f16_B %p_f16 %uint_0 %uint_8
%h_C = OpCooperativeMatrixLoadKHR %coopmat_f16_acc %p_f16 %uint_0 %uint_8
%h_D = OpCooperativeMatrixMulAddKHR %coopmat_f16_acc %h_A %h_B %h_C
OpCooperativeMatrixStoreKHR %p_f16 %h_D %uint_0 %uint_8
OpReturn
OpFunctionEnd
Loading
Loading