Skip to content

Commit e1f6bae

Browse files
committed
implemented REPL_Template support and removed bug in unary operators kernel
1 parent 8c70b8f commit e1f6bae

File tree

3 files changed

+375
-341
lines changed

3 files changed

+375
-341
lines changed

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -865,7 +865,7 @@ static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx,
865865
ggml_tensor * dst,
866866
webgpu_pipeline & pipeline,
867867
bool in_place,
868-
const std::vector<uint32_t> & xielu_params = {}) {
868+
const std::vector<uint32_t> & extra_params = {}) {
869869
uint32_t ne = (uint32_t) ggml_nelements(dst);
870870

871871
std::vector<uint32_t> params = {
@@ -881,7 +881,7 @@ static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx,
881881
(uint32_t) dst->ne[1], (uint32_t) dst->ne[2]
882882
};
883883

884-
params.insert(params.end(), xielu_params.begin(), xielu_params.end());
884+
params.insert(params.end(), extra_params.begin(), extra_params.end());
885885

886886
std::vector<wgpu::BindGroupEntry> entries = {
887887
{ .binding = 0,

ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,14 @@ def parse_decls(decls_text):
1818
decls[name.strip()] = code.strip()
1919
return decls
2020

21+
def replace_repl_placeholders(variant, template_map):
22+
23+
for repl, code in variant["REPLS"].items():
24+
for key, val in template_map.items():
25+
# Match "key" and avoid matching subsequences using by using \b
26+
code = re.sub(rf'\b{re.escape(str(key))}\b', str(val), code)
27+
variant["REPLS"][repl] = code
28+
return variant
2129

2230
def replace_placeholders(shader_text, replacements):
2331
for key, val in replacements.items():
@@ -71,6 +79,10 @@ def generate_variants(fname, input_dir, output_dir, outfile):
7179
decls_map = parse_decls(extract_block(text, "DECLS"))
7280
except ValueError:
7381
decls_map = {}
82+
try:
83+
templates_map = ast.literal_eval(extract_block(text, "REPL_TEMPLATES"))
84+
except ValueError:
85+
templates_map = {}
7486

7587
with open(os.path.join(input_dir, "common_decls.tmpl"), "r", encoding="utf-8") as f:
7688
common_decls = f.read()
@@ -85,11 +97,15 @@ def generate_variants(fname, input_dir, output_dir, outfile):
8597
decls_code = ""
8698
for key in decls:
8799
if key not in decls_map:
100+
88101
raise ValueError(f"DECLS key '{key}' not found.")
89102
decls_code += decls_map[key] + "\n\n"
90103

91104
final_shader = re.sub(r'\bDECLS\b', decls_code, shader_template)
105+
92106
if "REPLS" in variant:
107+
variant = replace_repl_placeholders(variant, templates_map)
108+
final_shader = replace_placeholders(final_shader, variant["REPLS"])
93109
final_shader = replace_placeholders(final_shader, variant["REPLS"])
94110
final_shader = expand_includes(final_shader, input_dir)
95111

0 commit comments

Comments
 (0)