Skip to content

Commit 1494cf2

Browse files
committed
work
1 parent 1e9b7b9 commit 1494cf2

File tree

2 files changed

+113
-1
lines changed

2 files changed

+113
-1
lines changed

examples/webgpu_compute_reduce.html

Lines changed: 112 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@
9191

9292
const rightEffectController = {
9393
algo: 'Reduce 3 (Subgroup Reduce)',
94-
currentAlgo: uniform( 0 ),
94+
currentAlgo: uniform( 3 ),
9595
highlight: uniform( 0 ),
9696
displayMode: 'Input Element 0',
9797
state: 'Run Algo',
@@ -179,6 +179,7 @@
179179
} ) );
180180

181181
const inputStorage = instancedArray( array, 'uint', size ).setPBO( true ).setName( `Current_${leftSideDisplay ? 'Left' : 'Right'}` );
182+
const inputVectorizedStorage = instancedArray( array, 'vec4' ).setPBO( true ).setName( `CurrentVectorized_${leftSideDisplay ? 'Left' : 'Right'}` );
182183
const atomicAccumulator = instancedArray( new Uint32Array( 1 ), 'uint' ).setPBO( true ).toAtomic();
183184

184185
// Reduce 3 Calculations
@@ -575,6 +576,116 @@
575576
} ).compute( 32, [ 32 ] )
576577
];
577578

579+
const createReduce4Fn = ( createReduce4FnProps ) => {
580+
581+
const { workgroupSize, workPerThread, numElements, minSubgroupSize } = createReduce4FnProps;
582+
const partitionSize = uint(workgroupSize * workPerThread);
583+
const numThreadBlocks =
584+
585+
const MAX_REDUCE_SIZE = uint(workgroupSize).div(minSubgroupSize);
586+
587+
vecSize = numElements / 4;
588+
589+
const fnDef = Fn( () => {
590+
591+
const workgroupReductionArray = createSubgroupArray('uint', maxWorkgroupSize, minSubgroupSize);
592+
593+
// Get the index of the subgroup within the workgroup
594+
const subgroupMetaRank = invocationLocalIndex.div( subgroupSize );
595+
// Offset by workPerThread subgroup-widths per subgroup rank, since each thread reads workPerThread values, one from each successive subgroup-sized stride
596+
const subgroupOffset = subgroupMetaRank.mul( subgroupSize ).mul( workPerThread );
597+
subgroupOffset.addAssign( invocationSubgroupIndex );
598+
599+
// Per workgroup, offset by number of elements scanned per workgroup
600+
const workgroupOffset = workgroupId.x.mul( partitionSize );
601+
602+
const startThread = subgroupOffset.add( workgroupOffset );
603+
604+
const subgroupReduction = uint(0);
605+
606+
If(workgroupId.x.lessThan(info.thread_blocks - 1u), () => {
607+
608+
const currentSubgroupInBlock = uint(0).toVar();
609+
610+
Loop( currentSubgroupInBlock.lessThan(workPerThread), () => {
611+
612+
// Get vectorized element from input array
613+
const val = inputVectorizedStorage.element(startThread);
614+
615+
// Sum the four components of the vec4 by taking its dot product with vec4(1)
616+
subgroupReduction.addAssign(dot(val, vec4(1)));
617+
618+
// Increment so thread will scan value in next subgroup
619+
startThread.addAssign(subgroupSize);
620+
621+
// Increment to continue loop
622+
currentSubgroupInBlock.addAssign(1);
623+
624+
})
625+
})
626+
627+
If(workgroupId.x.equal(info.thread_blocks - 1u), () => {
628+
629+
const currentSubgroupInBlock = uint(0).toVar();
630+
631+
Loop( currentSubgroupInBlock.lessThan(workPerThread), () => {
632+
633+
const inputValue = inputVectorizedStorage.element(startThread);
634+
635+
const val = select(startThread.lessThan(vecSize), inputValue, vec4(0));
636+
637+
// Sum the four components of the (possibly zeroed) vec4 by taking its dot product with vec4(1)
638+
subgroupReduction.addAssign(dot(val, vec4(1)));
639+
640+
// Increment so thread will scan value in next subgroup
641+
startThread.addAssign(subgroupSize);
642+
643+
// Increment to continue loop
644+
currentSubgroupInBlock.addAssign(1);
645+
646+
})
647+
})
648+
649+
subgroupReduction.assign(subgroupAdd(subgroupReduction));
650+
651+
// Delegate one thread per subgroup to assign to the workgroupArray storing elements per subgroup
652+
If(invocationSubgroupIndex.equal(0), () => {
653+
654+
workgroupArray.element(subgroupMetaRank).assign()
655+
656+
})
657+
658+
// Ensure that each workgroup has populated wg_reduce with data
659+
// from each subgroup before we begin reducing down its values
660+
661+
662+
663+
{
664+
for(var k = 0u; k < VEC4_SPT; k += 1u){
665+
let t = scan_in[i];
666+
s_red += dot(t, vec4(1u, 1u, 1u, 1u));
667+
i += lane_count;
668+
}
669+
}
670+
671+
if(wgid.x == info.thread_blocks - 1u){
672+
for(var k = 0u; k < VEC4_SPT; k += 1u){
673+
let t = select(vec4<u32>(0u, 0u, 0u, 0u), scan_in[i], i < info.vec_size);
674+
s_red += dot(t, vec4(1u, 1u, 1u, 1u));
675+
i += lane_count;
676+
}
677+
}
678+
679+
680+
681+
682+
683+
} );
684+
685+
686+
687+
};
688+
578689
const incorrectBaselineCalls = [
579690
createIncorrectBaselineFn().compute( size ),
580691
];

src/renderers/webgpu/nodes/WGSLNodeBuilder.js

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1867,6 +1867,7 @@ ${ flowData.code }
18671867
const workgroupSize = this.object.workgroupSize;
18681868

18691869
this.computeShader = this._getWGSLComputeCode( shadersData.compute, workgroupSize );
1870+
console.log( this.computeShader );
18701871

18711872
}
18721873

0 commit comments

Comments
 (0)