|
31 | 31 | <script type="module"> |
32 | 32 |
|
33 | 33 | import * as THREE from 'three/webgpu'; |
34 | | - import { instancedArray, Loop, If, vec3, nativeSelect, clamp, atomicStore, subgroupAdd, uniform, uv, uint, float, Fn, vec2, invocationLocalIndex, invocationSubgroupIndex, uvec2, floor, instanceIndex, workgroupId, workgroupBarrier, workgroupArray, subgroupSize } from 'three/tsl'; |
| 34 | + import { instancedArray, Loop, If, vec3, nativeSelect, clamp, atomicStore, subgroupAdd, uniform, uv, uint, float, Fn, vec2, invocationLocalIndex, invocationSubgroupIndex, uvec2, floor, instanceIndex, workgroupId, workgroupBarrier, workgroupArray, subgroupSize, vec4, dot, select, log2 } from 'three/tsl'; |
35 | 35 |
|
36 | 36 | import WebGPU from 'three/addons/capabilities/WebGPU.js'; |
37 | 37 |
|
|
578 | 578 |
|
579 | 579 | const createReduce4Fn = ( createReduce4FnProps ) => { |
580 | 580 |
|
581 | | - const { workgroupSize, workPerThread, numElements, minSubgroupSize } = createReduce4FnProps; |
582 | | - const partitionSize = uint(workgroupSize * workPerThread); |
583 | | - const numThreadBlocks = |
| 581 | + const { workgroupSize, workPerThread, numElements, minSubgroupSize, inputBuffer } = createReduce4FnProps; |
| 582 | + // Number of elements handled by each workgroup is equal to the workgroupSize * the number of
| 583 | + // elements scanned per thread * the number of elements packed into a vec4
| 584 | + const partitionSize = uint( workgroupSize * workPerThread * 4 ); |
| 585 | + const vecSize = numElements / 4; |
| 586 | + const NUM_WORKGROUPS = uint( divRoundUp( numElements, partitionSize ) ); |
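| | + // The last workgroup may receive a partial partition, so it takes the bounds-checked path below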
584 | 587 |
|
585 | | - const MAX_REDUCE_SIZE = uint(workgroupSize).div(minSubgroupSize); |
586 | | - |
587 | | - vecSize = numElements / 4; |
| 588 | + const MAX_REDUCE_SIZE = uint( workgroupSize ).div( minSubgroupSize ); |
588 | 589 |
|
589 | 590 | const fnDef = Fn( () => { |
590 | 591 |
|
591 | | - const workgroupReductionArray = createSubgroupArray('uint', maxWorkgroupSize, minSubgroupSize); |
| 592 | + const workgroupReductionArray = createSubgroupArray( 'uint', maxWorkgroupSize, minSubgroupSize ); |
592 | 593 |
|
593 | 594 | // Get the index of the subgroup within the workgroup |
594 | 595 | const subgroupMetaRank = invocationLocalIndex.div( subgroupSize ); |
|
601 | 602 |
|
602 | 603 | const startThread = subgroupOffset.add( workgroupOffset ); |
603 | 604 |
|
604 | | - const subgroupReduction = uint(0); |
| 605 | + const subgroupReduction = uint( 0 ); |
605 | 606 |
|
606 | | - If(workgroupId.x.lessThan(info.thread_blocks - 1u), () => { |
| 607 | + If( workgroupId.x.lessThan( NUM_WORKGROUPS.sub( 1 ) ), () => { |
607 | 608 |
|
608 | | - const currentSubgroupInBlock = uint(0).toVar(); |
| 609 | + const currentSubgroupInBlock = uint( 0 ).toVar(); |
609 | 610 |
|
610 | | - Loop( currentSubgroupInBlock.lessThan(workPerThread), () => { |
| 611 | + Loop( currentSubgroupInBlock.lessThan( workPerThread ), () => { |
611 | 612 |
|
612 | 613 | // Get vectorized element from input array |
613 | | - const val = inputVectorizedStorage.element(startThread); |
| 614 | + const val = inputVectorizedStorage.element( startThread ); |
614 | 615 |
|
615 | 616 | // Sum values within vec4 together by using result of dot product |
616 | | - subgroupReduction.addAssign(dot(val, vec4(1))); |
| 617 | + subgroupReduction.addAssign( dot( val, vec4( 1 ) ) ); |
617 | 618 |
|
618 | 619 | // Increment so thread will scan value in next subgroup |
619 | | - startThread.addAssign(subgroupSize); |
| 620 | + startThread.addAssign( subgroupSize ); |
620 | 621 |
|
621 | 622 | // Increment to continue loop |
622 | | - currentSubgroupInBlock.addAssign(1); |
623 | | - |
624 | | - }) |
625 | | - }) |
| 623 | + currentSubgroupInBlock.addAssign( 1 ); |
| 624 | + |
| 625 | + } ); |
| 626 | + |
| 627 | + } ); |
626 | 628 |
|
627 | | - If(workgroupId.x.equal(info.thread_blocks - 1u), () => { |
| 629 | + If( workgroupId.x.equal( NUM_WORKGROUPS.sub( 1 ) ), () => { |
628 | 630 |
|
629 | | - const currentSubgroupInBlock = uint(0).toVar(); |
| 631 | + const currentSubgroupInBlock = uint( 0 ).toVar(); |
630 | 632 |
|
631 | | - Loop( currentSubgroupInBlock.lessThan(workPerThread), () => { |
| 633 | + Loop( currentSubgroupInBlock.lessThan( workPerThread ), () => { |
632 | 634 |
|
633 | | - const inputValue = inputVectorizedStorage.element(startThread); |
| 635 | + const inputValue = inputVectorizedStorage.element( startThread ); |
634 | 636 |
|
635 | | - const val = select(startThread.lessThan(vecSize), inputValue, vec4(0)); |
| 637 | + const val = nativeSelect( startThread.lessThan( vecSize ), inputValue, vec4( 0 ) ); |
636 | 638 |
|
637 | 639 | // Sum values within vec4 together by using result of dot product |
638 | | - subgroupReduction.addAssign(dot(val, vec4(1))); |
| 640 | + subgroupReduction.addAssign( dot( val, vec4( 1 ) ) ); |
639 | 641 |
|
640 | 642 | // Increment so thread will scan value in next subgroup |
641 | | - startThread.addAssign(subgroupSize); |
| 643 | + startThread.addAssign( subgroupSize ); |
642 | 644 |
|
643 | 645 | // Increment to continue loop |
644 | | - currentSubgroupInBlock.addAssign(1); |
645 | | - |
646 | | - }) |
647 | | - }) |
| 646 | + currentSubgroupInBlock.addAssign( 1 ); |
| 647 | + |
| 648 | + } ); |
| 649 | + |
| 650 | + } ); |
648 | 651 |
|
649 | | - subgroupReduction.assign(subgroupAdd(subgroupReduction)); |
| 652 | + subgroupReduction.assign( subgroupAdd( subgroupReduction ) ); |
650 | 653 |
|
651 | | - // Delegate one thread per subgroup to assign to the workgroupArray storing elements per subgroup |
652 | | - If(invocationSubgroupIndex.equal(0), () => { |
| 654 | + // Delegate one thread per subgroup to assign each subgroup's reduction to the workgroup array |
| 655 | + If( invocationSubgroupIndex.equal( 0 ), () => { |
653 | 656 |
|
654 | | - workgroupArray.element(subgroupMetaRank).assign() |
| 657 | + workgroupReductionArray.element( subgroupMetaRank ).assign( subgroupReduction ); |
655 | 658 |
|
656 | | - }) |
| 659 | + } ); |
657 | 660 |
|
658 | 661 | // Ensure that each workgroup has populated wg_reduce with data |
659 | 662 | // from each subgroup before we begin reducing down its values |
| 663 | + workgroupBarrier(); |
660 | 664 |
|
661 | | - |
662 | | - |
663 | | - { |
664 | | - for(var k = 0u; k < VEC4_SPT; k += 1u){ |
665 | | - let t = scan_in[i]; |
666 | | - s_red += dot(t, vec4(1u, 1u, 1u, 1u)); |
667 | | - i += lane_count; |
668 | | - } |
669 | | - } |
| 665 | + // WORKGROUP LEVEL REDUCE |
670 | 666 |
|
671 | | - if(wgid.x == info.thread_blocks - 1u){ |
672 | | - for(var k = 0u; k < VEC4_SPT; k += 1u){ |
673 | | - let t = select(vec4<u32>(0u, 0u, 0u, 0u), scan_in[i], i < info.vec_size); |
674 | | - s_red += dot(t, vec4(1u, 1u, 1u, 1u)); |
675 | | - i += lane_count; |
676 | | - } |
677 | | - } |
| 667 | + const subgroupSizeLog = uint( log2( float( subgroupSize ) ) );
| 668 | + // Effectively equal to the number of subgroups in the workgroup
| 669 | + // (the 'spine_size' of the reduction)
| 670 | + const numSubgroupsInWorkgroup = uint( workgroupSize ).shiftRight( subgroupSizeLog );
| 671 | + const spineLog = uint( log2( float( numSubgroupsInWorkgroup ) ) );
678 | 672 |
|
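| | + // Round the spine size up to the next power of subgroupSize so the loop below runs a whole number of subgroup-wide passes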
| 673 | + const alignedSize = spineLog.add( subgroupSizeLog ).sub( 1 ).div( subgroupSizeLog ).mul( subgroupSizeLog ).toVar();
| 674 | + alignedSize.assign( uint( 1 ).shiftLeft( alignedSize ) );
679 | 675 |
|
| 676 | + const offset = uint( 0 ).toVar();
680 | 677 |
|
| 678 | + const j = subgroupSize.toVar(); |
681 | 679 |
|
682 | | - |
683 | | - } ); |
| 680 | + Loop( j.lessThanEqual( alignedSize ), () => { |
| 681 | + |
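| | + // Stride each thread onto the partial sum it should fold this pass ( every element on the first pass, then the tail of each group )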
| 682 | + const subgroupIndex = ( ( invocationLocalIndex.add( 1 ) ).shiftLeft( offset ) ).sub( 1 ); |
| 683 | + |
| 684 | + const isValidSubgroupIndex = subgroupIndex.lessThan( numSubgroupsInWorkgroup ); |
| 685 | + |
| 686 | + // Reduce values within the local workgroup memory |
| 687 | + const t = subgroupAdd( select( |
| 688 | + isValidSubgroupIndex, |
| 689 | + workgroupReductionArray.element( subgroupIndex ), |
| 690 | + 0 |
| 691 | + ) ); |
| 692 | + |
| 693 | + // Can assign back to the workgroup array since all
| 694 | + // subgroup threads work in lockstep for subgroupAdd
| 695 | + If( isValidSubgroupIndex, () => { |
| 696 | + |
| 697 | + workgroupReductionArray.element( subgroupIndex ).assign( t ); |
| 698 | + |
| 699 | + } ); |
| 700 | + |
| 701 | + // Ensure every thread has finished this pass before the next pass reads the array
| 702 | +
| 703 | + workgroupBarrier();
| 704 | + |
| 705 | + offset.addAssign( subgroupSizeLog ); |
| 706 | + j.shiftLeftAssign( subgroupSizeLog ); |
| 707 | + |
| 708 | + |
| 709 | + } ); |
684 | 710 |
|
685 | 711 |
|
686 | 712 |
|
| 713 | + |
| 714 | + } ); |
| 715 | + |
687 | 716 | }; |
688 | 717 |
|
689 | 718 | const incorrectBaselineCalls = [ |
|