Commit 64cc644

[core][torch.compile] discard the compile for profiling (#7796)
1 parent 39178c7 commit 64cc644

File tree: 4 files changed, +43 −2 lines


.buildkite/run-tpu-test.sh

Lines changed: 1 addition & 2 deletions
@@ -12,5 +12,4 @@ remove_docker_container
 # For HF_TOKEN.
 source /etc/environment
 # Run a simple end-to-end example.
-docker run --privileged --net host --shm-size=16G -it -e HF_TOKEN=$HF_TOKEN --name tpu-test vllm-tpu \
-    python3 /workspace/vllm/examples/offline_inference_tpu.py
+docker run --privileged --net host --shm-size=16G -it -e HF_TOKEN=$HF_TOKEN --name tpu-test vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git && python3 /workspace/vllm/tests/tpu/test_compilation.py && python3 /workspace/vllm/examples/offline_inference_tpu.py"

tests/tpu/test_compilation.py

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+import glob
+import os
+import runpy
+import tempfile
+
+import depyf
+
+temp_dir = tempfile.mkdtemp()
+with depyf.prepare_debug(temp_dir):
+    cur_dir = os.path.dirname(__file__)
+    parent_dir = os.path.dirname(cur_dir)
+    root_dir = os.path.dirname(parent_dir)
+    example_file = os.path.join(root_dir, "examples",
+                                "offline_inference_tpu.py")
+    runpy.run_path(example_file)
+
+compiled_code = sorted(
+    glob.glob(os.path.join(temp_dir, "__transformed_code*.py")))
+full_code = glob.glob(os.path.join(temp_dir, "full_code*.py"))[0]
+# we should only trigger Dynamo compilation three times:
+# one for the profiling phase (and the compiled artifact will be discarded)
+# one for the prefill phase with symbolic shapes
+# one for the decode phase with symbolic shapes
+# and later calls should not trigger Dynamo compilation again.
+# NOTE: it might still trigger XLA compilation.
+
+# check we have three compiled code
+assert len(compiled_code) == 3
+
+# check the first compilation is discarded
+with open(full_code) as f:
+    full_code_content = f.read()
+    profile_function = compiled_code[0].split(".")[0]
+    assert profile_function not in full_code_content

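For context, here is a minimal sketch of what depyf.prepare_debug does in the test above; the toy function and directory names are illustrative and not part of this commit. Every Dynamo compilation that happens inside the context dumps a __transformed_code*.py file into the chosen directory, which is what the assertion len(compiled_code) == 3 counts.

import glob
import os
import tempfile

import depyf
import torch


def toy_fn(x):
    # Any function compiled under the depyf context leaves its transformed
    # bytecode behind as a __transformed_code*.py file in dump_dir.
    return x * 2 + 1


dump_dir = tempfile.mkdtemp()
with depyf.prepare_debug(dump_dir):
    compiled = torch.compile(toy_fn)
    compiled(torch.randn(4))

# One Dynamo compilation happened, so one transformed-code file is expected.
print(sorted(glob.glob(os.path.join(dump_dir, "__transformed_code*.py"))))
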
vllm/worker/model_runner.py

Lines changed: 4 additions & 0 deletions
@@ -1097,6 +1097,10 @@ def profile_run(self) -> None:
                 device=self.device)
         self.execute_model(model_input, kv_caches, intermediate_tensors)
         torch.cuda.synchronize()
+
+        # reset and discard the guard and compiled bytecode for profiling runs
+        torch._dynamo.reset()
+
         return
 
     def remove_all_loras(self):

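The effect of torch._dynamo.reset() can be seen in isolation with a toy function and a counting backend (both illustrative, not vLLM code): the reset drops the cached guards and compiled bytecode, so the next call goes through Dynamo compilation again, exactly like the profiling run whose compiled artifact this commit discards.

import torch

compile_count = 0


def counting_backend(gm, example_inputs):
    # Minimal Dynamo backend: count compilations, then run the graph eagerly.
    global compile_count
    compile_count += 1
    return gm.forward


@torch.compile(backend=counting_backend)
def f(x):
    return x + 1


f(torch.randn(2))      # first call triggers compilation -> compile_count == 1
f(torch.randn(2))      # guards match, cached bytecode is reused
torch._dynamo.reset()  # discard guards and compiled bytecode
f(torch.randn(2))      # compiles again -> compile_count == 2
print(compile_count)
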
vllm/worker/tpu_worker.py

Lines changed: 4 additions & 0 deletions
@@ -143,6 +143,10 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
         num_cpu_blocks = int(self.cache_config.swap_space_bytes //
                              block_size_bytes)
         num_cpu_blocks = (num_cpu_blocks // 8) * 8  # Round down to 8.
+
+        # reset and discard the guard and compiled bytecode for profiling runs
+        torch._dynamo.reset()
+
         return num_tpu_blocks, num_cpu_blocks
 
     def initialize_cache(

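As a worked example of the CPU block arithmetic shown in the context lines above (the byte sizes here are made up for illustration, not taken from a real vLLM config): the swap space is divided into KV-cache blocks, and the count is then rounded down to a multiple of 8.

# Hypothetical sizes, for illustration only.
swap_space_bytes = 4 * 1024**3    # 4 GiB of CPU swap space
block_size_bytes = 3 * 1024**2    # ~3 MiB per KV-cache block

num_cpu_blocks = int(swap_space_bytes // block_size_bytes)  # 1365
num_cpu_blocks = (num_cpu_blocks // 8) * 8                  # 1360, a multiple of 8
print(num_cpu_blocks)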