Skip to content

Commit 404ee91

Browse files
AffectionateCurry, Sokserey Sun, and simonguozirui
authored
Toml-based Prompt Organization (#85)
* preliminary toml file change for prompt constructor. Will add tilelang * Cleaned up the toml file and added logic to use the toml * cleaned up prompt toml. Still need to add custom prompt logic * added custom prompt capabilities * small cleanup * finalize toml functionality * deprecate old prompt_constructor * prompt constructor bug fixes * syntax error * validate toml fix and verify_generation to merge Sokserey's PR * validate generate run and eval with intermediate prompt log, should make it shorter later --------- Co-authored-by: Sokserey Sun <sokserey@matx1.stanford.edu> Co-authored-by: Simon Guo <simonguo@stanford.edu>
1 parent 5c88b23 commit 404ee91

9 files changed

+904
-1133
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ If you don't have GPU available locally, you can set up [Modal](https://modal.co
9090

9191
## 🚀 Usage
9292
### Run on a single problem
93-
It is easier to get started with a single problem. This will fetch the problem, generate a sample, and evaluate the sample.
93+
It is easier to get started with a single problem. This will fetch the problem, generate a sample, and evaluate the sample.
9494

9595
```
9696
# for example, run level 2 problem 40 from huggingface
@@ -106,7 +106,7 @@ python3 scripts/generate_and_eval_single_sample.py dataset_src="huggingface" lev
106106
* **`precision`** - You can specify the precision of tensor by `precision=fp32`. Currently all of our reported results are `fp32` but we added support for `fp16` & `bf16`.
107107
* **`backend`** - We are also supporting other GPU programming languages beyond `cuda`. Simply specify `backend=triton`. For now we support DSLs: `cuda`, `triton`, `cute`, `tilelang`.
108108

109-
Check the config fields for comprehensive set of options.
109+
Check the config fields for comprehensive set of options. Note we provide the model with a one-shot example by default along with the minimum set of info; you can check out other prompt settings or construct your own in `src/prompt_constructor_toml.py`.
110110

111111
### Run on all problems
112112

scripts/generate_and_eval_single_sample.py

Lines changed: 73 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@
99

1010
from src.dataset import construct_kernelbench_dataset
1111
from src.eval import eval_kernel_against_ref
12-
from src.prompt_constructor import prompt_generate_custom_cuda_from_prompt_template
13-
from src.prompt_constructor_multilang import get_prompt_for_backend
12+
from src.prompt_constructor_toml import get_prompt_for_backend, get_custom_prompt
1413
from src.utils import (
1514
create_inference_server_from_presets,
1615
extract_first_code,
@@ -22,6 +21,9 @@
2221
"""
2322
Generate and evaluate a single sample
2423
Easiest way to get started, to test a single problem for experimentation or debugging
24+
25+
Example usage:
26+
python3 scripts/generate_and_eval_single_sample.py dataset_src=huggingface level=1 problem_id=1 eval_mode=local server_type=google model_name=gemini/gemini-2.5-flash max_tokens=8192 temperature=0.0
2527
"""
2628

2729
REPO_TOP_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@@ -72,6 +74,12 @@ def __init__(self):
7274

7375
self.backend = "cuda"
7476

77+
# Prompt construction
78+
self.prompt_option = "one_shot" # choices: zero_shot, one_shot, few_shot
79+
self.include_hardware_info = False
80+
self.hardware_gpu_name = None
81+
self.custom_prompt_key = None
82+
7583
def verbose_logging(self):
7684
self.log = True
7785
self.log_prompt = True
@@ -86,6 +94,7 @@ def __repr__(self):
8694
def main(config: EvalConfig):
8795
"""
8896
Keep it simple: Generate and evaluate a single sample
97+
Note: will shorten code logic to make this as simple as possible
8998
"""
9099
from src.utils import SERVER_PRESETS
91100

@@ -129,6 +138,7 @@ def main(config: EvalConfig):
129138
config.problem_id <= num_problems
130139
), f"Problem ID {config.problem_id} out of range for Level {config.level}"
131140

141+
# TODO: refactor dataset fetching logic to be as clean as possible.
132142
# 1. Fetch Problem
133143
if config.dataset_src == "huggingface":
134144

@@ -169,24 +179,70 @@ def main(config: EvalConfig):
169179
budget_tokens=config.budget_tokens,
170180
)
171181

182+
# Prompt Construction (Note: could be shortened in future PR)
183+
custom_prompt_key = getattr(config, "custom_prompt_key", None)
184+
if isinstance(custom_prompt_key, str):
185+
trimmed = custom_prompt_key.strip()
186+
if trimmed.lower() in {"", "none"}:
187+
custom_prompt_key = None
188+
else:
189+
custom_prompt_key = trimmed
190+
config.custom_prompt_key = custom_prompt_key
191+
172192
# Use appropriate prompt constructor based on backend
173-
if config.backend == "cuda":
174-
custom_prompt = prompt_generate_custom_cuda_from_prompt_template(ref_arch_src)
175-
elif config.backend in ["triton", "tilelang", "cute"]:
176-
custom_prompt = get_prompt_for_backend(ref_arch_src, config.backend)
177-
else:
193+
prompt_option = str(config.prompt_option).lower()
194+
valid_prompt_options = {"zero_shot", "one_shot", "few_shot"}
195+
include_hardware = config.include_hardware_info
196+
if isinstance(include_hardware, str):
197+
include_hardware = include_hardware.lower() in ["true", "1", "yes"]
198+
config.include_hardware_info = include_hardware
199+
200+
supported_backends = {"cuda", "triton", "tilelang", "cute"}
201+
backend = config.backend.lower()
202+
if backend not in supported_backends:
178203
raise ValueError(
179-
f"Unsupported backend: {config.backend}. Must be 'cuda', 'triton', 'tilelang', or 'cute'."
204+
f"Unsupported backend: {config.backend}. Must be one of {sorted(supported_backends)}."
180205
)
181206

207+
if backend == "tilelang":
208+
config.precision = "fp16" # tilelang only operates with fp16
209+
config.hardware_gpu_name = config.hardware_gpu_name or getattr(config, "gpu", None)
210+
211+
if not custom_prompt_key:
212+
if prompt_option not in valid_prompt_options:
213+
raise ValueError(
214+
f"Invalid prompt_option '{config.prompt_option}'. "
215+
f"Must be one of {sorted(valid_prompt_options)}."
216+
)
217+
if include_hardware and not config.hardware_gpu_name:
218+
raise ValueError(
219+
"include_hardware_info is True but hardware_gpu_name is not provided."
220+
)
221+
222+
if custom_prompt_key:
223+
custom_prompt = get_custom_prompt(
224+
custom_prompt_key,
225+
ref_arch_src=ref_arch_src,
226+
backend=backend,
227+
option=prompt_option,
228+
precision=config.precision,
229+
include_hardware=include_hardware,
230+
gpu_name=config.hardware_gpu_name,
231+
)
232+
else:
233+
custom_prompt = get_prompt_for_backend(
234+
ref_arch_src,
235+
backend,
236+
option=prompt_option,
237+
precision=config.precision,
238+
include_hardware=include_hardware,
239+
gpu_name=config.hardware_gpu_name,
240+
)
241+
242+
os.makedirs(config.logdir, exist_ok=True)
243+
182244
if config.log_prompt:
183-
with open(
184-
os.path.join(
185-
config.logdir,
186-
f"prompt_level_{config.level}_problem_{config.problem_id}.txt",
187-
),
188-
"w",
189-
) as f:
245+
with open(os.path.join(config.logdir, f"prompt_level_{config.level}_problem_{config.problem_id}.txt"), "w") as f:
190246
f.write(custom_prompt)
191247

192248
# Query server with constructed prompt
@@ -200,13 +256,7 @@ def main(config: EvalConfig):
200256

201257
# this should be optional
202258
if config.log:
203-
with open(
204-
os.path.join(
205-
config.logdir,
206-
f"generated_kernel_level_{config.level}_problem_{config.problem_id}.py",
207-
),
208-
"w",
209-
) as f:
259+
with open(os.path.join(config.logdir, f"generated_kernel_level_{config.level}_problem_{config.problem_id}.py"), "w") as f:
210260
f.write(custom_kernel)
211261

212262
# 3. Evaluate Kernel
@@ -228,13 +278,7 @@ def main(config: EvalConfig):
228278
)
229279

230280
if config.log:
231-
with open(
232-
os.path.join(
233-
config.logdir,
234-
f"eval_result_level_{config.level}_problem_{config.problem_id}.txt",
235-
),
236-
"a",
237-
) as f:
281+
with open(os.path.join(config.logdir, f"eval_result_level_{config.level}_problem_{config.problem_id}.txt"), "a",) as f:
238282
f.write(f"Problem Name: {problem_name}\n")
239283
f.write(str(kernel_exec_result))
240284

scripts/generate_and_eval_single_sample_modal.py

Lines changed: 62 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,7 @@
1515

1616
#from src.dataset import construct_kernelbench_dataset
1717
from src.eval import eval_kernel_against_ref
18-
from src.prompt_constructor import prompt_generate_custom_cuda_from_prompt_template
19-
from src.prompt_constructor_multilang import get_prompt_for_backend
18+
from src.prompt_constructor_toml import get_prompt_for_backend, get_custom_prompt
2019
from src.utils import extract_first_code, query_server, set_gpu_arch, read_file, create_inference_server_from_presets
2120

2221
app = modal.App("eval_single_sample")
@@ -76,6 +75,11 @@ def __init__(self):
7675
self.log_eval_result = False
7776

7877
self.backend = "cuda"
78+
# Prompt generation settings
79+
self.prompt_option = "one_shot" # zero_shot, one_shot, few_shot
80+
self.include_hardware_info = False
81+
self.hardware_gpu_name = None
82+
self.custom_prompt_key = None
7983

8084
def verbose_logging(self):
8185
self.log = True
@@ -194,14 +198,64 @@ def main(config: EvalConfig):
194198
budget_tokens=config.budget_tokens)
195199

196200

201+
custom_prompt_key = getattr(config, "custom_prompt_key", None)
202+
if isinstance(custom_prompt_key, str):
203+
trimmed = custom_prompt_key.strip()
204+
if trimmed.lower() in {"", "none"}:
205+
custom_prompt_key = None
206+
else:
207+
custom_prompt_key = trimmed
208+
config.custom_prompt_key = custom_prompt_key
209+
210+
# Checks if user has inputted a valid argument for how many examples they want to give as context to the model
211+
prompt_option = str(config.prompt_option).lower()
212+
valid_prompt_options = {"zero_shot", "one_shot", "few_shot"}
213+
include_hardware = config.include_hardware_info
214+
if isinstance(include_hardware, str):
215+
include_hardware = include_hardware.lower() in ["true", "1", "yes"]
216+
config.include_hardware_info = include_hardware
217+
218+
supported_backends = {"cuda", "triton", "tilelang", "cute"}
219+
backend = config.backend.lower()
220+
if backend not in supported_backends:
221+
raise ValueError(
222+
f"Unsupported backend: {config.backend}. Must be one of {sorted(supported_backends)}."
223+
)
197224

198-
# Use appropriate prompt constructor based on backend
199-
if config.backend == "cuda":
200-
custom_prompt = prompt_generate_custom_cuda_from_prompt_template(ref_arch_src)
201-
elif config.backend in ["triton", "tilelang", "cute"]:
202-
custom_prompt = get_prompt_for_backend(ref_arch_src, config.backend)
225+
#tilelang only supports fp16 or bf16
226+
if backend == "tilelang":
227+
config.precision = "fp16"
228+
config.hardware_gpu_name = config.hardware_gpu_name or getattr(config, "gpu", None)
229+
230+
if not custom_prompt_key:
231+
if prompt_option not in valid_prompt_options:
232+
raise ValueError(
233+
f"Invalid prompt_option '{config.prompt_option}'. Must be one of {sorted(valid_prompt_options)}."
234+
)
235+
if include_hardware and not config.hardware_gpu_name:
236+
raise ValueError(
237+
"include_hardware_info is True but hardware_gpu_name is not provided."
238+
)
239+
240+
if custom_prompt_key:
241+
custom_prompt = get_custom_prompt(
242+
custom_prompt_key,
243+
ref_arch_src=ref_arch_src,
244+
backend=backend,
245+
option=prompt_option,
246+
precision=config.precision,
247+
include_hardware=include_hardware,
248+
gpu_name=config.hardware_gpu_name,
249+
)
203250
else:
204-
raise ValueError(f"Unsupported backend: {config.backend}. Must be 'cuda', 'triton', 'tilelang', or 'cute'.")
251+
custom_prompt = get_prompt_for_backend(
252+
ref_arch_src,
253+
backend,
254+
option=prompt_option,
255+
precision=config.precision,
256+
include_hardware=include_hardware,
257+
gpu_name=config.hardware_gpu_name,
258+
)
205259

206260
if config.log_prompt:
207261
with open(os.path.join(config.logdir, f"prompt_level_{config.level}_problem_{config.problem_id}.txt"), "w") as f:

scripts/generate_samples.py

Lines changed: 62 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@
1010

1111
from src.dataset import construct_kernelbench_dataset
1212
from src.eval import eval_kernel_against_ref
13-
from src.prompt_constructor import prompt_generate_custom_cuda_from_prompt_template
14-
from src.prompt_constructor_multilang import get_prompt_for_backend
13+
from src.prompt_constructor_toml import get_prompt_for_backend, get_custom_prompt
1514
from src.utils import (
1615
create_inference_server_from_presets,
1716
extract_first_code,
@@ -80,6 +79,10 @@ def __init__(self):
8079
self.backend = "cuda"
8180

8281
self.precision = "fp32"
82+
self.prompt_option = "one_shot" # zero_shot, one_shot, few_shot
83+
self.include_hardware_info = False
84+
self.hardware_gpu_name = None
85+
self.custom_prompt_key = None
8386

8487
def greedy(self):
8588
# For greedy decoding, especially baseline eval
@@ -126,30 +129,38 @@ def generate_sample_single(
126129
problem_number == work.problem_id
127130
), f"Problem number in filename ({problem_number}) does not match config problem_id ({config.problem_id})"
128131

129-
# Construct Prompt
130-
if config.backend == "cuda":
131-
custom_cuda_prompt = prompt_generate_custom_cuda_from_prompt_template(
132-
ref_arch_src
132+
if config.custom_prompt_key:
133+
custom_prompt = get_custom_prompt(
134+
config.custom_prompt_key,
135+
ref_arch_src=ref_arch_src,
136+
backend=config.backend,
137+
option=config.prompt_option,
138+
precision=config.precision,
139+
include_hardware=config.include_hardware_info,
140+
gpu_name=config.hardware_gpu_name,
133141
)
134-
elif config.backend in ["triton", "cute", "tilelang"]:
135-
custom_cuda_prompt = get_prompt_for_backend(ref_arch_src, config.backend)
136142
else:
137-
raise ValueError(
138-
f"Unsupported backend: {config.backend}. Must be 'cuda', 'triton', 'cute', or 'tilelang'."
143+
custom_prompt = get_prompt_for_backend(
144+
ref_arch_src,
145+
config.backend,
146+
option=config.prompt_option,
147+
precision=config.precision,
148+
include_hardware=config.include_hardware_info,
149+
gpu_name=config.hardware_gpu_name,
139150
)
140151
if config.log_prompt:
141152
prompt_path = os.path.join(
142153
run_dir,
143154
f"level_{config.level}_problem_{work.problem_id}_sample_{work.sample_id}_prompt.txt",
144155
)
145156
with open(prompt_path, "w") as f:
146-
f.write(custom_cuda_prompt)
157+
f.write(custom_prompt)
147158

148159
# Query server with constructed prompt
149-
custom_cuda = inference_server(custom_cuda_prompt)
150-
custom_cuda = extract_first_code(custom_cuda, ["python", "cpp"])
160+
custom_kernel = inference_server(custom_prompt)
161+
custom_kernel = extract_first_code(custom_kernel, ["python", "cpp"])
151162
# check LLM is able to generate custom CUDA code
152-
assert custom_cuda is not None, "Custom CUDA code generation failed"
163+
assert custom_kernel is not None, "Custom CUDA code generation failed"
153164

154165
if config.verbose:
155166
print(
@@ -162,7 +173,7 @@ def generate_sample_single(
162173
f"level_{config.level}_problem_{work.problem_id}_sample_{work.sample_id}_kernel.py",
163174
)
164175
with open(kernel_path, "w") as f:
165-
f.write(custom_cuda)
176+
f.write(custom_kernel)
166177

167178
return True
168179

@@ -214,6 +225,42 @@ def main(config: GenerationConfig):
214225
if isinstance(config.is_reasoning_model, str):
215226
config.is_reasoning_model = config.is_reasoning_model.lower() in ['true', '1', 'yes']
216227

228+
custom_prompt_key = getattr(config, "custom_prompt_key", None)
229+
if isinstance(custom_prompt_key, str):
230+
trimmed = custom_prompt_key.strip()
231+
if trimmed.lower() in {"", "none"}:
232+
custom_prompt_key = None
233+
else:
234+
custom_prompt_key = trimmed
235+
config.custom_prompt_key = custom_prompt_key
236+
237+
include_hardware = config.include_hardware_info
238+
if isinstance(include_hardware, str):
239+
include_hardware = include_hardware.lower() in ["true", "1", "yes"]
240+
config.include_hardware_info = include_hardware
241+
242+
supported_backends = {"cuda", "triton", "cute", "tilelang"}
243+
backend = config.backend.lower()
244+
if backend not in supported_backends:
245+
raise ValueError(
246+
f"Unsupported backend: {config.backend}. Must be one of {sorted(supported_backends)}."
247+
)
248+
config.backend = backend
249+
if backend == "tilelang":
250+
config.precision = "fp16"
251+
252+
config.prompt_option = str(config.prompt_option).lower()
253+
valid_prompt_options = {"zero_shot", "one_shot", "few_shot"}
254+
if not config.custom_prompt_key:
255+
if config.prompt_option not in valid_prompt_options:
256+
raise ValueError(
257+
f"Invalid prompt_option '{config.prompt_option}'. Must be one of {sorted(valid_prompt_options)}."
258+
)
259+
if include_hardware and not config.hardware_gpu_name:
260+
raise ValueError(
261+
"include_hardware_info is True but hardware_gpu_name is not provided."
262+
)
263+
217264
print(f"Starting Batch Generation with config: {config}")
218265

219266
# Dataset Configurations

0 commit comments

Comments
 (0)