@@ -2,21 +2,47 @@ import { Quat } from '../../core/math/quat.js';
22import { Vec3 } from '../../core/math/vec3.js' ;
33import { Vec4 } from '../../core/math/vec4.js' ;
44import { GSplatData } from './gsplat-data.js' ;
5+ import { BlendState } from '../../platform/graphics/blend-state.js' ;
6+ import { DepthState } from '../../platform/graphics/depth-state.js' ;
7+ import { RenderTarget } from '../../platform/graphics/render-target.js' ;
8+ import { Texture } from '../../platform/graphics/texture.js' ;
9+ import { CULLFACE_NONE , PIXELFORMAT_R32U , PIXELFORMAT_RGBA8 , SEMANTIC_POSITION } from '../../platform/graphics/constants.js' ;
10+ import { drawQuadWithShader } from '../../scene/graphics/quad-render-utils.js' ;
11+ import { createShaderFromCode } from '../shader-lib/shader-utils.js' ;
512
6- let offscreen = null ;
7- let ctx = null ;
13+ const SH_C0 = 0.28209479177387814 ;
814
9- const readImageData = ( imageBitmap ) => {
10- if ( ! offscreen || offscreen . width !== imageBitmap . width || offscreen . height !== imageBitmap . height ) {
11- offscreen = new OffscreenCanvas ( imageBitmap . width , imageBitmap . height ) ;
12- ctx = offscreen . getContext ( '2d' ) ;
13- ctx . globalCompositeOperation = 'copy' ;
15+ const reorderVS = /* glsl */ `
16+ attribute vec2 vertex_position;
17+ void main(void) {
18+ gl_Position = vec4(vertex_position, 0.0, 1.0);
1419 }
15- ctx . drawImage ( imageBitmap , 0 , 0 ) ;
16- return ctx . getImageData ( 0 , 0 , imageBitmap . width , imageBitmap . height ) . data ;
17- } ;
20+ ` ;
1821
19- const SH_C0 = 0.28209479177387814 ;
22+ const reorderFS = /* glsl */ `
23+ uniform usampler2D orderTexture;
24+ uniform sampler2D sourceTexture;
25+ uniform highp uint numSplats;
26+
27+ void main(void) {
28+ uint w = uint(textureSize(sourceTexture, 0).x);
29+ uint idx = uint(gl_FragCoord.x) + uint(gl_FragCoord.y) * w;
30+ if (idx >= numSplats) discard;
31+
32+ // fetch the source index and calculate source uv
33+ uint sidx = texelFetch(orderTexture, ivec2(gl_FragCoord.xy), 0).x;
34+ uvec2 suv = uvec2(sidx % w, sidx / w);
35+
36+ // sample the source texture
37+ gl_FragColor = texelFetch(sourceTexture, ivec2(suv), 0);
38+ }
39+ ` ;
40+
41+ const resolve = ( scope , values ) => {
42+ for ( const key in values ) {
43+ scope . resolve ( key ) . setValue ( values [ key ] ) ;
44+ }
45+ } ;
2046
2147class GSplatSogsIterator {
2248 constructor ( data , p , r , s , c , sh ) {
@@ -25,16 +51,16 @@ class GSplatSogsIterator {
2551
2652 // extract means for centers
2753 const { meta } = data ;
28- const { means, quats , scales, opacities , sh0, shN } = meta ;
29- const means_l_data = p && readImageData ( data . means_l . _levels [ 0 ] ) ;
30- const means_u_data = p && readImageData ( data . means_u . _levels [ 0 ] ) ;
31- const quats_data = r && readImageData ( data . quats . _levels [ 0 ] ) ;
32- const scales_data = s && readImageData ( data . scales . _levels [ 0 ] ) ;
33- const opacities_data = c && readImageData ( data . opacities . _levels [ 0 ] ) ;
34- const sh0_data = c && readImageData ( data . sh0 . _levels [ 0 ] ) ;
35- const sh_labels_l_data = sh && readImageData ( data . sh_labels_l . _levels [ 0 ] ) ;
36- const sh_labels_u_data = sh && readImageData ( data . sh_labels_u . _levels [ 0 ] ) ;
37- const sh_centroids_data = sh && readImageData ( data . sh_centroids . _levels [ 0 ] ) ;
54+ const { means, scales, sh0, shN } = meta ;
55+ const means_l_data = p && data . means_l . _levels [ 0 ] ;
56+ const means_u_data = p && data . means_u . _levels [ 0 ] ;
57+ const quats_data = r && data . quats . _levels [ 0 ] ;
58+ const scales_data = s && data . scales . _levels [ 0 ] ;
59+ const sh0_data = c && data . sh0 . _levels [ 0 ] ;
60+ const sh_labels_data = sh && data . sh_labels . _levels [ 0 ] ;
61+ const sh_centroids_data = sh && data . sh_centroids . _levels [ 0 ] ;
62+
63+ const norm = 2.0 / Math . sqrt ( 2.0 ) ;
3864
3965 this . read = ( i ) => {
4066 if ( p ) {
@@ -48,11 +74,18 @@ class GSplatSogsIterator {
4874 }
4975
5076 if ( r ) {
51- const qx = lerp ( quats . mins [ 0 ] , quats . maxs [ 0 ] , quats_data [ i * 4 + 0 ] / 255 ) ;
52- const qy = lerp ( quats . mins [ 1 ] , quats . maxs [ 1 ] , quats_data [ i * 4 + 1 ] / 255 ) ;
53- const qz = lerp ( quats . mins [ 2 ] , quats . maxs [ 2 ] , quats_data [ i * 4 + 2 ] / 255 ) ;
54- const qw = Math . sqrt ( Math . max ( 0 , 1 - ( qx * qx + qy * qy + qz * qz ) ) ) ;
55- r . set ( qy , qz , qw , qx ) ;
77+ const a = ( quats_data [ i * 4 + 0 ] / 255 - 0.5 ) * norm ;
78+ const b = ( quats_data [ i * 4 + 1 ] / 255 - 0.5 ) * norm ;
79+ const c = ( quats_data [ i * 4 + 2 ] / 255 - 0.5 ) * norm ;
80+ const d = Math . sqrt ( Math . max ( 0 , 1 - ( a * a + b * b + c * c ) ) ) ;
81+ const mode = quats_data [ i * 4 + 3 ] - 252 ;
82+
83+ switch ( mode ) {
84+ case 0 : r . set ( a , b , c , d ) ; break ;
85+ case 1 : r . set ( d , b , c , a ) ; break ;
86+ case 2 : r . set ( b , d , c , a ) ; break ;
87+ case 3 : r . set ( b , c , d , a ) ; break ;
88+ }
5689 }
5790
5891 if ( s ) {
@@ -66,18 +99,18 @@ class GSplatSogsIterator {
6699 const r = lerp ( sh0 . mins [ 0 ] , sh0 . maxs [ 0 ] , sh0_data [ i * 4 + 0 ] / 255 ) ;
67100 const g = lerp ( sh0 . mins [ 1 ] , sh0 . maxs [ 1 ] , sh0_data [ i * 4 + 1 ] / 255 ) ;
68101 const b = lerp ( sh0 . mins [ 2 ] , sh0 . maxs [ 2 ] , sh0_data [ i * 4 + 2 ] / 255 ) ;
69- const a = lerp ( opacities . mins [ 0 ] , opacities . maxs [ 0 ] , opacities_data [ i * 4 + 0 ] / 255 ) ;
102+ const a = lerp ( sh0 . mins [ 3 ] , sh0 . maxs [ 3 ] , sh0_data [ i * 4 + 3 ] / 255 ) ;
70103
71104 c . set (
72105 0.5 + r * SH_C0 ,
73106 0.5 + g * SH_C0 ,
74107 0.5 + b * SH_C0 ,
75- 1.0 / ( 1.0 + Math . exp ( a * - 1.0 ) )
108+ 1.0 / ( 1.0 + Math . exp ( - a ) )
76109 ) ;
77110 }
78111
79112 if ( sh ) {
80- const n = sh_labels_l_data [ i * 4 + 0 ] + ( sh_labels_u_data [ i * 4 + 0 ] << 8 ) ;
113+ const n = sh_labels_data [ i * 4 + 0 ] + ( sh_labels_data [ i * 4 + 1 ] << 8 ) ;
81114 const u = ( n % 64 ) * 15 ;
82115 const v = Math . floor ( n / 64 ) ;
83116
@@ -106,15 +139,14 @@ class GSplatSogsData {
106139
107140 scales ;
108141
109- opacities ;
110-
111142 sh0 ;
112143
113144 sh_centroids ;
114145
115- sh_labels_l ;
146+ sh_labels ;
116147
117- sh_labels_u ;
148+ // if data is reordered at load, this texture stores the reorder indices.
149+ orderTexture ;
118150
119151 createIter ( p , r , s , c , sh ) {
120152 return new GSplatSogsIterator ( this , p , r , s , c , sh ) ;
@@ -141,9 +173,10 @@ class GSplatSogsData {
141173 getCenters ( result ) {
142174 const p = new Vec3 ( ) ;
143175 const iter = this . createIter ( p ) ;
176+ const order = this . orderTexture ?. _levels [ 0 ] ;
144177
145178 for ( let i = 0 ; i < this . numSplats ; i ++ ) {
146- iter . read ( i ) ;
179+ iter . read ( order ? order [ i ] : i ) ;
147180
148181 result [ i * 3 + 0 ] = p . x ;
149182 result [ i * 3 + 1 ] = p . y ;
@@ -234,6 +267,121 @@ class GSplatSogsData {
234267 } )
235268 } ] ) ;
236269 }
270+
271+ // reorder the sogs texture data in gpu memory given the ordering encoded in texture data
272+ reorderGpuMemory ( ) {
273+ const { orderTexture, numSplats } = this ;
274+ const { device, height, width } = orderTexture ;
275+ const { scope } = device ;
276+
277+ const shader = createShaderFromCode ( device , reorderVS , reorderFS , 'reorderShader' , {
278+ vertex_position : SEMANTIC_POSITION
279+ } ) ;
280+
281+ let targetTexture = new Texture ( device , {
282+ width : width ,
283+ height : height ,
284+ format : PIXELFORMAT_RGBA8 ,
285+ mipmaps : false
286+ } ) ;
287+
288+ const members = [ 'means_l' , 'means_u' , 'quats' , 'scales' , 'sh0' , 'sh_labels' ] ;
289+
290+ device . setBlendState ( BlendState . NOBLEND ) ;
291+ device . setCullMode ( CULLFACE_NONE ) ;
292+ device . setDepthState ( DepthState . NODEPTH ) ;
293+
294+ members . forEach ( ( member ) => {
295+ const sourceTexture = this [ member ] ;
296+
297+ const renderTarget = new RenderTarget ( {
298+ colorBuffer : targetTexture ,
299+ depth : false ,
300+ mipLevel : 0
301+ } ) ;
302+
303+ resolve ( scope , {
304+ orderTexture,
305+ sourceTexture,
306+ numSplats
307+ } ) ;
308+
309+ drawQuadWithShader ( device , renderTarget , shader ) ;
310+
311+ this [ member ] = targetTexture ;
312+ targetTexture . name = sourceTexture . name ;
313+ targetTexture . _levels = sourceTexture . _levels ;
314+ sourceTexture . _levels = [ ] ;
315+ targetTexture = sourceTexture ;
316+
317+ renderTarget . destroy ( ) ;
318+ } ) ;
319+
320+ targetTexture . destroy ( ) ;
321+ shader . destroy ( ) ;
322+ }
323+
324+ // construct an array containing the Morton order of the splats
325+ // returns an array of 32-bit unsigned integers
326+ calcMortonOrder ( ) {
327+ // https://fgiesen.wordpress.com/2009/12/13/decoding-morton-codes/
328+ const encodeMorton3 = ( x , y , z ) => {
329+ const Part1By2 = ( x ) => {
330+ x &= 0x000003ff ;
331+ x = ( x ^ ( x << 16 ) ) & 0xff0000ff ;
332+ x = ( x ^ ( x << 8 ) ) & 0x0300f00f ;
333+ x = ( x ^ ( x << 4 ) ) & 0x030c30c3 ;
334+ x = ( x ^ ( x << 2 ) ) & 0x09249249 ;
335+ return x ;
336+ } ;
337+
338+ return ( Part1By2 ( z ) << 2 ) + ( Part1By2 ( y ) << 1 ) + Part1By2 ( x ) ;
339+ } ;
340+
341+ const { means_l, means_u } = this ;
342+ const means_l_data = means_l . _levels [ 0 ] ;
343+ const means_u_data = means_u . _levels [ 0 ] ;
344+ const codes = new BigUint64Array ( this . numSplats ) ;
345+
346+ // generate Morton codes for each splat based on the means directly (i.e. the log-space coordinates)
347+ for ( let i = 0 ; i < this . numSplats ; ++ i ) {
348+ const ix = ( means_u_data [ i * 4 + 0 ] << 2 ) | ( means_l_data [ i * 4 + 0 ] >>> 6 ) ;
349+ const iy = ( means_u_data [ i * 4 + 1 ] << 2 ) | ( means_l_data [ i * 4 + 1 ] >>> 6 ) ;
350+ const iz = ( means_u_data [ i * 4 + 2 ] << 2 ) | ( means_l_data [ i * 4 + 2 ] >>> 6 ) ;
351+ codes [ i ] = BigInt ( encodeMorton3 ( ix , iy , iz ) ) << BigInt ( 32 ) | BigInt ( i ) ;
352+ }
353+
354+ codes . sort ( ) ;
355+
356+ // allocate data for the order buffer, but make it texture-memory sized
357+ const order = new Uint32Array ( means_l . width * means_l . height ) ;
358+ for ( let i = 0 ; i < this . numSplats ; ++ i ) {
359+ order [ i ] = Number ( codes [ i ] & BigInt ( 0xffffffff ) ) ;
360+ }
361+
362+ return order ;
363+ }
364+
365+ reorderData ( ) {
366+ if ( ! this . orderTexture ) {
367+ const { device, height, width } = this . means_l ;
368+
369+ this . orderTexture = new Texture ( device , {
370+ name : 'orderTexture' ,
371+ width,
372+ height,
373+ format : PIXELFORMAT_R32U ,
374+ mipmaps : false ,
375+ levels : [ this . calcMortonOrder ( ) ]
376+ } ) ;
377+
378+ device . on ( 'devicerestored' , ( ) => {
379+ this . reorderGpuMemory ( ) ;
380+ } ) ;
381+ }
382+
383+ this . reorderGpuMemory ( ) ;
384+ }
237385}
238386
239387export { GSplatSogsData } ;
0 commit comments