|
1 | 1 | # SPDX-License-Identifier: Apache-2.0
2 | 2 |
|
| 3 | +import copy |
3 | 4 | from contextlib import contextmanager |
4 | 5 | from dataclasses import asdict, dataclass |
5 | 6 | from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Type |
@@ -125,7 +126,7 @@ def graph_clone(self, batch_size: int):
125 | 126 | assert self._is_graph_capturing |
126 | 127 | state = self.__class__(self.runner) |
127 | 128 | state._workspace_buffer = self._graph_decode_workspace_buffer |
128 | | - state._decode_wrapper = self._graph_decode_wrapper |
| 129 | + state._decode_wrapper = copy.copy(self._graph_decode_wrapper) |
129 | 130 | return state |
130 | 131 |
|
131 | 132 | def graph_capture_get_metadata_for_batch( |
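The hunk above swaps a shared reference for a shallow copy: each cloned CUDA-graph state now gets its own decode-wrapper object while still pointing at the same workspace buffer. A minimal sketch of that copy semantics, using a hypothetical `DecodeWrapper` stand-in rather than the real FlashInfer class:

```python
import copy

# Hypothetical stand-in for the FlashInfer decode wrapper; the real class and
# attribute names may differ. Only the copy.copy() semantics matter here.
class DecodeWrapper:
    def __init__(self, workspace):
        self.workspace = workspace  # large buffer, intended to stay shared
        self.plan_state = None      # per-batch-size planning state

original = DecodeWrapper(workspace=bytearray(1024))
clone = copy.copy(original)        # shallow copy, as in graph_clone above

clone.plan_state = {"batch_size": 8}
assert original.plan_state is None             # clones no longer clobber each other
assert clone.workspace is original.workspace   # workspace buffer is still shared
```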
@@ -197,10 +198,12 @@ def begin_forward(self, model_input): |
197 | 198 | # In case of multistep chunked-prefill, there might be prefill requests |
198 | 199 | # scheduled while CUDA graph mode is enabled. We don't run graph in that |
199 | 200 | # case. |
| 201 | + print("begin_forward", model_input.input_tokens.shape[0]) |
200 | 202 | if use_cuda_graph and is_decode: |
201 | 203 | batch_size = model_input.input_tokens.shape[0] |
202 | 204 | state = (self.runner.graph_runners[model_input.virtual_engine] |
203 | 205 | [batch_size].attn_state) |
| 206 | + print("choosing decode_wrapper", batch_size) |
204 | 207 | model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper() |
205 | 208 | model_input.attn_metadata.begin_forward() |
206 | 209 |
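For context, `begin_forward` resolves the decode wrapper through a per-batch-size lookup: graph runners are keyed first by virtual engine and then by the padded decode batch size, and each runner carries the attention state captured with it. A rough sketch of that indexing, with illustrative names and plain dicts standing in for the runner's actual containers:

```python
from dataclasses import dataclass
from typing import Dict

# Illustrative only: a tiny dataclass and nested dicts stand in for the real
# graph_runners structure and its AttentionState objects.
@dataclass
class CapturedGraphRunner:
    attn_state: str  # in the real code this is a cloned attention state

graph_runners: Dict[int, Dict[int, CapturedGraphRunner]] = {
    0: {  # virtual engine 0
        1: CapturedGraphRunner(attn_state="state for batch size 1"),
        8: CapturedGraphRunner(attn_state="state for batch size 8"),
    },
}

virtual_engine, batch_size = 0, 8
state = graph_runners[virtual_engine][batch_size].attn_state
print(state)  # the state (and hence decode wrapper) captured for this exact batch size
```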
|
@@ -421,9 +424,17 @@ def build(self, seq_lens: List[int], query_lens: List[int], |
421 | 424 | self.paged_kv_indptr.extend([self.paged_kv_indptr[-1]] * |
422 | 425 | cuda_graph_pad_size) |
423 | 426 | self.paged_kv_last_page_len.extend([0] * cuda_graph_pad_size) |
424 | | - query_start_loc_host = torch.functional.F.pad( |
425 | | - query_start_loc_host, (cuda_graph_pad_size + 1, ), |
426 | | - value=query_start_loc_host[-1].item()) |
| 427 | + |
| 428 | + print(cuda_graph_pad_size + 1 - query_start_loc_host.shape[0], |
| 429 | + cuda_graph_pad_size + 1, query_start_loc_host.shape[0]) |
| 430 | + if cuda_graph_pad_size + 1 > query_start_loc_host.shape[0]: |
| 431 | + query_start_loc_host = torch.cat( |
| 432 | + (query_start_loc_host, |
| 433 | + torch.full((cuda_graph_pad_size + 1 - |
| 434 | + query_start_loc_host.shape[0], ), |
| 435 | + fill_value=query_start_loc_host[-1].item(), |
| 436 | + dtype=torch.int32, |
| 437 | + device="cpu"))) |
427 | 438 |
|
428 | 439 | if len(self.paged_kv_indptr) > 0: |
429 | 440 | # extend to the maximum number of blocks as returned by the |
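The new padding path replaces the old `F.pad` call with an explicit `torch.cat`: `query_start_loc_host` is only extended when the CUDA-graph target length exceeds its current length, and the padding repeats the last prefix-sum entry. A small self-contained sketch of that behavior with made-up sizes:

```python
import torch

# Made-up example values; the real tensors come from the metadata builder above.
query_start_loc_host = torch.tensor([0, 1, 2, 3], dtype=torch.int32)
cuda_graph_pad_size = 8

target_len = cuda_graph_pad_size + 1
if target_len > query_start_loc_host.shape[0]:
    pad = torch.full((target_len - query_start_loc_host.shape[0],),
                     fill_value=query_start_loc_host[-1].item(),
                     dtype=torch.int32)
    query_start_loc_host = torch.cat((query_start_loc_host, pad))

print(query_start_loc_host)
# tensor([0, 1, 2, 3, 3, 3, 3, 3, 3], dtype=torch.int32)
```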
|