
Conversation

@jeongin601 (Contributor) commented Nov 10, 2024

FIX #9834 (link existing issues this PR will resolve)

Problem

The current BatchExpansionTop1Scorer implements a speculative scoring mechanism that uses batch expansion to estimate the probabilities of speculative tokens under the scoring (target) model. However, in the existing setup, SequenceGroupMetadata applies default sampling parameters (top_p=1.0, temperature=1.0, repetition_penalty=1.0) when generating the target probabilities. According to comments in the code, this choice appears to have been made because the sampled tokens are not used directly.
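Conceptually, the as-is behavior amounts to something like the following (a simplified sketch, not the actual batch_expansion code; scoring_params is just an illustrative name):

```python
from vllm import SamplingParams

# As-is: the expanded target sequences are scored with neutral parameters,
# regardless of what the original request asked for.
scoring_params = SamplingParams(temperature=1.0, top_p=1.0,
                                repetition_penalty=1.0)

# With this PR (conceptually): reuse the request's own sampling parameters so
# the target probabilities are processed the same way as the draft's.
# scoring_params = seq_group_metadata.sampling_params
```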

Modification

Although we do not directly sample tokens from the target model while scoring, I believe applying consistent sampling parameters to both draft and target probabilities is essential for accurate rejection sampling. The current implementation uses draft probabilities that have already been shaped by the sampling parameters (e.g., filtered by top_p), while the target probabilities are not, creating a mismatch that can hurt scoring accuracy. Because the unmodified target probabilities do not represent the distribution actually used at decode time, I modified the code to apply the same sampling parameters to both draft and target probabilities for consistency in rejection sampling.
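To make the mismatch concrete, here is a toy illustration (plain PyTorch with made-up numbers, not the vLLM implementation): rejection sampling accepts a drafted token x with probability min(1, p_target(x) / q_draft(x)), so if q_draft has been top_p-filtered and renormalized while p_target has not, the ratio shrinks and drafted tokens are rejected more often than necessary.

```python
import torch

def accept_prob(p_target: torch.Tensor, q_draft: torch.Tensor, tok: int) -> float:
    """Speculative-decoding acceptance rule: min(1, p(x) / q(x))."""
    return min(1.0, (p_target[tok] / q_draft[tok]).item())

# Toy 4-token vocabulary (temperature already applied).
p = torch.softmax(torch.tensor([2.0, 1.0, 0.5, -1.0]), dim=-1)

# Draft-side top_p=0.9: keep the smallest prefix of sorted probs covering 0.9,
# zero out the tail, renormalize.
sorted_p, order = torch.sort(p, descending=True)
keep = (torch.cumsum(sorted_p, dim=-1) - sorted_p) < 0.9
q = torch.zeros_like(p)
q[order[keep]] = p[order[keep]]
q = q / q.sum()

tok = int(order[0])               # a token the draft proposed
print(accept_prob(p, q, tok))     # mismatched processing: ratio < 1.0
print(accept_prob(q, q, tok))     # consistent processing: exactly 1.0
```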

In my experiment, this change resulted in a significant difference in the acceptance rate, as shown in the tables below.

Experiment

Setting

  • Target Model / Draft Model: llama3-70B / llama3-8B
  • TP: 4
  • Devices: A100 * 4
  • Total number of requests: 500
  • input length / output length: 1024 / 128
  • sampling parameter: repetition_penalty=1.0, temperature=0.6, top_p=0.9, top_k=-1
  • dataset: c4
  • batch size: 1
  • K: number of speculative tokens

As-Is

| K | acceptance rate (%) | system efficiency (%) |
|---|---------------------|-----------------------|
| 1 | 65.1 | 82.5 |
| 2 | 63.3 | 68.3 |
| 3 | 62.4 | 57.4 |

To-be (applied in this PR)

| K | acceptance rate (%) | system efficiency (%) |
|---|---------------------|-----------------------|
| 1 | 81.9 | 91.0 |
| 2 | 80.5 | 82.3 |
| 3 | 80.2 | 75.4 |

@github-actions (bot) commented

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small, essential subset of CI tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@llsj14 (Contributor) commented Nov 11, 2024

@sroy745 @LiuXiaoxuanPKU @njhill
Could you please review this PR, which relates to the sampling process?

@njhill (Member) left a comment

Thanks @jeongin601 this looks like a very nice finding!

We may still want to make and use a (shallow) copy of the sampling parameters with the seed removed in the case a seed is set, to avoid doing seeded sampling for the non-final tokens.
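Something like this (a rough sketch, assuming the scorer has the request's SamplingParams in hand and that seed is a plain attribute on it):

```python
import copy

def scoring_params_without_seed(sampling_params):
    """Shallow-copy the request's params and drop the seed, so the
    intermediate (non-final) scoring pass is not seeded."""
    if sampling_params.seed is None:
        return sampling_params              # nothing to strip
    params = copy.copy(sampling_params)     # shallow copy is sufficient here
    params.seed = None
    return params
```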

@llsj14 (Contributor) commented Nov 12, 2024

We may still want to make and use a (shallow) copy of the sampling parameters with the seed removed in the case a seed is set, to avoid doing seeded sampling for the non-final tokens.

@njhill, I'm curious about the reason why the seed should be removed, especially if it is used for the target model sampling and affects the output token selection when proposals are rejected.

@jeongin601 force-pushed the main branch 2 times, most recently from 289341d to a54d83e on November 12, 2024 15:11
@njhill (Member) commented Nov 12, 2024

We may still want to make and use a (shallow) copy of the sampling parameters with the seed removed in the case a seed is set, to avoid doing seeded sampling for the non-final tokens.

@njhill, I'm curious about the reason why the seed should be removed, especially if it is used for the target model sampling and affects the output token selection when proposals are rejected.

@joennlae ah sorry, perhaps I misremembered the logic, I didn't think those sampled tokens could end up getting used. I'll check it again but if you're right then makes sense to ignore that seed optimization.

@sroy745 added the "ready" label (ONLY add when PR is ready to merge / full CI is needed) on Nov 15, 2024
@sroy745 (Collaborator) commented Nov 15, 2024

Adding /ready to kick off the tests and verify nothing else fails from this

@llsj14 (Contributor) commented Nov 15, 2024

We may still want to make and use a (shallow) copy of the sampling parameters with the seed removed in the case a seed is set, to avoid doing seeded sampling for the non-final tokens.

@njhill, I'm curious about the reason why the seed should be removed, especially if it is used for the target model sampling and affects the output token selection when proposals are rejected.
@joennlae ah sorry, perhaps I misremembered the logic, I didn't think those sampled tokens could end up getting used. I'll check it again but if you're right then makes sense to ignore that seed optimization.

@njhill
Yeah, I also needed to double-check. I think in this part, we might need to use seeds, but I haven't examined seeded_seqs in detail yet.

# NOTE: the recovered_probs are overwritten by this method.
recovered_token_ids = _multinomial(
    recovered_probs,
    num_samples=1,
    k=k,
    seeded_seqs=seeded_seqs or {},
).reshape(batch_size, k)

@sroy745 (Collaborator) commented Nov 18, 2024

We may still want to make and use a (shallow) copy of the sampling parameters with the seed removed in the case a seed is set, to avoid doing seeded sampling for the non-final tokens.

@njhill, I'm curious about the reason why the seed should be removed, especially if it is used for the target model sampling and affects the output token selection when proposals are rejected.
@joennlae ah sorry, perhaps I misremembered the logic, I didn't think those sampled tokens could end up getting used. I'll check it again but if you're right then makes sense to ignore that seed optimization.

@njhill Yeah, I also needed to double-check. I think in this part, we might need to use seeds, but I haven't examined seeded_seqs in detail yet.

# NOTE: the recovered_probs are overwritten by this method.
recovered_token_ids = _multinomial(
    recovered_probs,
    num_samples=1,
    k=k,
    seeded_seqs=seeded_seqs or {},
).reshape(batch_size, k)

Hi,
I think this change should not impact the per-request seed handling logic in the RejectionSampler. The per-request seeds are set here, which remains unchanged, hence I am wondering if this should be fine.
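For reference, roughly how the per-request seeding works in that path (a simplified sketch with hypothetical names; the real construction lives in the spec-decode worker, and seeded_seqs is assumed to map a batch index to a torch.Generator):

```python
import torch
from typing import Dict, Optional

def build_seeded_generators(
        request_seeds: Dict[int, Optional[int]],
        device: str = "cpu") -> Dict[int, torch.Generator]:
    """Hypothetical helper: one generator per request that set a seed.
    Requests without a seed are sampled with the default RNG."""
    generators: Dict[int, torch.Generator] = {}
    for batch_idx, seed in request_seeds.items():
        if seed is not None:
            generators[batch_idx] = torch.Generator(device=device).manual_seed(seed)
    return generators

# _multinomial then draws the recovered tokens for seeded rows with their own
# generator, so the same request (and seed) reproduces the same recovered tokens.
```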

cc: @tdoublep, who made the change for respecting per-request seeds in the spec-decode worker. @tdoublep, can you PTAL and see whether this change impacts the per-request seeding logic?

@jeongin601 there is one test failure in the spec_decoding tests (test_many_k[1-32-2-test_llm_kwargs3-baseline_llm_kwargs0-per_test_common_llm_kwargs0-common_llm_kwargs0]). I ran the test locally and it passes. Also from the failure logs it seems transient. Can you please trigger the tests once to see if it passes or not?

@llsj14 (Contributor) commented Nov 18, 2024

Thank you @sroy745; after your comments I was able to check this properly.

I found out that this PR also corrects the seed for non_spec_token_ids. Although I haven't used non_spec_token_ids while running spec decode, if it is used, its seed should match that of seq_group_metadata.

I also confirmed that this section remains unchanged by this PR and already uses the correct sampling parameters. This PR cannot affect seq_group_metadata (which holds the per-request sampling parameters), since target_seq_group_metadata_list is simply generated from seq_group_metadata.

@llsj14 (Contributor) commented Nov 19, 2024

test_many_k passed, but test_mlp_e2e_seeded_correctness failed (it didn't raise an assertion error). I think there shouldn't be any issue with the seed, but we need to check. @jeongin601 will rerun the test first.

What I suspect is that the number of attempts to sample with the same seed may have changed due to this PR. This could affect the output because the generator then consumes a different part of its random stream. If that's the case, I believe the outcome is not incorrect, but we need to verify it.
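That effect is easy to reproduce with a bare torch.Generator (illustrative only; this stands in for the sampler's per-request generator):

```python
import torch

g1 = torch.Generator().manual_seed(1234)
a = torch.rand(3, generator=g1)

g2 = torch.Generator().manual_seed(1234)
_ = torch.rand(1, generator=g2)   # one extra draw advances the generator state
b = torch.rand(3, generator=g2)

print(torch.equal(a, b))  # False: same seed, but a different slice of the random stream
```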

@jeongin601 (Contributor, Author) commented Nov 19, 2024

@sroy745 I retriggered the tests, but they still fail on a seeded speculative decoding test. It seems this happens because my code reproduces the same results when the sampling seed is set to 'None.' But if the seed value is set to 'None,' sampling should use the default seed value. Then why should it produce different results?

@llsj14 (Contributor) commented Nov 19, 2024

@sroy745 I retriggered the tests, but they still fail on a seeded speculative decoding test. It seems this happens because my code reproduces the same results when the sampling seed is set to 'None.' But if the seed value is set to 'None,' sampling should use the default seed value. Then why should it produce different results?

This test was added in this PR.
@njhill Would you please help us to understand the test results? We are confused due to the test failure in 'test_mlp_e2e_seeded_correctness'.

@sroy745 (Collaborator) commented Nov 20, 2024

@sroy745 I retriggered the tests, but they still fail on a seeded speculative decoding test. It seems this happens because my code reproduces the same results when the sampling seed is set to 'None.' But if the seed value is set to 'None,' sampling should use the default seed value. Then why should it produce different results?

This test was added in this PR. @njhill Would you please help us to understand the test results? We are confused due to the test failure in 'test_mlp_e2e_seeded_correctness'.

Hi,
I added some more logs to debug the test failure. The test that fails is for temperature 0.1. My understanding is the following.

The output is determined by the probability distribution of the target model. Prior to this change the temperature of the target model would be set to 1 (https://sourcegraph.com/github.com/vllm-project/vllm/-/blob/vllm/spec_decode/batch_expansion.py?L317) and hence the probability distribution would be uniform as shown in the logs below. When the output distribution is uniform, having a seed matters in order to guarantee a deterministic output.

=====================================
sorted_target_with_bonus_probs tensor([[[1.3449e-01, 1.2113e-01, 6.5448e-02,  ..., 5.9314e-14,
          5.9040e-14, 2.1329e-14],
         [4.2238e-01, 1.8298e-01, 6.3147e-02,  ..., 3.3977e-15,
          2.9729e-15, 4.5849e-16],
         [5.4143e-01, 1.7530e-01, 5.0799e-02,  ..., 4.8536e-15,
          3.9750e-15, 2.7892e-15],
         [9.9988e-01, 1.2378e-04, 3.6513e-07,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00]]], device='cuda:0')
sorted_indices_target_with_bonus_probs tensor([[[11794,  2965,  1307,  ..., 28566, 21585, 13066],
         [29979,  4214, 29991,  ..., 25436, 20044, 26365],
         [29991,  4214, 29973,  ..., 26365, 16539, 16731],
         [29991,   349,   382,  ..., 31997, 31998, 31999]]], device='cuda:0')
=====================================
==================
accepted tensor([[False, False, False]], device='cuda:0')
substitute_token_ids tensor([[ 3904,  4214, 29973]], device='cuda:0')
==================
=====================================
sorted_target_with_bonus_probs tensor([[[1.4940e-01, 1.0690e-01, 6.3709e-02,  ..., 2.3085e-15,
          7.9265e-16, 3.0244e-16],
         [3.5589e-01, 2.6260e-01, 1.1180e-01,  ..., 4.5457e-16,
          3.7591e-16, 1.1146e-16],
         [7.1558e-01, 1.0897e-01, 3.6871e-02,  ..., 9.7023e-16,
          8.0093e-16, 3.1187e-16],
         [9.9512e-01, 2.5934e-03, 1.1197e-03,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00]]], device='cuda:0')
sorted_indices_target_with_bonus_probs tensor([[[29889,  1806,  2965,  ..., 28856, 17419, 25459],
         [18474,  4214, 29991,  ..., 26812, 24449, 26365],
         [29991, 29889, 29973,  ..., 22759, 27037, 26365],
         [  349, 29991,   323,  ..., 31997, 31998, 31999]]], device='cuda:0')
=====================================
==================
accepted tensor([[False, False,  True]], device='cuda:0')
substitute_token_ids tensor([[29889,  3352, 29973]], device='cuda:0')
==================
=====================================
sorted_target_with_bonus_probs tensor([[[4.3602e-01, 6.8332e-02, 4.5080e-02,  ..., 1.2287e-13,
          1.1384e-13, 3.2875e-14],
         [5.7436e-02, 3.2391e-02, 2.4813e-02,  ..., 3.1412e-12,
          3.0167e-12, 1.2250e-12],
         [6.1436e-02, 5.5466e-02, 4.5348e-02,  ..., 3.1619e-13,
          2.9519e-13, 2.5927e-13],
         [9.1123e-01, 8.2438e-02, 2.1776e-03,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00]]], device='cuda:0')
sorted_indices_target_with_bonus_probs tensor([[[  349, 29991,   323,  ..., 27949, 18728, 29771],
         [29902,  1576, 29911,  ..., 13546, 24151, 28276],
         [  349,   350,   399,  ..., 28493,  3319, 12223],
         [ 3904,  1525,  1299,  ..., 31997, 31998, 31999]]], device='cuda:0')
=====================================
==================
accepted tensor([[False, False, False]], device='cuda:0')
substitute_token_ids tensor([[2672, 2776,  390]], device='cuda:0')
==================

However after your change, for the failure case you are now setting the sampling temperature for the target model to 0.1. This means that the probability distribution of the sampled tokens is no longer uniform (highly skewed). When that happens I think it does not matter if there is a seed or not. We will always sample the token with prob 1.0

=====================================
sorted_target_with_bonus_probs tensor([[[1.0000e+00, 3.2860e-13, 9.0046e-16,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [1.0000e+00, 5.1224e-14, 1.1282e-16,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [1.0000e+00, 5.2276e-07, 8.0635e-08,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [1.0000e+00, 1.0203e-08, 6.8024e-10,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00]]], device='cuda:0')
sorted_indices_target_with_bonus_probs tensor([[[29889, 29892, 29991,  ..., 31997, 31998, 31999],
         [  306,    13,     0,  ..., 31997, 31998, 31999],
         [29902,  1576,  4013,  ..., 31997, 31998, 31999],
         [  626,   505,  5360,  ..., 31997, 31998, 31999]]], device='cuda:0')
=====================================
==================
accepted tensor([[ True, False,  True]], device='cuda:0')
substitute_token_ids tensor([[29892,   306, 29902]], device='cuda:0')
==================
=====================================
sorted_target_with_bonus_probs tensor([[[1.0000e+00, 1.4068e-11, 9.6391e-14,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [1.0000e+00, 1.6998e-08, 1.4874e-14,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [1.0000e+00, 1.7947e-14, 4.7633e-18,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [9.9995e-01, 3.7569e-05, 9.5598e-06,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00]]], device='cuda:0')
sorted_indices_target_with_bonus_probs tensor([[[  626,   338,  5360,  ..., 31997, 31998, 31999],
         [  263, 29889, 29892,  ..., 31997, 31998, 31999],
         [  304, 29889, 29892,  ..., 31997, 31998, 31999],
         [  367,   304,   748,  ..., 31997, 31998, 31999]]], device='cuda:0')
=====================================
==================
accepted tensor([[ True, False,  True]], device='cuda:0')
substitute_token_ids tensor([[  626,   263, 29889]], device='cuda:0')
==================
=====================================
sorted_target_with_bonus_probs tensor([[[1.0000e+00, 1.1414e-14, 1.0582e-18,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [9.9957e-01, 4.3066e-04, 5.2830e-12,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [1.0000e+00, 9.7951e-11, 4.8004e-12,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [1.0000e+00, 8.7923e-14, 1.4424e-14,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00]]], device='cuda:0')
sorted_indices_target_with_bonus_probs tensor([[[ 7826,  8023, 29889,  ..., 31997, 31998, 31999],
         [  865,   262,  1253,  ..., 31997, 31998, 31999],
         [29889, 12456, 29892,  ..., 31997, 31998, 31999],
         [  306, 22063,  1334,  ..., 31997, 31998, 31999]]], device='cuda:0')
=====================================
==================
accepted tensor([[False, False,  True]], device='cuda:0')
substitute_token_ids tensor([[ 7826,   865, 12456]], device='cuda:0')

I tried running the test with temperature 0.6 instead of 0.1 and it passes.

Not sure if this is correct. @tdoublep / @njhill ptal.

@tdoublep (Member) commented:

I was always wondering why we use different sampling parameters for the speculative model vs. main model. I reviewed the paper referenced in the rejection sampling code and it doesn't explicitly say either way. I guess one has the freedom to sample the speculative model however one likes, but the results in this PR certainly suggest it makes sense to use the same sampling params as the main model. Really cool! @jeongin601

temperature of the target model would be set to 1 and hence the probability distribution would be uniform

I don't think that is correct: temperature=1.0 doesn't imply a uniform distribution. The distribution is determined by the logits; the temperature is just a scaling factor in the softmax that transforms the logits into probabilities.
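For example (plain PyTorch with toy logits, just to illustrate the scaling):

```python
import torch

logits = torch.tensor([2.0, 1.0, 0.5, -1.0])
for temperature in (1.0, 0.1):
    print(temperature, torch.softmax(logits / temperature, dim=-1))
# temperature=1.0 keeps whatever shape the logits give (clearly not uniform);
# temperature=0.1 sharpens the same logits toward a near one-hot distribution.
```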

However after your change, for the failure case you are now setting the sampling temperature for the target model to 0.1. This means that the probability distribution of the sampled tokens is no longer uniform (highly skewed). When that happens I think it does not matter if there is a seed or not. We will always sample the token with prob 1.0

Yes, setting the temperature to 0.1 will make the test more deterministic than it would be with temperature 1.0. However, it definitely doesn't mean we will always sample the same token with probability exactly 1.0, but rather some high probability.

@sroy745 I think your explanation makes sense: the failing test checks that we get different output each time when not applying the seed. This is not guaranteed to happen when using a low temperature (e.g., approaching the greedy limit), so the test is not really well-defined. The changes in this PR will make everything more deterministic at low temperature, which explains why it was passing before (e.g., because we were using temperature=1.0 to sample the speculative model). I would recommend we modify that test to run only when the temperature is >=1.0.

@jeongin601 (Contributor, Author) commented Nov 21, 2024

Thank you for all the reviews and analysis. I understood why my code kept failing the seeded correctness test. As you all pointed out, I also believe we need to test at higher temperatures. To address this, I modified the test to run at temperatures of 0.6, 1.0, and 1.2.

@sroy745 (Collaborator) left a comment

Thanks for the PR. Left one comment about one test; otherwise LGTM.

@pytest.mark.parametrize("output_len", [64])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("temperature", [0.1, 1.0])
@pytest.mark.parametrize("temperature", [0.6, 1.0, 1.2])
Collaborator commented:

nit - wondering if it would be ok to run it only for temperature 1.0?

@jeongin601 (Contributor, Author) commented Nov 22, 2024

I updated it! :) What do you think about the A100 distributed-test failure?

Collaborator commented:

Hi, I think the failures in distributed-tests-a100 may not be related to this PR (I see errors related to CustomAllreduce in the failure logs). I see this test fail in some other PRs as well. The other two failures seem related to timeouts.

@LiuXiaoxuanPKU (Collaborator) left a comment

LGTM! Sorry for the late reply here, yes, using the same sampling params for the draft and target model makes a lot of sense to me.

@LiuXiaoxuanPKU LiuXiaoxuanPKU enabled auto-merge (squash) November 24, 2024 04:18
@sroy745 (Collaborator) commented Nov 25, 2024

Hi @simon-mo,
The PR has been LGTM'd, but there is one test failing (distributed-tests-a100). The failure doesn't seem related to this PR. I also confirmed in the #sig-ci channel that this failure is happening at head (https://vllm-dev.slack.com/archives/C07R5PAL2L9/p1732571600143719). If it looks fine, I am wondering if we can get this merged?

cc: @LiuXiaoxuanPKU

@njhill (Member) commented Nov 27, 2024

Thanks @jeongin601 @llsj14 @sroy745 @tdoublep for all of the analysis here!

@jeongin601 could you merge in the latest main which should hopefully address the test failure?

@LiuXiaoxuanPKU LiuXiaoxuanPKU merged commit 1bf905d into vllm-project:main Nov 27, 2024
45 checks passed
afeldman-nm pushed a commit to neuralmagic/vllm that referenced this pull request Dec 2, 2024
…s for consistency in rejection sampling. (vllm-project#10198)

Signed-off-by: jeongin601 <[email protected]>
Signed-off-by: jeong_in.bae <[email protected]>
Signed-off-by: Andrew Feldman <[email protected]>
sleepwalker2017 pushed a commit to sleepwalker2017/vllm that referenced this pull request Dec 13, 2024
…s for consistency in rejection sampling. (vllm-project#10198)

Signed-off-by: jeongin601 <[email protected]>
Signed-off-by: jeong_in.bae <[email protected]>
anko-intel pushed a commit to HabanaAI/vllm-fork that referenced this pull request Feb 12, 2025
…s for consistency in rejection sampling. (vllm-project#10198)

Signed-off-by: jeongin601 <[email protected]>
Signed-off-by: jeong_in.bae <[email protected]>

Labels

ready (ONLY add when PR is ready to merge / full CI is needed)


Development

Successfully merging this pull request may close these issues.

[Bug]: Sampling parameter fixed issue while doing speculative sampling verification step

6 participants