Skip to content

Commit ba5438f

Browse files
committed
work
1 parent 1494cf2 commit ba5438f

File tree

4 files changed

+115
-54
lines changed

4 files changed

+115
-54
lines changed

examples/webgpu_compute_reduce.html

Lines changed: 82 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
<script type="module">
3232

3333
import * as THREE from 'three/webgpu';
34-
import { instancedArray, Loop, If, vec3, nativeSelect, clamp, atomicStore, subgroupAdd, uniform, uv, uint, float, Fn, vec2, invocationLocalIndex, invocationSubgroupIndex, uvec2, floor, instanceIndex, workgroupId, workgroupBarrier, workgroupArray, subgroupSize } from 'three/tsl';
34+
import { instancedArray, Loop, If, vec3, nativeSelect, clamp, atomicStore, subgroupAdd, uniform, uv, uint, float, Fn, vec2, invocationLocalIndex, invocationSubgroupIndex, uvec2, floor, instanceIndex, workgroupId, workgroupBarrier, workgroupArray, subgroupSize, vec4 } from 'three/tsl';
3535

3636
import WebGPU from 'three/addons/capabilities/WebGPU.js';
3737

@@ -578,17 +578,18 @@
578578

579579
const createReduce4Fn = ( createReduce4FnProps ) => {
580580

581-
const { workgroupSize, workPerThread, numElements, minSubgroupSize } = createReduce4FnProps;
582-
const partitionSize = uint(workgroupSize * workPerThread);
583-
const numThreadBlocks =
581+
const { workgroupSize, workPerThread, numElements, minSubgroupSize, inputBuffer } = createReduce4FnProps;
582+
// Number of elements handled by each workgroup is equal to the workgroupSize * the number of
583+
// elements scanned per thread * the number of elements packed into a vec4
584+
const partitionSize = uint( workgroupSize * workPerThread * 4 );
585+
const vecSize = numElements / 4;
586+
const NUM_WORKGROUPS = uint( divRoundUp( numElements, partitionSize ) );
584587

585-
const MAX_REDUCE_SIZE = uint(workgroupSize).div(minSubgroupSize);
586-
587-
vecSize = numElements / 4;
588+
const MAX_REDUCE_SIZE = uint( workgroupSize ).div( minSubgroupSize );
588589

589590
const fnDef = Fn( () => {
590591

591-
const workgroupReductionArray = createSubgroupArray('uint', maxWorkgroupSize, minSubgroupSize);
592+
const workgroupReductionArray = createSubgroupArray( 'uint', maxWorkgroupSize, minSubgroupSize );
592593

593594
// Get the index of the subgroup within the workgroup
594595
const subgroupMetaRank = invocationLocalIndex.div( subgroupSize );
@@ -601,89 +602,117 @@
601602

602603
const startThread = subgroupOffset.add( workgroupOffset );
603604

604-
const subgroupReduction = uint(0);
605+
const subgroupReduction = uint( 0 );
605606

606-
If(workgroupId.x.lessThan(info.thread_blocks - 1u), () => {
607+
If( workgroupId.x.lessThan( NUM_WORKGROUPS.sub( 1 ) ), () => {
607608

608-
const currentSubgroupInBlock = uint(0).toVar();
609+
const currentSubgroupInBlock = uint( 0 ).toVar();
609610

610-
Loop( currentSubgroupInBlock.lessThan(workPerThread), () => {
611+
Loop( currentSubgroupInBlock.lessThan( workPerThread ), () => {
611612

612613
// Get vectorized element from input array
613-
const val = inputVectorizedStorage.element(startThread);
614+
const val = inputVectorizedStorage.element( startThread );
614615

615616
// Sum values within vec4 together by using result of dot product
616-
subgroupReduction.addAssign(dot(val, vec4(1)));
617+
subgroupReduction.addAssign( dot( val, vec4( 1 ) ) );
617618

618619
// Increment so thread will scan value in next subgroup
619-
startThread.addAssign(subgroupSize);
620+
startThread.addAssign( subgroupSize );
620621

621622
// Increment to continue loop
622-
currentSubgroupInBlock.addAssign(1);
623-
624-
})
625-
})
623+
currentSubgroupInBlock.addAssign( 1 );
624+
625+
} );
626+
627+
} );
626628

627-
If(workgroupId.x.equal(info.thread_blocks - 1u), () => {
629+
If( workgroupId.x.equal( NUM_WORKGROUPS.sub( 1 ) ), () => {
628630

629-
const currentSubgroupInBlock = uint(0).toVar();
631+
const currentSubgroupInBlock = uint( 0 ).toVar();
630632

631-
Loop( currentSubgroupInBlock.lessThan(workPerThread), () => {
633+
Loop( currentSubgroupInBlock.lessThan( workPerThread ), () => {
632634

633-
const inputValue = inputVectorizedStorage.element(startThread);
635+
const inputValue = inputVectorizedStorage.element( startThread );
634636

635-
const val = select(startThread.lessThan(vecSize), inputValue, vec4(0));
637+
const val = nativeSelect( startThread.lessThan( vecSize ), inputValue, vec4( 0 ) );
636638

637639
// Sum values within vec4 together by using result of dot product
638-
subgroupReduction.addAssign(dot(val, vec4(1)));
640+
subgroupReduction.addAssign( dot( val, vec4( 1 ) ) );
639641

640642
// Increment so thread will scan value in next subgroup
641-
startThread.addAssign(subgroupSize);
643+
startThread.addAssign( subgroupSize );
642644

643645
// Increment to continue loop
644-
currentSubgroupInBlock.addAssign(1);
645-
646-
})
647-
})
646+
currentSubgroupInBlock.addAssign( 1 );
647+
648+
} );
649+
650+
} );
648651

649-
subgroupReduction.assign(subgroupAdd(subgroupReduction));
652+
subgroupReduction.assign( subgroupAdd( subgroupReduction ) );
650653

651-
// Delegate one thread per subgroup to assign to the workgroupArray storing elements per subgroup
652-
If(invocationSubgroupIndex.equal(0), () => {
654+
// Delegate one thread per subgroup to assign each subgroup's reduction to the workgroup array
655+
If( invocationSubgroupIndex.equal( 0 ), () => {
653656

654-
workgroupArray.element(subgroupMetaRank).assign()
657+
workgroupReductionArray.element( subgroupMetaRank ).assign( subgroupReduction );
655658

656-
})
659+
} );
657660

658661
// Ensure that each workgroup has populated workgroupReductionArray with data
659662
// from each subgroup before we begin reducing down its values
663+
workgroupBarrier();
660664

661-
662-
663-
{
664-
for(var k = 0u; k < VEC4_SPT; k += 1u){
665-
let t = scan_in[i];
666-
s_red += dot(t, vec4(1u, 1u, 1u, 1u));
667-
i += lane_count;
668-
}
669-
}
665+
// WORKGROUP LEVEL REDUCE
670666

671-
if(wgid.x == info.thread_blocks - 1u){
672-
for(var k = 0u; k < VEC4_SPT; k += 1u){
673-
let t = select(vec4<u32>(0u, 0u, 0u, 0u), scan_in[i], i < info.vec_size);
674-
s_red += dot(t, vec4(1u, 1u, 1u, 1u));
675-
i += lane_count;
676-
}
677-
}
667+
const subgroupSizeLog = log2( subgroupSize );
668+
// Effectively equal to number of subgroups in the workgroup
669+
// also 'spine_size'
670+
const numSubgroupsInWorkgroup = uint( workgroupSize ).shiftRight( subgroupSizeLog );
671+
const spineLog = log2( spineSize );
678672

673+
const alignedSize = ( spineLog.add( subgroupSizeLog ).sub( 1 ) ).div( laneLog );
674+
alignedSize.assign( uint( 1 ).shiftLeft( alignedSize ) );
679675

676+
const offset = uint( 0 );
680677

678+
const j = subgroupSize.toVar();
681679

682-
683-
} );
680+
Loop( j.lessThanEqual( alignedSize ), () => {
681+
682+
const subgroupIndex = ( ( invocationLocalIndex.add( 1 ) ).shiftLeft( offset ) ).sub( 1 );
683+
684+
const isValidSubgroupIndex = subgroupIndex.lessThan( numSubgroupsInWorkgroup );
685+
686+
// Reduce values within the local workgroup memory
687+
const t = subgroupAdd( select(
688+
isValidSubgroupIndex,
689+
workgroupReductionArray.element( subgroupIndex ),
690+
0
691+
) );
692+
693+
// Can assign back to workgroupArray since all
694+
// subgroup threads work in lockstep for subgroupAdd
695+
If( isValidSubgroupIndex, () => {
696+
697+
workgroupReductionArray.element( subgroupIndex ).assign( t );
698+
699+
} );
700+
701+
// Ensure all threads have completed work
702+
703+
workgroupBarrier();
704+
705+
offset.addAssign( subgroupSizeLog );
706+
j.shiftLeftAssign( subgroupSizeLog );
707+
708+
709+
} );
684710

685711

686712

713+
714+
} );
715+
687716
};
688717

689718
const incorrectBaselineCalls = [

src/nodes/core/ContextNode.js

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,11 @@ export const context = /*@__PURE__*/ nodeProxy( ContextNode ).setParameterLength
142142
export const uniformFlow = ( node ) => context( node, { uniformFlow: true } );
143143

144144
/**
145+
<<<<<<< HEAD
145146
* TSL function for defining a name for the context value for a given node.
147+
=======
148+
* TSL function for defining a label context value for a given node.
149+
>>>>>>> d15ca48302 (work)
146150
*
147151
* @tsl
148152
* @function

src/nodes/math/ConditionalNode.js

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,33 @@ class ConditionalNode extends Node {
135135

136136
const { condNode, ifNode, elseNode } = builder.getNodeProperties( this );
137137

138+
const isUniformControlFlow = builder.context.uniformFlow;
139+
140+
// Build node using ternary operator or select
141+
if ( isUniformControlFlow ) {
142+
143+
const condSnippet = condNode.build( builder, 'bool' );
144+
const ifSnippet = ifNode.build( builder, type );
145+
const elseSnippet = elseNode.build( builder, type );
146+
147+
let codeSnippet = '';
148+
149+
if ( builder.renderer.backend.isWebGLBackend ) {
150+
151+
codeSnippet = `${condSnippet} ? ${ifSnippet} : ${elseSnippet}`;
152+
153+
} else {
154+
155+
codeSnippet = `select(${elseSnippet}, ${ifSnippet}, )`;
156+
157+
}
158+
159+
builder.addFlowCode( codeSnippet );
160+
161+
return builder.format( nodeProperty, type, output );
162+
163+
}
164+
138165
const functionNode = builder.currentFunctionNode;
139166
const needsOutput = output !== 'void';
140167
const nodeProperty = needsOutput ? property( type ).build( builder ) : '';

src/renderers/webgl-fallback/nodes/GLSLNodeBuilder.js

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ import { DataTexture } from '../../../textures/DataTexture.js';
1010

1111
const glslMethods = {
1212
textureDimensions: 'textureSize',
13-
equals: 'equal'
13+
equals: 'equal',
14+
countTrailingZeros: 'findLSB'
1415
};
1516

1617
const precisionLib = {

0 commit comments

Comments
 (0)