Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 79 additions & 107 deletions src/scene/gsplat-unified/gsplat-unified-sort-worker.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,39 +2,76 @@ function UnifiedSortWorker() {

const myself = (typeof self !== 'undefined' && self) || (require('node:worker_threads').parentPort);

// let chunks;
// let sortSplatCount;

// cache of centers for each splat id
const centersMap = new Map();
let centersData;
let distances;
let countBuffer;

// could be increased, but this seems a good compromise between stability and performance
// const numBins = 32;
// const binCount = new Array(numBins).fill(0);
// const binBase = new Array(numBins).fill(0);
// const binDivider = new Array(numBins).fill(0);

// const binarySearch = (m, n, compare_fn) => {
// while (m <= n) {
// const k = (n + m) >> 1;
// const cmp = compare_fn(k);
// if (cmp > 0) {
// m = k + 1;
// } else if (cmp < 0) {
// n = k - 1;
// } else {
// return k;
// }
// }
// return ~m;
// };

const evaluateSortKeys = (sortParams, minDist, divider, distances, countBuffer, centersData, bucketCount) => {
// camera-relative bin-based precision optimization
const numBins = 32;
const binBase = new Array(numBins).fill(0);
const binDivider = new Array(numBins).fill(0);

// Weight tiers for camera-relative precision (distance from camera bin -> weight multiplier)
const weightTiers = [
{ maxDistance: 0, weight: 40.0 }, // Camera bin
{ maxDistance: 2, weight: 20.0 }, // Adjacent bins
{ maxDistance: 5, weight: 8.0 }, // Nearby bins
{ maxDistance: 10, weight: 3.0 }, // Medium distance
{ maxDistance: Infinity, weight: 1.0 } // Far bins
];

// Pre-calculate weight lookup table by distance from camera (constant)
const weightByDistance = new Array(numBins);
for (let dist = 0; dist < numBins; ++dist) {
let weight = 1.0;
for (let j = 0; j < weightTiers.length; ++j) {
if (dist <= weightTiers[j].maxDistance) {
weight = weightTiers[j].weight;
break;
}
}
weightByDistance[dist] = weight;
}

const setupCameraRelativeBins = (cameraBin, bucketCount) => {
const totalBudget = bucketCount;
const bitsPerBin = [];

// Assign weights to bins based on pre-calculated distance lookup
for (let i = 0; i < numBins; ++i) {
const distFromCamera = Math.abs(i - cameraBin);
bitsPerBin[i] = weightByDistance[distFromCamera];
}

// Normalize to fit within budget
const totalWeight = bitsPerBin.reduce((a, b) => a + b, 0);
let accumulated = 0;

for (let i = 0; i < numBins; ++i) {
binDivider[i] = Math.max(1, Math.floor((bitsPerBin[i] / totalWeight) * totalBudget));
binBase[i] = accumulated;
accumulated += binDivider[i];
}

// Adjust last bin to fit exactly
if (accumulated > bucketCount) {
const excess = accumulated - bucketCount;
binDivider[numBins - 1] = Math.max(1, binDivider[numBins - 1] - excess);
}

// Add safety entry for edge case where bin >= numBins due to floating point
binBase[numBins] = binBase[numBins - 1] + binDivider[numBins - 1];
binDivider[numBins] = 0;
};

const evaluateSortKeys = (sortParams, minDist, range, distances, countBuffer, centersData) => {
const { ids, lineStarts, padding, intervals, textureSize } = centersData;

// pre-calculate inverse bin range
const invBinRange = numBins / range;

// loop over all the splat placements
for (let paramIdx = 0; paramIdx < sortParams.length; paramIdx++) {

Expand All @@ -46,10 +83,10 @@ function UnifiedSortWorker() {
const dz = transformedDirection.z;

// pre-calculate camera related constants
const sdx = dx * scale * divider;
const sdy = dy * scale * divider;
const sdz = dz * scale * divider;
const add = (offset - minDist) * divider;
const sdx = dx * scale;
const sdy = dy * scale;
const sdz = dz * scale;
const add = offset - minDist;

// source centers
const id = ids[paramIdx];
Expand All @@ -75,7 +112,13 @@ function UnifiedSortWorker() {
const y = centers[srcIndex + 1];
const z = centers[srcIndex + 2];

const sortKey = Math.floor(x * sdx + y * sdy + z * sdz + add);
const dist = x * sdx + y * sdy + z * sdz + add;

// Bin-based mapping
const d = dist * invBinRange;
const bin = d >>> 0;
const sortKey = (binBase[bin] + binDivider[bin] * (d - bin)) >>> 0;

distances[targetIndex++] = sortKey;
countBuffer[sortKey]++;
}
Expand Down Expand Up @@ -173,88 +216,17 @@ function UnifiedSortWorker() {

const range = maxDist - minDist;

// // use chunks to calculate rough histogram of splats per distance
// const numChunks = chunks.length / 4;

// binCount.fill(0);
// for (let i = 0; i < numChunks; ++i) {
// const x = chunks[i * 4 + 0];
// const y = chunks[i * 4 + 1];
// const z = chunks[i * 4 + 2];
// const r = chunks[i * 4 + 3];
// const d = x * dx + y * dy + z * dz - minDist;

// const binMin = Math.max(0, Math.floor((d - r) * numBins / range));
// const binMax = Math.min(numBins, Math.ceil((d + r) * numBins / range));

// for (let j = binMin; j < binMax; ++j) {
// binCount[j]++;
// }
// }

// // count total number of histogram bin entries
// const binTotal = binCount.reduce((a, b) => a + b, 0);

// // calculate per-bin base and divider
// for (let i = 0; i < numBins; ++i) {
// binDivider[i] = (binCount[i] / binTotal * bucketCount) >>> 0;
// }
// for (let i = 0; i < numBins; ++i) {
// binBase[i] = i === 0 ? 0 : binBase[i - 1] + binDivider[i - 1];
// }

// // generate per vertex distance key using histogram to distribute bits
// const binRange = range / numBins;
// let ii = 0;
// for (let i = 0; i < numVertices; ++i) {
// const x = centers[ii++];
// const y = centers[ii++];
// const z = centers[ii++];
// const d = (x * dx + y * dy + z * dz - minDist) / binRange;
// const bin = d >>> 0;
// const sortKey = (binBase[bin] + binDivider[bin] * (d - bin)) >>> 0;

// distances[i] = sortKey;
// Set up camera-relative bin weighting for near-camera precision
const cameraOffsetFromRangeStart = 0 - minDist;
const cameraBinFloat = (cameraOffsetFromRangeStart / range) * numBins;
const cameraBin = Math.max(0, Math.min(numBins - 1, Math.floor(cameraBinFloat)));

// // count occurrences of each distance
// countBuffer[sortKey]++;
// }
setupCameraRelativeBins(cameraBin, bucketCount);


const divider = (range < 1e-6) ? 0 : (1 / range) * (2 ** compareBits);

// for (let i = 0; i < numVertices; ++i) {
// const istride = i * 3;
// const x = centers[istride + 0] - px;
// const y = centers[istride + 1] - py;
// const z = centers[istride + 2] - pz;
// const d = x * dx + y * dy + z * dz;
// const sortKey = Math.floor((d - minDist) * divider);

// distances[i] = sortKey;

// // count occurrences of each distance
// countBuffer[sortKey]++;
// }


evaluateSortKeys(sortParams, minDist, divider, distances, countBuffer, centersData, bucketCount);
evaluateSortKeys(sortParams, minDist, range, distances, countBuffer, centersData);

countingSort(bucketCount, countBuffer, numVertices, distances, order);


// // Find splat with distance 0 to limit rendering behind the camera
// const cameraDist = px * dx + py * dy + pz * dz;
// const dist = (i) => {
// let o = order[i] * 3;
// return centers[o++] * dx + centers[o++] * dy + centers[o] * dz - cameraDist;
// };
// const findZero = () => {
// const result = binarySearch(0, numVertices - 1, i => -dist(i));
// return Math.min(numVertices, Math.abs(result));
// };

// const count = dist(numVertices - 1) >= 0 ? findZero() : numVertices;
const count = numVertices;

// send results
Expand Down