From dc71b2e22979b190da145105155d72bf56ab3497 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Fri, 30 Aug 2024 13:24:35 -0400 Subject: [PATCH 1/3] [Kernel] Change interface to Mamba selective_state_update for continuous batching --- tests/kernels/test_mamba_ssm.py | 158 ++++++++++++++++++ .../layers/mamba/ops/mamba_ssm.py | 31 +++- 2 files changed, 186 insertions(+), 3 deletions(-) diff --git a/tests/kernels/test_mamba_ssm.py b/tests/kernels/test_mamba_ssm.py index d3cb0a8656a0..b8a8c4c4e62b 100644 --- a/tests/kernels/test_mamba_ssm.py +++ b/tests/kernels/test_mamba_ssm.py @@ -322,3 +322,161 @@ def test_selective_state_update(dim, dstate, has_z, itype): assert torch.allclose(state, state_ref, rtol=rtol, atol=atol) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + + +@pytest.mark.parametrize("itype", + [torch.float32, torch.float16, torch.bfloat16]) +# @pytest.mark.parametrize('itype', [torch.float16]) +@pytest.mark.parametrize("has_z", [False, True]) +# @pytest.mark.parametrize('has_z', [True]) +@pytest.mark.parametrize("dstate", [16, 32, 64]) +# @pytest.mark.parametrize("dstate", [16]) +@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) +# @pytest.mark.parametrize("dim", [2048]) +def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype): + device = "cuda" + rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2) + if itype == torch.bfloat16: + rtol, atol = 6e-2, 6e-2 + if torch.version.hip: + atol *= 2 + # set seed + torch.random.manual_seed(0) + batch_size = 16 + + total_entries = 10 * batch_size + state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device) + state_indices = torch.randperm(total_entries)[:batch_size].to( + dtype=torch.int32, device=device) + + x = torch.randn(batch_size, dim, device=device, dtype=itype) + dt = torch.randn(batch_size, dim, device=device, dtype=itype) + dt_bias = torch.rand(dim, device=device) - 4.0 + A = -torch.rand(dim, dstate, device=device) - 1.0 + B = torch.randn(batch_size, dstate, device=device) + C = torch.randn(batch_size, dstate, device=device) + D = torch.randn(dim, device=device) + z = torch.randn_like(x) if has_z else None + state_ref = state[state_indices, :].detach().clone() + out = selective_state_update(state, + x, + dt, + A, + B, + C, + D=D, + z=z, + dt_bias=dt_bias, + dt_softplus=True, + state_batch_indices=state_indices) + out_ref = selective_state_update_ref(state_ref, + x, + dt, + A, + B, + C, + D=D, + z=z, + dt_bias=dt_bias, + dt_softplus=True) + + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + assert torch.allclose(state[state_indices, :], + state_ref, + rtol=rtol, + atol=atol) + assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + + +@pytest.mark.parametrize("itype", + [torch.float32, torch.float16, torch.bfloat16]) +#@pytest.mark.parametrize('itype', [torch.float32]) +@pytest.mark.parametrize("has_z", [False, True]) +# @pytest.mark.parametrize('has_z', [True]) +@pytest.mark.parametrize("tie_hdim", [False, True]) +# @pytest.mark.parametrize('tie_hdim', [True]) +@pytest.mark.parametrize("ngroups", [1, 2, 4]) +# @pytest.mark.parametrize("ngroups", [2]) +@pytest.mark.parametrize("dstate", [16, 32, 64]) +# @pytest.mark.parametrize("dstate", [16]) +@pytest.mark.parametrize("dim", [2048, 4096]) +# @pytest.mark.parametrize("dim", [2048]) +def test_selective_state_update_with_heads_with_batch_indices( + dim, dstate, ngroups, has_z, tie_hdim, itype): + device = "cuda" + 
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 3e-2) + if itype == torch.bfloat16: + rtol, atol = 1e-1, 1e-1 + # set seed + torch.random.manual_seed(0) + batch_size = 16 + headdim = 64 + nheads = dim // headdim + + total_entries = 10 * batch_size + state = torch.randn(total_entries, + nheads, + headdim, + dstate, + dtype=itype, + device=device) + state_indices = torch.randperm(total_entries)[:batch_size].to( + dtype=torch.int32, device=device) + + x = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype) + if not tie_hdim: + dt = torch.randn(batch_size, + nheads, + headdim, + device=device, + dtype=itype) + dt_bias = torch.rand(nheads, headdim, device=device) - 4.0 + A = -torch.rand(nheads, headdim, dstate, device=device) - 1.0 + D = torch.randn(nheads, headdim, device=device) + else: + dt = repeat(torch.randn(batch_size, nheads, device=device, + dtype=itype), + "b h -> b h p", + p=headdim) + dt_bias = repeat(torch.rand(nheads, device=device) - 4.0, + "h -> h p", + p=headdim) + A = repeat(-torch.rand(nheads, device=device) - 1.0, + "h -> h p n", + p=headdim, + n=dstate) + D = repeat(torch.randn(nheads, device=device), "h -> h p", p=headdim) + B = torch.randn(batch_size, ngroups, dstate, device=device) + C = torch.randn(batch_size, ngroups, dstate, device=device) + z = torch.randn_like(x) if has_z else None + state_ref = state[state_indices, :].detach().clone() + out = selective_state_update(state, + x, + dt, + A, + B, + C, + D=D, + z=z, + dt_bias=dt_bias, + dt_softplus=True, + state_batch_indices=state_indices) + out_ref = selective_state_update_ref(state_ref, + x, + dt, + A, + B, + C, + D=D, + z=z, + dt_bias=dt_bias, + dt_softplus=True) + + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + assert torch.allclose(state[state_indices, :], + state_ref, + rtol=rtol, + atol=atol) + assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index 869c69214caf..a0bed07ac619 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -1,4 +1,5 @@ # Copyright (c) 2024, Tri Dao, Albert Gu. 
+# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/selective_state_update.py
 
 import torch
 import triton
@@ -27,6 +28,10 @@ def softplus(dt):
     {"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
 @triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
 @triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
+@triton.heuristics({
+    "HAS_STATE_BATCH_INDICES":
+    lambda args: args["state_batch_indices_ptr"] is not None
+})
 @triton.heuristics(
     {"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])})
 @triton.jit
@@ -42,6 +47,7 @@ def _selective_scan_update_kernel(
     D_ptr,
     z_ptr,
     out_ptr,
+    state_batch_indices_ptr,
     # Matrix dimensions
     batch,
     nheads,
@@ -85,12 +91,24 @@ def _selective_scan_update_kernel(
     HAS_DT_BIAS: tl.constexpr,
     HAS_D: tl.constexpr,
     HAS_Z: tl.constexpr,
+    HAS_STATE_BATCH_INDICES: tl.constexpr,
     BLOCK_SIZE_DSTATE: tl.constexpr,
 ):
     pid_m = tl.program_id(axis=0)
     pid_b = tl.program_id(axis=1)
     pid_h = tl.program_id(axis=2)
-    state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
+
+    # If HAS_STATE_BATCH_INDICES is true, then the SSM state's batch
+    # coordinate is taken from state_batch_indices_ptr. Otherwise, the state
+    # coordinate is the same as the batch id.
+    if HAS_STATE_BATCH_INDICES:
+        state_batch_indices_ptr += pid_b
+        state_batch_idx = tl.load(state_batch_indices_ptr)
+        state_ptr += (state_batch_idx * stride_state_batch +
+                      pid_h * stride_state_head)
+    else:
+        state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
+
     x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
     dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
     if HAS_DT_BIAS:
@@ -177,7 +195,8 @@ def selective_state_update(state,
                            D=None,
                            z=None,
                            dt_bias=None,
-                           dt_softplus=False):
+                           dt_softplus=False,
+                           state_batch_indices=None):
     """
     Argument:
         state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
@@ -211,7 +230,10 @@ def selective_state_update(state,
         z = z.unsqueeze(1)
     if dt_bias is not None and dt_bias.dim() == 1:
         dt_bias = dt_bias.unsqueeze(0)
-    batch, nheads, dim, dstate = state.shape
+
+    _, nheads, dim, dstate = state.shape
+    batch = x.shape[0]
+
     assert x.shape == (batch, nheads, dim)
     assert dt.shape == x.shape
     assert A.shape == (nheads, dim, dstate)
@@ -225,6 +247,8 @@ def selective_state_update(state,
         assert z.shape == x.shape
     if dt_bias is not None:
         assert dt_bias.shape == (nheads, dim)
+    if state_batch_indices is not None:
+        assert state_batch_indices.shape == (batch, )
     out = torch.empty_like(x)
     grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)
     z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else
@@ -249,6 +273,7 @@ def selective_state_update(state,
             D,
             z,
             out,
+            state_batch_indices,
             batch,
             nheads,
             dim,
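
As a usage sketch of the new interface: a continuous-batching caller keeps one
persistent state entry per request and passes the slot indices of the
currently scheduled requests at each decode step. The cache size, slot ids,
and names like ssm_state/slot_ids below are hypothetical; the call itself
mirrors test_selective_state_update_with_batch_indices above.

    import torch
    from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
        selective_state_update)

    # One persistent SSM state slot per cacheable request (sizes are
    # illustrative only).
    max_slots, dim, dstate = 160, 2048, 16
    ssm_state = torch.randn(max_slots, dim, dstate, device="cuda")

    # Slots of the requests scheduled this step; they need not be
    # contiguous or in batch order.
    slot_ids = torch.tensor([3, 17, 42], dtype=torch.int32, device="cuda")
    batch = slot_ids.numel()

    x = torch.randn(batch, dim, device="cuda")
    dt = torch.randn(batch, dim, device="cuda")
    A = -torch.rand(dim, dstate, device="cuda") - 1.0
    B = torch.randn(batch, dstate, device="cuda")
    C = torch.randn(batch, dstate, device="cuda")

    # Each request's slot of ssm_state is updated in place, so no
    # gather/scatter copy of the state is needed between steps.
    out = selective_state_update(ssm_state, x, dt, A, B, C,
                                 dt_softplus=True,
                                 state_batch_indices=slot_ids)

Because the kernel indexes the state through state_batch_indices, a request
can keep the same cache slot for its whole lifetime while the composition of
the batch changes from step to step.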
From 37f4eefe0d569439afd001d0db025397ba28c389 Mon Sep 17 00:00:00 2001
From: Tyler Michael Smith
Date: Fri, 30 Aug 2024 13:47:47 -0400
Subject: [PATCH 2/3] Remove commented-out parametrize cruft

---
 tests/kernels/test_mamba_ssm.py | 10 ----------
 1 file changed, 10 deletions(-)

diff --git a/tests/kernels/test_mamba_ssm.py b/tests/kernels/test_mamba_ssm.py
index b8a8c4c4e62b..326e95087587 100644
--- a/tests/kernels/test_mamba_ssm.py
+++ b/tests/kernels/test_mamba_ssm.py
@@ -326,13 +326,9 @@ def test_selective_state_update(dim, dstate, has_z, itype):
 
 @pytest.mark.parametrize("itype",
                          [torch.float32, torch.float16, torch.bfloat16])
-# @pytest.mark.parametrize('itype', [torch.float16])
 @pytest.mark.parametrize("has_z", [False, True])
-# @pytest.mark.parametrize('has_z', [True])
 @pytest.mark.parametrize("dstate", [16, 32, 64])
-# @pytest.mark.parametrize("dstate", [16])
 @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
-# @pytest.mark.parametrize("dim", [2048])
 def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
     device = "cuda"
     rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
@@ -391,17 +387,11 @@ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
 
 @pytest.mark.parametrize("itype",
                          [torch.float32, torch.float16, torch.bfloat16])
-#@pytest.mark.parametrize('itype', [torch.float32])
 @pytest.mark.parametrize("has_z", [False, True])
-# @pytest.mark.parametrize('has_z', [True])
 @pytest.mark.parametrize("tie_hdim", [False, True])
-# @pytest.mark.parametrize('tie_hdim', [True])
 @pytest.mark.parametrize("ngroups", [1, 2, 4])
-# @pytest.mark.parametrize("ngroups", [2])
 @pytest.mark.parametrize("dstate", [16, 32, 64])
-# @pytest.mark.parametrize("dstate", [16])
 @pytest.mark.parametrize("dim", [2048, 4096])
-# @pytest.mark.parametrize("dim", [2048])
 def test_selective_state_update_with_heads_with_batch_indices(
         dim, dstate, ngroups, has_z, tie_hdim, itype):
     device = "cuda"

From 4b776da8e82170ac52a13805c750765cbfaac4b5 Mon Sep 17 00:00:00 2001
From: Tyler Michael Smith
Date: Wed, 18 Sep 2024 14:03:36 -0400
Subject: [PATCH 3/3] Remove debug prints and adjust tolerances

---
 tests/kernels/test_mamba_ssm.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/tests/kernels/test_mamba_ssm.py b/tests/kernels/test_mamba_ssm.py
index 326e95087587..79cbe5469bb5 100644
--- a/tests/kernels/test_mamba_ssm.py
+++ b/tests/kernels/test_mamba_ssm.py
@@ -333,7 +333,7 @@ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
     device = "cuda"
     rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
     if itype == torch.bfloat16:
-        rtol, atol = 6e-2, 6e-2
+        rtol, atol = 7e-2, 7e-2
     if torch.version.hip:
         atol *= 2
     # set seed
@@ -376,8 +376,6 @@ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
                                          dt_bias=dt_bias,
                                          dt_softplus=True)
 
-    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
-    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
     assert torch.allclose(state[state_indices, :],
                           state_ref,
                           rtol=rtol,
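
The multi-head variant follows the same pattern. A sketch with hypothetical
sizes, mirroring test_selective_state_update_with_heads_with_batch_indices
above; here B and C are shared across heads in ngroups groups:

    import torch
    from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
        selective_state_update)

    # Illustrative sizes: 32 heads of size 64, 2 B/C groups.
    max_slots, nheads, headdim, dstate, ngroups = 160, 32, 64, 16, 2
    ssm_state = torch.randn(max_slots, nheads, headdim, dstate, device="cuda")
    slot_ids = torch.tensor([5, 9], dtype=torch.int32, device="cuda")
    batch = slot_ids.numel()

    x = torch.randn(batch, nheads, headdim, device="cuda")
    dt = torch.randn(batch, nheads, headdim, device="cuda")
    A = -torch.rand(nheads, headdim, dstate, device="cuda") - 1.0
    B = torch.randn(batch, ngroups, dstate, device="cuda")
    C = torch.randn(batch, ngroups, dstate, device="cuda")

    # As in the 2-D case, the indexed slots of ssm_state are updated in place.
    out = selective_state_update(ssm_state, x, dt, A, B, C,
                                 dt_softplus=True,
                                 state_batch_indices=slot_ids)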