-
Notifications
You must be signed in to change notification settings - Fork 48
Draft: Proper chunked prefill bucketing #295
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -28,11 +28,15 @@ def check_for_user_flags(self, phase): | |||||
|
|
||||||
|
|
||||||
| def get_prompt_buckets(self, max_num_prefill_seqs, block_size, | ||||||
| max_num_batched_tokens, max_model_len): | ||||||
| max_num_batched_tokens, max_model_len, max_num_blocks): | ||||||
| self.check_for_user_flags('prompt') | ||||||
| use_merged_prefill = get_config().merged_prefill | ||||||
| use_merged_prefill = get_config().merged_prefill | ||||||
| prefix_caching = get_config().prefix_caching | ||||||
| max_prompt_seq = max_model_len | ||||||
| # NOTE(kzawora): v1 requires chunked prefill, | ||||||
| # and we assume it is not going to be supported in v0 hpu code | ||||||
| enable_chunked_prefill = get_config().engine_version == 'v1' | ||||||
| # NOTE(kzawora): Chunked prefill scenarios never exceed the upper bound set by max_num_batched_tokens, regardless of max_model_len | ||||||
| max_prompt_seq = max_model_len if not enable_chunked_prefill else max_num_batched_tokens | ||||||
|
|
||||||
| # cfgs shape: [min, step, max, limit] | ||||||
| prompt_bs_limit = math.ceil(math.log2(max_num_prefill_seqs)) + 1 | ||||||
|
|
@@ -54,8 +58,10 @@ def get_prompt_buckets(self, max_num_prefill_seqs, block_size, | |||||
| prompt_seq_bucket_cfg, | ||||||
| block_size, | ||||||
| prefix_caching, | ||||||
| enable_chunked_prefill, | ||||||
| max_num_batched_tokens, | ||||||
| max_model_len) | ||||||
| max_model_len, | ||||||
| max_num_blocks) | ||||||
|
|
||||||
| return sorted(prompt_buckets) | ||||||
|
|
||||||
|
|
@@ -89,8 +95,10 @@ def generate_prompt_buckets(bs_bucket_config, | |||||
| seq_bucket_config, | ||||||
| block_size, | ||||||
| prefix_caching, | ||||||
| enable_chunked_prefill, | ||||||
| max_num_batched_tokens=None, | ||||||
| max_model_len=None): | ||||||
| max_model_len=None, | ||||||
| max_num_blocks=None): | ||||||
| _, _, bmax, _ = seq_bucket_config | ||||||
| batch_size_buckets = warmup_range_with_limit(bs_bucket_config) | ||||||
| long_context = False | ||||||
|
|
@@ -103,7 +111,7 @@ def generate_prompt_buckets(bs_bucket_config, | |||||
| for bs in batch_size_buckets: | ||||||
| for b in seq_bucket_config: | ||||||
| buckets_3d.append((bs, b, 0)) | ||||||
| max_blocks_range = (bmax - b) // block_size | ||||||
| max_blocks_range = (bmax - b) // block_size if not max_num_blocks else max_num_blocks | ||||||
| if max_blocks_range == 0: | ||||||
| continue | ||||||
| else: | ||||||
|
|
@@ -131,10 +139,36 @@ def generate_prompt_buckets(bs_bucket_config, | |||||
| filtered_buckets = buckets | ||||||
| if max_num_batched_tokens is not None and max_model_len is not None: | ||||||
| # Remove buckets exceeding batch token budget | ||||||
| filtered_buckets = list( | ||||||
| filter( | ||||||
| lambda bucket: bucket[0] * (bucket[1] + bucket[2] * block_size) <= max_num_batched_tokens \ | ||||||
| and bucket[1] <= max_model_len, buckets)) | ||||||
| if not enable_chunked_prefill: | ||||||
| filtered_buckets = list( | ||||||
| filter( | ||||||
| lambda bucket: bucket[0] * (bucket[1] + bucket[2] * block_size) <= max_num_batched_tokens \ | ||||||
| and bucket[1] <= max_model_len, buckets)) | ||||||
| else: | ||||||
| def filter_fn(bucket): | ||||||
| # NOTE(kzawora): Chunked prefill scenarios never exceed the upper bound set by max_num_batched_tokens, regardless of max_model_len | ||||||
| _, seq, block = bucket | ||||||
| is_seq_in_bounds = seq <= max_num_batched_tokens | ||||||
| is_block_in_bounds = block <= max_num_blocks | ||||||
| # New logic: allow all buckets up to and including the first that exceeds max_model_len, then filter the rest | ||||||
| return is_seq_in_bounds and is_block_in_bounds | ||||||
| # Find the first bucket that exceeds max_model_len | ||||||
| # For each (bs, seq), keep all buckets that do not exceed model len, and the first that does | ||||||
| from collections import defaultdict | ||||||
|
||||||
| from collections import defaultdict |
Copilot
AI
Jul 17, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nitpick] This complex nested lambda expression reduces readability. Consider using a list comprehension or separating into multiple steps for better clarity.
| filtered_buckets = list(map(lambda x: x[1], filter(keep_bucket, enumerate(buckets)))) | |
| filtered_buckets = [bucket for _, bucket in enumerate(buckets) if keep_bucket((_, bucket))] |
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -11,9 +11,10 @@ | |||||
|
|
||||||
| class LinearBucketingStrategy: | ||||||
| def get_prompt_buckets(self, max_num_prefill_seqs, block_size, | ||||||
| max_num_batched_tokens, max_model_len): | ||||||
| max_num_batched_tokens, max_model_len, max_num_blocks): | ||||||
| use_merged_prefill = get_config().merged_prefill | ||||||
| prefix_caching = get_config().prefix_caching | ||||||
| chunked_prefill = get_config().engine_version == 'v1' | ||||||
|
|
||||||
| max_prompt_seq = max_model_len | ||||||
|
|
||||||
|
|
@@ -50,7 +51,10 @@ def get_prompt_buckets(self, max_num_prefill_seqs, block_size, | |||||
| prompt_seq_bucket_cfg, | ||||||
| block_size, | ||||||
| prefix_caching, | ||||||
| max_num_batched_tokens) | ||||||
| chunked_prefill, | ||||||
| max_num_batched_tokens, | ||||||
| max_model_len, | ||||||
| max_num_blocks) | ||||||
|
|
||||||
| return sorted(prompt_buckets) | ||||||
|
|
||||||
|
|
@@ -129,7 +133,10 @@ def generate_prompt_buckets(bs_bucket_config, | |||||
| seq_bucket_config, | ||||||
| block_size, | ||||||
| prefix_caching, | ||||||
| max_num_batched_tokens=None): | ||||||
| enable_chunked_prefill, | ||||||
| max_num_batched_tokens=None, | ||||||
| max_model_len=None, | ||||||
| max_num_blocks=None): | ||||||
| _, _, bmax = seq_bucket_config | ||||||
| batch_size_buckets = warmup_range(bs_bucket_config) | ||||||
| seq_bucket_config = warmup_range(seq_bucket_config) | ||||||
|
|
@@ -157,10 +164,37 @@ def generate_prompt_buckets(bs_bucket_config, | |||||
| filtered_buckets = buckets | ||||||
| if max_num_batched_tokens is not None: | ||||||
| # Remove buckets exceeding batch token budget | ||||||
| filtered_buckets = list( | ||||||
| filter( | ||||||
| lambda bucket: bucket[0] * (bucket[1] + bucket[2] * block_size) <= max_num_batched_tokens, | ||||||
| buckets)) | ||||||
| if not enable_chunked_prefill: | ||||||
| filtered_buckets = list( | ||||||
| filter( | ||||||
| lambda bucket: bucket[0] * (bucket[1] + bucket[2] * block_size) <= max_num_batched_tokens, | ||||||
| buckets)) | ||||||
| else: | ||||||
| def filter_fn(bucket): | ||||||
| # NOTE(kzawora): Chunked prefill scenarios never exceed the upper bound set by max_num_batched_tokens, regardless of max_model_len | ||||||
| _, seq, block = bucket | ||||||
| is_seq_in_bounds = seq <= max_num_batched_tokens | ||||||
| is_block_in_bounds = block <= max_num_blocks | ||||||
| # New logic: allow all buckets up to and including the first that exceeds max_model_len, then filter the rest | ||||||
| return is_seq_in_bounds and is_block_in_bounds | ||||||
| # Find the first bucket that exceeds max_model_len | ||||||
| # For each (bs, seq), keep all buckets that do not exceed model len, and the first that does | ||||||
| from collections import defaultdict | ||||||
|
||||||
| from collections import defaultdict |
Copilot
AI
Jul 17, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nitpick] This complex nested lambda expression reduces readability. Consider using a list comprehension or separating into multiple steps for better clarity.
| filtered_buckets = list(map(lambda x: x[1], filter(keep_bucket, enumerate(buckets)))) | |
| filtered_buckets = [bucket for _, bucket in enumerate(buckets) if keep_bucket((_, bucket))] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nitpick] The conditional logic is unclear: 'if not max_num_blocks' is true for 0 as well as for None, so the two cases are silently treated the same. Consider using 'max_num_blocks if max_num_blocks is not None else (bmax - b) // block_size' to make the None check explicit.