Skip to content

Commit bf0ece5

Browse files
mvaligurskyMartin Valigurskywilleastcott
authored
Generate SOG centers buffer on GPU instead of CPU (#7999)
* Generate SOG centers buffer on GPU instead of CPU
* lint
* Update src/scene/gsplat-unified/gsplat-unified-sorter.js

Co-authored-by: Will Eastcott <[email protected]>

---------

Co-authored-by: Martin Valigursky <[email protected]>
Co-authored-by: Will Eastcott <[email protected]>
1 parent 0e437bf commit bf0ece5

File tree

7 files changed

+158
-51
lines changed

7 files changed

+158
-51
lines changed

src/scene/gsplat-unified/gsplat-unified-sorter.js

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,8 @@ class GSplatUnifiedSorter extends EventHandler {
7878
if (!this.centersSet.has(id)) {
7979
this.centersSet.add(id);
8080

81-
// use the original buffer, as we do not need it on the main thread anymore
82-
const centersBuffer = centers.buffer;
81+
// clone centers buffer - required when multiple workers sort the same splat resource
82+
const centersBuffer = centers.buffer.slice();
8383

8484
// post centers to worker
8585
this.worker.postMessage({

src/scene/gsplat/gsplat-compressed-data.js

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,10 +199,12 @@ class GSplatCompressedData {
199199
}
200200

201201
/**
202-
* @param {Float32Array} result - Array containing the centers.
202+
* Returns a new Float32Array of centers (x, y, z per splat).
203+
* @returns {Float32Array} Centers buffer
203204
*/
204-
getCenters(result) {
205+
getCenters() {
205206
const { vertexData, chunkData, numChunks, chunkSize } = this;
207+
const result = new Float32Array(this.numSplats * 3);
206208

207209
let mx, my, mz, Mx, My, Mz;
208210

@@ -226,6 +228,8 @@ class GSplatCompressedData {
226228
result[i * 3 + 2] = (1 - pz) * mz + pz * Mz;
227229
}
228230
}
231+
232+
return result;
229233
}
230234

231235
getChunks(result) {

src/scene/gsplat/gsplat-data.js

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,9 @@ class GSplatData {
130130

131131
// access a named property
132132
getProp(name, elementName = 'vertex') {
133-
return this.getElement(elementName)?.properties.find(p => p.name === name)?.storage;
133+
const el = this.getElement(elementName);
134+
const storage = el && el.properties.find(p => p.name === name)?.storage;
135+
return /** @type {Float32Array} */ (storage ?? new Float32Array(0));
134136
}
135137

136138
// access the named element
@@ -259,18 +261,21 @@ class GSplatData {
259261
}
260262

261263
/**
262-
* @param {Float32Array} result - Array containing the centers.
264+
* Returns a new Float32Array of centers (x, y, z per splat).
265+
* @returns {Float32Array} Centers buffer
263266
*/
264-
getCenters(result) {
267+
getCenters() {
265268
const x = this.getProp('x');
266269
const y = this.getProp('y');
267270
const z = this.getProp('z');
268271

272+
const result = new Float32Array(this.numSplats * 3);
269273
for (let i = 0; i < this.numSplats; ++i) {
270274
result[i * 3 + 0] = x[i];
271275
result[i * 3 + 1] = y[i];
272276
result[i * 3 + 2] = z[i];
273277
}
278+
return result;
274279
}
275280

276281
/**

src/scene/gsplat/gsplat-resource-base.js

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,7 @@ class GSplatResourceBase {
9292
this.device = device;
9393
this.gsplatData = gsplatData;
9494

95-
this.centers = new Float32Array(gsplatData.numSplats * 3);
96-
gsplatData.getCenters(this.centers);
95+
this.centers = gsplatData.getCenters();
9796

9897
this.aabb = new BoundingBox();
9998
gsplatData.calcAabb(this.aabb);

src/scene/gsplat/gsplat-sogs-data.js

Lines changed: 84 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import { Debug } from '../../core/debug.js';
12
import { Quat } from '../../core/math/quat.js';
23
import { Vec3 } from '../../core/math/vec3.js';
34
import { Vec4 } from '../../core/math/vec4.js';
@@ -18,6 +19,9 @@ import glslGsplatPackingPS from '../shader-lib/glsl/chunks/gsplat/frag/gsplatPac
1819
import wgslGsplatSogsReorderSH from '../shader-lib/wgsl/chunks/gsplat/frag/gsplatSogsReorderSh.js';
1920
import wgslGsplatPackingPS from '../shader-lib/wgsl/chunks/gsplat/frag/gsplatPacking.js';
2021

22+
import glslSogsCentersPS from '../shader-lib/glsl/chunks/gsplat/frag/gsplatSogsCenters.js';
23+
import wgslSogsCentersPS from '../shader-lib/wgsl/chunks/gsplat/frag/gsplatSogsCenters.js';
24+
2125
const SH_C0 = 0.28209479177387814;
2226

2327
const readImageDataAsync = (texture) => {
@@ -169,6 +173,14 @@ class GSplatSogsData {
169173

170174
packedShN;
171175

176+
/**
177+
* Cached centers array (x, y, z per splat), length = numSplats * 3.
178+
*
179+
* @type {Float32Array | null}
180+
* @private
181+
*/
182+
_centers = null;
183+
172184
// Marked when resource is destroyed, to abort any in-flight async preparation
173185
destroyed = false;
174186

@@ -212,41 +224,12 @@ class GSplatSogsData {
212224
);
213225
}
214226

215-
getCenters(result) {
216-
const { meta, means_l, means_u, numSplats } = this;
217-
const { means } = meta;
218-
219-
const means_u_data = new Uint32Array(means_u._levels[0].buffer);
220-
const means_l_data = new Uint32Array(means_l._levels[0].buffer);
221-
222-
const mx = means.mins[0] / 65535;
223-
const my = means.mins[1] / 65535;
224-
const mz = means.mins[2] / 65535;
225-
const Mx = means.maxs[0] / 65535;
226-
const My = means.maxs[1] / 65535;
227-
const Mz = means.maxs[2] / 65535;
228-
229-
for (let i = 0; i < numSplats; i++) {
230-
const idx = i;
231-
232-
const means_u = means_u_data[idx];
233-
const means_l = means_l_data[idx];
234-
235-
const wx = ((means_u << 8) & 0xff00) | (means_l & 0xff);
236-
const wy = (means_u & 0xff00) | ((means_l >>> 8) & 0xff);
237-
const wz = ((means_u >>> 8) & 0xff00) | ((means_l >>> 16) & 0xff);
238-
239-
const nx = mx * (65535 - wx) + Mx * wx;
240-
const ny = my * (65535 - wy) + My * wy;
241-
const nz = mz * (65535 - wz) + Mz * wz;
242-
243-
const ax = nx < 0 ? -nx : nx;
244-
const ay = ny < 0 ? -ny : ny;
245-
const az = nz < 0 ? -nz : nz;
246-
result[i * 3] = (nx < 0 ? -1 : 1) * (Math.exp(ax) - 1);
247-
result[i * 3 + 1] = (ny < 0 ? -1 : 1) * (Math.exp(ay) - 1);
248-
result[i * 3 + 2] = (nz < 0 ? -1 : 1) * (Math.exp(az) - 1);
249-
}
227+
getCenters() {
228+
// centers can be only copied once to avoid making copies.
229+
Debug.assert(this._centers);
230+
const centers = /** @type {Float32Array} */ this._centers;
231+
this._centers = null;
232+
return centers;
250233
}
251234

252235
// use bound center for focal point
@@ -363,6 +346,69 @@ class GSplatSogsData {
363346
}]);
364347
}
365348

349+
async generateCenters() {
350+
const { device, width, height } = this.means_l;
351+
const { scope } = device;
352+
353+
// create a temporary texture to render centers into
354+
const centersTexture = new Texture(device, {
355+
name: 'sogsCentersTexture',
356+
width,
357+
height,
358+
format: PIXELFORMAT_RGBA32U,
359+
mipmaps: false
360+
});
361+
362+
const shader = ShaderUtils.createShader(device, {
363+
uniqueName: 'GsplatSogsCentersShader',
364+
attributes: { vertex_position: SEMANTIC_POSITION },
365+
vertexChunk: 'fullscreenQuadVS',
366+
fragmentGLSL: glslSogsCentersPS,
367+
fragmentWGSL: wgslSogsCentersPS,
368+
fragmentOutputTypes: ['uvec4'],
369+
fragmentIncludes: new Map([['gsplatPackingPS', device.isWebGPU ? wgslGsplatPackingPS : glslGsplatPackingPS]])
370+
});
371+
372+
const renderTarget = new RenderTarget({
373+
colorBuffer: centersTexture,
374+
depth: false,
375+
mipLevel: 0
376+
});
377+
378+
device.setCullMode(CULLFACE_NONE);
379+
device.setBlendState(BlendState.NOBLEND);
380+
device.setDepthState(DepthState.NODEPTH);
381+
382+
resolve(scope, {
383+
means_l: this.means_l,
384+
means_u: this.means_u,
385+
numSplats: this.numSplats,
386+
means_mins: this.meta.means.mins,
387+
means_maxs: this.meta.means.maxs
388+
});
389+
390+
drawQuadWithShader(device, renderTarget, shader);
391+
392+
renderTarget.destroy();
393+
394+
const u32 = await readImageDataAsync(centersTexture);
395+
if (this.destroyed || device._destroyed) {
396+
centersTexture.destroy();
397+
return;
398+
}
399+
400+
const asFloat = new Float32Array(u32.buffer);
401+
const result = new Float32Array(this.numSplats * 3);
402+
for (let i = 0; i < this.numSplats; i++) {
403+
const base = i * 4;
404+
result[i * 3 + 0] = asFloat[base + 0];
405+
result[i * 3 + 1] = asFloat[base + 1];
406+
result[i * 3 + 2] = asFloat[base + 2];
407+
}
408+
this._centers = result;
409+
centersTexture.destroy();
410+
}
411+
366412
// pack the means, quats, scales and sh_labels data into one RGBA32U texture
367413
packGpuMemory() {
368414
const { meta, means_l, means_u, quats, scales, sh0, sh_labels, numSplats } = this;
@@ -456,13 +502,6 @@ class GSplatSogsData {
456502
async prepareGpuData() {
457503
const { device, height, width } = this.means_l;
458504

459-
// copy back means_l and means_u data so cpu reorder has access to it
460-
if (this.destroyed || device._destroyed) return; // skip the rest if the resource was destroyed
461-
this.means_l._levels[0] = await readImageDataAsync(this.means_l);
462-
463-
if (this.destroyed || device._destroyed) return; // skip the rest if the resource was destroyed
464-
this.means_u._levels[0] = await readImageDataAsync(this.means_u);
465-
466505
if (this.destroyed || device._destroyed) return; // skip the rest if the resource was destroyed
467506
this.packedTexture = new Texture(device, {
468507
name: 'sogsPackedTexture',
@@ -495,6 +534,9 @@ class GSplatSogsData {
495534
}
496535
});
497536

537+
if (this.destroyed || device._destroyed) return; // skip the rest if the resource was destroyed
538+
await this.generateCenters();
539+
498540
if (this.destroyed || device._destroyed) return; // skip the rest if the resource was destroyed
499541
this.packGpuMemory();
500542
if (this.packedShN) {
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
export default /* glsl */`
2+
#include "gsplatPackingPS"
3+
4+
uniform highp sampler2D means_l;
5+
uniform highp sampler2D means_u;
6+
7+
uniform highp uint numSplats;
8+
uniform highp vec3 means_mins;
9+
uniform highp vec3 means_maxs;
10+
11+
void main(void) {
12+
int w = int(textureSize(means_l, 0).x);
13+
ivec2 uv = ivec2(gl_FragCoord.xy);
14+
if (uint(uv.x + uv.y * w) >= numSplats) {
15+
discard;
16+
}
17+
18+
vec3 l = texelFetch(means_l, uv, 0).xyz;
19+
vec3 u = texelFetch(means_u, uv, 0).xyz;
20+
vec3 n = (l * 255.0 + u * 255.0 * 256.0) / 65535.0;
21+
vec3 v = mix(means_mins, means_maxs, n);
22+
vec3 center = sign(v) * (exp(abs(v)) - 1.0);
23+
24+
// store float bits into u32 RGBA, alpha unused
25+
pcFragColor0 = uvec4(floatBitsToUint(center), 0u);
26+
}
27+
`;
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
export default /* wgsl */`
2+
var means_l: texture_2d<f32>;
3+
var means_u: texture_2d<f32>;
4+
5+
uniform numSplats: u32;
6+
uniform means_mins: vec3f;
7+
uniform means_maxs: vec3f;
8+
9+
@fragment
10+
fn fragmentMain(input: FragmentInput) -> FragmentOutput {
11+
var output: FragmentOutput;
12+
13+
let w: u32 = textureDimensions(means_l, 0).x;
14+
let uv: vec2<i32> = vec2<i32>(input.position.xy);
15+
if (u32(uv.x + uv.y * i32(w)) >= uniform.numSplats) {
16+
discard;
17+
return output;
18+
}
19+
20+
let l: vec3f = textureLoad(means_l, uv, 0).xyz;
21+
let u: vec3f = textureLoad(means_u, uv, 0).xyz;
22+
let n: vec3f = (l * 255.0 + u * 255.0 * 256.0) / 65535.0;
23+
let v: vec3f = mix(uniform.means_mins, uniform.means_maxs, n);
24+
let center: vec3f = sign(v) * (exp(abs(v)) - 1.0);
25+
26+
let packed: vec4<u32> = bitcast<vec4<u32>>(vec4f(center, 0.0));
27+
output.color = packed;
28+
return output;
29+
}
30+
`;

0 commit comments

Comments
 (0)