Skip to content

Commit f0d271b

Browse files
committed
Sogs updates (#7662)
1 parent d0d4f36 commit f0d271b

File tree

10 files changed

+258
-114
lines changed

10 files changed

+258
-114
lines changed

src/framework/parsers/sogs.js

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,14 @@ import { GSplatSogsData } from '../../scene/gsplat/gsplat-sogs-data.js';
88
* @import { ResourceHandlerCallback } from '../handlers/handler.js'
99
*/
1010

11+
const readImageDataAsync = (texture) => {
12+
return texture.read(0, 0, texture.width, texture.height, {
13+
mipLevel: 0,
14+
face: 0,
15+
immediate: true
16+
});
17+
};
18+
1119
class SogsParser {
1220
/** @type {AppBase} */
1321
app;
@@ -27,7 +35,7 @@ class SogsParser {
2735
async loadTextures(url, callback, asset, meta) {
2836
const { assets } = this.app;
2937

30-
const subs = ['means', 'opacities', 'quats', 'scales', 'sh0', 'shN'];
38+
const subs = ['means', 'quats', 'scales', 'sh0', 'shN'];
3139

3240
const textures = {};
3341
const promises = [];
@@ -72,15 +80,34 @@ class SogsParser {
7280
data.shBands = widths[textures.shN?.[0]?.resource?.width] ?? 0;
7381
data.means_l = textures.means[0].resource;
7482
data.means_u = textures.means[1].resource;
75-
data.opacities = textures.opacities[0].resource;
7683
data.quats = textures.quats[0].resource;
7784
data.scales = textures.scales[0].resource;
7885
data.sh0 = textures.sh0[0].resource;
7986
data.sh_centroids = textures.shN?.[0]?.resource;
80-
data.sh_labels_l = textures.shN?.[1]?.resource;
81-
data.sh_labels_u = textures.shN?.[2]?.resource;
87+
data.sh_labels = textures.shN?.[1]?.resource;
88+
89+
if (asset.data?.reorder ?? true) {
90+
// copy back means_l and means_u data from gpu so cpu reorder has access to it
91+
data.means_l._levels[0] = await readImageDataAsync(data.means_l);
92+
data.means_u._levels[0] = await readImageDataAsync(data.means_u);
93+
data.reorderData();
94+
}
8295

83-
const resource = new GSplatResource(this.app, (asset.data?.decompress) ? data.decompress() : data, []);
96+
let resource;
97+
if (asset.data?.decompress) {
98+
// copy back gpu texture data so cpu iterator has access to it
99+
const { means_l, means_u, quats, scales, sh0, sh_labels, sh_centroids } = data;
100+
means_l._levels[0] = await readImageDataAsync(means_l);
101+
means_u._levels[0] = await readImageDataAsync(means_u);
102+
quats._levels[0] = await readImageDataAsync(quats);
103+
scales._levels[0] = await readImageDataAsync(scales);
104+
sh0._levels[0] = await readImageDataAsync(sh0);
105+
sh_labels._levels[0] = await readImageDataAsync(sh_labels);
106+
sh_centroids._levels[0] = await readImageDataAsync(sh_centroids);
107+
resource = new GSplatResource(this.app, data.decompress(), []);
108+
} else {
109+
resource = new GSplatResource(this.app, data, []);
110+
}
84111

85112
callback(null, resource);
86113
}

src/scene/gsplat/gsplat-data.js

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -392,9 +392,9 @@ class GSplatData {
392392

393393
const codes = new Map();
394394
for (let i = 0; i < this.numSplats; i++) {
395-
const ix = Math.floor((x[i] - minX) * sizeX);
396-
const iy = Math.floor((y[i] - minY) * sizeY);
397-
const iz = Math.floor((z[i] - minZ) * sizeZ);
395+
const ix = Math.min(1023, Math.floor((x[i] - minX) * sizeX));
396+
const iy = Math.min(1023, Math.floor((y[i] - minY) * sizeY));
397+
const iz = Math.min(1023, Math.floor((z[i] - minZ) * sizeZ));
398398
const code = encodeMorton3(ix, iy, iz);
399399

400400
const val = codes.get(code);

src/scene/gsplat/gsplat-sogs-data.js

Lines changed: 182 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,47 @@ import { Quat } from '../../core/math/quat.js';
22
import { Vec3 } from '../../core/math/vec3.js';
33
import { Vec4 } from '../../core/math/vec4.js';
44
import { GSplatData } from './gsplat-data.js';
5+
import { BlendState } from '../../platform/graphics/blend-state.js';
6+
import { DepthState } from '../../platform/graphics/depth-state.js';
7+
import { RenderTarget } from '../../platform/graphics/render-target.js';
8+
import { Texture } from '../../platform/graphics/texture.js';
9+
import { CULLFACE_NONE, PIXELFORMAT_R32U, PIXELFORMAT_RGBA8, SEMANTIC_POSITION } from '../../platform/graphics/constants.js';
10+
import { drawQuadWithShader } from '../../scene/graphics/quad-render-utils.js';
11+
import { createShaderFromCode } from '../shader-lib/shader-utils.js';
512

6-
let offscreen = null;
7-
let ctx = null;
13+
const SH_C0 = 0.28209479177387814;
814

9-
const readImageData = (imageBitmap) => {
10-
if (!offscreen || offscreen.width !== imageBitmap.width || offscreen.height !== imageBitmap.height) {
11-
offscreen = new OffscreenCanvas(imageBitmap.width, imageBitmap.height);
12-
ctx = offscreen.getContext('2d');
13-
ctx.globalCompositeOperation = 'copy';
15+
const reorderVS = /* glsl */`
16+
attribute vec2 vertex_position;
17+
void main(void) {
18+
gl_Position = vec4(vertex_position, 0.0, 1.0);
1419
}
15-
ctx.drawImage(imageBitmap, 0, 0);
16-
return ctx.getImageData(0, 0, imageBitmap.width, imageBitmap.height).data;
17-
};
20+
`;
1821

19-
const SH_C0 = 0.28209479177387814;
22+
const reorderFS = /* glsl */`
23+
uniform usampler2D orderTexture;
24+
uniform sampler2D sourceTexture;
25+
uniform highp uint numSplats;
26+
27+
void main(void) {
28+
uint w = uint(textureSize(sourceTexture, 0).x);
29+
uint idx = uint(gl_FragCoord.x) + uint(gl_FragCoord.y) * w;
30+
if (idx >= numSplats) discard;
31+
32+
// fetch the source index and calculate source uv
33+
uint sidx = texelFetch(orderTexture, ivec2(gl_FragCoord.xy), 0).x;
34+
uvec2 suv = uvec2(sidx % w, sidx / w);
35+
36+
// sample the source texture
37+
gl_FragColor = texelFetch(sourceTexture, ivec2(suv), 0);
38+
}
39+
`;
40+
41+
const resolve = (scope, values) => {
42+
for (const key in values) {
43+
scope.resolve(key).setValue(values[key]);
44+
}
45+
};
2046

2147
class GSplatSogsIterator {
2248
constructor(data, p, r, s, c, sh) {
@@ -25,16 +51,16 @@ class GSplatSogsIterator {
2551

2652
// extract means for centers
2753
const { meta } = data;
28-
const { means, quats, scales, opacities, sh0, shN } = meta;
29-
const means_l_data = p && readImageData(data.means_l._levels[0]);
30-
const means_u_data = p && readImageData(data.means_u._levels[0]);
31-
const quats_data = r && readImageData(data.quats._levels[0]);
32-
const scales_data = s && readImageData(data.scales._levels[0]);
33-
const opacities_data = c && readImageData(data.opacities._levels[0]);
34-
const sh0_data = c && readImageData(data.sh0._levels[0]);
35-
const sh_labels_l_data = sh && readImageData(data.sh_labels_l._levels[0]);
36-
const sh_labels_u_data = sh && readImageData(data.sh_labels_u._levels[0]);
37-
const sh_centroids_data = sh && readImageData(data.sh_centroids._levels[0]);
54+
const { means, scales, sh0, shN } = meta;
55+
const means_l_data = p && data.means_l._levels[0];
56+
const means_u_data = p && data.means_u._levels[0];
57+
const quats_data = r && data.quats._levels[0];
58+
const scales_data = s && data.scales._levels[0];
59+
const sh0_data = c && data.sh0._levels[0];
60+
const sh_labels_data = sh && data.sh_labels._levels[0];
61+
const sh_centroids_data = sh && data.sh_centroids._levels[0];
62+
63+
const norm = 2.0 / Math.sqrt(2.0);
3864

3965
this.read = (i) => {
4066
if (p) {
@@ -48,11 +74,18 @@ class GSplatSogsIterator {
4874
}
4975

5076
if (r) {
51-
const qx = lerp(quats.mins[0], quats.maxs[0], quats_data[i * 4 + 0] / 255);
52-
const qy = lerp(quats.mins[1], quats.maxs[1], quats_data[i * 4 + 1] / 255);
53-
const qz = lerp(quats.mins[2], quats.maxs[2], quats_data[i * 4 + 2] / 255);
54-
const qw = Math.sqrt(Math.max(0, 1 - (qx * qx + qy * qy + qz * qz)));
55-
r.set(qy, qz, qw, qx);
77+
const a = (quats_data[i * 4 + 0] / 255 - 0.5) * norm;
78+
const b = (quats_data[i * 4 + 1] / 255 - 0.5) * norm;
79+
const c = (quats_data[i * 4 + 2] / 255 - 0.5) * norm;
80+
const d = Math.sqrt(Math.max(0, 1 - (a * a + b * b + c * c)));
81+
const mode = quats_data[i * 4 + 3] - 252;
82+
83+
switch (mode) {
84+
case 0: r.set(a, b, c, d); break;
85+
case 1: r.set(d, b, c, a); break;
86+
case 2: r.set(b, d, c, a); break;
87+
case 3: r.set(b, c, d, a); break;
88+
}
5689
}
5790

5891
if (s) {
@@ -66,18 +99,18 @@ class GSplatSogsIterator {
6699
const r = lerp(sh0.mins[0], sh0.maxs[0], sh0_data[i * 4 + 0] / 255);
67100
const g = lerp(sh0.mins[1], sh0.maxs[1], sh0_data[i * 4 + 1] / 255);
68101
const b = lerp(sh0.mins[2], sh0.maxs[2], sh0_data[i * 4 + 2] / 255);
69-
const a = lerp(opacities.mins[0], opacities.maxs[0], opacities_data[i * 4 + 0] / 255);
102+
const a = lerp(sh0.mins[3], sh0.maxs[3], sh0_data[i * 4 + 3] / 255);
70103

71104
c.set(
72105
0.5 + r * SH_C0,
73106
0.5 + g * SH_C0,
74107
0.5 + b * SH_C0,
75-
1.0 / (1.0 + Math.exp(a * -1.0))
108+
1.0 / (1.0 + Math.exp(-a))
76109
);
77110
}
78111

79112
if (sh) {
80-
const n = sh_labels_l_data[i * 4 + 0] + (sh_labels_u_data[i * 4 + 0] << 8);
113+
const n = sh_labels_data[i * 4 + 0] + (sh_labels_data[i * 4 + 1] << 8);
81114
const u = (n % 64) * 15;
82115
const v = Math.floor(n / 64);
83116

@@ -106,15 +139,14 @@ class GSplatSogsData {
106139

107140
scales;
108141

109-
opacities;
110-
111142
sh0;
112143

113144
sh_centroids;
114145

115-
sh_labels_l;
146+
sh_labels;
116147

117-
sh_labels_u;
148+
// if data is reordered at load, this texture stores the reorder indices.
149+
orderTexture;
118150

119151
createIter(p, r, s, c, sh) {
120152
return new GSplatSogsIterator(this, p, r, s, c, sh);
@@ -141,9 +173,10 @@ class GSplatSogsData {
141173
getCenters(result) {
142174
const p = new Vec3();
143175
const iter = this.createIter(p);
176+
const order = this.orderTexture?._levels[0];
144177

145178
for (let i = 0; i < this.numSplats; i++) {
146-
iter.read(i);
179+
iter.read(order ? order[i] : i);
147180

148181
result[i * 3 + 0] = p.x;
149182
result[i * 3 + 1] = p.y;
@@ -234,6 +267,121 @@ class GSplatSogsData {
234267
})
235268
}]);
236269
}
270+
271+
// reorder the sogs texture data in gpu memory given the ordering encoded in texture data
272+
reorderGpuMemory() {
273+
const { orderTexture, numSplats } = this;
274+
const { device, height, width } = orderTexture;
275+
const { scope } = device;
276+
277+
const shader = createShaderFromCode(device, reorderVS, reorderFS, 'reorderShader', {
278+
vertex_position: SEMANTIC_POSITION
279+
});
280+
281+
let targetTexture = new Texture(device, {
282+
width: width,
283+
height: height,
284+
format: PIXELFORMAT_RGBA8,
285+
mipmaps: false
286+
});
287+
288+
const members = ['means_l', 'means_u', 'quats', 'scales', 'sh0', 'sh_labels'];
289+
290+
device.setBlendState(BlendState.NOBLEND);
291+
device.setCullMode(CULLFACE_NONE);
292+
device.setDepthState(DepthState.NODEPTH);
293+
294+
members.forEach((member) => {
295+
const sourceTexture = this[member];
296+
297+
const renderTarget = new RenderTarget({
298+
colorBuffer: targetTexture,
299+
depth: false,
300+
mipLevel: 0
301+
});
302+
303+
resolve(scope, {
304+
orderTexture,
305+
sourceTexture,
306+
numSplats
307+
});
308+
309+
drawQuadWithShader(device, renderTarget, shader);
310+
311+
this[member] = targetTexture;
312+
targetTexture.name = sourceTexture.name;
313+
targetTexture._levels = sourceTexture._levels;
314+
sourceTexture._levels = [];
315+
targetTexture = sourceTexture;
316+
317+
renderTarget.destroy();
318+
});
319+
320+
targetTexture.destroy();
321+
shader.destroy();
322+
}
323+
324+
// construct an array containing the Morton order of the splats
325+
// returns an array of 32-bit unsigned integers
326+
calcMortonOrder() {
327+
// https://fgiesen.wordpress.com/2009/12/13/decoding-morton-codes/
328+
const encodeMorton3 = (x, y, z) => {
329+
const Part1By2 = (x) => {
330+
x &= 0x000003ff;
331+
x = (x ^ (x << 16)) & 0xff0000ff;
332+
x = (x ^ (x << 8)) & 0x0300f00f;
333+
x = (x ^ (x << 4)) & 0x030c30c3;
334+
x = (x ^ (x << 2)) & 0x09249249;
335+
return x;
336+
};
337+
338+
return (Part1By2(z) << 2) + (Part1By2(y) << 1) + Part1By2(x);
339+
};
340+
341+
const { means_l, means_u } = this;
342+
const means_l_data = means_l._levels[0];
343+
const means_u_data = means_u._levels[0];
344+
const codes = new BigUint64Array(this.numSplats);
345+
346+
// generate Morton codes for each splat based on the means directly (i.e. the log-space coordinates)
347+
for (let i = 0; i < this.numSplats; ++i) {
348+
const ix = (means_u_data[i * 4 + 0] << 2) | (means_l_data[i * 4 + 0] >>> 6);
349+
const iy = (means_u_data[i * 4 + 1] << 2) | (means_l_data[i * 4 + 1] >>> 6);
350+
const iz = (means_u_data[i * 4 + 2] << 2) | (means_l_data[i * 4 + 2] >>> 6);
351+
codes[i] = BigInt(encodeMorton3(ix, iy, iz)) << BigInt(32) | BigInt(i);
352+
}
353+
354+
codes.sort();
355+
356+
// allocate data for the order buffer, but make it texture-memory sized
357+
const order = new Uint32Array(means_l.width * means_l.height);
358+
for (let i = 0; i < this.numSplats; ++i) {
359+
order[i] = Number(codes[i] & BigInt(0xffffffff));
360+
}
361+
362+
return order;
363+
}
364+
365+
reorderData() {
366+
if (!this.orderTexture) {
367+
const { device, height, width } = this.means_l;
368+
369+
this.orderTexture = new Texture(device, {
370+
name: 'orderTexture',
371+
width,
372+
height,
373+
format: PIXELFORMAT_R32U,
374+
mipmaps: false,
375+
levels: [this.calcMortonOrder()]
376+
});
377+
378+
device.on('devicerestored', () => {
379+
this.reorderGpuMemory();
380+
});
381+
}
382+
383+
this.reorderGpuMemory();
384+
}
237385
}
238386

239387
export { GSplatSogsData };

src/scene/gsplat/gsplat-sogs.js

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,11 @@ class GSplatSogs {
4242
result.setDefine('GSPLAT_SOGS_DATA', true);
4343
result.setDefine('SH_BANDS', this.gsplatData.shBands);
4444

45-
['means_l', 'means_u', 'quats', 'scales', 'opacities', 'sh0', 'sh_centroids', 'sh_labels_u', 'sh_labels_l'].forEach((name) => {
45+
['means_l', 'means_u', 'quats', 'scales', 'sh0', 'sh_centroids', 'sh_labels'].forEach((name) => {
4646
result.setParameter(name, gsplatData[name]);
4747
});
4848

49-
['means', 'quats', 'scales', 'opacities', 'sh0', 'shN'].forEach((name) => {
49+
['means', 'scales', 'sh0', 'shN'].forEach((name) => {
5050
const v = gsplatData.meta[name];
5151
if (v) {
5252
result.setParameter(`${name}_mins`, v.mins);

0 commit comments

Comments
 (0)