 #pragma OPENCL EXTENSION cl_khr_fp16 : enable
 #pragma OPENCL EXTENSION cl_khr_fp64 : enable
 
+int __nvvm_reflect(const char __constant *);
+
 // CLC helpers
 __local bool *
 __clc__get_group_scratch_bool() __asm("__clc__get_group_scratch_bool");
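Note: __nvvm_reflect is resolved by LLVM's NVVMReflect pass when this bitcode library is linked against a concrete target, and the "__CUDA_ARCH" query folds to the SM version times ten (e.g. 800 for sm_80). Arch checks written against it therefore become constants late enough that a single distributable bitcode library can carry both the fast and the portable paths; the guard in the hunk below relies on this.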
@@ -150,43 +152,58 @@ __clc__SubgroupBitwiseAny(uint op, bool predicate, bool *carry) {
 #define __CLC_OR(x, y) (x | y)
 #define __CLC_AND(x, y) (x & y)
 
+#define __CLC_SUBGROUP_COLLECTIVE_BODY(OP, TYPE, IDENTITY)                    \
+  uint sg_lid = __spirv_SubgroupLocalInvocationId();                          \
+  /* Can't use XOR/butterfly shuffles; some lanes may be inactive */          \
+  for (int o = 1; o < __spirv_SubgroupMaxSize(); o *= 2) {                    \
+    TYPE contribution = __clc__SubgroupShuffleUp(x, o);                       \
+    bool inactive = (sg_lid < o);                                             \
+    contribution = (inactive) ? IDENTITY : contribution;                      \
+    x = OP(x, contribution);                                                  \
+  }                                                                           \
+  /* For Reduce, broadcast result from highest active lane */                 \
+  TYPE result;                                                                \
+  if (op == Reduce) {                                                         \
+    result = __clc__SubgroupShuffle(x, __spirv_SubgroupSize() - 1);           \
+    *carry = result;                                                          \
+  } /* For InclusiveScan, use results as computed */                          \
+  else if (op == InclusiveScan) {                                             \
+    result = x;                                                               \
+    *carry = result;                                                          \
+  } /* For ExclusiveScan, shift and prepend identity */                       \
+  else if (op == ExclusiveScan) {                                             \
+    *carry = x;                                                               \
+    result = __clc__SubgroupShuffleUp(x, 1);                                  \
+    if (sg_lid == 0) {                                                        \
+      result = IDENTITY;                                                      \
+    }                                                                         \
+  }                                                                           \
+  return result;
+
 #define __CLC_SUBGROUP_COLLECTIVE(NAME, OP, TYPE, IDENTITY)                   \
   _CLC_DEF _CLC_OVERLOAD _CLC_CONVERGENT TYPE __CLC_APPEND(                   \
       __clc__Subgroup, NAME)(uint op, TYPE x, TYPE *carry) {                  \
-    uint sg_lid = __spirv_SubgroupLocalInvocationId();                        \
-    /* Can't use XOR/butterfly shuffles; some lanes may be inactive */        \
-    for (int o = 1; o < __spirv_SubgroupMaxSize(); o *= 2) {                  \
-      TYPE contribution = __clc__SubgroupShuffleUp(x, o);                     \
-      bool inactive = (sg_lid < o);                                           \
-      contribution = (inactive) ? IDENTITY : contribution;                    \
-      x = OP(x, contribution);                                                \
-    }                                                                         \
-    /* For Reduce, broadcast result from highest active lane */               \
-    TYPE result;                                                              \
-    if (op == Reduce) {                                                       \
-      result = __clc__SubgroupShuffle(x, __spirv_SubgroupSize() - 1);         \
-      *carry = result;                                                        \
-    } /* For InclusiveScan, use results as computed */                        \
-    else if (op == InclusiveScan) {                                           \
-      result = x;                                                             \
+    __CLC_SUBGROUP_COLLECTIVE_BODY(OP, TYPE, IDENTITY)                        \
+  }
+
+#define __CLC_SUBGROUP_COLLECTIVE_REDUX(NAME, OP, REDUX_OP, TYPE, IDENTITY)   \
+  _CLC_DEF _CLC_OVERLOAD _CLC_CONVERGENT TYPE __CLC_APPEND(                   \
+      __clc__Subgroup, NAME)(uint op, TYPE x, TYPE *carry) {                  \
+    /* Fast path for warp reductions for sm_80+ */                            \
+    if (__nvvm_reflect("__CUDA_ARCH") >= 800 && op == Reduce) {               \
+      TYPE result = __nvvm_redux_sync_##REDUX_OP(x, __clc__membermask());     \
       *carry = result;                                                        \
-    } /* For ExclusiveScan, shift and prepend identity */                     \
-    else if (op == ExclusiveScan) {                                           \
-      *carry = x;                                                             \
-      result = __clc__SubgroupShuffleUp(x, 1);                                \
-      if (sg_lid == 0) {                                                      \
-        result = IDENTITY;                                                    \
-      }                                                                       \
+      return result;                                                          \
     }                                                                         \
-    return result;                                                            \
+    __CLC_SUBGROUP_COLLECTIVE_BODY(OP, TYPE, IDENTITY)                        \
   }
 
 __CLC_SUBGROUP_COLLECTIVE(IAdd, __CLC_ADD, char, 0)
 __CLC_SUBGROUP_COLLECTIVE(IAdd, __CLC_ADD, uchar, 0)
 __CLC_SUBGROUP_COLLECTIVE(IAdd, __CLC_ADD, short, 0)
 __CLC_SUBGROUP_COLLECTIVE(IAdd, __CLC_ADD, ushort, 0)
-__CLC_SUBGROUP_COLLECTIVE(IAdd, __CLC_ADD, int, 0)
-__CLC_SUBGROUP_COLLECTIVE(IAdd, __CLC_ADD, uint, 0)
+__CLC_SUBGROUP_COLLECTIVE_REDUX(IAdd, __CLC_ADD, add, int, 0)
+__CLC_SUBGROUP_COLLECTIVE_REDUX(IAdd, __CLC_ADD, add, uint, 0)
 __CLC_SUBGROUP_COLLECTIVE(IAdd, __CLC_ADD, long, 0)
 __CLC_SUBGROUP_COLLECTIVE(IAdd, __CLC_ADD, ulong, 0)
 __CLC_SUBGROUP_COLLECTIVE(FAdd, __CLC_ADD, half, 0)
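For context, __nvvm_redux_sync_* lowers to the single redux.sync PTX instruction introduced with sm_80, replacing the whole log2-step shuffle loop for the Reduce case. Below is a minimal standalone CUDA sketch of the same fast-path/fallback split; it uses CUDA's public __reduce_add_sync / __shfl_up_sync / __shfl_sync intrinsics rather than the libclc-internal helpers, and the kernel and function names are illustrative only, not part of this patch.

#include <cstdio>

// Warp-wide integer sum: one redux.sync on sm_80+, otherwise an inclusive
// shuffle-up scan followed by a broadcast from the highest lane, mirroring
// __CLC_SUBGROUP_COLLECTIVE_BODY above.
__device__ int warp_reduce_add(int x) {
  const unsigned mask = 0xffffffffu; // all 32 lanes participate in this sketch
#if __CUDA_ARCH__ >= 800
  return __reduce_add_sync(mask, x); // fast path (CUDA 11+, sm_80+)
#else
  unsigned lid = threadIdx.x % warpSize;
  for (unsigned o = 1; o < (unsigned)warpSize; o *= 2) {
    int contribution = __shfl_up_sync(mask, x, o);
    // Lanes with no source lane o steps below contribute the identity (0).
    x += (lid < o) ? 0 : contribution;
  }
  return __shfl_sync(mask, x, warpSize - 1); // broadcast the total
#endif
}

__global__ void reduce_kernel(int *out) {
  int r = warp_reduce_add((int)threadIdx.x);
  if (threadIdx.x == 0)
    *out = r; // 0 + 1 + ... + 31 = 496
}

int main() {
  int *d_out, h_out = 0;
  cudaMalloc(&d_out, sizeof(int));
  reduce_kernel<<<1, 32>>>(d_out);
  cudaMemcpy(&h_out, d_out, sizeof(int), cudaMemcpyDeviceToHost);
  printf("warp sum = %d\n", h_out);
  cudaFree(d_out);
  return 0;
}

One deliberate difference: the sketch dispatches at compile time via __CUDA_ARCH__, whereas the patch uses __nvvm_reflect so that the decision can be deferred until the bitcode library is finalized for a target.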
@@ -197,8 +214,8 @@ __CLC_SUBGROUP_COLLECTIVE(SMin, __CLC_MIN, char, CHAR_MAX)
 __CLC_SUBGROUP_COLLECTIVE(UMin, __CLC_MIN, uchar, UCHAR_MAX)
 __CLC_SUBGROUP_COLLECTIVE(SMin, __CLC_MIN, short, SHRT_MAX)
 __CLC_SUBGROUP_COLLECTIVE(UMin, __CLC_MIN, ushort, USHRT_MAX)
-__CLC_SUBGROUP_COLLECTIVE(SMin, __CLC_MIN, int, INT_MAX)
-__CLC_SUBGROUP_COLLECTIVE(UMin, __CLC_MIN, uint, UINT_MAX)
+__CLC_SUBGROUP_COLLECTIVE_REDUX(SMin, __CLC_MIN, min, int, INT_MAX)
+__CLC_SUBGROUP_COLLECTIVE_REDUX(UMin, __CLC_MIN, umin, uint, UINT_MAX)
 __CLC_SUBGROUP_COLLECTIVE(SMin, __CLC_MIN, long, LONG_MAX)
 __CLC_SUBGROUP_COLLECTIVE(UMin, __CLC_MIN, ulong, ULONG_MAX)
 __CLC_SUBGROUP_COLLECTIVE(FMin, __CLC_MIN, half, HALF_MAX)
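The distinct min/umin (and later max/umax) REDUX_OP spellings are needed because redux.sync encodes signedness in the instruction itself, whereas the shuffle fallback inherits it from the C types flowing through __CLC_MIN. CUDA exposes the same split as overloads of one intrinsic; a short illustration under the same assumptions as the sketch above (function names ours):

// Overload resolution on the value type selects the signed or unsigned
// flavor of the underlying redux.sync min reduction (sm_80+).
__device__ int smin32(unsigned mask, int x) { return __reduce_min_sync(mask, x); }
__device__ unsigned umin32(unsigned mask, unsigned x) { return __reduce_min_sync(mask, x); }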
@@ -209,15 +226,17 @@ __CLC_SUBGROUP_COLLECTIVE(SMax, __CLC_MAX, char, CHAR_MIN)
 __CLC_SUBGROUP_COLLECTIVE(UMax, __CLC_MAX, uchar, 0)
 __CLC_SUBGROUP_COLLECTIVE(SMax, __CLC_MAX, short, SHRT_MIN)
 __CLC_SUBGROUP_COLLECTIVE(UMax, __CLC_MAX, ushort, 0)
-__CLC_SUBGROUP_COLLECTIVE(SMax, __CLC_MAX, int, INT_MIN)
-__CLC_SUBGROUP_COLLECTIVE(UMax, __CLC_MAX, uint, 0)
+__CLC_SUBGROUP_COLLECTIVE_REDUX(SMax, __CLC_MAX, max, int, INT_MIN)
+__CLC_SUBGROUP_COLLECTIVE_REDUX(UMax, __CLC_MAX, umax, uint, 0)
 __CLC_SUBGROUP_COLLECTIVE(SMax, __CLC_MAX, long, LONG_MIN)
 __CLC_SUBGROUP_COLLECTIVE(UMax, __CLC_MAX, ulong, 0)
 __CLC_SUBGROUP_COLLECTIVE(FMax, __CLC_MAX, half, -HALF_MAX)
 __CLC_SUBGROUP_COLLECTIVE(FMax, __CLC_MAX, float, -FLT_MAX)
 __CLC_SUBGROUP_COLLECTIVE(FMax, __CLC_MAX, double, -DBL_MAX)
 
+#undef __CLC_SUBGROUP_COLLECTIVE_BODY
 #undef __CLC_SUBGROUP_COLLECTIVE
+#undef __CLC_SUBGROUP_COLLECTIVE_REDUX
 
 #define __CLC_GROUP_COLLECTIVE(NAME, OP, TYPE, IDENTITY)                      \
   _CLC_DEF _CLC_OVERLOAD _CLC_CONVERGENT TYPE __CLC_APPEND(                   \