diff --git a/src/framework/parsers/sogs.js b/src/framework/parsers/sogs.js index 73b25bd3e66..0747d56a3d2 100644 --- a/src/framework/parsers/sogs.js +++ b/src/framework/parsers/sogs.js @@ -8,6 +8,14 @@ import { GSplatSogsData } from '../../scene/gsplat/gsplat-sogs-data.js'; * @import { ResourceHandlerCallback } from '../handlers/handler.js' */ +const readImageDataAsync = (texture) => { + return texture.read(0, 0, texture.width, texture.height, { + mipLevel: 0, + face: 0, + immediate: true + }); +}; + class SogsParser { /** @type {AppBase} */ app; @@ -27,7 +35,7 @@ class SogsParser { async loadTextures(url, callback, asset, meta) { const { assets } = this.app; - const subs = ['means', 'opacities', 'quats', 'scales', 'sh0', 'shN']; + const subs = ['means', 'quats', 'scales', 'sh0', 'shN']; const textures = {}; const promises = []; @@ -72,15 +80,34 @@ class SogsParser { data.shBands = widths[textures.shN?.[0]?.resource?.width] ?? 0; data.means_l = textures.means[0].resource; data.means_u = textures.means[1].resource; - data.opacities = textures.opacities[0].resource; data.quats = textures.quats[0].resource; data.scales = textures.scales[0].resource; data.sh0 = textures.sh0[0].resource; data.sh_centroids = textures.shN?.[0]?.resource; - data.sh_labels_l = textures.shN?.[1]?.resource; - data.sh_labels_u = textures.shN?.[2]?.resource; + data.sh_labels = textures.shN?.[1]?.resource; + + if (asset.data?.reorder ?? true) { + // copy back means_l and means_u data from gpu so cpu reorder has access to it + data.means_l._levels[0] = await readImageDataAsync(data.means_l); + data.means_u._levels[0] = await readImageDataAsync(data.means_u); + data.reorderData(); + } - const resource = new GSplatResource(this.app, (asset.data?.decompress) ? data.decompress() : data, []); + let resource; + if (asset.data?.decompress) { + // copy back gpu texture data so cpu iterator has access to it + const { means_l, means_u, quats, scales, sh0, sh_labels, sh_centroids } = data; + means_l._levels[0] = await readImageDataAsync(means_l); + means_u._levels[0] = await readImageDataAsync(means_u); + quats._levels[0] = await readImageDataAsync(quats); + scales._levels[0] = await readImageDataAsync(scales); + sh0._levels[0] = await readImageDataAsync(sh0); + sh_labels._levels[0] = await readImageDataAsync(sh_labels); + sh_centroids._levels[0] = await readImageDataAsync(sh_centroids); + resource = new GSplatResource(this.app, data.decompress(), []); + } else { + resource = new GSplatResource(this.app, data, []); + } callback(null, resource); } diff --git a/src/scene/gsplat/gsplat-data.js b/src/scene/gsplat/gsplat-data.js index 7b92ca995b4..0eee795fc95 100644 --- a/src/scene/gsplat/gsplat-data.js +++ b/src/scene/gsplat/gsplat-data.js @@ -392,9 +392,9 @@ class GSplatData { const codes = new Map(); for (let i = 0; i < this.numSplats; i++) { - const ix = Math.floor((x[i] - minX) * sizeX); - const iy = Math.floor((y[i] - minY) * sizeY); - const iz = Math.floor((z[i] - minZ) * sizeZ); + const ix = Math.min(1023, Math.floor((x[i] - minX) * sizeX)); + const iy = Math.min(1023, Math.floor((y[i] - minY) * sizeY)); + const iz = Math.min(1023, Math.floor((z[i] - minZ) * sizeZ)); const code = encodeMorton3(ix, iy, iz); const val = codes.get(code); diff --git a/src/scene/gsplat/gsplat-sogs-data.js b/src/scene/gsplat/gsplat-sogs-data.js index bd4c1c3bcf9..946219f2b8e 100644 --- a/src/scene/gsplat/gsplat-sogs-data.js +++ b/src/scene/gsplat/gsplat-sogs-data.js @@ -2,21 +2,47 @@ import { Quat } from '../../core/math/quat.js'; import { Vec3 } from '../../core/math/vec3.js'; import { Vec4 } from '../../core/math/vec4.js'; import { GSplatData } from './gsplat-data.js'; +import { BlendState } from '../../platform/graphics/blend-state.js'; +import { DepthState } from '../../platform/graphics/depth-state.js'; +import { RenderTarget } from '../../platform/graphics/render-target.js'; +import { Texture } from '../../platform/graphics/texture.js'; +import { CULLFACE_NONE, PIXELFORMAT_R32U, PIXELFORMAT_RGBA8, SEMANTIC_POSITION } from '../../platform/graphics/constants.js'; +import { drawQuadWithShader } from '../../scene/graphics/quad-render-utils.js'; +import { createShaderFromCode } from '../shader-lib/shader-utils.js'; -let offscreen = null; -let ctx = null; +const SH_C0 = 0.28209479177387814; -const readImageData = (imageBitmap) => { - if (!offscreen || offscreen.width !== imageBitmap.width || offscreen.height !== imageBitmap.height) { - offscreen = new OffscreenCanvas(imageBitmap.width, imageBitmap.height); - ctx = offscreen.getContext('2d'); - ctx.globalCompositeOperation = 'copy'; +const reorderVS = /* glsl */` + attribute vec2 vertex_position; + void main(void) { + gl_Position = vec4(vertex_position, 0.0, 1.0); } - ctx.drawImage(imageBitmap, 0, 0); - return ctx.getImageData(0, 0, imageBitmap.width, imageBitmap.height).data; -}; +`; -const SH_C0 = 0.28209479177387814; +const reorderFS = /* glsl */` + uniform usampler2D orderTexture; + uniform sampler2D sourceTexture; + uniform highp uint numSplats; + + void main(void) { + uint w = uint(textureSize(sourceTexture, 0).x); + uint idx = uint(gl_FragCoord.x) + uint(gl_FragCoord.y) * w; + if (idx >= numSplats) discard; + + // fetch the source index and calculate source uv + uint sidx = texelFetch(orderTexture, ivec2(gl_FragCoord.xy), 0).x; + uvec2 suv = uvec2(sidx % w, sidx / w); + + // sample the source texture + gl_FragColor = texelFetch(sourceTexture, ivec2(suv), 0); + } +`; + +const resolve = (scope, values) => { + for (const key in values) { + scope.resolve(key).setValue(values[key]); + } +}; class GSplatSogsIterator { constructor(data, p, r, s, c, sh) { @@ -25,16 +51,16 @@ class GSplatSogsIterator { // extract means for centers const { meta } = data; - const { means, quats, scales, opacities, sh0, shN } = meta; - const means_l_data = p && readImageData(data.means_l._levels[0]); - const means_u_data = p && readImageData(data.means_u._levels[0]); - const quats_data = r && readImageData(data.quats._levels[0]); - const scales_data = s && readImageData(data.scales._levels[0]); - const opacities_data = c && readImageData(data.opacities._levels[0]); - const sh0_data = c && readImageData(data.sh0._levels[0]); - const sh_labels_l_data = sh && readImageData(data.sh_labels_l._levels[0]); - const sh_labels_u_data = sh && readImageData(data.sh_labels_u._levels[0]); - const sh_centroids_data = sh && readImageData(data.sh_centroids._levels[0]); + const { means, scales, sh0, shN } = meta; + const means_l_data = p && data.means_l._levels[0]; + const means_u_data = p && data.means_u._levels[0]; + const quats_data = r && data.quats._levels[0]; + const scales_data = s && data.scales._levels[0]; + const sh0_data = c && data.sh0._levels[0]; + const sh_labels_data = sh && data.sh_labels._levels[0]; + const sh_centroids_data = sh && data.sh_centroids._levels[0]; + + const norm = 2.0 / Math.sqrt(2.0); this.read = (i) => { if (p) { @@ -48,11 +74,18 @@ class GSplatSogsIterator { } if (r) { - const qx = lerp(quats.mins[0], quats.maxs[0], quats_data[i * 4 + 0] / 255); - const qy = lerp(quats.mins[1], quats.maxs[1], quats_data[i * 4 + 1] / 255); - const qz = lerp(quats.mins[2], quats.maxs[2], quats_data[i * 4 + 2] / 255); - const qw = Math.sqrt(Math.max(0, 1 - (qx * qx + qy * qy + qz * qz))); - r.set(qy, qz, qw, qx); + const a = (quats_data[i * 4 + 0] / 255 - 0.5) * norm; + const b = (quats_data[i * 4 + 1] / 255 - 0.5) * norm; + const c = (quats_data[i * 4 + 2] / 255 - 0.5) * norm; + const d = Math.sqrt(Math.max(0, 1 - (a * a + b * b + c * c))); + const mode = quats_data[i * 4 + 3] - 252; + + switch (mode) { + case 0: r.set(a, b, c, d); break; + case 1: r.set(d, b, c, a); break; + case 2: r.set(b, d, c, a); break; + case 3: r.set(b, c, d, a); break; + } } if (s) { @@ -66,18 +99,18 @@ class GSplatSogsIterator { const r = lerp(sh0.mins[0], sh0.maxs[0], sh0_data[i * 4 + 0] / 255); const g = lerp(sh0.mins[1], sh0.maxs[1], sh0_data[i * 4 + 1] / 255); const b = lerp(sh0.mins[2], sh0.maxs[2], sh0_data[i * 4 + 2] / 255); - const a = lerp(opacities.mins[0], opacities.maxs[0], opacities_data[i * 4 + 0] / 255); + const a = lerp(sh0.mins[3], sh0.maxs[3], sh0_data[i * 4 + 3] / 255); c.set( 0.5 + r * SH_C0, 0.5 + g * SH_C0, 0.5 + b * SH_C0, - 1.0 / (1.0 + Math.exp(a * -1.0)) + 1.0 / (1.0 + Math.exp(-a)) ); } if (sh) { - const n = sh_labels_l_data[i * 4 + 0] + (sh_labels_u_data[i * 4 + 0] << 8); + const n = sh_labels_data[i * 4 + 0] + (sh_labels_data[i * 4 + 1] << 8); const u = (n % 64) * 15; const v = Math.floor(n / 64); @@ -106,15 +139,14 @@ class GSplatSogsData { scales; - opacities; - sh0; sh_centroids; - sh_labels_l; + sh_labels; - sh_labels_u; + // if data is reordered at load, this texture stores the reorder indices. + orderTexture; createIter(p, r, s, c, sh) { return new GSplatSogsIterator(this, p, r, s, c, sh); @@ -141,9 +173,10 @@ class GSplatSogsData { getCenters(result) { const p = new Vec3(); const iter = this.createIter(p); + const order = this.orderTexture?._levels[0]; for (let i = 0; i < this.numSplats; i++) { - iter.read(i); + iter.read(order ? order[i] : i); result[i * 3 + 0] = p.x; result[i * 3 + 1] = p.y; @@ -234,6 +267,121 @@ class GSplatSogsData { }) }]); } + + // reorder the sogs texture data in gpu memory given the ordering encoded in texture data + reorderGpuMemory() { + const { orderTexture, numSplats } = this; + const { device, height, width } = orderTexture; + const { scope } = device; + + const shader = createShaderFromCode(device, reorderVS, reorderFS, 'reorderShader', { + vertex_position: SEMANTIC_POSITION + }); + + let targetTexture = new Texture(device, { + width: width, + height: height, + format: PIXELFORMAT_RGBA8, + mipmaps: false + }); + + const members = ['means_l', 'means_u', 'quats', 'scales', 'sh0', 'sh_labels']; + + device.setBlendState(BlendState.NOBLEND); + device.setCullMode(CULLFACE_NONE); + device.setDepthState(DepthState.NODEPTH); + + members.forEach((member) => { + const sourceTexture = this[member]; + + const renderTarget = new RenderTarget({ + colorBuffer: targetTexture, + depth: false, + mipLevel: 0 + }); + + resolve(scope, { + orderTexture, + sourceTexture, + numSplats + }); + + drawQuadWithShader(device, renderTarget, shader); + + this[member] = targetTexture; + targetTexture.name = sourceTexture.name; + targetTexture._levels = sourceTexture._levels; + sourceTexture._levels = []; + targetTexture = sourceTexture; + + renderTarget.destroy(); + }); + + targetTexture.destroy(); + shader.destroy(); + } + + // construct an array containing the Morton order of the splats + // returns an array of 32-bit unsigned integers + calcMortonOrder() { + // https://fgiesen.wordpress.com/2009/12/13/decoding-morton-codes/ + const encodeMorton3 = (x, y, z) => { + const Part1By2 = (x) => { + x &= 0x000003ff; + x = (x ^ (x << 16)) & 0xff0000ff; + x = (x ^ (x << 8)) & 0x0300f00f; + x = (x ^ (x << 4)) & 0x030c30c3; + x = (x ^ (x << 2)) & 0x09249249; + return x; + }; + + return (Part1By2(z) << 2) + (Part1By2(y) << 1) + Part1By2(x); + }; + + const { means_l, means_u } = this; + const means_l_data = means_l._levels[0]; + const means_u_data = means_u._levels[0]; + const codes = new BigUint64Array(this.numSplats); + + // generate Morton codes for each splat based on the means directly (i.e. the log-space coordinates) + for (let i = 0; i < this.numSplats; ++i) { + const ix = (means_u_data[i * 4 + 0] << 2) | (means_l_data[i * 4 + 0] >>> 6); + const iy = (means_u_data[i * 4 + 1] << 2) | (means_l_data[i * 4 + 1] >>> 6); + const iz = (means_u_data[i * 4 + 2] << 2) | (means_l_data[i * 4 + 2] >>> 6); + codes[i] = BigInt(encodeMorton3(ix, iy, iz)) << BigInt(32) | BigInt(i); + } + + codes.sort(); + + // allocate data for the order buffer, but make it texture-memory sized + const order = new Uint32Array(means_l.width * means_l.height); + for (let i = 0; i < this.numSplats; ++i) { + order[i] = Number(codes[i] & BigInt(0xffffffff)); + } + + return order; + } + + reorderData() { + if (!this.orderTexture) { + const { device, height, width } = this.means_l; + + this.orderTexture = new Texture(device, { + name: 'orderTexture', + width, + height, + format: PIXELFORMAT_R32U, + mipmaps: false, + levels: [this.calcMortonOrder()] + }); + + device.on('devicerestored', () => { + this.reorderGpuMemory(); + }); + } + + this.reorderGpuMemory(); + } } export { GSplatSogsData }; diff --git a/src/scene/gsplat/gsplat-sogs.js b/src/scene/gsplat/gsplat-sogs.js index cac46a3252c..2486b51334f 100644 --- a/src/scene/gsplat/gsplat-sogs.js +++ b/src/scene/gsplat/gsplat-sogs.js @@ -42,11 +42,11 @@ class GSplatSogs { result.setDefine('GSPLAT_SOGS_DATA', true); result.setDefine('SH_BANDS', this.gsplatData.shBands); - ['means_l', 'means_u', 'quats', 'scales', 'opacities', 'sh0', 'sh_centroids', 'sh_labels_u', 'sh_labels_l'].forEach((name) => { + ['means_l', 'means_u', 'quats', 'scales', 'sh0', 'sh_centroids', 'sh_labels'].forEach((name) => { result.setParameter(name, gsplatData[name]); }); - ['means', 'quats', 'scales', 'opacities', 'sh0', 'shN'].forEach((name) => { + ['means', 'scales', 'sh0', 'shN'].forEach((name) => { const v = gsplatData.meta[name]; if (v) { result.setParameter(`${name}_mins`, v.mins); diff --git a/src/scene/shader-lib/chunks-glsl/gsplat/vert/gsplatCommon.js b/src/scene/shader-lib/chunks-glsl/gsplat/vert/gsplatCommon.js index 1a7322693d8..54bd741ef32 100644 --- a/src/scene/shader-lib/chunks-glsl/gsplat/vert/gsplatCommon.js +++ b/src/scene/shader-lib/chunks-glsl/gsplat/vert/gsplatCommon.js @@ -13,9 +13,29 @@ struct SplatCenter { vec3 view; // center in view space vec4 proj; // center in clip space mat4 modelView; // model-view matrix - float projMat00; // elememt [0][0] of the projection matrix + float projMat00; // element [0][0] of the projection matrix }; +mat3 quatToMat3(vec4 R) { + vec4 R2 = R + R; + float X = R2.x * R.w; + vec4 Y = R2.y * R; + vec4 Z = R2.z * R; + float W = R2.w * R.w; + + return mat3( + 1.0 - Z.z - W, + Y.z + X, + Y.w - Z.x, + Y.z - X, + 1.0 - Y.y - W, + Z.w + Y.x, + Y.w + Z.x, + Z.w - Y.x, + 1.0 - Y.y - Z.z + ); +} + // stores the offset from center for the current gaussian struct SplatCorner { vec2 offset; // corner offset from center in clip space diff --git a/src/scene/shader-lib/chunks-glsl/gsplat/vert/gsplatCompressedData.js b/src/scene/shader-lib/chunks-glsl/gsplat/vert/gsplatCompressedData.js index d6b765ed687..089fcb9fa7d 100644 --- a/src/scene/shader-lib/chunks-glsl/gsplat/vert/gsplatCompressedData.js +++ b/src/scene/shader-lib/chunks-glsl/gsplat/vert/gsplatCompressedData.js @@ -42,23 +42,6 @@ vec4 unpackRotation(uint bits) { return vec4(a, b, c, m); } -mat3 quatToMat3(vec4 R) { - float x = R.x; - float y = R.y; - float z = R.z; - float w = R.w; - return mat3( - 1.0 - 2.0 * (z * z + w * w), - 2.0 * (y * z + x * w), - 2.0 * (y * w - x * z), - 2.0 * (y * z - x * w), - 1.0 - 2.0 * (y * y + w * w), - 2.0 * (z * w + x * y), - 2.0 * (y * w + x * z), - 2.0 * (z * w - x * y), - 1.0 - 2.0 * (y * y + z * z) - ); -} // read center vec3 readCenter(SplatSource source) { diff --git a/src/scene/shader-lib/chunks-glsl/gsplat/vert/gsplatData.js b/src/scene/shader-lib/chunks-glsl/gsplat/vert/gsplatData.js index 65310871240..e44d5f104a2 100644 --- a/src/scene/shader-lib/chunks-glsl/gsplat/vert/gsplatData.js +++ b/src/scene/shader-lib/chunks-glsl/gsplat/vert/gsplatData.js @@ -13,24 +13,6 @@ vec3 readCenter(SplatSource source) { return uintBitsToFloat(tA.xyz); } -mat3 quatToMat3(vec4 R) { - float x = R.w; - float y = R.x; - float z = R.y; - float w = R.z; - return mat3( - 1.0 - 2.0 * (z * z + w * w), - 2.0 * (y * z + x * w), - 2.0 * (y * w - x * z), - 2.0 * (y * z - x * w), - 1.0 - 2.0 * (y * y + w * w), - 2.0 * (z * w + x * y), - 2.0 * (y * w + x * z), - 2.0 * (z * w - x * y), - 1.0 - 2.0 * (y * y + z * z) - ); -} - vec4 unpackRotation(vec3 packed) { return vec4(packed.xyz, sqrt(max(0.0, 1.0 - dot(packed, packed)))); } @@ -39,7 +21,7 @@ vec4 unpackRotation(vec3 packed) { void readCovariance(in SplatSource source, out vec3 covA, out vec3 covB) { vec4 tB = texelFetch(transformB, source.uv, 0); - mat3 rot = quatToMat3(unpackRotation(vec3(unpackHalf2x16(tAw), tB.w))); + mat3 rot = quatToMat3(unpackRotation(vec3(unpackHalf2x16(tAw), tB.w)).wxyz); vec3 scale = tB.xyz; // M = S * R diff --git a/src/scene/shader-lib/chunks-glsl/gsplat/vert/gsplatSogsColor.js b/src/scene/shader-lib/chunks-glsl/gsplat/vert/gsplatSogsColor.js index f610c52e6d9..2d4d280485c 100644 --- a/src/scene/shader-lib/chunks-glsl/gsplat/vert/gsplatSogsColor.js +++ b/src/scene/shader-lib/chunks-glsl/gsplat/vert/gsplatSogsColor.js @@ -1,19 +1,13 @@ export default /* glsl */` uniform mediump sampler2D sh0; -uniform mediump sampler2D opacities; -uniform vec3 sh0_mins; -uniform vec3 sh0_maxs; - -uniform float opacities_mins; -uniform float opacities_maxs; +uniform vec4 sh0_mins; +uniform vec4 sh0_maxs; float SH_C0 = 0.28209479177387814; vec4 readColor(in SplatSource source) { - vec3 clr = mix(sh0_mins, sh0_maxs, texelFetch(sh0, source.uv, 0).xyz); - float opacity = mix(opacities_mins, opacities_maxs, texelFetch(opacities, source.uv, 0).x); - - return vec4(vec3(0.5) + clr * SH_C0, 1.0 / (1.0 + exp(opacity * -1.0))); + vec4 clr = mix(sh0_mins, sh0_maxs, texelFetch(sh0, source.uv, 0)); + return vec4(vec3(0.5) + clr.xyz * SH_C0, 1.0 / (1.0 + exp(-clr.w))); } `; diff --git a/src/scene/shader-lib/chunks-glsl/gsplat/vert/gsplatSogsData.js b/src/scene/shader-lib/chunks-glsl/gsplat/vert/gsplatSogsData.js index e7c21be9805..1f00d31d125 100644 --- a/src/scene/shader-lib/chunks-glsl/gsplat/vert/gsplatSogsData.js +++ b/src/scene/shader-lib/chunks-glsl/gsplat/vert/gsplatSogsData.js @@ -23,28 +23,20 @@ vec3 readCenter(SplatSource source) { return sign(v) * (exp(abs(v)) - 1.0); } -mat3 quatToMat3(vec3 R) { - float x = R.x; - float y = R.y; - float z = R.z; - float w2 = clamp(1.0 - (x*x + y*y + z*z), 0.0, 1.0); - float w = sqrt(w2); - return mat3( - 1.0 - 2.0 * (z * z + w2), - 2.0 * (y * z + x * w), - 2.0 * (y * w - x * z), - 2.0 * (y * z - x * w), - 1.0 - 2.0 * (y * y + w2), - 2.0 * (z * w + x * y), - 2.0 * (y * w + x * z), - 2.0 * (z * w - x * y), - 1.0 - 2.0 * (y * y + z * z) - ); -} +const float norm = 2.0 / sqrt(2.0); // sample covariance vectors void readCovariance(in SplatSource source, out vec3 covA, out vec3 covB) { - vec3 quat = mix(quats_mins, quats_maxs, texelFetch(quats, source.uv, 0).xyz); + vec4 qdata = texelFetch(quats, source.uv, 0); + vec3 abc = (qdata.xyz - 0.5) * norm; + float d = sqrt(max(0.0, 1.0 - dot(abc, abc))); + + uint mode = uint(qdata.w * 255.0 + 0.5) - 252u; + + vec4 quat = (mode == 0u) ? vec4(d, abc) : + ((mode == 1u) ? vec4(abc.x, d, abc.yz) : + ((mode == 2u) ? vec4(abc.xy, d, abc.z) : vec4(abc, d))); + mat3 rot = quatToMat3(quat); vec3 scale = exp(mix(scales_mins, scales_maxs, texelFetch(scales, source.uv, 0).xyz)); diff --git a/src/scene/shader-lib/chunks-glsl/gsplat/vert/gsplatSogsSH.js b/src/scene/shader-lib/chunks-glsl/gsplat/vert/gsplatSogsSH.js index 0ab7d35b3ba..cb7383a9e85 100644 --- a/src/scene/shader-lib/chunks-glsl/gsplat/vert/gsplatSogsSH.js +++ b/src/scene/shader-lib/chunks-glsl/gsplat/vert/gsplatSogsSH.js @@ -1,6 +1,5 @@ export default /* glsl */` -uniform highp sampler2D sh_labels_u; -uniform highp sampler2D sh_labels_l; +uniform highp sampler2D sh_labels; uniform highp sampler2D sh_centroids; uniform float shN_mins; @@ -8,9 +7,8 @@ uniform float shN_maxs; void readSHData(in SplatSource source, out vec3 sh[15], out float scale) { // extract spherical harmonics palette index - int tu = int(texelFetch(sh_labels_u, source.uv, 0).x * 255.0 * 256.0); - int tl = int(texelFetch(sh_labels_l, source.uv, 0).x * 255.0); - int n = tu + tl; + ivec2 t = ivec2(texelFetch(sh_labels, source.uv, 0).xy * 255.0); + int n = t.x + t.y * 256; int u = (n % 64) * 15; int v = n / 64;