@@ -1,6 +1,7 @@
 """A layer that samples the next tokens from the model's outputs."""
 from typing import Dict, List, Optional, Tuple

+import time
 import torch
 import torch.nn as nn

@@ -37,13 +38,20 @@ def forward(
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
         embedding_bias: Optional[torch.Tensor] = None,
+        logits: Optional[torch.Tensor] = None,
     ) -> Optional[SamplerOutput]:
-        # Get the hidden states that we use for sampling.
-        hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)
-
-        # Get the logits for the next tokens.
-        logits = _get_logits(hidden_states, embedding, embedding_bias,
-                             self.vocab_size)
+        if logits is None:
+            # Get the hidden states that we use for sampling.
+            hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)
+
+            start = time.time()
+            # Get the logits for the next tokens.
+            logits = _get_logits(hidden_states, embedding, embedding_bias,
+                                 self.vocab_size)
+            end = time.time()
+            print(f'Out-of-model logits calculation (MatMul) took {(end - start)*1000} ms')
+        else:
+            logits = _prune_hidden_states(logits, sampling_metadata)

         # Only perform sampling in the driver worker.
         # Note: `_get_logits` is still distributed across TP workers because