Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions vllm_hpu_extension/bucketing/exponential.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def get_prompt_buckets(self, max_num_prefill_seqs, block_size,
max_num_batched_tokens, max_model_len):
self.check_for_user_flags('prompt')
use_merged_prefill = get_config().merged_prefill
-prefix_caching = get_config().prefix_caching
+prefix_caching = get_config().prefix_caching or get_config().chunked_prefill
max_prompt_seq = max_model_len

# cfgs shape: [min, step, max, limit]
Expand Down Expand Up @@ -131,10 +131,16 @@ def generate_prompt_buckets(bs_bucket_config,
filtered_buckets = buckets
if max_num_batched_tokens is not None and max_model_len is not None:
# Remove buckets exceeding batch token budget
-filtered_buckets = list(
-filter(
-lambda bucket: bucket[0] * (bucket[1] + bucket[2] * block_size) <= max_num_batched_tokens \
-and bucket[1] <= max_model_len, buckets))
+if get_config().chunked_prefill:
+filtered_buckets = list(
+filter(
+lambda bucket: bucket[1] <= max_num_batched_tokens \
+and bucket[1] <= max_model_len and bucket[0] == 1 , buckets))
+else:
+filtered_buckets = list(
+filter(
+lambda bucket: bucket[0] * (bucket[1] + bucket[2] * block_size) <= max_num_batched_tokens \
+and bucket[1] <= max_model_len, buckets))

if len(filtered_buckets) == 0:
# we can handle this if we ignore max_num_batched_tokens
Expand Down
9 changes: 7 additions & 2 deletions vllm_hpu_extension/bucketing/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class LinearBucketingStrategy:
def get_prompt_buckets(self, max_num_prefill_seqs, block_size,
max_num_batched_tokens, max_model_len):
use_merged_prefill = get_config().merged_prefill
-prefix_caching = get_config().prefix_caching
+prefix_caching = get_config().prefix_caching or get_config().chunked_prefill

max_prompt_seq = max_model_len

Expand Down Expand Up @@ -176,12 +176,17 @@ def generate_prompt_buckets(bs_bucket_config,
filtered_buckets = buckets
if max_num_batched_tokens is not None:
# Remove buckets exceeding batch token budget
-if prefix_caching:
+if get_config().prefix_caching:
max_tokens = max_num_batched_tokens + context_bucket_step * block_size
filtered_buckets = list(
filter(
lambda bucket: bucket[0] * (bucket[1] + bucket[2] * block_size) <= max_tokens,
buckets))
+elif get_config().chunked_prefill:
+filtered_buckets = list(
+filter(
+lambda bucket: bucket[1] <= max_num_batched_tokens \
+and bucket[0] == 1 , buckets))
else:
filtered_buckets = list(
filter(
Expand Down
2 changes: 2 additions & 0 deletions vllm_hpu_extension/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def set_vllm_config(cfg):
else:
_VLLM_VALUES['model_type'] = cfg.model_config.model_type
_VLLM_VALUES['prefix_caching'] = cfg.cache_config.enable_prefix_caching
+_VLLM_VALUES['chunked_prefill'] = cfg.scheduler_config.enable_chunked_prefill

# t.compile is very picky about what functions we can call inside modules
# since this is the last step we can force recompilation of config to
Expand Down Expand Up @@ -89,5 +90,6 @@ def get_environment():
Value('bridge_mode', _get_pt_bridge_mode, env_var_type=choice('eager', 'lazy')),
VllmValue('model_type', str),
VllmValue('prefix_caching', boolean),
+VllmValue('chunked_prefill', boolean),
]
return split_values_and_flags(values)