|
13 | 13 |
|
14 | 14 | from vllm.attention import AttentionMetadata, get_attn_backend |
15 | 15 | from vllm.config import VllmConfig |
16 | | -from vllm.forward_context import set_forward_context |
17 | 16 | from vllm.logger import init_logger |
18 | 17 | from vllm.model_executor.layers.sampler import SamplerOutput |
19 | 18 | from vllm.model_executor.model_loader import get_model |
@@ -272,9 +271,8 @@ def _dummy_run( |
272 | 271 | torch._dynamo.mark_dynamic(t, 0) |
273 | 272 | torch._dynamo.mark_dynamic(p, 0) |
274 | 273 | # Dummy run. |
275 | | - with set_forward_context(attn_metadata, self.vllm_config, 0): |
276 | | - self.model(token_ids, position_ids, attn_metadata, input_lens, t, |
277 | | - p, num_samples, kv_caches) |
| 274 | + self.model(token_ids, position_ids, attn_metadata, input_lens, t, p, |
| 275 | + num_samples, kv_caches) |
278 | 276 |
|
279 | 277 | def warmup_model( |
280 | 278 | self, |
@@ -673,13 +671,10 @@ def execute_model( |
673 | 671 | input_lens = model_input.input_lens[i:i + 1].to(self.device) |
674 | 672 | t = model_input.t[i:i + 1].to(self.device) |
675 | 673 | p = model_input.p[i:i + 1].to(self.device) |
676 | | - with set_forward_context(model_input.attn_metadata, |
677 | | - self.vllm_config, |
678 | | - model_input.virtual_engine): |
679 | | - output_token_ids = self.model(token_ids, position_ids, |
680 | | - attn_metadata, input_lens, t, |
681 | | - p, model_input.num_samples, |
682 | | - kv_caches) |
| 674 | + output_token_ids = self.model(token_ids, position_ids, |
| 675 | + attn_metadata, input_lens, t, p, |
| 676 | + model_input.num_samples, |
| 677 | + kv_caches) |
683 | 678 | next_token_ids.append(output_token_ids[0]) |
684 | 679 | start_idx = end_idx |
685 | 680 |
|
@@ -724,13 +719,10 @@ def execute_model( |
724 | 719 | input_lens = model_input.input_lens.to(self.device) |
725 | 720 | for i in range(num_steps): |
726 | 721 | slot_mapping = attn_metadata.slot_mapping |
727 | | - with set_forward_context(model_input.attn_metadata, |
728 | | - self.vllm_config, |
729 | | - model_input.virtual_engine): |
730 | | - output_token_ids = self.model(token_ids, position_ids, |
731 | | - attn_metadata, input_lens, t, |
732 | | - p, model_input.num_samples, |
733 | | - kv_caches) |
| 722 | + output_token_ids = self.model(token_ids, position_ids, |
| 723 | + attn_metadata, input_lens, t, p, |
| 724 | + model_input.num_samples, |
| 725 | + kv_caches) |
734 | 726 | self.cached_step_outputs.append(output_token_ids) |
735 | 727 |
|
736 | 728 | if i < num_steps - 1: |
|
0 commit comments