|
1 | 1 | import contextlib |
| 2 | +import dataclasses |
2 | 3 | import time |
3 | 4 | from typing import Dict, List, Optional, Tuple, Set, Union |
4 | 5 |
|
@@ -521,45 +522,27 @@ def prepare_input_tensors( |
521 | 522 | metadata_dict = { |
522 | 523 | "input_tokens": input_tokens, |
523 | 524 | "input_positions": input_positions, |
524 | | - "is_prompt": input_metadata.is_prompt, |
525 | | - "slot_mapping": input_metadata.slot_mapping, |
526 | | - "prompt_lens": input_metadata.prompt_lens, |
527 | | - "max_seq_len": input_metadata.max_seq_len, |
528 | | - "start_loc": input_metadata.start_loc, |
529 | | - "max_context_len": input_metadata.max_context_len, |
530 | | - "context_lens": input_metadata.context_lens, |
531 | | - "block_tables": input_metadata.block_tables, |
532 | | - "use_cuda_graph": input_metadata.use_cuda_graph, |
533 | | - "kv_cache_dtype": input_metadata.kv_cache_dtype, |
534 | 525 | "selected_token_indices": |
535 | 526 | sampling_metadata.selected_token_indices, |
536 | 527 | "lora_requests": lora_requests, |
537 | 528 | "lora_mapping": lora_mapping, |
538 | 529 | } |
| 530 | + metadata_dict.update(dataclasses.asdict(input_metadata)) |
539 | 531 | broadcast_tensor_dict(metadata_dict, src=0) |
540 | 532 | else: |
541 | 533 | metadata_dict = broadcast_tensor_dict(src=0) |
542 | | - input_tokens = metadata_dict["input_tokens"] |
543 | | - input_positions = metadata_dict["input_positions"] |
544 | | - lora_mapping = metadata_dict["lora_mapping"] |
545 | | - lora_requests = metadata_dict["lora_requests"] |
546 | | - input_metadata = InputMetadata( |
547 | | - is_prompt=metadata_dict["is_prompt"], |
548 | | - slot_mapping=metadata_dict["slot_mapping"], |
549 | | - prompt_lens=metadata_dict["prompt_lens"], |
550 | | - max_seq_len=metadata_dict["max_seq_len"], |
551 | | - start_loc=metadata_dict["start_loc"], |
552 | | - max_context_len=metadata_dict["max_context_len"], |
553 | | - context_lens=metadata_dict["context_lens"], |
554 | | - block_tables=metadata_dict["block_tables"], |
555 | | - use_cuda_graph=metadata_dict["use_cuda_graph"], |
556 | | - kv_cache_dtype=metadata_dict["kv_cache_dtype"], |
557 | | - ) |
| 534 | + input_tokens = metadata_dict.pop("input_tokens") |
| 535 | + input_positions = metadata_dict.pop("input_positions") |
| 536 | + selected_token_indices = metadata_dict.pop( |
| 537 | + "selected_token_indices") |
| 538 | + lora_mapping = metadata_dict.pop("lora_mapping") |
| 539 | + lora_requests = metadata_dict.pop("lora_requests") |
| 540 | + input_metadata = InputMetadata(**metadata_dict) |
558 | 541 | sampling_metadata = SamplingMetadata( |
559 | 542 | seq_groups=None, |
560 | 543 | seq_data=None, |
561 | 544 | prompt_lens=None, |
562 | | - selected_token_indices=metadata_dict["selected_token_indices"], |
| 545 | + selected_token_indices=selected_token_indices, |
563 | 546 | categorized_sample_indices=None, |
564 | 547 | generators=None, |
565 | 548 | perform_sampling=False, |
|
0 commit comments