Skip to content

Commit 1ee2fca

Browse files
authored
WebGPU: WGSL Support and webgpu_compute example updated (mrdoob#22653)
* WebGPU: Hybrid language * webgpu_compute: Update of GLSL -> WGSL * highlight point color * update screenshot
1 parent 1150618 commit 1ee2fca

File tree

3 files changed

+101
-50
lines changed

3 files changed

+101
-50
lines changed

examples/jsm/renderers/webgpu/WebGPUProgrammableStage.js

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,24 @@ class WebGPUProgrammableStage {
1010
this.type = type;
1111
this.usedTimes = 0;
1212

13-
const byteCode = glslang.compileGLSL( code, type );
13+
let data = null;
14+
15+
if ( /^#version 450/.test( code ) === true ) {
16+
17+
// GLSL
18+
19+
data = glslang.compileGLSL( code, type );
20+
21+
} else {
22+
23+
// WGSL
24+
25+
data = code;
26+
27+
}
1428

1529
this.stage = {
16-
module: device.createShaderModule( { code: byteCode } ),
30+
module: device.createShaderModule( { code: data } ),
1731
entryPoint: 'main'
1832
};
1933

-82.7 KB
Loading

examples/webgpu_compute.html

Lines changed: 85 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,12 @@
6262
scene.background = new THREE.Color( 0x000000 );
6363

6464
const particleNum = 65000; // 16-bit limit
65-
const particleSize = 3;
65+
const particleSize = 4; // 16-byte stride align
6666

6767
const particleArray = new Float32Array( particleNum * particleSize );
6868
const velocityArray = new Float32Array( particleNum * particleSize );
6969

70-
for ( let i = 0; i < particleArray.length; i += 3 ) {
70+
for ( let i = 0; i < particleArray.length; i += particleSize ) {
7171

7272
const r = Math.random() * 0.01 + 0.0005;
7373
const degree = Math.random() * 360;
@@ -76,12 +76,12 @@
7676

7777
}
7878

79-
const particleBuffer = new WebGPUStorageBuffer( 'particle', new THREE.BufferAttribute( particleArray, 3 ) );
80-
const velocityBuffer = new WebGPUStorageBuffer( 'velocity', new THREE.BufferAttribute( velocityArray, 3 ) );
79+
const particleBuffer = new WebGPUStorageBuffer( 'particle', new THREE.BufferAttribute( particleArray, particleSize ) );
80+
const velocityBuffer = new WebGPUStorageBuffer( 'velocity', new THREE.BufferAttribute( velocityArray, particleSize ) );
8181

8282
const scaleUniformLength = WebGPUBufferUtils.getVectorLength( 2, 3 ); // two vector3 for array
8383

84-
scaleUniformBuffer = new WebGPUUniformBuffer( 'scaleUniform', new Float32Array( scaleUniformLength ) );
84+
scaleUniformBuffer = new WebGPUUniformBuffer( 'scaleUniform', new Float32Array( scaleUniformLength ) );
8585

8686
pointer = new THREE.Vector2( - 10.0, - 10.0 ); // Out of bounds first
8787

@@ -98,61 +98,98 @@
9898
pointerGroup
9999
];
100100

101-
const computeShader = /* glsl */`#version 450
102-
#define PARTICLE_NUM ${particleNum}
103-
#define PARTICLE_SIZE ${particleSize}
104-
#define ROOM_SIZE 1.0
105-
#define POINTER_SIZE 0.1
101+
const computeShader = `
106102
107-
// Limitation for now: the order should be the same as bindings order
103+
//
104+
// Buffer
105+
//
108106
109-
layout(set = 0, binding = 0) buffer Particle {
110-
float particle[ PARTICLE_NUM * PARTICLE_SIZE ];
111-
} particle;
107+
[[ block ]]
108+
struct Particle {
109+
value : array< vec4<f32> >;
110+
};
111+
[[ binding( 0 ), group( 0 ) ]]
112+
var<storage,read_write> particle : Particle;
112113
113-
layout(set = 0, binding = 1) buffer Velocity {
114-
float velocity[ PARTICLE_NUM * PARTICLE_SIZE ];
115-
} velocity;
114+
[[ block ]]
115+
struct Velocity {
116+
value : array< vec4<f32> >;
117+
};
118+
[[ binding( 1 ), group( 0 ) ]]
119+
var<storage,read_write> velocity : Velocity;
116120
117-
layout(set = 0, binding = 2) uniform Scale {
118-
vec3 value[2];
119-
} scaleUniform;
121+
//
122+
// Uniforms
123+
//
120124
121-
layout(set = 0, binding = 3) uniform MouseUniforms {
122-
vec2 pointer;
123-
} mouseUniforms;
125+
[[ block ]]
126+
struct Scale {
127+
value : array< vec3<f32>, 2 >;
128+
};
129+
[[ binding( 2 ), group( 0 ) ]]
130+
var<uniform> scaleUniform : Scale;
124131
125-
void main() {
126-
uint index = gl_GlobalInvocationID.x;
127-
if ( index >= PARTICLE_NUM ) { return; }
132+
[[block]]
133+
struct MouseUniforms {
134+
pointer : vec2<f32>;
135+
};
136+
[[ binding( 3 ), group( 0 ) ]]
137+
var<uniform> mouseUniforms : MouseUniforms;
128138
129-
vec3 position = vec3(
130-
particle.particle[ index * 3 + 0 ] + velocity.velocity[ index * 3 + 0 ],
131-
particle.particle[ index * 3 + 1 ] + velocity.velocity[ index * 3 + 1 ],
132-
particle.particle[ index * 3 + 2 ] + velocity.velocity[ index * 3 + 2 ]
133-
);
139+
[[ stage( compute ), workgroup_size( 64 ) ]]
140+
fn main( [[builtin(global_invocation_id)]] id : vec3<u32> ) {
134141
135-
if ( abs( position.x ) >= ROOM_SIZE ) {
142+
// get particle index
136143
137-
velocity.velocity[ index * 3 + 0 ] = - velocity.velocity[ index * 3 + 0 ];
144+
let index : u32 = id.x * 3u;
138145
139-
}
146+
// update speed
147+
148+
var position : vec4<f32> = particle.value[ index ] + velocity.value[ index ];
149+
150+
// update limit
151+
152+
let limit : vec2<f32> = scaleUniform.value[ 0 ].xy;
153+
154+
if ( abs( position.x ) >= limit.x ) {
155+
156+
if ( position.x > 0.0 ) {
157+
158+
position.x = limit.x;
159+
160+
} else {
161+
162+
position.x = -limit.x;
140163
141-
if ( abs( position.y ) >= ROOM_SIZE ) {
164+
}
142165
143-
velocity.velocity[ index * 3 + 1 ] = - velocity.velocity[ index * 3 + 1 ];
166+
velocity.value[ index ].x = - velocity.value[ index ].x;
144167
145168
}
146169
147-
if ( abs( position.z ) >= ROOM_SIZE ) {
170+
if ( abs( position.y ) >= limit.y ) {
148171
149-
velocity.velocity[ index * 3 + 2 ] = - velocity.velocity[ index * 3 + 2 ];
172+
if ( position.y > 0.0 ) {
173+
174+
position.y = limit.y;
175+
176+
} else {
177+
178+
position.y = -limit.y;
179+
180+
}
181+
182+
velocity.value[ index ].y = - velocity.value[ index ].y;
150183
151184
}
152185
153-
float dx = mouseUniforms.pointer.x - position.x;
154-
float dy = mouseUniforms.pointer.y - position.y;
155-
float distanceFromPointer = sqrt( dx * dx + dy * dy );
186+
// update mouse
187+
188+
let POINTER_SIZE : f32 = .1;
189+
190+
let dx : f32 = mouseUniforms.pointer.x - position.x;
191+
let dy : f32 = mouseUniforms.pointer.y - position.y;
192+
let distanceFromPointer : f32 = sqrt( dx * dx + dy * dy );
156193
157194
if ( distanceFromPointer <= POINTER_SIZE ) {
158195
@@ -162,11 +199,12 @@
162199
163200
}
164201
165-
particle.particle[ index * 3 + 0 ] = position.x * scaleUniform.value[0].x;
166-
particle.particle[ index * 3 + 1 ] = position.y * scaleUniform.value[0].y;
167-
particle.particle[ index * 3 + 2 ] = position.z * scaleUniform.value[0].z;
202+
// update buffer
203+
204+
particle.value[ index ] = position;
168205
169206
}
207+
170208
`;
171209

172210
computeParams.push( {
@@ -182,7 +220,7 @@
182220
);
183221

184222
const pointsMaterial = new Nodes.PointsNodeMaterial();
185-
pointsMaterial.colorNode = new Nodes.OperatorNode( '+', new Nodes.PositionNode(), new Nodes.ColorNode( new THREE.Color( 0x0000FF ) ) );
223+
pointsMaterial.colorNode = new Nodes.OperatorNode( '+', new Nodes.PositionNode(), new Nodes.ColorNode( new THREE.Color( 0xFFFFFF ) ) );
186224

187225
const mesh = new THREE.Points( pointsGeometry, pointsMaterial );
188226
scene.add( mesh );
@@ -199,9 +237,8 @@
199237

200238
const gui = new GUI();
201239

202-
gui.add( scaleVector, 'x', 0.9, 1.1, 0.01 );
203-
gui.add( scaleVector, 'y', 0.9, 1.1, 0.01 );
204-
gui.add( scaleVector, 'z', 0.9, 1.1, 0.01 );
240+
gui.add( scaleVector, 'x', 0, 1, 0.01 );
241+
gui.add( scaleVector, 'y', 0, 1, 0.01 );
205242

206243
return renderer.init();
207244

0 commit comments

Comments
 (0)