Skip to content

Conversation

@rkooo567
Copy link
Collaborator

@rkooo567 rkooo567 commented Apr 4, 2024

This is a part of the RFC. #3130

This PR

  • fixes a regression caused from [3/N] Refactor scheduler for chunked prefill scheduling #3550 where RUNNING num_batched tokens are added before running prefill. It is fixed & regression test is added test_scheduler_prefill_prioritized.
  • Update SchedulingBudget API to include request_ids to dedup add/subtract budgets. It is to make APIs more order agonistic.
  • From this PR, running and swapped both now can include prefill requests (chunked prefill). To figure out if the seq group is prefill vs decode, the stage is added to SequenceData class.
  • enable_chunk to support chunking prefills for _schedule_prefills, schedule_running, and _schedule_swapped.
  • Support chunked prefill scheduling algorithm. The algorithm is as follow;
1. Schedule all decodes
2. Schedule all chunked prefills (running prefills)
3. Schedule swapped
4. Schedule new prefills
  • Since chunked prefill requests are now a part of running, preemption & swapping works with chunked prefill. There are test cases covering it in a new test file test_chunked_prefill_scheduler.py.

PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

@simon-mo simon-mo requested review from simon-mo and zhuohan123 April 4, 2024 21:53
Copy link
Collaborator

@simon-mo simon-mo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this mean we can do end to end test now?


for scheduled_seq_group, outputs in zip(scheduled_seq_groups, output):
seq_group = scheduled_seq_group.seq_group
token_chunk_size = scheduled_seq_group.token_chunk_size
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why removed?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made it handled in scheduler.py. Basically as soon as we confirm it is scheduled, we update num computed tokens (which is consistent to moving waiting item to running within a scheduler).

seq_group=seq_group,
token_chunk_size=num_running_tokens))
else:
assert num_running_tokens == 1, num_running_tokens
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would this break spec decode because the rejection sampling takes multiple token @cadedaniel

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think when the feature is working e2e, we can remove the assert?

@rkooo567
Copy link
Collaborator Author

rkooo567 commented Apr 4, 2024

does this mean we can do end to end test now?

We will need attention level change which will come in the next PR. That'd be the last one

def remaining_token_budget(self):
return self.token_budget - self.num_batched_tokens

def add_num_batched_tokens(self, seq_group: SequenceGroup,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's required so that we can avoid having duplicated budget add/subtract

)
# Make sure we include num running seqs before scheduling prefill,
# so that we don't schedule beyond max_num_seqs for prefill.
for seq_group in self.running:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this was a bug (we should not include num_batched tokens in the beginning). it is fixed & regression tests are added.


for scheduled_seq_group, outputs in zip(scheduled_seq_groups, output):
seq_group = scheduled_seq_group.seq_group
token_chunk_size = scheduled_seq_group.token_chunk_size
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made it handled in scheduler.py. Basically as soon as we confirm it is scheduled, we update num computed tokens (which is consistent to moving waiting item to running within a scheduler).

vllm/sequence.py Outdated
"""
num_uncomputed_tokens = self.data.get_num_uncomputed_tokens()
if self.data.stage == SequenceStage.DECODE:
assert num_uncomputed_tokens == 1
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where is this set? get_num_uncomputed_tokens in RequestMetrics only returns return self.get_len() - self.get_num_computed_tokens()

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm yeah, this assumes that a new output token is added to sequence (so that get_len is incremented by 1) before update happens. I just removed the assert here

Copy link
Collaborator

@simon-mo simon-mo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approving. The PR is mostly additive and some slight refactoring over existing code path to enable the chunked prefill path.

self._requeset_ids_num_curr_seqs.add(req_id)
self._num_curr_seqs += num_curr_seqs

def subtract_num_seqs(self, seq_group: SequenceGroup, num_curr_seqs: int):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think the two add/substract methods should be combined

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unfortunately, for existing logic, we only add budget for num_seqs. So we need a separate method here...

def add_num_batched_tokens(self, seq_group: SequenceGroup,
num_batched_tokens: int):
req_id = seq_group.request_id
if req_id in self._requeset_ids_num_batched_tokens:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when would this happen? I think this should raise an exception

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same case as #3853 (comment)

seq_group, SequenceStatus.RUNNING, enable_chunking, budget)
# We can have up to 1 running prefill at any given time in running
# queue, which means we can guarantee chunk size is at least 1.
assert num_running_tokens != 0

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wha if budget is zero?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can't happen now. (because for normal case, if prefill is scheduled, this is not called. In chunked prefill, decoding is called first. I don't think we need to handle num_max_batched_tokens=0). So I will leave it in a current way and not handle this case (we can remove assert if there's other cases that this can happen)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants