[Bug]: OLMo 2 does not split qkv correctly for grouped query attention #13686

@2015aroras

Description

Your current environment

I do not have access to the environment anymore, but the bug and its fix are straightforward.

🐛 Describe the bug

OLMo 2 attention is broken when the number of query heads differs from the number of KV heads (i.e. when GQA or MQA is used instead of MHA). Specifically, the fused qkv projection output is split into three equal chunks rather than into chunks sized by q, k, and v. The fix is a one-liner.
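A minimal sketch of the failure mode (the shapes below are illustrative, not the exact OLMo 2 config): splitting the fused qkv tensor into three equal chunks only works when q, k, and v all have the same width, which holds for MHA but not for GQA/MQA.

```python
import torch

# Hypothetical GQA shapes: 8 query heads, 2 kv heads, head_dim 16
# -> q_size=128, kv_size=32, fused qkv width 192.
num_heads, num_kv_heads, head_dim = 8, 2, 16
q_size = num_heads * head_dim       # 128
kv_size = num_kv_heads * head_dim   # 32
qkv = torch.randn(4, q_size + 2 * kv_size)

# Buggy: equal thirds, correct only when q_size == kv_size (MHA).
q_bad, k_bad, v_bad = qkv.chunk(chunks=3, dim=-1)
print(q_bad.shape[-1])  # 64 -- wrong, q_norm expects 128

# Fixed: split by the actual q/k/v widths.
q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
print(q.shape[-1], k.shape[-1], v.shape[-1])  # 128 32 32
```

With the equal-chunk split, `q` has the wrong last dimension, so the downstream `q_norm` raises the hidden-size mismatch shown in the trace below.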

I don't have a minimal repro, but below is the stack trace produced when running OLMo 2 with GQA.

Exception in worker VllmWorkerProcess while processing method determine_num_available_blocks.
Traceback (most recent call last):
  File "/opt/conda/lib/python3.11/site-packages/vllm/executor/multiproc_worker_utils.py", line 236, in _run_worker_process
    output = run_method(worker, method, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/vllm/utils.py", line 2220, in run_method
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/vllm/worker/worker.py", line 229, in determine_num_available_blocks
    self.model_runner.profile_run()
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/vllm/worker/model_runner.py", line 1235, in profile_run
    self._dummy_run(max_num_batched_tokens, max_num_seqs)
  File "/opt/conda/lib/python3.11/site-packages/vllm/worker/model_runner.py", line 1346, in _dummy_run
    self.execute_model(model_input, kv_caches, intermediate_tensors)
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/vllm/worker/model_runner.py", line 1719, in execute_model
    hidden_or_intermediate_states = model_executable(
                                    ^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/vllm/model_executor/models/olmo2.py", line 364, in forward
    hidden_states = self.model(
                    ^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/vllm/model_executor/models/olmo2.py", line 312, in forward
    hidden_states = self.layers[i](
                    ^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/vllm/model_executor/models/olmo2.py", line 247, in forward
    hidden_states = self.self_attn(positions, hidden_states, kv_cache,
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/vllm/model_executor/models/olmo2.py", line 161, in forward
    q, k = self._apply_qk_norm(q, k)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/vllm/model_executor/models/olmo2.py", line 143, in _apply_qk_norm
    q = self.q_norm.forward_native(q)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/vllm/model_executor/layers/layernorm.py", line 52, in forward_native
    raise ValueError("Expected hidden_size to be "
ValueError: Expected hidden_size to be 5120, but found: 2392

