Skip to content

Commit 77711e8

Browse files
committed
feat: add example + switch to the published versions of encase and shader-slang
1 parent de88bf0 commit 77711e8

File tree

6 files changed

+99
-4
lines changed

6 files changed

+99
-4
lines changed

Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,5 +31,5 @@ rust.unexpected_cfgs = { level = "warn", check-cfg = [
3131
opt-level = 'z'
3232

3333
[patch.crates-io]
34-
encase = { path = "../encase" }
35-
shader-slang = { path = "../slang-rs" }
34+
#encase = { path = "../encase" }
35+
#shader-slang = { path = "../slang-rs" }

crates/minislang/src/lib.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,9 @@ impl SlangCompiler {
9797

9898
let entry_points: Vec<_> = module
9999
.entry_points()
100-
.filter(|e| entry_point.is_none() || e.function_reflection().name() == entry_point)
100+
.filter(|e| {
101+
entry_point.is_none() || Some(e.function_reflection().name()) == entry_point
102+
})
101103
.map(|e| e.downcast().clone())
102104
.collect();
103105
let program = session

crates/slang-hal/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ futures-test = "0.3"
4949
serial_test = "3"
5050
approx = "0.5"
5151
async-std = { version = "1", features = ["attributes"] }
52+
slang-hal = { path = ".", features = ["derive"] }
5253

5354
[build-dependencies]
5455
dircpy = "0.3"

crates/slang-hal/examples/add.rs

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
use minislang::SlangCompiler;
2+
use slang_hal::backend::{Backend, Encoder, WebGpu};
3+
use slang_hal::function::GpuFunction;
4+
use slang_hal::{Shader, ShaderArgs, backend::Buffer};
5+
use wgpu::BufferUsages;
6+
7+
// Embed the shaders into the executable for simplicity.
8+
const SLANG_SRC_DIR: include_dir::Dir<'_> =
9+
include_dir::include_dir!("$CARGO_MANIFEST_DIR/examples/shaders");
10+
11+
#[derive(Shader)]
12+
#[shader(module = "add")]
13+
pub struct GpuAdd<B: Backend> {
14+
add_assign: GpuFunction<B>,
15+
}
16+
17+
#[derive(ShaderArgs)]
18+
pub struct AddArgs<'a, B: Backend> {
19+
a: &'a B::Buffer<f32>,
20+
b: &'a B::Buffer<f32>,
21+
}
22+
23+
impl<B: Backend> GpuAdd<B> {
24+
pub fn launch(
25+
&self,
26+
backend: &B,
27+
pass: &mut B::Pass,
28+
a: &B::Buffer<f32>,
29+
b: &B::Buffer<f32>,
30+
) -> Result<(), B::Error> {
31+
assert_eq!(a.len(), b.len());
32+
33+
let args = AddArgs { a, b };
34+
self.add_assign
35+
.launch(backend, pass, &args, [a.len() as u32, 1, 1])?;
36+
37+
Ok(())
38+
}
39+
}
40+
41+
#[async_std::main]
42+
async fn main() {
43+
// Initialize the backend and slang compiler.
44+
#[cfg(feature = "cuda")]
45+
let backend = Cuda::new().unwrap();
46+
#[cfg(not(feature = "cuda"))]
47+
let backend = WebGpu::default().await.unwrap();
48+
let mut compiler = SlangCompiler::new(vec![]);
49+
compiler.add_dir(SLANG_SRC_DIR);
50+
51+
// Run the operation and display the result.
52+
let a = (0..10000).map(|i| i as f32).collect::<Vec<_>>();
53+
let b = (0..10000).map(|i| i as f32 * 10.0).collect::<Vec<_>>();
54+
let result = compute_sum_on_gpu(&backend, &compiler, &a, &b)
55+
.await
56+
.unwrap();
57+
println!("Computed sum: {result:?}");
58+
}
59+
60+
async fn compute_sum_on_gpu<B: Backend>(
61+
backend: &B,
62+
compiler: &SlangCompiler,
63+
a: &[f32],
64+
b: &[f32],
65+
) -> Result<Vec<f32>, B::Error> {
66+
// Generate the GPU buffers.
67+
let a = backend.init_buffer(&a, BufferUsages::STORAGE | BufferUsages::COPY_SRC)?;
68+
let b = backend.init_buffer(&b, BufferUsages::STORAGE)?;
69+
70+
// Dispatch the operation on the gpu.
71+
let add = GpuAdd::from_backend(backend, compiler)?;
72+
let mut encoder = backend.begin_encoding();
73+
let mut pass = encoder.begin_pass();
74+
add.launch(backend, &mut pass, &a, &b)?;
75+
drop(pass);
76+
backend.submit(encoder)?;
77+
78+
// Read the result (slow but convenient version).
79+
backend.slow_read_vec(&a).await
80+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
[shader("compute")]
2+
[numthreads(64, 1, 1)]
3+
func add_assign(
4+
uint3 invocation_id: SV_DispatchThreadID,
5+
RWStructuredBuffer<float> a,
6+
StructuredBuffer<float> b,
7+
) {
8+
let thread_id = invocation_id.x;
9+
if (thread_id < a.getCount()) {
10+
a[thread_id] += b[thread_id];
11+
}
12+
}

crates/slang-hal/src/function.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ impl<B: Backend> GpuFunction<B> {
5353
buffers.push((
5454
param_var
5555
.name()
56-
.expect("unnamed parameters not supported yet")
56+
// .expect("unnamed parameters not supported yet")
5757
.to_string(),
5858
binding,
5959
));

0 commit comments

Comments
 (0)