
Commit dca6719

tianmu-li and xuechendi authored
Fully overlap model execution (#134)
Dependent on vllm-project/vllm#23569

Signed-off-by: Tianmu Li <[email protected]>
Co-authored-by: Chendi.Xue <[email protected]>

Parent: 1d3731b
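"Fully overlap model execution" refers to vLLM's asynchronous scheduling: the CPU-side scheduler prepares step N+1 while the accelerator is still executing step N, so the device does not idle waiting for host work. Below is a toy sketch of that overlap pattern; every name is invented for illustration, and this is not vLLM code.

# Toy sketch of overlapped ("async") scheduling -- illustrative only.
# schedule() stands in for CPU-side input preparation, device_execute()
# for a forward pass + sampling; overlapping the two is the whole idea.
import time
from concurrent.futures import ThreadPoolExecutor

def schedule(step):
    time.sleep(0.02)          # simulated CPU scheduling work
    return step

def device_execute(batch):
    time.sleep(0.05)          # simulated accelerator step
    return f"sampled tokens for batch {batch}"

with ThreadPoolExecutor(max_workers=1) as pool:
    in_flight = None
    for step in range(4):
        batch = schedule(step)                    # overlaps the previous launch
        nxt = pool.submit(device_execute, batch)  # non-blocking "device" launch
        if in_flight is not None:
            print(in_flight.result())             # sync one step late
        in_flight = nxt
    print(in_flight.result())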

File tree

5 files changed: +268 additions, -56 deletions


tests/full_tests/ci_gsm8k_tests.sh

Lines changed: 12 additions & 0 deletions
@@ -140,6 +140,18 @@ if [ $? -ne 0 ]; then
 fi
 echo "Test with granite-8b passed"
 
+# used to check asynchronous scheduling
+echo "Testing GSM8K on granite-8b with async scheduling"
+echo VLLM_CONTIGUOUS_PA=False VLLM_SKIP_WARMUP=True PT_HPU_LAZY_MODE=1 VLLM_USE_V1=1 ASYNC_SCHEDULING=1 \
+pytest -v -s vllm-gaudi/tests/models/language/generation/test_common.py --model_card_path vllm-gaudi/tests/full_tests/model_cards/granite-8b.yaml
+VLLM_CONTIGUOUS_PA=False VLLM_SKIP_WARMUP=True PT_HPU_LAZY_MODE=1 VLLM_USE_V1=1 ASYNC_SCHEDULING=1 \
+pytest -v -s vllm-gaudi/tests/models/language/generation/test_common.py --model_card_path vllm-gaudi/tests/full_tests/model_cards/granite-8b.yaml
+if [ $? -ne 0 ]; then
+    echo "Error: Test failed for granite-8b + async_scheduling" >&2
+    exit -1
+fi
+echo "Test with granite-8b + async_scheduling passed"
+
 # used to check MLA + MOE
 echo "Testing GSM8K on deepseek v2 lite"
 # deepseek-R1
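ASYNC_SCHEDULING=1 here is a test-harness flag; the next diff shows test_common.py reading it and forwarding async_scheduling into the engine arguments. For reference, a minimal sketch of enabling the same option directly on vLLM's offline API, assuming a vLLM build where async_scheduling is accepted as an engine argument (the model id below is illustrative, not the one in granite-8b.yaml):

from vllm import LLM, SamplingParams

# async_scheduling is forwarded to the engine arguments; whether the flag is
# available depends on the vLLM version (this commit depends on vllm#23569).
llm = LLM(model="ibm-granite/granite-3.1-8b-instruct",  # illustrative model id
          async_scheduling=True)
outputs = llm.generate(["What is 12 * 7?"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)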

tests/models/language/generation/test_common.py

Lines changed: 2 additions & 0 deletions
@@ -23,9 +23,11 @@ def launch_lm_eval(eval_config):
     enforce_eager = os.environ.get('ENFORCE_EAGER', 'False').lower() in ['true', '1']
     kv_cache_dtype = os.environ.get('KV_CACHE_DTYPE', None)
     task = eval_config.get('tasks', 'gsm8k')
+    async_scheduling = os.environ.get('ASYNC_SCHEDULING', 'False').lower() in ['true', '1']
     model_args = {
         'pretrained': eval_config['model_name'],
         'tensor_parallel_size': tp_size,
+        'async_scheduling': async_scheduling,
         'enforce_eager': enforce_eager,
         'enable_prefix_caching': enable_apc,
         'add_bos_token': True,
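These two added lines let the harness toggle async scheduling per run: the environment flag is parsed into a boolean and passed through model_args. A hedged sketch of how such a dict typically reaches lm-eval's vLLM backend follows; the surrounding launch_lm_eval body is not part of this diff, so the call below is an assumption based on lm-eval's public simple_evaluate API, with an illustrative model id.

import lm_eval

model_args = {
    'pretrained': 'ibm-granite/granite-3.1-8b-instruct',  # illustrative
    'tensor_parallel_size': 1,
    'async_scheduling': True,    # the flag added by this commit
    'enforce_eager': False,
    'add_bos_token': True,
}
# lm-eval accepts model_args as a comma-separated key=value string.
results = lm_eval.simple_evaluate(
    model='vllm',
    model_args=','.join(f'{k}={v}' for k, v in model_args.items()),
    tasks=['gsm8k'],
)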

vllm_gaudi/v1/worker/hpu_input_batch.py

Lines changed: 5 additions & 0 deletions
@@ -215,6 +215,11 @@ def __init__(
         self.sampling_metadata = self._make_sampling_metadata()
         self.pooling_params: dict[str, PoolingParams] = {}
 
+        # Cached reference to the GPU tensor of previously sampled tokens
+        self.prev_sampled_token_ids: Optional[torch.Tensor] = None
+        self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None
+        self.prev_req_id_to_index: Optional[dict[str, int]] = None
+
         self.req_type: dict[str, str] = {}
 
     @property
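These cached fields are what make the overlap safe: the next step's inputs can be assembled with placeholders and backfilled from the previous step's sampled-token tensor, avoiding a device-to-host sync between steps, while the invalid-index set marks rows whose cached token must not be reused. A minimal sketch of that backfill pattern follows, with illustrative values; it is not the actual model-runner logic.

import torch

# Stand-ins for the cached fields added above.
prev_sampled_token_ids = torch.tensor([[11], [22], [33]])  # device tensor in practice
prev_req_id_to_index = {'req-a': 0, 'req-b': 1, 'req-c': 2}
invalid_indices = {1}   # e.g. a request whose cached token is stale

# Step N+1 inputs start as placeholders and are backfilled per request.
next_input_ids = torch.zeros(3, dtype=torch.long)
for req_id, idx in prev_req_id_to_index.items():
    if idx not in invalid_indices:
        next_input_ids[idx] = prev_sampled_token_ids[idx, 0]
print(next_input_ids)  # tensor([11,  0, 33]); index 1 awaits a fresh token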
