
Conversation

@youkaichao
Member

manually fix the functionalization pass of torch.compile

without compile, after profiling the model, we have:

INFO 09-13 19:19:09 gpu_executor.py:122] # GPU blocks: 27975, # CPU blocks: 2048

after turning on torch compile, with this pr, we have:

INFO 09-13 19:20:44 gpu_executor.py:122] # GPU blocks: 27887, # CPU blocks: 2048

We lose about 88 GPU blocks, corresponding to roughly 0.17 GB of memory. It still costs 0.17 GB, but this is acceptable now.

Without this pr, if we naively turn on inductor, we have:

INFO 09-13 16:35:35 gpu_executor.py:122] # GPU blocks: 17753, # CPU blocks: 2048

That costs about 20 GB of memory.
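(For scale: 0.17 GB over 88 blocks is roughly 2 MB per KV-cache block for this model, so the naive-Inductor gap of 27975 - 17753 = 10222 blocks works out to about 20 GB.)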

the pattern for reference:

# ============== start reference ==============

import torch

def pattern_rotary_embedding(mm, positions, cos_sin_cache):
    """This is the graph for rotary embedding, after post-grad 
    functionalization. We need to remove the `auto_functionalized`
    and `slice_scatter`, to avoid unnecessary copy.
    """
    # File: vllm/vllm/model_executor/models/llama.py:179 in forward, code: q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) # noqa
    split_with_sizes = torch.ops.aten.split_with_sizes.default(
        mm, [4096, 1024, 1024], -1)
    getitem_2 = split_with_sizes[0]
    getitem_3 = split_with_sizes[1]
    split_with_sizes = None

    # File: vllm/vllm/_custom_ops.py:139 in rotary_embedding, code: torch.ops._C.rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox=True) # noqa
    auto_functionalized_1 = torch._higher_order_ops.auto_functionalize.auto_functionalized( # noqa
        torch.ops._C.rotary_embedding.default,
        positions=positions,
        query=getitem_2,
        key=getitem_3,
        head_size=128,
        cos_sin_cache=cos_sin_cache,
        is_neox=True)
    getitem_2 = getitem_3 = None
    getitem_6 = auto_functionalized_1[1]
    getitem_7 = auto_functionalized_1[2]
    auto_functionalized_1 = None
    slice_scatter = torch.ops.aten.slice_scatter.default(
        mm, getitem_6, 1, 0, 4096)
    mm = getitem_6 = None
    slice_scatter_1 = torch.ops.aten.slice_scatter.default(
        slice_scatter, getitem_7, 1, 4096, 5120)
    slice_scatter = getitem_7 = None
    return slice_scatter_1


def replace_rotary_embedding(mm, positions, cos_sin_cache):
    """
    This is the ideal graph for rotary embedding, after post-grad functionalization.
    We want to replace the above pattern with a direct call to `torch.ops._C.rotary_embedding.default`.
    """
    split_with_sizes = torch.ops.aten.split_with_sizes.default(
        mm, [4096, 1024, 1024], -1)
    getitem_2 = split_with_sizes[0]
    getitem_3 = split_with_sizes[1]
    split_with_sizes = None

    torch.ops._C.rotary_embedding.default(positions=positions,
                                          query=getitem_2,
                                          key=getitem_3,
                                          head_size=128,
                                          cos_sin_cache=cos_sin_cache,
                                          is_neox=True)
    getitem_2 = getitem_3 = None
    return mm


def pattern_fused_add_rms_norm(input, residual, weight):
    """
    This is the graph for fused_add_rms_norm, after post-grad functionalization.
    We need to remove the `auto_functionalized` wrapper, because PyTorch cannot
    functionalize this fused in-place op without creating unnecessary copies.
    """
    # File: vllm/vllm/_custom_ops.py:161 in fused_add_rms_norm, code: torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)
    auto_functionalized_3 = torch._higher_order_ops.auto_functionalize.auto_functionalized(  # noqa
        torch.ops._C.fused_add_rms_norm.default,
        input=input,
        residual=residual,
        weight=weight,
        epsilon=1e-05)
    input = residual = weight = None
    getitem_26 = auto_functionalized_3[1]
    getitem_27 = auto_functionalized_3[2]
    auto_functionalized_3 = None
    return getitem_26, getitem_27


def replace_fused_add_rms_norm(input, residual, weight):
    """
    This is the ideal graph for fused_add_rms_norm, after post-grad functionalization.
    """
    torch.ops._C.fused_add_rms_norm.default(input=input,
                                            residual=residual,
                                            weight=weight,
                                            epsilon=1e-05)
    return input, residual


# ============== end reference ==============
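For context, one way to apply such a rewrite (a minimal sketch, not necessarily the exact pass used in this PR) is a post-grad FX pass that finds the `auto_functionalized` nodes, calls the in-place custom op directly, and redirects the getitem users to the tensors that are now mutated in place. The helper name `defunctionalize_rms_norm` below is illustrative:

import operator

import torch
from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized


def defunctionalize_rms_norm(graph: fx.Graph) -> None:
    # Illustrative helper: rewrite auto_functionalized(fused_add_rms_norm)
    # nodes into direct in-place calls on the original tensors.
    target_op = torch.ops._C.fused_add_rms_norm.default
    for node in list(graph.nodes):
        if node.op != "call_function" or node.target is not auto_functionalized:
            continue
        if node.args[0] is not target_op:
            continue
        kwargs = dict(node.kwargs)
        with graph.inserting_before(node):
            # Call the mutating op directly instead of its functional wrapper.
            graph.call_function(target_op, kwargs=kwargs)
        # The wrapper's outputs 1 and 2 are the (copied) input and residual;
        # redirect each getitem user to the tensor now mutated in place.
        replacements = {1: kwargs["input"], 2: kwargs["residual"]}
        for user in list(node.users):
            if user.op == "call_function" and user.target is operator.getitem:
                user.replace_all_uses_with(replacements[user.args[1]])
                graph.erase_node(user)
        graph.erase_node(node)
    graph.eliminate_dead_code()

The rotary-embedding pattern additionally requires dropping the trailing `slice_scatter` nodes, since the direct call already writes into the query/key slices of `mm`.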

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@youkaichao
Member Author

In the main branch ci:

[2024-09-14T00:29:44Z] INFO 09-13 17:29:44 gpu_executor.py:122] # GPU blocks: 1678, # CPU blocks: 2048

in this commit ci:

[2024-09-14T03:05:39Z] INFO 09-13 20:05:39 gpu_executor.py:122] # GPU blocks: 1440, # CPU blocks: 2048

We lost 238 blocks, roughly 0.4 GB of memory. We can investigate further to recover the rest, but I think the major issue is solved now.

@youkaichao
Member Author

confirmed by @Chillee, merging.

@youkaichao youkaichao merged commit a36e070 into vllm-project:main Sep 14, 2024
@youkaichao youkaichao deleted the fix_functionalization branch September 14, 2024 16:46
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
garg-amit pushed a commit to garg-amit/vllm that referenced this pull request Oct 28, 2024
LeiWang1999 pushed a commit to LeiWang1999/vllm-bitblas that referenced this pull request Mar 26, 2025