Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#include <metal_stdlib>
#include <simd/simd.h>
#include <metal_simdgroup_matrix>

using namespace metal;

struct SSBO
{
bfloat data[1];
};

kernel void main0(device SSBO& ssbo [[buffer(0)]])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing a test for shared memory load-store. MSL spec says it's supported.

Also, this is missing tests for load-store with different types. E.g. in SPIR-V you can load a float16_t coopmat via a uint8_t array of data. Just pointer casting and recomputing the row/col stride should work.

{
simdgroup_bfloat8x8 _21;
simdgroup_load(_21, &ssbo.data[0u], 8u);
simdgroup_bfloat8x8 _22;
simdgroup_load(_22, &ssbo.data[0u], 8u);
simdgroup_bfloat8x8 _23;
simdgroup_load(_23, &ssbo.data[0u], 8u);
simdgroup_bfloat8x8 _24;
simdgroup_multiply_accumulate(_24, _21, _22, _23);
simdgroup_store(_24, &ssbo.data[0u], 8u);
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#include <metal_stdlib>
#include <simd/simd.h>
#include <metal_simdgroup_matrix>

using namespace metal;

struct SSBO
{
uint data[1];
};

kernel void main0(device SSBO& ssbo [[buffer(0)]])
{
ssbo.data[0u] = uint(sizeof(simdgroup_float8x8::storage_type) / sizeof(float));
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#include <metal_stdlib>
#include <simd/simd.h>
#include <metal_simdgroup_matrix>

using namespace metal;

struct SSBO
{
float data[1];
};

kernel void main0(device SSBO& ssbo [[buffer(0)]])
{
simdgroup_float8x8 _20;
simdgroup_load(_20, &ssbo.data[0u], 8u);
simdgroup_store(_20, &ssbo.data[0u], 8u);
simdgroup_float8x8 _21;
simdgroup_load(_21, &ssbo.data[0u], 8u, ulong2(0), true);
simdgroup_store(_21, &ssbo.data[0u], 8u, ulong2(0), true);
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#include <metal_stdlib>
#include <simd/simd.h>
#include <metal_simdgroup_matrix>

using namespace metal;

struct SSBO32
{
float data[1];
};

struct SSBO16
{
half data[1];
};

kernel void main0(device SSBO32& ssbo32 [[buffer(0)]], device SSBO16& ssbo16 [[buffer(1)]])
{
simdgroup_float8x8 _30;
simdgroup_load(_30, &ssbo32.data[0u], 8u);
simdgroup_float8x8 _31;
simdgroup_load(_31, &ssbo32.data[0u], 8u);
simdgroup_float8x8 _32;
simdgroup_load(_32, &ssbo32.data[0u], 8u);
simdgroup_float8x8 _33;
simdgroup_multiply_accumulate(_33, _30, _31, _32);
simdgroup_store(_33, &ssbo32.data[0u], 8u);
simdgroup_half8x8 _35;
simdgroup_load(_35, &ssbo16.data[0u], 8u);
simdgroup_half8x8 _36;
simdgroup_load(_36, &ssbo16.data[0u], 8u);
simdgroup_half8x8 _37;
simdgroup_load(_37, &ssbo16.data[0u], 8u);
simdgroup_half8x8 _38;
simdgroup_multiply_accumulate(_38, _35, _36, _37);
simdgroup_store(_38, &ssbo16.data[0u], 8u);
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#pragma clang diagnostic ignored "-Wmissing-prototypes"
#pragma clang diagnostic ignored "-Wmissing-braces"

#include <metal_stdlib>
#include <simd/simd.h>
#include <metal_simdgroup_matrix>

using namespace metal;

template<typename T, size_t Num>
struct spvUnsafeArray
{
T elements[Num ? Num : 1];

thread T& operator [] (size_t pos) thread
{
return elements[pos];
}
constexpr const thread T& operator [] (size_t pos) const thread
{
return elements[pos];
}

device T& operator [] (size_t pos) device
{
return elements[pos];
}
constexpr const device T& operator [] (size_t pos) const device
{
return elements[pos];
}

constexpr const constant T& operator [] (size_t pos) const constant
{
return elements[pos];
}

threadgroup T& operator [] (size_t pos) threadgroup
{
return elements[pos];
}
constexpr const threadgroup T& operator [] (size_t pos) const threadgroup
{
return elements[pos];
}
};

kernel void main0()
{
threadgroup spvUnsafeArray<uchar, 128> _15;
_15[0u] = uchar(0);
simdgroup_half8x8 _20;
simdgroup_load(_20, reinterpret_cast<threadgroup half*>(&_15[0u]), (16u) / 2u);
simdgroup_store(_20, reinterpret_cast<threadgroup half*>(&_15[0u]), (16u) / 2u);
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#pragma clang diagnostic ignored "-Wmissing-prototypes"
#pragma clang diagnostic ignored "-Wmissing-braces"

#include <metal_stdlib>
#include <simd/simd.h>
#include <metal_simdgroup_matrix>

using namespace metal;

template<typename T, size_t Num>
struct spvUnsafeArray
{
T elements[Num ? Num : 1];

thread T& operator [] (size_t pos) thread
{
return elements[pos];
}
constexpr const thread T& operator [] (size_t pos) const thread
{
return elements[pos];
}

device T& operator [] (size_t pos) device
{
return elements[pos];
}
constexpr const device T& operator [] (size_t pos) const device
{
return elements[pos];
}

constexpr const constant T& operator [] (size_t pos) const constant
{
return elements[pos];
}

threadgroup T& operator [] (size_t pos) threadgroup
{
return elements[pos];
}
constexpr const threadgroup T& operator [] (size_t pos) const threadgroup
{
return elements[pos];
}
};

kernel void main0()
{
threadgroup spvUnsafeArray<float, 64> _14;
simdgroup_float8x8 _18;
simdgroup_load(_18, &_14[0u], 8u);
simdgroup_store(_18, &_14[0u], 8u);
simdgroup_float8x8 _19;
simdgroup_load(_19, &_14[0u], 8u, ulong2(0), true);
simdgroup_store(_19, &_14[0u], 8u, ulong2(0), true);
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
; SPIR-V
; Version: 1.6
; Generator: Khronos SPIR-V Tools Assembler; 0
; Bound: 50
; Schema: 0
OpCapability Shader
OpCapability CooperativeMatrixKHR
OpCapability BFloat16TypeKHR
OpCapability BFloat16CooperativeMatrixKHR
OpCapability VulkanMemoryModel
OpExtension "SPV_KHR_cooperative_matrix"
OpExtension "SPV_KHR_bfloat16"
OpExtension "SPV_KHR_vulkan_memory_model"
OpMemoryModel Logical Vulkan
OpEntryPoint GLCompute %main "main"
OpExecutionMode %main LocalSize 32 1 1
OpName %main "main"
OpName %SSBO "SSBO"
OpMemberName %SSBO 0 "data"
OpName %ssbo "ssbo"
OpDecorate %arr_bf16 ArrayStride 2
OpMemberDecorate %SSBO 0 Offset 0
OpDecorate %SSBO Block
OpDecorate %ssbo DescriptorSet 0
OpDecorate %ssbo Binding 0
%void = OpTypeVoid
%3 = OpTypeFunction %void
%bfloat = OpTypeFloat 16 BFloat16KHR
%uint = OpTypeInt 32 0
%uint_0 = OpConstant %uint 0
%uint_1 = OpConstant %uint 1
%uint_2 = OpConstant %uint 2
%uint_3 = OpConstant %uint 3
%uint_8 = OpConstant %uint 8
%arr_bf16 = OpTypeRuntimeArray %bfloat
%SSBO = OpTypeStruct %arr_bf16
%ptr_ssbo_SSBO = OpTypePointer StorageBuffer %SSBO
%ssbo = OpVariable %ptr_ssbo_SSBO StorageBuffer
%ptr_ssbo_bf16 = OpTypePointer StorageBuffer %bfloat
%coopmat_bf16_A = OpTypeCooperativeMatrixKHR %bfloat %uint_3 %uint_8 %uint_8 %uint_0
%coopmat_bf16_B = OpTypeCooperativeMatrixKHR %bfloat %uint_3 %uint_8 %uint_8 %uint_1
%coopmat_bf16_acc = OpTypeCooperativeMatrixKHR %bfloat %uint_3 %uint_8 %uint_8 %uint_2
%main = OpFunction %void None %3
%5 = OpLabel
%p0 = OpAccessChain %ptr_ssbo_bf16 %ssbo %uint_0 %uint_0
%bf_A = OpCooperativeMatrixLoadKHR %coopmat_bf16_A %p0 %uint_0 %uint_8
%bf_B = OpCooperativeMatrixLoadKHR %coopmat_bf16_B %p0 %uint_0 %uint_8
%bf_C = OpCooperativeMatrixLoadKHR %coopmat_bf16_acc %p0 %uint_0 %uint_8
%bf_D = OpCooperativeMatrixMulAddKHR %coopmat_bf16_acc %bf_A %bf_B %bf_C
OpCooperativeMatrixStoreKHR %p0 %bf_D %uint_0 %uint_8
OpReturn
OpFunctionEnd
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
; SPIR-V
; Version: 1.6
; Generator: Khronos SPIR-V Tools Assembler; 0
; Bound: 24
; Schema: 0
OpCapability Shader
OpCapability CooperativeMatrixKHR
OpCapability VulkanMemoryModel
OpExtension "SPV_KHR_cooperative_matrix"
OpExtension "SPV_KHR_vulkan_memory_model"
OpMemoryModel Logical Vulkan
OpEntryPoint GLCompute %main "main"
OpExecutionMode %main LocalSize 32 1 1
OpName %main "main"
OpName %SSBO "SSBO"
OpMemberName %SSBO 0 "data"
OpName %ssbo "ssbo"
OpDecorate %arr_uint ArrayStride 4
OpMemberDecorate %SSBO 0 Offset 0
OpDecorate %SSBO Block
OpDecorate %ssbo DescriptorSet 0
OpDecorate %ssbo Binding 0
%void = OpTypeVoid
%3 = OpTypeFunction %void
%uint = OpTypeInt 32 0
%float = OpTypeFloat 32
%uint_0 = OpConstant %uint 0
%uint_3 = OpConstant %uint 3
%uint_8 = OpConstant %uint 8
%arr_uint = OpTypeRuntimeArray %uint
%SSBO = OpTypeStruct %arr_uint
%ptr_ssbo_SSBO = OpTypePointer StorageBuffer %SSBO
%ssbo = OpVariable %ptr_ssbo_SSBO StorageBuffer
%ptr_ssbo_uint = OpTypePointer StorageBuffer %uint
%coopmat_a = OpTypeCooperativeMatrixKHR %float %uint_3 %uint_8 %uint_8 %uint_0
%main = OpFunction %void None %3
%5 = OpLabel
%len = OpCooperativeMatrixLengthKHR %uint %coopmat_a
%p = OpAccessChain %ptr_ssbo_uint %ssbo %uint_0 %uint_0
OpStore %p %len
OpReturn
OpFunctionEnd
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
; SPIR-V
; Version: 1.6
; Generator: Khronos SPIR-V Tools Assembler; 0
; Bound: 50
; Schema: 0
OpCapability Shader
OpCapability CooperativeMatrixKHR
OpCapability VulkanMemoryModel
OpExtension "SPV_KHR_cooperative_matrix"
OpExtension "SPV_KHR_vulkan_memory_model"
OpMemoryModel Logical Vulkan
OpEntryPoint GLCompute %main "main"
OpExecutionMode %main LocalSize 32 1 1
OpName %main "main"
OpName %SSBO "SSBO"
OpMemberName %SSBO 0 "data"
OpName %ssbo "ssbo"
OpDecorate %arr_float ArrayStride 4
OpMemberDecorate %SSBO 0 Offset 0
OpDecorate %SSBO Block
OpDecorate %ssbo DescriptorSet 0
OpDecorate %ssbo Binding 0
%void = OpTypeVoid
%3 = OpTypeFunction %void
%float = OpTypeFloat 32
%uint = OpTypeInt 32 0
%uint_0 = OpConstant %uint 0
%uint_1 = OpConstant %uint 1
%uint_2 = OpConstant %uint 2
%uint_3 = OpConstant %uint 3
%uint_8 = OpConstant %uint 8
%arr_float = OpTypeRuntimeArray %float
%SSBO = OpTypeStruct %arr_float
%ptr_ssbo_SSBO = OpTypePointer StorageBuffer %SSBO
%ssbo = OpVariable %ptr_ssbo_SSBO StorageBuffer
%ptr_ssbo_float = OpTypePointer StorageBuffer %float
%coopmat_a = OpTypeCooperativeMatrixKHR %float %uint_3 %uint_8 %uint_8 %uint_0
%coopmat_acc = OpTypeCooperativeMatrixKHR %float %uint_3 %uint_8 %uint_8 %uint_2
%main = OpFunction %void None %3
%5 = OpLabel
; Row-major load from offset 0
%p0 = OpAccessChain %ptr_ssbo_float %ssbo %uint_0 %uint_0
%mat_a = OpCooperativeMatrixLoadKHR %coopmat_a %p0 %uint_0 %uint_8
; Row-major store to offset 0
OpCooperativeMatrixStoreKHR %p0 %mat_a %uint_0 %uint_8
; Column-major load from offset 0
%mat_col = OpCooperativeMatrixLoadKHR %coopmat_acc %p0 %uint_1 %uint_8
; Column-major store to offset 0
OpCooperativeMatrixStoreKHR %p0 %mat_col %uint_1 %uint_8
OpReturn
OpFunctionEnd
Loading
Loading