
Commit d10b6ad

njhill authored and Akshat-Tripathi committed
[V1][Sampler] Avoid an operation during temperature application (vllm-project#13587)
1 parent 761be24 commit d10b6ad

File tree

4 files changed: +18 -10 lines changed
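
The change this commit makes is narrow: it drops a per-step sanitizing operation from temperature application. A minimal sketch of the old pattern being removed, assuming PyTorch and an illustrative epsilon (the real _SAMPLING_EPS constant lives in the sampler code); tensor values are made up:

import torch

_SAMPLING_EPS = 1e-5                          # illustrative value
temp = torch.tensor([0.0, 0.8, 0.0, 1.2])     # 0.0 marks greedy requests
logits = torch.randn(4, 8)

# Old behaviour: on every sampling step, replace near-zero (greedy) temperatures
# with 1.0 to avoid division by zero, then divide. The torch.where allocates an
# extra tensor and launches an extra kernel each step; the diff below removes it
# by never storing a zero temperature in the first place.
safe_temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
logits = logits / safe_temp.unsqueeze(dim=1)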

vllm/v1/sample/metadata.py

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@
 @dataclass
 class SamplingMetadata:
 
-    temperature: torch.Tensor
+    temperature: Optional[torch.Tensor]
     all_greedy: bool
    all_random: bool
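
A tiny sketch of the contract this type change implies, using a trimmed-down, hypothetical stand-in for SamplingMetadata: temperature may be None (the all-greedy case), so non-greedy callers check it before use, mirroring the assert added in sampler.py below.

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class SamplingMetadataSketch:  # hypothetical, trimmed-down stand-in
    temperature: Optional[torch.Tensor]
    all_greedy: bool
    all_random: bool


meta = SamplingMetadataSketch(temperature=None, all_greedy=True, all_random=False)
if not meta.all_greedy:
    assert meta.temperature is not None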

vllm/v1/sample/sampler.py

Lines changed: 4 additions & 4 deletions
@@ -77,11 +77,8 @@ def apply_temperature(
         logits: torch.Tensor,
         temp: torch.Tensor,
     ) -> torch.Tensor:
-        # Avoid division by zero.
-        temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
         # Use in-place division to avoid creating a new tensor.
-        logits.div_(temp.unsqueeze(dim=1))
-        return logits
+        return logits.div_(temp.unsqueeze(dim=1))
 
     def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
         return logits.argmax(dim=-1).view(-1)
@@ -100,6 +97,8 @@ def sample(
         if sampling_metadata.all_greedy:
             return greedy_sampled
 
+        assert sampling_metadata.temperature is not None
+
         # Apply temperature.
         logits = self.apply_temperature(logits, sampling_metadata.temperature)
 
@@ -122,6 +121,7 @@ def sample(
             sampling_metadata.temperature < _SAMPLING_EPS,
             greedy_sampled,
             random_sampled,
+            out=greedy_sampled,  # Reuse tensor
         )
         return sampled
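
A standalone sketch of the two sampler-side tricks in this diff, assuming PyTorch; shapes, values, and the epsilon are illustrative:

import torch

_SAMPLING_EPS = 1e-5                            # illustrative epsilon
temp = torch.tensor([0.7, -1.0, 0.5, -1.0])     # -1.0 is the greedy sentinel set at add-time
logits = torch.randn(4, 8)

# Tensor.div_ mutates in place and returns the same tensor, so apply_temperature
# can return the division directly instead of mutating and then returning.
# (Greedy rows get a meaningless scale, but their logits are never used.)
scaled = logits.div_(temp.unsqueeze(dim=1))
assert scaled is logits

# torch.where with out= writes into an existing tensor instead of allocating a
# new one; here the greedy-results tensor doubles as the output buffer.
greedy_sampled = torch.tensor([1, 2, 3, 4])
random_sampled = torch.tensor([5, 6, 7, 0])
sampled = torch.where(temp < _SAMPLING_EPS, greedy_sampled, random_sampled,
                      out=greedy_sampled)
assert sampled is greedy_sampled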

vllm/v1/utils.py

Lines changed: 4 additions & 2 deletions
@@ -191,11 +191,13 @@ def bind_kv_cache(
 
 
 def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
-               length: int) -> None:
+               length: int) -> torch.Tensor:
     """
     Copy the first length elements of a tensor into another tensor in a
     non-blocking manner.
 
     Used to copy pinned CPU tensor data to pre-allocated GPU tensors.
+
+    Returns the sliced target tensor.
     """
-    to_tensor[:length].copy_(from_tensor[:length], non_blocking=True)
+    return to_tensor[:length].copy_(from_tensor[:length], non_blocking=True)
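
A small self-contained sketch of why copy_slice now returns a value: Tensor.copy_ returns the tensor it was called on, so the caller gets the already-sliced destination view in one expression. Two plain CPU tensors stand in for the pinned-CPU source and preallocated GPU destination:

import torch


def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
               length: int) -> torch.Tensor:
    # copy_ returns its target, so this hands back the sliced destination view.
    return to_tensor[:length].copy_(from_tensor[:length], non_blocking=True)


source = torch.arange(8, dtype=torch.float32)   # stands in for a pinned CPU tensor
dest = torch.zeros(8)                           # stands in for a preallocated GPU tensor
view = copy_slice(source, dest, 4)
assert view.shape == (4,)
assert torch.equal(view, source[:4])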

vllm/v1/worker/gpu_input_batch.py

Lines changed: 9 additions & 3 deletions
@@ -242,10 +242,12 @@ def add_request(
         self.block_table.add_row(req_index, request.block_ids)
 
         sampling_params = request.sampling_params
-        self.temperature_cpu[req_index] = sampling_params.temperature
         if sampling_params.sampling_type == SamplingType.GREEDY:
+            # Avoid later division by zero.
+            self.temperature_cpu[req_index] = -1.0
             self.greedy_reqs.add(req_id)
         else:
+            self.temperature_cpu[req_index] = sampling_params.temperature
             self.random_reqs.add(req_id)
 
         self.top_p_cpu[req_index] = sampling_params.top_p
@@ -410,7 +412,11 @@ def refresh_sampling_metadata(self):
 
     def _make_sampling_metadata(self) -> SamplingMetadata:
         num_reqs = self.num_reqs
-        copy_slice(self.temperature_cpu_tensor, self.temperature, num_reqs)
+        if not self.all_greedy:
+            temperature = copy_slice(self.temperature_cpu_tensor,
+                                     self.temperature, num_reqs)
+        else:
+            temperature = None
         if not self.no_top_p:
             copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
         if not self.no_top_k:
@@ -437,7 +443,7 @@ def _make_sampling_metadata(self) -> SamplingMetadata:
             prompt_token_ids = None
 
         return SamplingMetadata(
-            temperature=self.temperature[:num_reqs],
+            temperature=temperature,
             all_greedy=self.all_greedy,
             all_random=self.all_random,
             top_p=None if self.no_top_p else self.top_p[:num_reqs],
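
A hedged sketch of the sentinel scheme these hunks implement: greedy requests store -1.0 at add-time so the later in-place division never hits zero, and a fully greedy batch skips building (and copying) the temperature tensor entirely. The helper below is illustrative, not the actual InputBatch code:

from typing import List, Optional

import torch


def pack_temperatures(request_temps: List[float]) -> Optional[torch.Tensor]:
    # All-greedy batch: return None and skip the CPU-to-GPU copy entirely.
    if all(t == 0.0 for t in request_temps):
        return None
    # Greedy requests get -1.0 so logits.div_() never divides by zero.
    return torch.tensor([t if t != 0.0 else -1.0 for t in request_temps])


print(pack_temperatures([0.0, 0.0]))        # None
print(pack_temperatures([0.0, 0.8, 1.2]))   # tensor with -1.0 in the greedy slot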

0 commit comments
