Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 32 additions & 5 deletions src/framework/parsers/sogs.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,14 @@ import { GSplatSogsData } from '../../scene/gsplat/gsplat-sogs-data.js';
* @import { ResourceHandlerCallback } from '../handlers/handler.js'
*/

const readImageDataAsync = (texture) => {
return texture.read(0, 0, texture.width, texture.height, {
mipLevel: 0,
face: 0,
immediate: true
});
};

class SogsParser {
/** @type {AppBase} */
app;
Expand All @@ -27,7 +35,7 @@ class SogsParser {
async loadTextures(url, callback, asset, meta) {
const { assets } = this.app;

const subs = ['means', 'opacities', 'quats', 'scales', 'sh0', 'shN'];
const subs = ['means', 'quats', 'scales', 'sh0', 'shN'];

const textures = {};
const promises = [];
Expand Down Expand Up @@ -72,15 +80,34 @@ class SogsParser {
data.shBands = widths[textures.shN?.[0]?.resource?.width] ?? 0;
data.means_l = textures.means[0].resource;
data.means_u = textures.means[1].resource;
data.opacities = textures.opacities[0].resource;
data.quats = textures.quats[0].resource;
data.scales = textures.scales[0].resource;
data.sh0 = textures.sh0[0].resource;
data.sh_centroids = textures.shN?.[0]?.resource;
data.sh_labels_l = textures.shN?.[1]?.resource;
data.sh_labels_u = textures.shN?.[2]?.resource;
data.sh_labels = textures.shN?.[1]?.resource;

if (asset.data?.reorder ?? true) {
// copy back means_l and means_u data from gpu so cpu reorder has access to it
data.means_l._levels[0] = await readImageDataAsync(data.means_l);
data.means_u._levels[0] = await readImageDataAsync(data.means_u);
data.reorderData();
}

const resource = new GSplatResource(this.app, (asset.data?.decompress) ? data.decompress() : data, []);
let resource;
if (asset.data?.decompress) {
// copy back gpu texture data so cpu iterator has access to it
const { means_l, means_u, quats, scales, sh0, sh_labels, sh_centroids } = data;
means_l._levels[0] = await readImageDataAsync(means_l);
means_u._levels[0] = await readImageDataAsync(means_u);
quats._levels[0] = await readImageDataAsync(quats);
scales._levels[0] = await readImageDataAsync(scales);
sh0._levels[0] = await readImageDataAsync(sh0);
sh_labels._levels[0] = await readImageDataAsync(sh_labels);
sh_centroids._levels[0] = await readImageDataAsync(sh_centroids);
resource = new GSplatResource(this.app, data.decompress(), []);
} else {
resource = new GSplatResource(this.app, data, []);
}

callback(null, resource);
}
Expand Down
6 changes: 3 additions & 3 deletions src/scene/gsplat/gsplat-data.js
Original file line number Diff line number Diff line change
Expand Up @@ -392,9 +392,9 @@ class GSplatData {

const codes = new Map();
for (let i = 0; i < this.numSplats; i++) {
const ix = Math.floor((x[i] - minX) * sizeX);
const iy = Math.floor((y[i] - minY) * sizeY);
const iz = Math.floor((z[i] - minZ) * sizeZ);
const ix = Math.min(1023, Math.floor((x[i] - minX) * sizeX));
const iy = Math.min(1023, Math.floor((y[i] - minY) * sizeY));
const iz = Math.min(1023, Math.floor((z[i] - minZ) * sizeZ));
const code = encodeMorton3(ix, iy, iz);

const val = codes.get(code);
Expand Down
216 changes: 182 additions & 34 deletions src/scene/gsplat/gsplat-sogs-data.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,47 @@ import { Quat } from '../../core/math/quat.js';
import { Vec3 } from '../../core/math/vec3.js';
import { Vec4 } from '../../core/math/vec4.js';
import { GSplatData } from './gsplat-data.js';
import { BlendState } from '../../platform/graphics/blend-state.js';
import { DepthState } from '../../platform/graphics/depth-state.js';
import { RenderTarget } from '../../platform/graphics/render-target.js';
import { Texture } from '../../platform/graphics/texture.js';
import { CULLFACE_NONE, PIXELFORMAT_R32U, PIXELFORMAT_RGBA8, SEMANTIC_POSITION } from '../../platform/graphics/constants.js';
import { drawQuadWithShader } from '../../scene/graphics/quad-render-utils.js';
import { createShaderFromCode } from '../shader-lib/shader-utils.js';

let offscreen = null;
let ctx = null;
const SH_C0 = 0.28209479177387814;

const readImageData = (imageBitmap) => {
if (!offscreen || offscreen.width !== imageBitmap.width || offscreen.height !== imageBitmap.height) {
offscreen = new OffscreenCanvas(imageBitmap.width, imageBitmap.height);
ctx = offscreen.getContext('2d');
ctx.globalCompositeOperation = 'copy';
const reorderVS = /* glsl */`
attribute vec2 vertex_position;
void main(void) {
gl_Position = vec4(vertex_position, 0.0, 1.0);
}
ctx.drawImage(imageBitmap, 0, 0);
return ctx.getImageData(0, 0, imageBitmap.width, imageBitmap.height).data;
};
`;

const SH_C0 = 0.28209479177387814;
const reorderFS = /* glsl */`
uniform usampler2D orderTexture;
uniform sampler2D sourceTexture;
uniform highp uint numSplats;

void main(void) {
uint w = uint(textureSize(sourceTexture, 0).x);
uint idx = uint(gl_FragCoord.x) + uint(gl_FragCoord.y) * w;
if (idx >= numSplats) discard;

// fetch the source index and calculate source uv
uint sidx = texelFetch(orderTexture, ivec2(gl_FragCoord.xy), 0).x;
uvec2 suv = uvec2(sidx % w, sidx / w);

// sample the source texture
gl_FragColor = texelFetch(sourceTexture, ivec2(suv), 0);
}
`;

const resolve = (scope, values) => {
for (const key in values) {
scope.resolve(key).setValue(values[key]);
}
};

class GSplatSogsIterator {
constructor(data, p, r, s, c, sh) {
Expand All @@ -25,16 +51,16 @@ class GSplatSogsIterator {

// extract means for centers
const { meta } = data;
const { means, quats, scales, opacities, sh0, shN } = meta;
const means_l_data = p && readImageData(data.means_l._levels[0]);
const means_u_data = p && readImageData(data.means_u._levels[0]);
const quats_data = r && readImageData(data.quats._levels[0]);
const scales_data = s && readImageData(data.scales._levels[0]);
const opacities_data = c && readImageData(data.opacities._levels[0]);
const sh0_data = c && readImageData(data.sh0._levels[0]);
const sh_labels_l_data = sh && readImageData(data.sh_labels_l._levels[0]);
const sh_labels_u_data = sh && readImageData(data.sh_labels_u._levels[0]);
const sh_centroids_data = sh && readImageData(data.sh_centroids._levels[0]);
const { means, scales, sh0, shN } = meta;
const means_l_data = p && data.means_l._levels[0];
const means_u_data = p && data.means_u._levels[0];
const quats_data = r && data.quats._levels[0];
const scales_data = s && data.scales._levels[0];
const sh0_data = c && data.sh0._levels[0];
const sh_labels_data = sh && data.sh_labels._levels[0];
const sh_centroids_data = sh && data.sh_centroids._levels[0];

const norm = 2.0 / Math.sqrt(2.0);

this.read = (i) => {
if (p) {
Expand All @@ -48,11 +74,18 @@ class GSplatSogsIterator {
}

if (r) {
const qx = lerp(quats.mins[0], quats.maxs[0], quats_data[i * 4 + 0] / 255);
const qy = lerp(quats.mins[1], quats.maxs[1], quats_data[i * 4 + 1] / 255);
const qz = lerp(quats.mins[2], quats.maxs[2], quats_data[i * 4 + 2] / 255);
const qw = Math.sqrt(Math.max(0, 1 - (qx * qx + qy * qy + qz * qz)));
r.set(qy, qz, qw, qx);
const a = (quats_data[i * 4 + 0] / 255 - 0.5) * norm;
const b = (quats_data[i * 4 + 1] / 255 - 0.5) * norm;
const c = (quats_data[i * 4 + 2] / 255 - 0.5) * norm;
const d = Math.sqrt(Math.max(0, 1 - (a * a + b * b + c * c)));
const mode = quats_data[i * 4 + 3] - 252;

switch (mode) {
case 0: r.set(a, b, c, d); break;
case 1: r.set(d, b, c, a); break;
case 2: r.set(b, d, c, a); break;
case 3: r.set(b, c, d, a); break;
}
}

if (s) {
Expand All @@ -66,18 +99,18 @@ class GSplatSogsIterator {
const r = lerp(sh0.mins[0], sh0.maxs[0], sh0_data[i * 4 + 0] / 255);
const g = lerp(sh0.mins[1], sh0.maxs[1], sh0_data[i * 4 + 1] / 255);
const b = lerp(sh0.mins[2], sh0.maxs[2], sh0_data[i * 4 + 2] / 255);
const a = lerp(opacities.mins[0], opacities.maxs[0], opacities_data[i * 4 + 0] / 255);
const a = lerp(sh0.mins[3], sh0.maxs[3], sh0_data[i * 4 + 3] / 255);

c.set(
0.5 + r * SH_C0,
0.5 + g * SH_C0,
0.5 + b * SH_C0,
1.0 / (1.0 + Math.exp(a * -1.0))
1.0 / (1.0 + Math.exp(-a))
);
}

if (sh) {
const n = sh_labels_l_data[i * 4 + 0] + (sh_labels_u_data[i * 4 + 0] << 8);
const n = sh_labels_data[i * 4 + 0] + (sh_labels_data[i * 4 + 1] << 8);
const u = (n % 64) * 15;
const v = Math.floor(n / 64);

Expand Down Expand Up @@ -106,15 +139,14 @@ class GSplatSogsData {

scales;

opacities;

sh0;

sh_centroids;

sh_labels_l;
sh_labels;

sh_labels_u;
// if data is reordered at load, this texture stores the reorder indices.
orderTexture;

createIter(p, r, s, c, sh) {
return new GSplatSogsIterator(this, p, r, s, c, sh);
Expand All @@ -141,9 +173,10 @@ class GSplatSogsData {
getCenters(result) {
const p = new Vec3();
const iter = this.createIter(p);
const order = this.orderTexture?._levels[0];

for (let i = 0; i < this.numSplats; i++) {
iter.read(i);
iter.read(order ? order[i] : i);

result[i * 3 + 0] = p.x;
result[i * 3 + 1] = p.y;
Expand Down Expand Up @@ -234,6 +267,121 @@ class GSplatSogsData {
})
}]);
}

// reorder the sogs texture data in gpu memory given the ordering encoded in texture data
reorderGpuMemory() {
const { orderTexture, numSplats } = this;
const { device, height, width } = orderTexture;
const { scope } = device;

const shader = createShaderFromCode(device, reorderVS, reorderFS, 'reorderShader', {
vertex_position: SEMANTIC_POSITION
});

let targetTexture = new Texture(device, {
width: width,
height: height,
format: PIXELFORMAT_RGBA8,
mipmaps: false
});

const members = ['means_l', 'means_u', 'quats', 'scales', 'sh0', 'sh_labels'];

device.setBlendState(BlendState.NOBLEND);
device.setCullMode(CULLFACE_NONE);
device.setDepthState(DepthState.NODEPTH);

members.forEach((member) => {
const sourceTexture = this[member];

const renderTarget = new RenderTarget({
colorBuffer: targetTexture,
depth: false,
mipLevel: 0
});

resolve(scope, {
orderTexture,
sourceTexture,
numSplats
});

drawQuadWithShader(device, renderTarget, shader);

this[member] = targetTexture;
targetTexture.name = sourceTexture.name;
targetTexture._levels = sourceTexture._levels;
sourceTexture._levels = [];
targetTexture = sourceTexture;

renderTarget.destroy();
});

targetTexture.destroy();
shader.destroy();
}

// construct an array containing the Morton order of the splats
// returns an array of 32-bit unsigned integers
calcMortonOrder() {
// https://fgiesen.wordpress.com/2009/12/13/decoding-morton-codes/
const encodeMorton3 = (x, y, z) => {
const Part1By2 = (x) => {
x &= 0x000003ff;
x = (x ^ (x << 16)) & 0xff0000ff;
x = (x ^ (x << 8)) & 0x0300f00f;
x = (x ^ (x << 4)) & 0x030c30c3;
x = (x ^ (x << 2)) & 0x09249249;
return x;
};

return (Part1By2(z) << 2) + (Part1By2(y) << 1) + Part1By2(x);
};

const { means_l, means_u } = this;
const means_l_data = means_l._levels[0];
const means_u_data = means_u._levels[0];
const codes = new BigUint64Array(this.numSplats);

// generate Morton codes for each splat based on the means directly (i.e. the log-space coordinates)
for (let i = 0; i < this.numSplats; ++i) {
const ix = (means_u_data[i * 4 + 0] << 2) | (means_l_data[i * 4 + 0] >>> 6);
const iy = (means_u_data[i * 4 + 1] << 2) | (means_l_data[i * 4 + 1] >>> 6);
const iz = (means_u_data[i * 4 + 2] << 2) | (means_l_data[i * 4 + 2] >>> 6);
codes[i] = BigInt(encodeMorton3(ix, iy, iz)) << BigInt(32) | BigInt(i);
}

codes.sort();

// allocate data for the order buffer, but make it texture-memory sized
const order = new Uint32Array(means_l.width * means_l.height);
for (let i = 0; i < this.numSplats; ++i) {
order[i] = Number(codes[i] & BigInt(0xffffffff));
}

return order;
}

reorderData() {
if (!this.orderTexture) {
const { device, height, width } = this.means_l;

this.orderTexture = new Texture(device, {
name: 'orderTexture',
width,
height,
format: PIXELFORMAT_R32U,
mipmaps: false,
levels: [this.calcMortonOrder()]
});

device.on('devicerestored', () => {
this.reorderGpuMemory();
});
}

this.reorderGpuMemory();
}
}

export { GSplatSogsData };
4 changes: 2 additions & 2 deletions src/scene/gsplat/gsplat-sogs.js
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ class GSplatSogs {
result.setDefine('GSPLAT_SOGS_DATA', true);
result.setDefine('SH_BANDS', this.gsplatData.shBands);

['means_l', 'means_u', 'quats', 'scales', 'opacities', 'sh0', 'sh_centroids', 'sh_labels_u', 'sh_labels_l'].forEach((name) => {
['means_l', 'means_u', 'quats', 'scales', 'sh0', 'sh_centroids', 'sh_labels'].forEach((name) => {
result.setParameter(name, gsplatData[name]);
});

['means', 'quats', 'scales', 'opacities', 'sh0', 'shN'].forEach((name) => {
['means', 'scales', 'sh0', 'shN'].forEach((name) => {
const v = gsplatData.meta[name];
if (v) {
result.setParameter(`${name}_mins`, v.mins);
Expand Down
Loading